Skip to content

Commit

Permalink
Merge pull request #149 from JuliaDiff/sim/fix_ambiguities
Browse files Browse the repository at this point in the history
fix ambiguities with `Composite`
  • Loading branch information
simeonschaub authored Apr 26, 2020
2 parents 200bd8f + 8a3867b commit c3dfd3c
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.7.4"
version = "0.7.5"

[deps]
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
Expand Down
31 changes: 16 additions & 15 deletions src/differential_arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Notice:

Base.:+(::DoesNotExist, ::DoesNotExist) = DoesNotExist()
Base.:*(::DoesNotExist, ::DoesNotExist) = DoesNotExist()
for T in (:One, :AbstractThunk, :Any)
for T in (:One, :AbstractThunk, :Composite, :Any)
@eval Base.:+(::DoesNotExist, b::$T) = b
@eval Base.:+(a::$T, ::DoesNotExist) = a

Expand All @@ -43,7 +43,7 @@ Base.muladd(::Zero, ::Zero, ::Zero) = Zero()

Base.:+(::Zero, ::Zero) = Zero()
Base.:*(::Zero, ::Zero) = Zero()
for T in (:One, :AbstractThunk, :Any)
for T in (:One, :AbstractThunk, :Composite, :Any)
@eval Base.:+(::Zero, b::$T) = b
@eval Base.:+(a::$T, ::Zero) = a

Expand All @@ -53,9 +53,11 @@ end

Base.:+(a::One, b::One) = extern(a) + extern(b)
Base.:*(::One, ::One) = One()
for T in (:AbstractThunk, :Any)
@eval Base.:+(a::One, b::$T) = extern(a) + b
@eval Base.:+(a::$T, b::One) = a + extern(b)
for T in (:AbstractThunk, :Composite, :Any)
if T != :Composite
@eval Base.:+(a::One, b::$T) = extern(a) + b
@eval Base.:+(a::$T, b::One) = a + extern(b)
end

@eval Base.:*(::One, b::$T) = b
@eval Base.:*(a::$T, ::One) = a
Expand All @@ -64,23 +66,14 @@ end

Base.:+(a::AbstractThunk, b::AbstractThunk) = unthunk(a) + unthunk(b)
Base.:*(a::AbstractThunk, b::AbstractThunk) = unthunk(a) * unthunk(b)
for T in (:Any,)
for T in (:Composite, :Any)
@eval Base.:+(a::AbstractThunk, b::$T) = unthunk(a) + b
@eval Base.:+(a::$T, b::AbstractThunk) = a + unthunk(b)

@eval Base.:*(a::AbstractThunk, b::$T) = unthunk(a) * b
@eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b)
end

################## Composite ##############################################################

# We intentionally do not define, `Base.*(::Composite, ::Composite)` as that is not meaningful
# In general one doesn't have to represent multiplications of 2 differentials
# Only of a differential and a scaling factor (generally `Real`)
Base.:*(s::Any, comp::Composite) = map(x->s*x, comp)
Base.:*(comp::Composite, s::Any) = map(x->x*s, comp)


function Base.:+(a::Composite{P}, b::Composite{P}) where P
data = elementwise_add(backing(a), backing(b))
return Composite{P, typeof(data)}(data)
Expand All @@ -98,3 +91,11 @@ function Base.:+(a::P, d::Composite{P}) where P
end
end
Base.:+(a::Composite{P}, b::P) where P = b + a

# We intentionally do not define, `Base.*(::Composite, ::Composite)` as that is not meaningful
# In general one doesn't have to represent multiplications of 2 differentials
# Only of a differential and a scaling factor (generally `Real`)
for T in (:Any,)
@eval Base.:*(s::$T, comp::Composite) = map(x->s*x, comp)
@eval Base.:*(comp::Composite, s::$T) = map(x->x*s, comp)
end
19 changes: 18 additions & 1 deletion test/differentials/composite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,24 @@ end
@test diff + value == StructWithInvariant(12.5)
end

@testset "Scaling" begin
@testset "differential arithmetic" begin
c = Composite{Foo}(y=1.5, x=2.5)

@test DoesNotExist() * c == DoesNotExist()
@test c * DoesNotExist() == DoesNotExist()

@test Zero() * c == Zero()
@test c * Zero() == Zero()

@test One() * c === c
@test c * One() === c

t = @thunk 2
@test t * c == 2 * c
@test c * t == c * 2
end

@testset "scaling" begin
@test (
2 * Composite{Foo}(y=1.5, x=2.5)
== Composite{Foo}(y=3.0, x=5.0)
Expand Down

2 comments on commit c3dfd3c

@simeonschaub
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/13672

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.5 -m "<description of version>" c3dfd3cb7a465375a11f8c51fee63a67fb3ccefc
git push origin v0.7.5

Please sign in to comment.