Skip to content

Commit

Permalink
Fix flattening bug (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored May 24, 2023
1 parent 8da575b commit cf5ab9d
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 25 deletions.
2 changes: 1 addition & 1 deletion CITATION.bib
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ @misc{ImplicitDifferentiation.jl
author = {Guillaume Dalle, Mohamed Tarek and contributors},
title = {ImplicitDifferentiation.jl},
url = {https://github.com/gdalle/ImplicitDifferentiation.jl},
version = {v0.4.3},
version = {v0.4.4},
year = {2023},
month = {5}
}
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ImplicitDifferentiation"
uuid = "57b37032-215b-411a-8a7c-41a003a55207"
authors = ["Guillaume Dalle", "Mohamed Tarek and contributors"]
version = "0.4.3"
version = "0.4.4"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
Expand Down
22 changes: 0 additions & 22 deletions examples/0_basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,25 +162,3 @@ h = rand(2)
J_Z(t) = Zygote.jacobian(first implicit2, x .+ t .* h)[1]
ForwardDiff.derivative(J_Z, 0) Diagonal((-0.25 .* h) ./ (x .^ 1.5))
@test ForwardDiff.derivative(J_Z, 0) Diagonal((-0.25 .* h) ./ (x .^ 1.5)) #src

# The following tests are not included in the docs #src

X = rand(2, 3, 4) #src
JJ = Diagonal(0.5 ./ sqrt.(vec(X))) #src
@test (first implicit)(X) sqrt.(X) #src
@test ForwardDiff.jacobian(first implicit, X) JJ #src
@test Zygote.jacobian(first implicit, X)[1] JJ #src

# Skipped because of https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/232 and because it detects weird type instabilities #src
@testset verbose = true "ChainRulesTestUtils.jl" begin #src
@test_skip test_rrule(implicit, x) #src
@test_skip test_rrule(implicit, X) #src
end #src

x_and_dx = [ForwardDiff.Dual(x[i], (0, 0)) for i in eachindex(x)] #src
@inferred implicit(x_and_dx) #src

rc = Zygote.ZygoteRuleConfig() #src
_, pullback = @inferred rrule(rc, implicit, x) #src
dy, dz = zero(implicit(x)[1]), 0
@inferred pullback((dy, dz))
3 changes: 2 additions & 1 deletion ext/ImplicitDifferentiationForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ function (implicit::ImplicitFunction)(
end

y_and_dy = let y = y, dy = dy
map(eachindex(y)) do i
y_and_dy_vec = map(eachindex(y)) do i
Dual{T}(y[i], Partials(ntuple(k -> dy[k][i], Val(N))))
end
reshape(y_and_dy_vec, size(y))
end
return y_and_dy, z
end
Expand Down
80 changes: 80 additions & 0 deletions test/misc.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
using ChainRulesCore
using ChainRulesTestUtils
using ForwardDiff
using ImplicitDifferentiation
using JET
using LinearAlgebra
using Random
using Test
using Zygote

Random.seed!(63);

function mysqrt(x::AbstractArray)
a = [0.0]
a[1] = first(x)
return sqrt.(x)
end

forward(x) = mysqrt(x), 0
conditions(x, y, z) = y .^ 2 .- x
implicit = ImplicitFunction(forward, conditions)

# Skipped because of https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/232 and because it detects weird type instabilities
@testset verbose = true "ChainRulesTestUtils.jl" begin
@test_skip test_rrule(implicit, x)
@test_skip test_rrule(implicit, X)
end

@testset verbose = true "Vectors" begin
x = rand(2)
y, _ = implicit(x)
J = Diagonal(0.5 ./ sqrt.(x))

@testset "Exactness" begin
@test (first implicit)(x) sqrt.(x)
@test ForwardDiff.jacobian(first implicit, x) J
@test Zygote.jacobian(first implicit, x)[1] J
end

@testset verbose = true "Forward inference" begin
x_and_dx = ForwardDiff.Dual.(x, ((0, 0),))
@test (@inferred implicit(x_and_dx)) == implicit(x_and_dx)
y_and_dy, _ = implicit(x_and_dx)
@test size(y_and_dy) == size(y)
end
@testset "Reverse type inference" begin
_, pullback = @inferred rrule(Zygote.ZygoteRuleConfig(), implicit, x)
dy, dz = zero(implicit(x)[1]), 0
@test (@inferred pullback((dy, dz))) == pullback((dy, dz))
_, dx = pullback((dy, dz))
@test size(dx) == size(x)
end
end

@testset verbose = true "Arrays" begin
X = rand(2, 3, 4)
Y, _ = implicit(X)
JJ = Diagonal(0.5 ./ sqrt.(vec(X)))

@testset "Exactness" begin
@test (first implicit)(X) sqrt.(X)
@test ForwardDiff.jacobian(first implicit, X) JJ
@test Zygote.jacobian(first implicit, X)[1] JJ
end

@testset "Forward type inference" begin
X_and_dX = ForwardDiff.Dual.(X, ((0, 0),))
@test (@inferred implicit(X_and_dX)) == implicit(X_and_dX)
Y_and_dY, _ = implicit(X_and_dX)
@test size(Y_and_dY) == size(Y)
end

@testset "Reverse type inference" begin
_, pullback = @inferred rrule(Zygote.ZygoteRuleConfig(), implicit, X)
dY, dZ = zero(implicit(X)[1]), 0
@test (@inferred pullback((dY, dZ))) == pullback((dY, dZ))
_, dX = pullback((dY, dZ))
@test size(dX) == size(X)
end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ EXAMPLES_DIR_JL = joinpath(dirname(@__DIR__), "examples")
@testset verbose = false "Doctests (Documenter.jl)" begin
doctest(ImplicitDifferentiation)
end
@testset verbose = true "Miscellaneous" begin
include("misc.jl")
end
for file in readdir(EXAMPLES_DIR_JL)
path = joinpath(EXAMPLES_DIR_JL, file)
title = markdown_title(path)
Expand Down

0 comments on commit cf5ab9d

Please sign in to comment.