Skip to content

Commit

Permalink
Initial Enzyme support (#668)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: William Moses <[email protected]>
  • Loading branch information
pxl-th and wsmoses authored Jan 9, 2025
1 parent 0fece1f commit 1c1ec4d
Show file tree
Hide file tree
Showing 8 changed files with 402 additions and 14 deletions.
19 changes: 19 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,25 @@ steps:
# JULIA_AMDGPU_HIP_MUST_LOAD: "1"
# JULIA_AMDGPU_DISABLE_ARTIFACTS: "1"

# CI step for the Enzyme integration: installs Julia 1.10 and runs the package
# tests with the "enzyme" test argument on a "juliagpu" agent that provides
# ROCm and a ROCm GPU. The step is skipped when the build message contains
# "[skip tests]". The `command` updates the project's packages via Pkg.update().
- label: "Julia 1.10 Enzyme"
plugins:
- JuliaCI/julia#v1:
version: "1.10"
- JuliaCI/julia-test#v1:
test_args: "enzyme"
agents:
queue: "juliagpu"
rocm: "*"
rocmgpu: "*"
if: build.message !~ /\[skip tests\]/
command: "julia --project -e 'using Pkg; Pkg.update()'"
timeout_in_minutes: 180
env:
JULIA_NUM_THREADS: 4
JULIA_AMDGPU_CORE_MUST_LOAD: "1"
JULIA_AMDGPU_HIP_MUST_LOAD: "1"
JULIA_AMDGPU_DISABLE_ARTIFACTS: "1"

- label: "GPU-less environment"
plugins:
- JuliaCI/julia#v1:
Expand Down
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,20 @@ UnsafeAtomics = "013be700-e6cd-48c3-b4a1-df204f14c38f"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"

[extensions]
AMDGPUChainRulesCoreExt = "ChainRulesCore"
AMDGPUEnzymeCoreExt = "EnzymeCore"

[compat]
AbstractFFTs = "1.0"
AcceleratedKernels = "0.2"
Adapt = "4"
Atomix = "0.1, 1"
Atomix = "1"
CEnum = "0.4, 0.5"
ChainRulesCore = "1"
EnzymeCore = "0.8"
ExprTools = "0.1"
GPUArrays = "11.2"
GPUCompiler = "0.27, 1.0"
Expand Down
218 changes: 218 additions & 0 deletions ext/AMDGPUEnzymeCoreExt/AMDGPUEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
module AMDGPUEnzymeCoreExt

using AMDGPU
using EnzymeCore
using EnzymeCore: EnzymeRules
using GPUCompiler

include("meta_kernels.jl")

# Build the `GPUCompiler.CompilerJob` for calling a function of type `F` with
# argument tuple type `TT` on the ROC backend. The compiler configuration is
# taken from the device returned by `AMDGPU.device()`.
function EnzymeCore.compiler_job_from_backend(
    ::ROCBackend, @nospecialize(F::Type), @nospecialize(TT::Type),
)
    method_instance = GPUCompiler.methodinstance(F, TT)
    device_config = AMDGPU.compiler_config(AMDGPU.device())
    return GPUCompiler.CompilerJob(method_instance, device_config)
end

# Forward-mode rule for `AMDGPU.hipfunction` with a `Duplicated` return
# annotation: `hipfunction` is called once on the constant arguments and its
# result is used as both the primal and the shadow.
function EnzymeRules.forward(
    config, fn::Const{typeof(AMDGPU.hipfunction)}, ::Type{<:Duplicated},
    f::Const{F}, tt::Const{TT}; kwargs...,
) where {F, TT}
    compiled = fn.val(f.val, tt.val; kwargs...)
    return Duplicated(compiled, compiled)
end

# Forward-mode rule for `AMDGPU.hipfunction` with a `BatchDuplicated{T, N}`
# return annotation: `hipfunction` is called once and its result is reused as
# the primal and as each of the `N` shadows.
function EnzymeRules.forward(
    config, fn::Const{typeof(AMDGPU.hipfunction)}, ::Type{<:BatchDuplicated{T, N}},
    f::Const{F}, tt::Const{TT}; kwargs...,
) where {F, TT, T, N}
    compiled = fn.val(f.val, tt.val; kwargs...)
    shadows = ntuple(Val(N)) do _
        compiled
    end
    return BatchDuplicated(compiled, shadows)
end

# Reverse pass for `AMDGPU.hipfunction`: nothing is propagated; one `nothing`
# is returned for each of the two positional arguments `f` and `tt`.
function EnzymeRules.reverse(
    config, fn::Const{typeof(AMDGPU.hipfunction)}, ::Type{RT},
    subtape, f, tt; kwargs...,
) where {RT}
    no_gradients = (nothing, nothing)
    return no_gradients
end

# Forward-mode rule for `AMDGPU.rocconvert`.
#
# Applies `rocconvert` (through `fn.val`) to the primal `x.val` and/or the
# shadow(s) `x.dval`, depending on what `config` requests:
# - primal and shadow: `Duplicated` when the width is 1, otherwise
#   `BatchDuplicated` with one converted shadow per batch entry;
# - shadow only: the converted shadow, or a tuple of converted shadows;
# - primal only: the converted primal;
# - neither: `nothing`.
function EnzymeRules.forward(
    config, fn::Const{typeof(AMDGPU.rocconvert)}, ::Type{RT}, x::IT,
) where {RT, IT}
    if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
        config_width = EnzymeRules.width(config)
        if config_width == 1
            return Duplicated(fn.val(x.val), fn.val(x.dval))
        else
            tup = ntuple(Val(config_width)) do i
                Base.@_inline_meta
                fn.val(x.dval[i])::eltype(RT)
            end
            return BatchDuplicated(fn.val(x.val), tup)
        end

    elseif EnzymeRules.needs_shadow(config)
        config_width = EnzymeRules.width(config)
        # `shadow_type` is looked up in `EnzymeRules`, the same way the
        # `augmented_primal` rules of this module do it.
        ST = EnzymeRules.shadow_type(config, RT)
        if config_width == 1
            return fn.val(x.dval)::ST
        else
            return (ntuple(Val(config_width)) do i
                Base.@_inline_meta
                fn.val(x.dval[i])::eltype(RT)
            end)::ST
        end

    elseif EnzymeRules.needs_primal(config)
        return fn.val(x.val)::eltype(RT)
    else
        return nothing
    end
end

# Augmented forward pass for `AMDGPU.rocconvert`.
#
# Converts the primal and/or shadow(s) of `x` as requested by `config` and
# wraps them in an `AugmentedReturn`. No tape is needed, so the tape slot is
# `nothing`.
function EnzymeRules.augmented_primal(
    config, fn::Const{typeof(AMDGPU.rocconvert)}, ::Type{RT}, x::IT,
) where {RT, IT}
    primal = if EnzymeRules.needs_primal(config)
        fn.val(x.val)
    else
        nothing
    end

    shadow = if !EnzymeRules.needs_shadow(config)
        nothing
    elseif EnzymeRules.width(config) == 1
        fn.val(x.dval)
    else
        # Batched mode: convert every shadow in the batch.
        ntuple(Val(EnzymeRules.width(config))) do i
            Base.@_inline_meta
            fn.val(x.dval[i])
        end
    end

    PrimalT = EnzymeRules.primal_type(config, RT)
    ShadowT = EnzymeRules.shadow_type(config, RT)
    return EnzymeRules.AugmentedReturn{PrimalT, ShadowT, Nothing}(
        primal, shadow, nothing)
end

# Reverse pass for `AMDGPU.rocconvert`: nothing is propagated; a single
# `nothing` is returned for the one positional argument `x`.
function EnzymeRules.reverse(
    config, fn::Const{typeof(AMDGPU.rocconvert)}, ::Type{RT}, tape, x::IT,
) where {RT, IT}
    no_gradient = (nothing,)
    return no_gradient
end

# Forward-mode rule for launching a `HIPKernel`.
#
# Instead of launching the wrapped kernel directly, this compiles `meta_fn`
# for the signature `(config, F, converted args...)` and launches it with the
# original keyword arguments. `args` are kept alive for the duration of the
# launch via `GC.@preserve`.
function EnzymeRules.forward(
    config, fn::EnzymeCore.Annotation{AMDGPU.Runtime.HIPKernel{F, TT}},
    ::Type{Const{Nothing}}, args...; kwargs...,
) where {F, TT}
    GC.@preserve args begin
        device_args = map(rocconvert, args)
        signature = Tuple{typeof(config), F, map(typeof, device_args)...}
        meta_kernel = AMDGPU.hipfunction(meta_fn, signature)
        meta_kernel(config, fn.val.f, device_args...; kwargs...)
    end
    return nothing
end

# Reverse pass for launching a `HIPKernel`.
#
# Compiles `meta_revf` for the signature `(config, F, tape, converted args...)`
# and launches it with the same `groupsize`/`gridsize` that were given to the
# launch being differentiated, passing the `subtape` produced by the matching
# `augmented_primal` rule.
#
# The tape element type is not recomputed here: the device-side `meta_revf`
# obtains it from the element type of the tape array it receives.
#
# Returns one `nothing` per positional kernel argument.
function EnzymeRules.reverse(
    config, ofn::EnzymeCore.Annotation{AMDGPU.Runtime.HIPKernel{F, TT}},
    ::Type{Const{Nothing}}, subtape, args...;
    groupsize::AMDGPU.Runtime.ROCDim = 1,
    gridsize::AMDGPU.Runtime.ROCDim = 1,
    kwargs...,
) where {F, TT}
    kernel_args = ((rocconvert(a) for a in args)...,)
    kernel_tt = map(typeof, kernel_args)

    # Normalize the launch dimensions to their 3-component form.
    groupsize = AMDGPU.Runtime.ROCDim3(groupsize)
    gridsize = AMDGPU.Runtime.ROCDim3(gridsize)

    # Keep the host-side objects alive while their converted counterparts
    # are in use by the launch.
    GC.@preserve args subtape begin
        subtape_cc = rocconvert(subtape)
        kernel_tt2 = Tuple{
            (typeof(config), F, typeof(subtape_cc), kernel_tt...)...}
        kernel = AMDGPU.hipfunction(meta_revf, kernel_tt2)
        kernel(config, ofn.val.f, subtape_cc, kernel_args...;
            groupsize, gridsize, kwargs...)
    end

    return ntuple(Val(length(kernel_args))) do i
        Base.@_inline_meta
        nothing
    end
end

# Augmented forward pass for `AMDGPU.hipfunction`.
#
# Calls `hipfunction` once on the constant arguments. The result serves as the
# primal (when requested) and as the shadow (when requested); in batched mode
# the same result is repeated once per batch entry. No tape is recorded.
function EnzymeRules.augmented_primal(
    config, fn::Const{typeof(AMDGPU.hipfunction)},
    ::Type{RT}, f::Const{F}, tt::Const{TT}; kwargs...
) where {F, CT, RT <: EnzymeCore.Annotation{CT}, TT}
    compiled = fn.val(f.val, tt.val; kwargs...)

    primal = if EnzymeRules.needs_primal(config)
        compiled
    else
        nothing
    end

    shadow = if !EnzymeRules.needs_shadow(config)
        nothing
    elseif EnzymeRules.width(config) == 1
        compiled
    else
        ntuple(Val(EnzymeRules.width(config))) do i
            Base.@_inline_meta
            compiled
        end
    end

    PrimalT = EnzymeRules.primal_type(config, RT)
    ShadowT = EnzymeRules.shadow_type(config, RT)
    return EnzymeRules.AugmentedReturn{PrimalT, ShadowT, Nothing}(
        primal, shadow, nothing)
end

# Augmented forward pass for launching a `HIPKernel`.
#
# Determines the tape element type for the split reverse-mode thunk of the
# kernel function, allocates a `ROCArray` with one tape slot per element of the
# `gridsize` x `groupsize` launch configuration, and launches `meta_augf`,
# which fills that array. The array is returned as the tape of the
# `AugmentedReturn`; there is no primal or shadow result.
function EnzymeRules.augmented_primal(
    config, fn::EnzymeCore.Annotation{AMDGPU.Runtime.HIPKernel{F,TT}},
    ::Type{Const{Nothing}}, args...;
    groupsize::AMDGPU.Runtime.ROCDim = 1,
    gridsize::AMDGPU.Runtime.ROCDim = 1, kwargs...,
) where {F,TT}
    device_args = ((rocconvert(a) for a in args)...,)
    device_arg_types = map(typeof, device_args)

    overwritten = EnzymeRules.overwritten(config)
    job = EnzymeCore.compiler_job_from_backend(
        ROCBackend(), typeof(Base.identity), Tuple{Float64})
    split_mode = ReverseSplitModified(
        EnzymeCore.set_runtime_activity(ReverseSplitWithPrimal, config),
        Val(overwritten))
    TapeType = EnzymeCore.tape_type(
        job, split_mode, Const{F}, Const{Nothing}, device_arg_types...)

    # Normalize the launch dimensions to their 3-component form.
    groupsize = AMDGPU.Runtime.ROCDim3(groupsize)
    gridsize = AMDGPU.Runtime.ROCDim3(gridsize)

    tape_length =
        gridsize.x * gridsize.y * gridsize.z *
        groupsize.x * groupsize.y * groupsize.z
    subtape = ROCArray{TapeType}(undef, tape_length)

    # Keep the host-side objects alive while their converted counterparts
    # are in use by the launch.
    GC.@preserve args subtape begin
        device_tape = rocconvert(subtape)
        signature = Tuple{
            typeof(config), F, typeof(device_tape), device_arg_types...}
        meta_kernel = AMDGPU.hipfunction(meta_augf, signature)
        meta_kernel(config, fn.val.f, device_tape, device_args...;
            groupsize, gridsize, kwargs...)
    end

    return EnzymeRules.AugmentedReturn{Nothing, Nothing, ROCArray}(
        nothing, nothing, subtape)
end

end
83 changes: 83 additions & 0 deletions ext/AMDGPUEnzymeCoreExt/meta_kernels.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Kernel launched by the forward-mode `HIPKernel` rule: runs deferred
# forward-mode autodiff of `fn` on `args`, with the runtime-activity setting
# taken from `config`. Returns nothing.
function meta_fn(config, fn, args::Vararg{Any, N}) where N
    forward_mode = EnzymeCore.set_runtime_activity(Forward, config)
    EnzymeCore.autodiff_deferred(forward_mode, Const(fn), Const, args...)
    return nothing
end


# 1-based linear index of the calling workitem, built from its workgroup index
# and its workitem index within the workgroup (all three dimensions each).
# `meta_augf` and `meta_revf` both use it to address their per-workitem tape
# slot, so a workitem reads back in the reverse pass the slot it wrote in the
# augmented forward pass, provided both launches use the same dimensions.
@inline function _meta_tape_index()
    idx = 0
    # The leading dimension is not scaled: `idx` is still zero at this point.
    idx += workgroupIdx().x - 1

    idx *= gridGroupDim().y
    idx += workgroupIdx().y - 1

    idx *= gridGroupDim().z
    idx += workgroupIdx().z - 1

    idx *= workgroupDim().x
    idx += workitemIdx().x - 1

    idx *= workgroupDim().y
    idx += workitemIdx().y - 1

    idx *= workgroupDim().z
    idx += workitemIdx().z - 1

    # Convert from 0-based to 1-based indexing.
    idx += 1
    return idx
end

# Kernel launched by the `HIPKernel` `augmented_primal` rule.
#
# Builds the split reverse-mode thunk pair for `f` on `args`, runs the forward
# thunk, and stores the first element of its result in this workitem's slot of
# `tape`.
function meta_augf(
    config, f, tape::ROCDeviceArray{TapeType}, args::Vararg{Any, N},
) where {N, TapeType}
    ModifiedBetween = EnzymeRules.overwritten(config)
    forward_thunk, _ = EnzymeCore.autodiff_deferred_thunk(
        ReverseSplitModified(
            EnzymeCore.set_runtime_activity(ReverseSplitWithPrimal, config),
            Val(ModifiedBetween)),
        TapeType,
        Const{Core.Typeof(f)},
        Const{Nothing},
        map(typeof, args)...,
    )

    idx = _meta_tape_index()

    @inbounds tape[idx] = forward_thunk(Const(f), args...)[1]
    return
end

# Kernel launched by the `HIPKernel` `reverse` rule.
#
# Builds the same split reverse-mode thunk pair as `meta_augf` and runs the
# reverse thunk, passing it the entry stored in this workitem's slot of `tape`.
function meta_revf(
    config, f, tape::ROCDeviceArray{TapeType}, args::Vararg{Any, N},
) where {N, TapeType}
    ModifiedBetween = EnzymeRules.overwritten(config)
    _, reverse_thunk = EnzymeCore.autodiff_deferred_thunk(
        ReverseSplitModified(
            EnzymeCore.set_runtime_activity(ReverseSplitWithPrimal, config),
            Val(ModifiedBetween)),
        TapeType,
        Const{Core.Typeof(f)},
        Const{Nothing},
        map(typeof, args)...,
    )

    idx = _meta_tape_index()

    reverse_thunk(Const(f), args..., @inbounds tape[idx])
    return
end
18 changes: 9 additions & 9 deletions src/AMDGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ using .ROCmDiscovery

include("utils.jl")

include(joinpath("hsa", "HSA.jl"))
include(joinpath("hip", "HIP.jl"))
include("hsa/HSA.jl")
include("hip/HIP.jl")

using .HIP
using .HIP: HIPContext, HIPDevice, HIPStream
Expand Down Expand Up @@ -101,7 +101,7 @@ export sync_workgroup, sync_workgroup_count, sync_workgroup_and, sync_workgroup_

include("compiler/Compiler.jl")
import .Compiler
import .Compiler: hipfunction
import .Compiler: hipfunction, compiler_config

include("tls.jl")
include("highlevel.jl")
Expand All @@ -117,12 +117,12 @@ include("kernels/accumulate.jl")
include("kernels/sorting.jl")
include("kernels/reverse.jl")

include(joinpath("blas", "rocBLAS.jl"))
include(joinpath("solver", "rocSOLVER.jl"))
include(joinpath("sparse", "rocSPARSE.jl"))
include(joinpath("rand", "rocRAND.jl"))
include(joinpath("fft", "rocFFT.jl"))
include(joinpath("dnn", "MIOpen.jl"))
include("blas/rocBLAS.jl")
include("solver/rocSOLVER.jl")
include("sparse/rocSPARSE.jl")
include("rand/rocRAND.jl")
include("fft/rocFFT.jl")
include("dnn/MIOpen.jl")

include("random.jl")

Expand Down
Loading

0 comments on commit 1c1ec4d

Please sign in to comment.