From 54ebdf8a62be3925a069e1be38b7b73f571dbe14 Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Fri, 17 Apr 2020 10:41:01 +0200 Subject: [PATCH 1/3] fix ambiguities with `Composite` Co-Authored-By: Lyndon White --- src/differential_arithmetic.jl | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/src/differential_arithmetic.jl b/src/differential_arithmetic.jl index 26be95479..bb2a2170a 100644 --- a/src/differential_arithmetic.jl +++ b/src/differential_arithmetic.jl @@ -16,7 +16,7 @@ Notice: Base.:+(::DoesNotExist, ::DoesNotExist) = DoesNotExist() Base.:*(::DoesNotExist, ::DoesNotExist) = DoesNotExist() -for T in (:One, :AbstractThunk, :Any) +for T in (:One, :AbstractThunk, :Composite, :Any) @eval Base.:+(::DoesNotExist, b::$T) = b @eval Base.:+(a::$T, ::DoesNotExist) = a @@ -43,7 +43,7 @@ Base.muladd(::Zero, ::Zero, ::Zero) = Zero() Base.:+(::Zero, ::Zero) = Zero() Base.:*(::Zero, ::Zero) = Zero() -for T in (:One, :AbstractThunk, :Any) +for T in (:One, :AbstractThunk, :Composite, :Any) @eval Base.:+(::Zero, b::$T) = b @eval Base.:+(a::$T, ::Zero) = a @@ -53,9 +53,11 @@ end Base.:+(a::One, b::One) = extern(a) + extern(b) Base.:*(::One, ::One) = One() -for T in (:AbstractThunk, :Any) - @eval Base.:+(a::One, b::$T) = extern(a) + b - @eval Base.:+(a::$T, b::One) = a + extern(b) +for T in (:AbstractThunk, :Composite, :Any) + if T != :Composite + @eval Base.:+(a::One, b::$T) = extern(a) + b + @eval Base.:+(a::$T, b::One) = a + extern(b) + end @eval Base.:*(::One, b::$T) = b @eval Base.:*(a::$T, ::One) = a @@ -64,7 +66,7 @@ end Base.:+(a::AbstractThunk, b::AbstractThunk) = unthunk(a) + unthunk(b) Base.:*(a::AbstractThunk, b::AbstractThunk) = unthunk(a) * unthunk(b) -for T in (:Any,) +for T in (:Composite, :Any) @eval Base.:+(a::AbstractThunk, b::$T) = unthunk(a) + b @eval Base.:+(a::$T, b::AbstractThunk) = a + unthunk(b) @@ -72,15 +74,6 @@ for T in (:Any,) @eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b) end -################## Composite ############################################################## - -# We intentionally do not define, `Base.*(::Composite, ::Composite)` as that is not meaningful -# In general one doesn't have to represent multiplications of 2 differentials -# Only of a differential and a scaling factor (generally `Real`) -Base.:*(s::Any, comp::Composite) = map(x->s*x, comp) -Base.:*(comp::Composite, s::Any) = map(x->x*s, comp) - - function Base.:+(a::Composite{P}, b::Composite{P}) where P data = elementwise_add(backing(a), backing(b)) return Composite{P, typeof(data)}(data) @@ -98,3 +91,11 @@ function Base.:+(a::P, d::Composite{P}) where P end end Base.:+(a::Composite{P}, b::P) where P = b + a + +# We intentionally do not define, `Base.*(::Composite, ::Composite)` as that is not meaningful +# In general one doesn't have to represent multiplications of 2 differentials +# Only of a differential and a scaling factor (generally `Real`) +for T in (:Any,) + @eval Base.:*(s::$T, comp::Composite) = map(x->s*x, comp) + @eval Base.:*(comp::Composite, s::$T) = map(x->x*s, comp) +end From 6586b746eefc1a5448441040011cea469cf7b8ef Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Fri, 17 Apr 2020 10:41:19 +0200 Subject: [PATCH 2/3] test differential arithmetic with `Composite` --- test/differentials/composite.jl | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/test/differentials/composite.jl b/test/differentials/composite.jl index 396cd2bd8..0f4256e17 100644 --- a/test/differentials/composite.jl +++ b/test/differentials/composite.jl @@ -154,7 +154,24 @@ end @test diff + value == StructWithInvariant(12.5) end - @testset "Scaling" begin + @testset "differential arithmetic" begin + c = Composite{Foo}(y=1.5, x=2.5) + + @test DoesNotExist() * c == DoesNotExist() + @test c * DoesNotExist() == DoesNotExist() + + @test Zero() * c == Zero() + @test c * Zero() == Zero() + + @test One() * c === c + @test c * One() === c + + t = @thunk 2 + @test t * c == 2 * c + @test c * t == c * 2 + end + + @testset "scaling" begin @test ( 2 * Composite{Foo}(y=1.5, x=2.5) == Composite{Foo}(y=3.0, x=5.0) From 8a3867b14d724c2221c34e7830f9e9787ef4d449 Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Sat, 18 Apr 2020 10:15:02 +0200 Subject: [PATCH 3/3] bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b54c9369c..25cdc6a05 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.7.4" +version = "0.7.5" [deps] MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"