Skip to content

Commit

Permalink
Conjugate partials in scalar rrule (#170)
Browse files Browse the repository at this point in the history
* Conjugate partials in scalar rrule

* Document holomorphic requirement

* Test rrule partial is conjugated

* Increment version number

* Test complex rules using FD

* Remove only
  • Loading branch information
sethaxen authored Jun 26, 2020
1 parent d187432 commit 23b1a6e
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 6 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.8.1"
version = "0.9.0"

[deps]
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"

[compat]
BenchmarkTools = "0.5"
FiniteDifferences = "0.10"
MuladdMacro = "0.2.1"
StaticArrays = "0.11, 0.12"
julia = "^1.0"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "BenchmarkTools", "LinearAlgebra", "StaticArrays"]
test = ["Test", "BenchmarkTools", "FiniteDifferences", "LinearAlgebra", "StaticArrays"]
16 changes: 12 additions & 4 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ The result of `f(x₁, x₂, ...)` is automatically bound to `Ω`. This
allows the primal result to be conveniently referenced (as `Ω`) within the
derivative/setup expressions.
This macro assumes complex functions are holomorphic. In general, for non-holomorphic
functions, the `frule` and `rrule` must be defined manually.
The `@setup` argument can be elided if no setup code is need. In other
words:
Expand Down Expand Up @@ -182,7 +185,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
# 1 partial derivative per input
pullback_returns = map(1:n_inputs) do input_i
∂s = [partial.args[input_i] for partial in partials]
propagation_expr(Δs, ∂s)
propagation_expr(Δs, ∂s, true)
end

# Multi-output functions have pullbacks with a tuple input that will be destructured
Expand All @@ -203,19 +206,24 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
end

"""
propagation_expr(Δs, ∂s)
propagation_expr(Δs, ∂s, _conj = false)
Returns the expression for the propagation of
the input gradient `Δs` though the partials `∂s`.
Specify `_conj = true` to conjugate the partials.
"""
function propagation_expr(Δs, ∂s)
function propagation_expr(Δs, ∂s, _conj = false)
# This is basically Δs ⋅ ∂s
∂s = map(esc, ∂s)
n∂s = length(∂s)

# Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression
# literals.
∂_mul_Δs = ntuple(i->:($(∂s[i]) * $(Δs[i])), n∂s)
∂_mul_Δs = if _conj
ntuple(i->:(conj($(∂s[i])) * $(Δs[i])), n∂s)
else
ntuple(i->:($(∂s[i]) * $(Δs[i])), n∂s)
end

# Avoiding the extra `+` operation, it is potentially expensive for vector
# mode AD.
Expand Down
15 changes: 15 additions & 0 deletions test/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ nice(x) = 1
very_nice(x, y) = x + y
@scalar_rule(very_nice(x, y), (One(), One()))

complex_times(x) = (1 + 2im) * x
@scalar_rule(complex_times(x), 1 + 2im)

# Tests that aim to ensure that the API for frules doesn't regress and make these things
# hard to implement.
Expand Down Expand Up @@ -122,6 +124,19 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
# Test that @scalar_rule and `One()` play nice together, w.r.t broadcasting
@inferred frule((Zero(), sx, sy), very_nice, 1, 2)
end

@testset "complex inputs" begin
x, ẋ, Ω̄ = randn(ComplexF64, 3)
Ω = complex_times(x)
Ω_fwd, Ω̇ = frule((nothing, ẋ), complex_times, x)
@test Ω_fwd == Ω
@test Ω̇ jvp(central_fdm(5, 1), complex_times, (x, ẋ))
Ω_rev, back = rrule(complex_times, x)
@test Ω_rev == Ω
∂self, ∂x = back(Ω̄)
@test ∂self == NO_FIELDS
@test ∂x j′vp(central_fdm(5, 1), complex_times, Ω̄, x)[1]
end
end


Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using Base.Broadcast: broadcastable
using BenchmarkTools
using ChainRulesCore
using LinearAlgebra: Diagonal
using FiniteDifferences
using Test

@testset "ChainRulesCore" begin
Expand Down

2 comments on commit 23b1a6e

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/17028

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.9.0 -m "<description of version>" 23b1a6eb8953e18b1052d256339aca69a58eb149
git push origin v0.9.0

Please sign in to comment.