From 1c1ec4d078682296b1acf232697eadb7b3f80b56 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 9 Jan 2025 18:36:02 +0200 Subject: [PATCH] Initial Enzyme support (#668) --------- Co-authored-by: William Moses --- .buildkite/pipeline.yml | 19 ++ Project.toml | 5 +- .../AMDGPUEnzymeCoreExt.jl | 218 ++++++++++++++++++ ext/AMDGPUEnzymeCoreExt/meta_kernels.jl | 83 +++++++ src/AMDGPU.jl | 18 +- src/runtime/hip-execution.jl | 5 +- test/enzyme_tests.jl | 55 +++++ test/runtests.jl | 13 +- 8 files changed, 402 insertions(+), 14 deletions(-) create mode 100644 ext/AMDGPUEnzymeCoreExt/AMDGPUEnzymeCoreExt.jl create mode 100644 ext/AMDGPUEnzymeCoreExt/meta_kernels.jl create mode 100644 test/enzyme_tests.jl diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index e674c111d..f1391e7d0 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -83,6 +83,25 @@ steps: # JULIA_AMDGPU_HIP_MUST_LOAD: "1" # JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + - label: "Julia 1.10 Enzyme" + plugins: + - JuliaCI/julia#v1: + version: "1.10" + - JuliaCI/julia-test#v1: + test_args: "enzyme" + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + if: build.message !~ /\[skip tests\]/ + command: "julia --project -e 'using Pkg; Pkg.update()'" + timeout_in_minutes: 180 + env: + JULIA_NUM_THREADS: 4 + JULIA_AMDGPU_CORE_MUST_LOAD: "1" + JULIA_AMDGPU_HIP_MUST_LOAD: "1" + JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" + - label: "GPU-less environment" plugins: - JuliaCI/julia#v1: diff --git a/Project.toml b/Project.toml index 968ae226d..bce4d520e 100644 --- a/Project.toml +++ b/Project.toml @@ -34,17 +34,20 @@ UnsafeAtomics = "013be700-e6cd-48c3-b4a1-df204f14c38f" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" [extensions] AMDGPUChainRulesCoreExt = "ChainRulesCore" +AMDGPUEnzymeCoreExt = "EnzymeCore" [compat] AbstractFFTs = "1.0" AcceleratedKernels = "0.2" Adapt = "4" -Atomix = "0.1, 1" +Atomix = "1" CEnum = "0.4, 0.5" ChainRulesCore = "1" +EnzymeCore = "0.8" ExprTools = "0.1" GPUArrays = "11.2" GPUCompiler = "0.27, 1.0" diff --git a/ext/AMDGPUEnzymeCoreExt/AMDGPUEnzymeCoreExt.jl b/ext/AMDGPUEnzymeCoreExt/AMDGPUEnzymeCoreExt.jl new file mode 100644 index 000000000..b91c58177 --- /dev/null +++ b/ext/AMDGPUEnzymeCoreExt/AMDGPUEnzymeCoreExt.jl @@ -0,0 +1,218 @@ +module AMDGPUEnzymeCoreExt + +using AMDGPU +using EnzymeCore +using EnzymeCore: EnzymeRules +using GPUCompiler + +include("meta_kernels.jl") + +function EnzymeCore.compiler_job_from_backend( + ::ROCBackend, @nospecialize(F::Type), @nospecialize(TT::Type), +) + mi = GPUCompiler.methodinstance(F, TT) + return GPUCompiler.CompilerJob(mi, AMDGPU.compiler_config(AMDGPU.device())) +end + +function EnzymeRules.forward( + config, fn::Const{typeof(AMDGPU.hipfunction)}, ::Type{<: Duplicated}, + f::Const{F}, tt::Const{TT}; kwargs..., +) where {F, TT} + res = fn.val(f.val, tt.val; kwargs...) + return Duplicated(res, res) +end + +function EnzymeRules.forward( + config, fn::Const{typeof(AMDGPU.hipfunction)}, ::Type{<: BatchDuplicated{T, N}}, + f::Const{F}, tt::Const{TT}; kwargs..., +) where {F, TT, T, N} + res = fn.val(f.val, tt.val; kwargs...) + return BatchDuplicated(res, ntuple(_ -> res, Val(N))) +end + +function EnzymeRules.reverse( + config, fn::Const{typeof(AMDGPU.hipfunction)}, ::Type{RT}, + subtape, f, tt; kwargs..., +) where RT + return (nothing, nothing) +end + +function EnzymeRules.forward( + config, fn::Const{typeof(AMDGPU.rocconvert)}, ::Type{RT}, x::IT, +) where {RT, IT} + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + config_width = EnzymeRules.width(config) + if config_width == 1 + Duplicated(fn.val(x.val), fn.val(x.dval)) + else + tup = ntuple(Val(config_width)) do i + Base.@_inline_meta + fn.val(x.dval[i])::eltype(RT) + end + BatchDuplicated(fn.val(x.val), tup) + end + + elseif EnzymeRules.needs_shadow(config) + config_width = EnzymeRules.width(config) + ST = EnzymeCore.shadow_type(config, RT) + if config_width == 1 + fn.val(x.dval)::ST + else + (ntuple(Val(config_width)) do i + Base.@_inline_meta + fn.val(x.dval[i])::eltype(RT) + end)::ST + end + + elseif EnzymeRules.needs_primal(config) + fn.val(x.val)::eltype(RT) + else + nothing + end +end + +function EnzymeRules.augmented_primal( + config, fn::Const{typeof(AMDGPU.rocconvert)}, ::Type{RT}, x::IT, +) where {RT, IT} + primal = EnzymeRules.needs_primal(config) ? + fn.val(x.val) : nothing + + shadow = if EnzymeRules.needs_shadow(config) + config_width = EnzymeRules.width(config) + if config_width == 1 + fn.val(x.dval) + else + ntuple(Val(config_width)) do i + Base.@_inline_meta + fn.val(x.dval[i]) + end + end + else + nothing + end + + return EnzymeRules.AugmentedReturn{ + EnzymeRules.primal_type(config, RT), + EnzymeRules.shadow_type(config, RT), Nothing + }(primal, shadow, nothing) +end + +function EnzymeRules.reverse( + config, fn::Const{typeof(AMDGPU.rocconvert)}, ::Type{RT}, tape, x::IT, +) where {RT, IT} + return (nothing,) +end + +function EnzymeRules.forward( + config, fn::EnzymeCore.Annotation{AMDGPU.Runtime.HIPKernel{F, TT}}, + ::Type{Const{Nothing}}, args...; kwargs..., +) where {F, TT} + GC.@preserve args begin + kernel_args = ((rocconvert(a) for a in args)...,) + kernel_tt = Tuple{(typeof(config), F, (typeof(a) for a in kernel_args)...)...} + kernel = AMDGPU.hipfunction(meta_fn, kernel_tt) + kernel(config, fn.val.f, kernel_args...; kwargs...) + end + return +end + +function EnzymeRules.reverse( + config, ofn::EnzymeCore.Annotation{AMDGPU.Runtime.HIPKernel{F, TT}}, + ::Type{Const{Nothing}}, subtape, args...; + groupsize::AMDGPU.Runtime.ROCDim = 1, + gridsize::AMDGPU.Runtime.ROCDim = 1, + kwargs..., +) where {F, TT} + kernel_args = ((rocconvert(a) for a in args)...,) + kernel_tt = map(typeof, kernel_args) + + ModifiedBetween = EnzymeRules.overwritten(config) + TapeType = EnzymeCore.tape_type( + ReverseSplitModified( + EnzymeCore.set_runtime_activity(ReverseSplitWithPrimal, config), + Val(ModifiedBetween)), + Const{F}, + Const{Nothing}, + kernel_tt..., + ) + groupsize = AMDGPU.Runtime.ROCDim3(groupsize) + gridsize = AMDGPU.Runtime.ROCDim3(gridsize) + + GC.@preserve args subtape begin + subtape_cc = rocconvert(subtape) + kernel_tt2 = Tuple{ + (typeof(config), F, typeof(subtape_cc), kernel_tt...)...} + kernel = AMDGPU.hipfunction(meta_revf, kernel_tt2) + kernel(config, ofn.val.f, subtape_cc, kernel_args...; + groupsize, gridsize, kwargs...) + end + + return ntuple(Val(length(kernel_args))) do i + Base.@_inline_meta + nothing + end +end + +function EnzymeRules.augmented_primal( + config, fn::Const{typeof(AMDGPU.hipfunction)}, + ::Type{RT}, f::Const{F}, tt::Const{TT}; kwargs... +) where {F, CT, RT <: EnzymeCore.Annotation{CT}, TT} + res = fn.val(f.val, tt.val; kwargs...) + primal = EnzymeRules.needs_primal(config) ? res : nothing + + shadow = if EnzymeRules.needs_shadow(config) + config_width = EnzymeRules.width(config) + config_width == 1 ? + res : + ntuple(Val(config_width)) do i + Base.@_inline_meta + res + end + else + nothing + end + + return EnzymeRules.AugmentedReturn{ + EnzymeRules.primal_type(config, RT), + EnzymeRules.shadow_type(config, RT), Nothing, + }(primal, shadow, nothing) +end + +function EnzymeRules.augmented_primal( + config, fn::EnzymeCore.Annotation{AMDGPU.Runtime.HIPKernel{F,TT}}, + ::Type{Const{Nothing}}, args...; + groupsize::AMDGPU.Runtime.ROCDim = 1, + gridsize::AMDGPU.Runtime.ROCDim = 1, kwargs..., +) where {F,TT} + kernel_args = ((rocconvert(a) for a in args)...,) + kernel_tt = map(typeof, kernel_args) + + ModifiedBetween = EnzymeRules.overwritten(config) + compiler_job = EnzymeCore.compiler_job_from_backend( + ROCBackend(), typeof(Base.identity), Tuple{Float64}) + TapeType = EnzymeCore.tape_type( + compiler_job, + ReverseSplitModified( + EnzymeCore.set_runtime_activity(ReverseSplitWithPrimal, config), + Val(ModifiedBetween)), + Const{F}, Const{Nothing}, + kernel_tt..., + ) + groupsize = AMDGPU.Runtime.ROCDim3(groupsize) + gridsize = AMDGPU.Runtime.ROCDim3(gridsize) + subtape = ROCArray{TapeType}(undef, + gridsize.x * gridsize.y * gridsize.z * + groupsize.x * groupsize.y * groupsize.z) + + GC.@preserve args subtape begin + subtape_cc = rocconvert(subtape) + kernel_tt2 = Tuple{ + (typeof(config), F, typeof(subtape_cc), kernel_tt...)...} + kernel = AMDGPU.hipfunction(meta_augf, kernel_tt2) + kernel(config, fn.val.f, subtape_cc, kernel_args...; + groupsize, gridsize, kwargs...) + end + return EnzymeRules.AugmentedReturn{Nothing, Nothing, ROCArray}(nothing, nothing, subtape) +end + +end diff --git a/ext/AMDGPUEnzymeCoreExt/meta_kernels.jl b/ext/AMDGPUEnzymeCoreExt/meta_kernels.jl new file mode 100644 index 000000000..2584491b4 --- /dev/null +++ b/ext/AMDGPUEnzymeCoreExt/meta_kernels.jl @@ -0,0 +1,83 @@ +function meta_fn(config, fn, args::Vararg{Any, N}) where N + EnzymeCore.autodiff_deferred( + EnzymeCore.set_runtime_activity(Forward, config), + Const(fn), Const, args...) + return +end + + +function meta_augf( + config, f, tape::ROCDeviceArray{TapeType}, args::Vararg{Any, N}, +) where {N, TapeType} + ModifiedBetween = EnzymeRules.overwritten(config) + forward, _ = EnzymeCore.autodiff_deferred_thunk( + ReverseSplitModified( + EnzymeCore.set_runtime_activity(ReverseSplitWithPrimal, config), + Val(ModifiedBetween)), + TapeType, + Const{Core.Typeof(f)}, + Const{Nothing}, + map(typeof, args)..., + ) + + idx = 0 + # idx *= gridDim().x + idx += workgroupIdx().x - 1 + + idx *= gridGroupDim().y + idx += workgroupIdx().y - 1 + + idx *= gridGroupDim().z + idx += workgroupIdx().z - 1 + + idx *= workgroupDim().x + idx += workitemIdx().x - 1 + + idx *= workgroupDim().y + idx += workitemIdx().y - 1 + + idx *= workgroupDim().z + idx += workitemIdx().z - 1 + idx += 1 + + @inbounds tape[idx] = forward(Const(f), args...)[1] + return +end + +function meta_revf( + config, f, tape::ROCDeviceArray{TapeType}, args::Vararg{Any, N}, +) where {N, TapeType} + ModifiedBetween = EnzymeRules.overwritten(config) + _, reverse = EnzymeCore.autodiff_deferred_thunk( + ReverseSplitModified( + EnzymeCore.set_runtime_activity(ReverseSplitWithPrimal, config), + Val(ModifiedBetween)), + TapeType, + Const{Core.Typeof(f)}, + Const{Nothing}, + map(typeof, args)..., + ) + + idx = 0 + # idx *= gridDim().x + idx += workgroupIdx().x - 1 + + idx *= gridGroupDim().y + idx += workgroupIdx().y - 1 + + idx *= gridGroupDim().z + idx += workgroupIdx().z - 1 + + idx *= workgroupDim().x + idx += workitemIdx().x - 1 + + idx *= workgroupDim().y + idx += workitemIdx().y - 1 + + idx *= workgroupDim().z + idx += workitemIdx().z - 1 + idx += 1 + + reverse(Const(f), args..., @inbounds tape[idx]) + return +end diff --git a/src/AMDGPU.jl b/src/AMDGPU.jl index c3c45d6e8..b93932d26 100644 --- a/src/AMDGPU.jl +++ b/src/AMDGPU.jl @@ -65,8 +65,8 @@ using .ROCmDiscovery include("utils.jl") -include(joinpath("hsa", "HSA.jl")) -include(joinpath("hip", "HIP.jl")) +include("hsa/HSA.jl") +include("hip/HIP.jl") using .HIP using .HIP: HIPContext, HIPDevice, HIPStream @@ -101,7 +101,7 @@ export sync_workgroup, sync_workgroup_count, sync_workgroup_and, sync_workgroup_ include("compiler/Compiler.jl") import .Compiler -import .Compiler: hipfunction +import .Compiler: hipfunction, compiler_config include("tls.jl") include("highlevel.jl") @@ -117,12 +117,12 @@ include("kernels/accumulate.jl") include("kernels/sorting.jl") include("kernels/reverse.jl") -include(joinpath("blas", "rocBLAS.jl")) -include(joinpath("solver", "rocSOLVER.jl")) -include(joinpath("sparse", "rocSPARSE.jl")) -include(joinpath("rand", "rocRAND.jl")) -include(joinpath("fft", "rocFFT.jl")) -include(joinpath("dnn", "MIOpen.jl")) +include("blas/rocBLAS.jl") +include("solver/rocSOLVER.jl") +include("sparse/rocSPARSE.jl") +include("rand/rocRAND.jl") +include("fft/rocFFT.jl") +include("dnn/MIOpen.jl") include("random.jl") diff --git a/src/runtime/hip-execution.jl b/src/runtime/hip-execution.jl index 0002d7568..165473b47 100644 --- a/src/runtime/hip-execution.jl +++ b/src/runtime/hip-execution.jl @@ -115,10 +115,11 @@ end function launch( fun::HIP.HIPFunction, args::Vararg{Any, N}; - gridsize::ROCDim = 1, groupsize::ROCDim = 1, + gridsize = 1, groupsize = 1, shmem::Integer = 0, stream::HIP.HIPStream, ) where N - gd, bd = ROCDim3(gridsize), ROCDim3(groupsize) + gd = gridsize isa ROCDim3 ? gridsize : ROCDim3(gridsize) + bd = groupsize isa ROCDim3 ? groupsize : ROCDim3(groupsize) pack_arguments(args...) do kernel_params HIP.hipModuleLaunchKernel( fun, gd.x, gd.y, gd.z, bd.x, bd.y, bd.z, diff --git a/test/enzyme_tests.jl b/test/enzyme_tests.jl new file mode 100644 index 000000000..2fa12953b --- /dev/null +++ b/test/enzyme_tests.jl @@ -0,0 +1,55 @@ +@testitem "enzyme" begin + +using AMDGPU +using EnzymeCore, Enzyme +using GPUCompiler + +Enzyme.Compiler.VERBOSE_ERRORS[] = true + + +@testset "CompilerJob from backend" begin + job = EnzymeCore.compiler_job_from_backend( + ROCBackend(), typeof(() -> nothing), Tuple{}) + @test job isa GPUCompiler.CompilerJob +end + +function square_kernel!(x) + i = workitemIdx().x + x[i] *= x[i] + return +end + +function square!(x) + @roc groupsize=length(x) gridsize=1 square_kernel!(x) + return nothing +end + +@testset "Forward Kernel" begin + A = ROCArray(collect(1.0:64.0)) + dA = ROCArray(ones(Float64, 64)) + Enzyme.autodiff(Forward, square!, Duplicated(A, dA)) + @test all(dA .≈ (2:2:128)) + + A = ROCArray(collect(1.0:64.0)) + dA = ROCArray(ones(Float64, 64)) + dA2 = ROCArray(ones(Float64, 64) .* 3.0) + Enzyme.autodiff(Forward, square!, BatchDuplicated(A, (dA, dA2))) + @test all(dA .≈ (2:2:128)) + @test all(dA2 .≈ (2:2:128) .* 3) +end + +@testset "Reverse Kernel" begin + A = ROCArray(collect(1.0:64.0)) + dA = ROCArray(ones(Float64, 64)) + Enzyme.autodiff(Reverse, square!, Duplicated(A, dA)) + @test all(dA .≈ (2:2:128)) + + A = ROCArray(collect(1.0:64.0)) + dA = ROCArray(ones(Float64, 64)) + dA2 = ROCArray(ones(Float64, 64) .* 3.0) + Enzyme.autodiff(Reverse, square!, BatchDuplicated(A, (dA, dA2))) + @test all(dA .≈ (2:2:128)) + @test all(dA2 .≈ (2:2:128) .* 3) +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 18e44e365..5369a77b5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,8 +2,10 @@ using AMDGPU using AMDGPU: Device, Runtime, @allowscalar import AMDGPU.Device: HostCallHolder, hostcall! +import Pkg import PrettyTables import InteractiveUtils + using LinearAlgebra using ReTestItems using Test @@ -30,7 +32,7 @@ end AMDGPU.allowscalar(false) -const TEST_NAMES = ["core", "hip", "ext", "gpuarrays", "kernelabstractions"] +const TEST_NAMES = ["core", "hip", "ext", "gpuarrays", "kernelabstractions", "enzyme"] function parse_flags!(args, flag; default = nothing, typ = typeof(default)) for f in args @@ -88,7 +90,14 @@ for test_name in ARGS """) end -const TARGET_TESTS = isempty(ARGS) ? TEST_NAMES : ARGS +# Do not run Enzyme tests by default. +const TARGET_TESTS = isempty(ARGS) ? + [t for t in TEST_NAMES if t != "enzyme"] : + ARGS + +if "enzyme" in TARGET_TESTS + Pkg.add(["EnzymeCore", "Enzyme"]) +end # Run tests in parallel. np = set_jobs ? jobs : (Sys.CPU_THREADS ÷ 2)