Skip to content

Commit

Permalink
Merge pull request #172 from JuliaDiff/duals
Browse files Browse the repository at this point in the history
proper duals for JacVec
  • Loading branch information
ChrisRackauckas authored Dec 29, 2021
2 parents 5152045 + 7c8723e commit 508a6cc
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 24 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SparseDiffTools"
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
authors = ["Pankaj Mishra <[email protected]>", "Chris Rackauckas <[email protected]>"]
version = "1.19.1"
version = "1.19.2"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
37 changes: 22 additions & 15 deletions src/differentiation/jaches_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,22 @@ function auto_jacvec!(
f,
x,
v,
cache1 = Dual{DeivVecTag}.(x, reshape(v, size(x))),
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))),
cache2 = similar(cache1),
)
cache1 .= Dual{DeivVecTag}.(x, reshape(v, size(x)))
cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
f(cache2, cache1)
dy .= partials.(cache2, 1)
vecdy = _vec(dy)
vecdy .= partials.(_vec(cache2), 1)
end

_vec(v) = vec(v)
_vec(v::AbstractVector) = v

function auto_jacvec(f, x, v)
vv = reshape(v, axes(x))
vec(partials.(vec(f(ForwardDiff.Dual{DeivVecTag}.(x, vv))), 1))
y = ForwardDiff.Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(vv)))
vec(partials.(vec(f(y)), 1))
end

function num_jacvec!(
Expand Down Expand Up @@ -122,12 +127,12 @@ function autonum_hesvec!(
f,
x,
v,
cache1 = ForwardDiff.Dual{DeivVecTag}.(x, v),
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, v),
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))),
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))),
)
cache = FiniteDiff.GradientCache(v[1], cache1, Val{:central})
g = (dx, x) -> FiniteDiff.finite_difference_gradient!(dx, f, x, cache)
cache1 .= Dual{DeivVecTag}.(x, v)
cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
g(cache2, cache1)
dy .= partials.(cache2, 1)
end
Expand Down Expand Up @@ -164,16 +169,17 @@ function auto_hesvecgrad!(
g,
x,
v,
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, v),
cache3 = ForwardDiff.Dual{DeivVecTag}.(x, v),
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))),
cache3 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))),
)
cache2 .= Dual{DeivVecTag}.(x, v)
cache2 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
g(cache3, cache2)
dy .= partials.(cache3, 1)
end

function auto_hesvecgrad(g, x, v)
partials.(g(Dual{DeivVecTag}.(x, v)), 1)
y = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
partials.(g(y), 1)
end

### Operator Forms
Expand All @@ -188,15 +194,16 @@ end

function JacVec(f, x::AbstractArray; autodiff = true)
if autodiff
cache1 = ForwardDiff.Dual{DeivVecTag}.(x, x)
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, x)
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(x)))
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(x)))
else
cache1 = similar(x)
cache2 = similar(x)
end
JacVec(f, cache1, cache2, x, autodiff)
end

Base.eltype(L::JacVec) = eltype(L.x)
Base.size(L::JacVec) = (length(L.cache1), length(L.cache1))
Base.size(L::JacVec, i::Int) = length(L.cache1)
Base.:*(L::JacVec, v::AbstractVector) =
Expand Down Expand Up @@ -256,8 +263,8 @@ end

function HesVecGrad(g, x::AbstractArray; autodiff = false)
if autodiff
cache1 = ForwardDiff.Dual{DeivVecTag}.(x, x)
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, x)
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(x)))
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(x)))
else
cache1 = similar(x)
cache2 = similar(x)
Expand Down
10 changes: 6 additions & 4 deletions src/differentiation/jaches_products_zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,19 @@ function numback_hesvec(f, x, v)
(gxp - gxm)/(2ϵ)
end

function autoback_hesvec!(dy, f, x, v, cache2 = ForwardDiff.Dual{Nothing}.(x, v),
cache3 = ForwardDiff.Dual{Nothing}.(x, v))
function autoback_hesvec!(dy, f, x, v,
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))),
cache3 = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))))
g = let f=f
(dx, x) -> dx .= first(Zygote.gradient(f,x))
end
cache2 .= Dual{Nothing}.(x, v)
cache2 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
g(cache3,cache2)
dy .= partials.(cache3, 1)
end

function autoback_hesvec(f, x, v)
g = x -> first(Zygote.gradient(f,x))
ForwardDiff.partials.(g(ForwardDiff.Dual{Nothing}.(x, v)), 1)
y = Dual{typeof(ForwardDiff.Tag(DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
ForwardDiff.partials.(g(y), 1)
end
8 changes: 4 additions & 4 deletions test/test_jaches_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ function h(dy,x)
FiniteDiff.finite_difference_gradient!(dy,g,x)
end

cache1 = ForwardDiff.Dual{SparseDiffTools.DeivVecTag}.(x, v)
cache2 = ForwardDiff.Dual{SparseDiffTools.DeivVecTag}.(x, v)
cache1 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(v)))
cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(v)))
@test num_jacvec!(dy, f, x, v) ForwardDiff.jacobian(f,similar(x),x)*v rtol=1e-6
@test num_jacvec!(dy, f, x, v, similar(v), similar(v)) ForwardDiff.jacobian(f,similar(x),x)*v rtol=1e-6
@test num_jacvec(f, x, v) ForwardDiff.jacobian(f,similar(x),x)*v rtol=1e-6
Expand All @@ -44,8 +44,8 @@ cache2 = ForwardDiff.Dual{SparseDiffTools.DeivVecTag}.(x, v)
@test numback_hesvec!(dy, g, x, v, similar(v), similar(v)) ForwardDiff.hessian(g,x)*v rtol=1e-8
@test numback_hesvec(g, x, v) ForwardDiff.hessian(g,x)*v rtol=1e-8

cache3 = ForwardDiff.Dual{Nothing}.(x, v)
cache4 = ForwardDiff.Dual{Nothing}.(x, v)
cache3 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(v)))
cache4 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing,eltype(x))),eltype(x),1}.(x, ForwardDiff.Partials.(Tuple.(v)))
@test autoback_hesvec!(dy, g, x, v) ForwardDiff.hessian(g,x)*v rtol=1e-8
@test autoback_hesvec!(dy, g, x, v, cache3, cache4) ForwardDiff.hessian(g,x)*v rtol=1e-8
@test autoback_hesvec(g, x, v) ForwardDiff.hessian(g,x)*v rtol=1e-8
Expand Down

0 comments on commit 508a6cc

Please sign in to comment.