From 6195cd3cafcb71dfb677f9776f8e86763cace4bd Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 12 Aug 2024 08:44:13 +0200 Subject: [PATCH] Fix Enzyme extension and add new broken test (#151) * Fix Enzyme extension and add new test * Adapt to latest version * No function annotation * Test broken * Fix tests * Mode * Const * Bump version and move constructor doc --- Project.toml | 6 ++-- examples/3_tricks.jl | 3 ++ ext/ImplicitDifferentiationEnzymeExt.jl | 7 +++-- src/implicit_function.jl | 41 ++++++++++++++----------- test/systematic.jl | 4 +-- 5 files changed, 35 insertions(+), 26 deletions(-) diff --git a/Project.toml b/Project.toml index 9b80cc3..1bae515 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ImplicitDifferentiation" uuid = "57b37032-215b-411a-8a7c-41a003a55207" authors = ["Guillaume Dalle", "Mohamed Tarek and contributors"] -version = "0.6.0" +version = "0.6.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -21,9 +21,9 @@ ImplicitDifferentiationEnzymeExt = "Enzyme" ImplicitDifferentiationForwardDiffExt = "ForwardDiff" [compat] -ADTypes = "1.0" +ADTypes = "1.7.1" ChainRulesCore = "1.23.0" -DifferentiationInterface = "0.5" +DifferentiationInterface = "0.5.12" Enzyme = "0.11.20,0.12" ForwardDiff = "0.10.36" Krylov = "0.9.5" diff --git a/examples/3_tricks.jl b/examples/3_tricks.jl index 572b13b..86ff5dd 100644 --- a/examples/3_tricks.jl +++ b/examples/3_tricks.jl @@ -5,6 +5,7 @@ We demonstrate several features that may come in handy for some users. =# using ComponentArrays +using Enzyme #src using ForwardDiff using ImplicitDifferentiation using Krylov @@ -67,6 +68,8 @@ J = ForwardDiff.jacobian(forward_components, x) #src Zygote.jacobian(implicit_components, x)[1] @test Zygote.jacobian(implicit_components, x)[1] ≈ J #src +@test_broken Enzyme.jacobian(Enzyme.Forward, implicit_components, x) ≈ J #src + #- The full differentiable pipeline looks like this function full_pipeline(a, b, m) diff --git a/ext/ImplicitDifferentiationEnzymeExt.jl b/ext/ImplicitDifferentiationEnzymeExt.jl index eaeec45..ff9e29f 100644 --- a/ext/ImplicitDifferentiationEnzymeExt.jl +++ b/ext/ImplicitDifferentiationEnzymeExt.jl @@ -5,6 +5,8 @@ using Enzyme using Enzyme.EnzymeCore using ImplicitDifferentiation: ImplicitFunction, build_A, build_B, byproduct, output +const FORWARD_BACKEND = AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const) + function EnzymeRules.forward( func::Const{<:ImplicitFunction}, RT::Type{<:Union{BatchDuplicated,BatchDuplicatedNoNeed}}, @@ -20,12 +22,11 @@ function EnzymeRules.forward( y = output(y_or_yz) Y = typeof(y) - suggested_backend = AutoEnzyme(Enzyme.Forward) + suggested_backend = FORWARD_BACKEND A = build_A(implicit, x, y_or_yz, args...; suggested_backend) B = build_B(implicit, x, y_or_yz, args...; suggested_backend) - dx_batch = reduce(hcat, dx) - dc_batch = mapreduce(hcat, eachcol(dx_batch)) do dₖx + dc_batch = mapreduce(hcat, dx) do dₖx B * dₖx end dy_batch = implicit.linear_solver(A, -dc_batch) diff --git a/src/implicit_function.jl b/src/implicit_function.jl index 51393cc..534176f 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -60,6 +60,26 @@ The value of `lazy` must be chosen together with the `linear_solver`, see below. - `conditions_x_backend`: how the conditions will be differentiated w.r.t. the first argument `x` - `conditions_y_backend`: how the conditions will be differentiated w.r.t. the second argument `y` +# Constructors + + ImplicitFunction( + forward, conditions; + linear_solver=KrylovLinearSolver(), + conditions_x_backend=nothing, + conditions_x_backend=nothing, + ) + +Picks the `lazy` parameter automatically based on the `linear_solver`, using the following heuristic: `lazy = linear_solver != \\`. + + ImplicitFunction{lazy}( + forward, conditions; + linear_solver=lazy ? KrylovLinearSolver() : \\, + conditions_x_backend=nothing, + conditions_y_backend=nothing, + ) + +Picks the `linear_solver` automatically based on the `lazy` parameter. + # Function signatures There are two possible signatures for `forward` and `conditions`, which must be consistent with one another: @@ -87,8 +107,10 @@ Typically, direct solvers work best with dense Jacobians (`lazy = false`) while # Condition backends The provided `conditions_x_backend` and `conditions_y_backend` can be either: +- `nothing` (the default), in which case the outer backend (the one differentiating through the `ImplicitFunction`) is used. - an object subtyping `AbstractADType` from [ADTypes.jl](https://github.com/SciML/ADTypes.jl); -- `nothing`, in which case the outer backend (the one differentiating through the `ImplicitFunction`) is used. + +When differentiating with Enzyme as an outer backend, the default setting assumes that `conditions` does not contain writeable data involved in derivatives. """ struct ImplicitFunction{ lazy,F,C,L,B1<:Union{Nothing,AbstractADType},B2<:Union{Nothing,AbstractADType} @@ -101,14 +123,7 @@ struct ImplicitFunction{ end """ - ImplicitFunction{lazy}( - forward, conditions; - linear_solver=lazy ? KrylovLinearSolver() : \\, - conditions_x_backend=nothing, - conditions_y_backend=nothing, - ) -Constructor for an [`ImplicitFunction`](@ref) which picks the `linear_solver` automatically based on the `lazy` parameter. """ function ImplicitFunction{lazy}( forward::F, @@ -126,16 +141,6 @@ function ImplicitFunction{lazy}( ) end -""" - ImplicitFunction( - forward, conditions; - linear_solver=KrylovLinearSolver(), - conditions_x_backend=nothing, - conditions_x_backend=nothing, - ) - -Constructor for an [`ImplicitFunction`](@ref) which picks the `lazy` parameter automatically based on the `linear_solver`, using the following heuristic: `lazy = linear_solver != \\`. -""" function ImplicitFunction( forward, conditions; diff --git a/test/systematic.jl b/test/systematic.jl index 8056bd8..bff1afd 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -12,7 +12,7 @@ include("utils.jl") backends = [ AutoForwardDiff(; chunksize=1), # - AutoEnzyme(Enzyme.Forward), + AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const), AutoZygote(), ] @@ -24,7 +24,7 @@ linear_solver_candidates = ( conditions_backend_candidates = ( nothing, # AutoForwardDiff(; chunksize=1), - AutoEnzyme(Enzyme.Forward), + AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const), ); x_candidates = (