diff --git a/.travis.yml b/.travis.yml index 8cc900185..3c0a64eb9 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,10 +7,10 @@ julia: - 1.1 - 1.2 - 1.3 + - 1.4 - nightly jobs: allow_failures: - - julia: 1.3 - julia: nightly include: - stage: "Documentation" diff --git a/Project.toml b/Project.toml index a8d04495e..c032e4fc4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.6.1" +version = "0.7.0" [deps] MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" diff --git a/docs/src/index.md b/docs/src/index.md index 9c0b70b9f..38c5c4cfa 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -59,7 +59,7 @@ Almost always the _pullback_ will be declared locally within the `rrule`, and wi The `frule` is written: ```julia -function frule(::typeof(foo), args..., Δself, Δargs...; kwargs...) +function frule((Δself, Δargs...), ::typeof(foo), args...; kwargs...) ... return y, ∂Y end @@ -175,7 +175,7 @@ end ``` But because it is fused into frule we see it as part of: ```julia -function frule(::typeof(foo), args..., Δself, Δargs...; kwargs...) +function frule((Δself, Δargs...), ::typeof(foo), args...; kwargs...) ... return y, ∂y end @@ -183,7 +183,7 @@ end The input to the pushforward is often called the _perturbation_. -If the function is `y = f(x)` often the pushforward will be written `ẏ = last(frule(f, x, ṡelf, ẋ))`. +If the function is `y = f(x)` often the pushforward will be written `ẏ = last(frule((ṡelf, ẋ), f, x))`. `ẏ` is commonly used to represent the perturbation for `y`. !!! note @@ -238,14 +238,14 @@ If we would like to know the the directional derivative of `f` for an input chan ```julia direction = (1.5, 0.4, -1) # (ȧ, ḃ, ċ) -y, ẏ = frule(f, a, b, c, Zero(), direction) +y, ẏ = frule((Zero(), direction...), f, a, b, c) ``` On the basis directions one gets the partial derivatives of `y`: ```julia -y, ∂y_∂a = frule(f, a, b, c, Zero(), 1, 0, 0) -y, ∂y_∂b = frule(f, a, b, c, Zero(), 0, 1, 0) -y, ∂y_∂c = frule(f, a, b, c, Zero(), 0, 0, 1) +y, ∂y_∂a = frule((Zero(), 1, 0, 0), f, a, b, c) +y, ∂y_∂b = frule((Zero(), 0, 1, 0), f, a, b, c) +y, ∂y_∂c = frule((Zero(), 0, 0, 1), f, a, b, c) ``` Similarly, the most trivial use of `rrule` and returned `pullback` is to calculate the [Gradient](https://en.wikipedia.org/wiki/Gradient): @@ -320,10 +320,10 @@ x = 3; ẋ = 1; # ∂x/∂x nofields = Zero(); # ∂self/∂self -a, ȧ = frule(sin, x, nofields, ẋ); # ∂a/∂x -b, ḃ = frule(*, 2, a, nofields, Zero(), unthunk(ȧ)); # ∂b/∂x = ∂b/∂a⋅∂a/∂x +a, ȧ = frule((nofields, ẋ), sin, x); # ∂a/∂x +b, ḃ = frule((nofields, Zero(), unthunk(ȧ)), *, 2, a); # ∂b/∂x = ∂b/∂a⋅∂a/∂x -c, ċ = frule(asin, b, nofields, unthunk(ḃ)); # ∂c/∂x = ∂c/∂b⋅∂b/∂x = ∂f/∂x +c, ċ = frule((nofields, unthunk(ḃ)), asin, b); # ∂c/∂x = ∂c/∂b⋅∂b/∂x = ∂f/∂x unthunk(ċ) # output -2.0638950738662625 diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 29046ecd7..4edc8c88f 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -148,7 +148,7 @@ function scalar_frule_expr(f, call, setup_stmts, inputs, partials) # Δs is the input to the propagator rule # because this is push-forward there is one per input to the function - Δs = [Symbol(string(:Δ, i)) for i in 1:n_inputs] + Δs = [esc(Symbol(:Δ, i)) for i in 1:n_inputs] pushforward_returns = map(1:n_outputs) do output_i ∂s = partials[output_i].args propagation_expr(Δs, ∂s) @@ -163,7 +163,7 @@ function scalar_frule_expr(f, call, setup_stmts, inputs, partials) return quote # _ is the input derivative w.r.t. function internals. since we do not # allow closures/functors with @scalar_rule, it is always ignored - function ChainRulesCore.frule(::typeof($f), $(inputs...), _, $(Δs...)) + function ChainRulesCore.frule((_, $(Δs...)), ::typeof($f), $(inputs...)) $(esc(:Ω)) = $call $(setup_stmts...) return $(esc(:Ω)), $pushforward_returns diff --git a/src/rules.jl b/src/rules.jl index 0f0c84e69..0fb73ece2 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -2,54 +2,48 @@ ##### `frule`/`rrule` ##### -# TODO: remember to update the examples """ - frule(f, x..., ṡelf, Δx...) + frule((Δf, Δx...), f, x...) -Expressing `x` as the tuple `(x₁, x₂, ...)`, `Δx` as the tuple `(Δx₁, Δx₂, -...)`, and the output tuple of `f(x...)` as `Ω`, return the tuple: +Expressing the output of `f(x...)` as `Ω`, return the tuple: - (Ω, (Ω̇₁, Ω̇₂, ...)) + (Ω, ΔΩ) -The second return value is the propagation rule, or the pushforward. -It takes in differentials corresponding to the inputs (`ẋ₁, ẋ₂, ...`) -and `ṡelf` the internal values of the function (for closures). +The second return value is the differential w.r.t. the output. - -If no method matching `frule(f, x..., ṡelf, Δx...)` has been defined, then -return `nothing`. +If no method matching `frule((Δf, Δx...), f, x...)` has been defined, then return `nothing`. Examples: unary input, unary output scalar function: -``` +```jldoctest julia> dself = Zero() Zero() julia> x = rand(); -julia> sinx, sin_pushforward = frule(sin, x, dself, 1) +julia> sinx, Δsinx = frule(sin, x, dself, 1) (0.35696518021277485, 0.9341176907197836) julia> sinx == sin(x) true -julia> sin_pushforward == cos(x) +julia> Δsinx == cos(x) true ``` unary input, binary output scalar function: -``` +```jldoctest julia> x = rand(); -julia> sincosx, sincos_pushforward = frule(sincos, x, dself, 1); +julia> sincosx, Δsincosx = frule(sincos, x, dself, 1); julia> sincosx == sincos(x) true -julia> sincos_pushforward == (cos(x), -sin(x)) +julia> Δsincosx == (cos(x), -sin(x)) true ``` diff --git a/test/rules.jl b/test/rules.jl index bd502d9f5..a80000376 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -15,14 +15,56 @@ nice(x) = 1 very_nice(x, y) = x + y @scalar_rule(very_nice(x, y), (One(), One())) + +# Tests that aim to ensure that the API for frules doesn't regress and make these things +# hard to implement. + +varargs_function(x...) = sum(x) +function ChainRulesCore.frule(dargs, ::typeof(varargs_function), x...) + Δx = Base.tail(dargs) + return sum(x), sum(Δx) +end + +mixed_vararg(x, y, z...) = x + y + sum(z) +function ChainRulesCore.frule( + dargs::Tuple{Any, Any, Any, Vararg}, + ::typeof(mixed_vararg), x, y, z..., +) + Δx = dargs[2] + Δy = dargs[3] + Δz = dargs[4:end] + return mixed_vararg(x, y, z...), Δx + Δy + sum(Δz) +end + +type_constraints(x::Int, y::Float64) = x + y +function ChainRulesCore.frule( + (_, Δx, Δy)::Tuple{Any, Int, Float64}, + ::typeof(type_constraints), x::Int, y::Float64, +) + return type_constraints(x, y), Δx + Δy +end + +mixed_vararg_type_constaint(x::Float64, y::Real, z::Vararg{Float64}) = x + y + sum(z) +function ChainRulesCore.frule( + dargs::Tuple{Any, Float64, Real, Vararg{Float64}}, + ::typeof(mixed_vararg_type_constaint), x::Float64, y::Real, z::Vararg{Float64}, +) + Δx = dargs[2] + Δy = dargs[3] + Δz = dargs[4:end] + return x + y + sum(z), Δx + Δy + sum(Δz) +end + +ChainRulesCore.frule(dargs, ::typeof(Core._apply), f, x...) = frule(dargs[2:end], f, x...) + ####### _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @testset "frule and rrule" begin dself = Zero() - @test frule(cool, 1, dself, 1) === nothing - @test frule(cool, 1, dself, 1; iscool=true) === nothing + @test frule((dself, 1), cool, 1) === nothing + @test frule((dself, 1), cool, 1; iscool=true) === nothing @test rrule(cool, 1) === nothing @test rrule(cool, 1; iscool=true) === nothing @@ -37,7 +79,7 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) Tuple{typeof(rrule),typeof(cool),String}]) @test cool_methods == only_methods - frx, cool_pushforward = frule(cool, 1, dself, 1) + frx, cool_pushforward = frule((dself, 1), cool, 1) @test frx === 2 @test cool_pushforward === 1 rrx, cool_pullback = rrule(cool, 1) @@ -46,13 +88,38 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test rrx === 2 @test rr1 === 1 - frx, nice_pushforward = frule(nice, 1, dself, 1) + frx, nice_pushforward = frule((dself, 1), nice, 1) @test nice_pushforward === Zero() rrx, nice_pullback = rrule(nice, 1) @test (NO_FIELDS, Zero()) === nice_pullback(1) - sx = @SVector [1, 2] - sy = @SVector [3, 4] - # This is testing that @scalar_rule and `One()` play nice together, w.r.t broadcasting - @inferred frule(very_nice, 1, 2, Zero(), sx, sy) + + # Test that these run. Do not care about numerical correctness. + @test frule((nothing, 1.0, 1.0, 1.0), varargs_function, 0.5, 0.5, 0.5) == (1.5, 3.0) + + @test frule((nothing, 1.0, 2.0, 3.0, 4.0), mixed_vararg, 1.0, 2.0, 3.0, 4.0) == (10.0, 10.0) + + @test frule((nothing, 3, 2.0), type_constraints, 5, 4.0) == (9.0, 5.0) + @test frule((nothing, 3.0, 2.0im), type_constraints, 5, 4.0) == nothing + + @test(frule( + (nothing, 3.0, 2.0, 1.0, 0.0), + mixed_vararg_type_constaint, 3.0, 2.0, 1.0, 0.0, + ) == (6.0, 6.0)) + + # violates type constraints, thus an frule should not be found. + @test frule( + (nothing, 3, 2.0, 1.0, 5.0), + mixed_vararg_type_constaint, 3, 2.0, 1.0, 0, + ) == nothing + + @test frule((nothing, nothing, 5.0), Core._apply, dummy_identity, 4.0) == (4.0, 5.0) + + @testset "broadcasting One" begin + sx = @SVector [1, 2] + sy = @SVector [3, 4] + + # Test that @scalar_rule and `One()` play nice together, w.r.t broadcasting + @inferred frule((Zero(), sx, sy), very_nice, 1, 2) + end end