From 7c97dd846e05b142539610881dd67c46d39f0b40 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 10 Nov 2024 08:54:13 +0100 Subject: [PATCH 1/2] fix: respect array type in wrong-mode pushforward/pullback --- .../src/first_order/pullback.jl | 24 +++++++++---------- .../src/first_order/pushforward.jl | 18 +++++++------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index ebf74a26b..e2d1a38cc 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -167,8 +167,8 @@ function _pullback_via_pushforward( dy, contexts::Vararg{Context,C}, ) where {F,C} - t1 = pushforward(f, pushforward_prep, backend, x, (one(x),), contexts...) - dx = dot(dy, only(t1)) + t = pushforward(f, pushforward_prep, backend, x, (one(x),), contexts...) + dx = dot(dy, only(t)) return dx end @@ -180,9 +180,10 @@ function _pullback_via_pushforward( dy, contexts::Vararg{Context,C}, ) where {F,C} - dx = map(CartesianIndices(x)) do j - t1 = pushforward(f, pushforward_prep, backend, x, (basis(backend, x, j),), contexts...) - dot(dy, only(t1)) + dx = map(x, CartesianIndices(x)) do xj, j + bj = basis(backend, x, j) + tj = pushforward(f, pushforward_prep, backend, x, (bj,), contexts...) + dot(dy, only(tj)) end return dx end @@ -252,8 +253,8 @@ function _pullback_via_pushforward( dy, contexts::Vararg{Context,C}, ) where {F,C} - t1 = pushforward(f!, y, pushforward_prep, backend, x, (one(x),), contexts...) - dx = dot(dy, only(t1)) + t = pushforward(f!, y, pushforward_prep, backend, x, (one(x),), contexts...) + dx = dot(dy, only(t)) return dx end @@ -266,11 +267,10 @@ function _pullback_via_pushforward( dy, contexts::Vararg{Context,C}, ) where {F,C} - dx = map(CartesianIndices(x)) do j # preserve shape - t1 = pushforward( - f!, y, pushforward_prep, backend, x, (basis(backend, x, j),), contexts... - ) - dot(dy, only(t1)) + dx = map(x, CartesianIndices(x)) do xj, j # preserve shape + bj = basis(backend, x, j) + tj = pushforward(f!, y, pushforward_prep, backend, x, (bj,), contexts...) + dot(dy, only(tj)) end return dx end diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index 83475757b..44670bbaa 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -175,8 +175,8 @@ function _pushforward_via_pullback( dx, contexts::Vararg{Context,C}, ) where {F,C} - t1 = pullback(f, pullback_prep, backend, x, (one(y),), contexts...) - dy = dot(dx, only(t1)) + t = pullback(f, pullback_prep, backend, x, (one(y),), contexts...) + dy = dot(dx, only(t)) return dy end @@ -189,9 +189,10 @@ function _pushforward_via_pullback( dx, contexts::Vararg{Context,C}, ) where {F,C} - dy = map(CartesianIndices(y)) do i - t1 = pullback(f, pullback_prep, backend, x, (basis(backend, y, i),), contexts...) - dot(dx, only(t1)) + dy = map(y, CartesianIndices(y)) do yi, i + bi = basis(backend, y, i) + ti = pullback(f, pullback_prep, backend, x, (bi,), contexts...) + dot(dx, only(ti)) end return dy end @@ -261,9 +262,10 @@ function _pushforward_via_pullback( dx, contexts::Vararg{Context,C}, ) where {F,C} - dy = map(CartesianIndices(y)) do i # preserve shape - t1 = pullback(f!, y, pullback_prep, backend, x, (basis(backend, y, i),), contexts...) - dot(dx, only(t1)) + dy = map(y, CartesianIndices(y)) do yi, i # preserve shape + bi = basis(backend, y, i) + ti = pullback(f!, y, pullback_prep, backend, x, (bi,), contexts...) + dot(dx, only(ti)) end return dy end From dde3d4a5f6fe1e79209855e2cca4a443f05799bb Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 30 Jan 2025 08:43:29 +0100 Subject: [PATCH 2/2] Correct map --- DifferentiationInterface/src/first_order/pullback.jl | 4 ++-- DifferentiationInterface/src/first_order/pushforward.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index a77d15c0c..d4a4ba7ef 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -168,7 +168,7 @@ function _pullback_via_pushforward( dy, contexts::Vararg{Context,C}, ) where {F,C} - dx = map(CartesianIndices(x)) do j + dx = map(CartesianIndices(x)) do xj, j t1 = pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...) convert(eltype(x), dot(only(t1), dy)) end @@ -254,7 +254,7 @@ function _pullback_via_pushforward( dy, contexts::Vararg{Context,C}, ) where {F,C} - dx = map(CartesianIndices(x)) do j # preserve shape + dx = map(CartesianIndices(x)) do xj, j # preserve shape t1 = pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...) convert(eltype(x), dot(only(t1), dy)) end diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index 013fd4fb7..2a62fc94d 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -171,7 +171,7 @@ function _pushforward_via_pullback( dx, contexts::Vararg{Context,C}, ) where {F,C} - dy = map(y, CartesianIndices(y)) do i + dy = map(y, CartesianIndices(y)) do yi, i t1 = pullback(f, pullback_prep, backend, x, (basis(y, i),), contexts...) convert(eltype(y), dot(only(t1), dx)) end @@ -243,7 +243,7 @@ function _pushforward_via_pullback( dx, contexts::Vararg{Context,C}, ) where {F,C} - dy = map(CartesianIndices(y)) do i # preserve shape + dy = map(CartesianIndices(y)) do yi, i # preserve shape t1 = pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...) convert(eltype(y), dot(only(t1), dx)) end