Skip to content

Commit 81106f5

Browse files
authored
Merge pull request #225 from JuliaDiff/ox/adjoint
Add adjoint and transpose for thunks
2 parents a57b46e + ca44608 commit 81106f5

File tree

5 files changed

+58
-42
lines changed

5 files changed

+58
-42
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.9.10"
3+
version = "0.9.11"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/differentials/abstract_differential.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,6 @@ The subtypes of `AbstractDifferential` define a custom \"algebra\" for chain
77
rule evaluation that attempts to factor various features like complex derivative
88
support, broadcast fusion, zero-elision, etc. into nicely separated parts.
99
10-
All subtypes of `AbstractDifferential` implement the following operations:
11-
12-
`+(a, b)`: linearly combine differential `a` and differential `b`
13-
14-
`*(a, b)`: multiply the differential `b` by the scaling factor `a`
15-
16-
`Base.conj(x)`: complex conjugate of the differential `x`
17-
18-
`Base.zero(x) = Zero()`: a zero.
19-
2010
In general a differential type is the type of a derivative of a value.
2111
The type of the value is for contrast called the primal type.
2212
Differential types correspond to primal types, although the relation is not one-to-one.
@@ -30,6 +20,18 @@ That allows for gradients to be accumulated.
3020
3121
It generally also should be able to be added to a primal to give back another primal, as
3222
this facilitates gradient descent.
23+
24+
All subtypes of `AbstractDifferential` implement the following operations:
25+
26+
- `+(a, b)`: linearly combine differential `a` and differential `b`
27+
- `*(a, b)`: multiply the differential `b` by the scaling factor `a`
28+
- `Base.zero(x) = Zero()`: a zero.
29+
30+
Further, they often implement other linear operators, such as `conj`, `adjoint`, `dot`.
31+
Pullbacks/pushforwards are linear operators, and their inputs are often
32+
`AbstractDifferential` subtypes.
33+
Pullbacks/pushforwards in-turn call other linear operators on those inputs.
34+
Thus it is desirable to have all common linear operators work on `AbstractDifferential`s.
3335
"""
3436
abstract type AbstractDifferential end
3537

src/differentials/thunks.jl

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,34 @@ end
1313
return element, (val, new_state)
1414
end
1515

16+
"""
17+
@thunk expr
18+
19+
Define a [`Thunk`](@ref) wrapping the `expr`, to lazily defer its evaluation.
20+
"""
21+
macro thunk(body)
22+
# Basically `:(Thunk(() -> $(esc(body))))` but use the location where it is defined.
23+
# so we get useful stack traces if it errors.
24+
func = Expr(:->, Expr(:tuple), Expr(:block, __source__, body))
25+
return :(Thunk($(esc(func))))
26+
end
27+
28+
"""
29+
unthunk(x)
30+
31+
On `AbstractThunk`s this removes 1 layer of thunking.
32+
On any other type, it is the identity operation.
33+
34+
In contrast to [`extern`](@ref) this is nonrecursive.
35+
"""
36+
@inline unthunk(x) = x
37+
38+
@inline extern(x::AbstractThunk) = extern(unthunk(x))
39+
40+
Base.conj(x::AbstractThunk) = @thunk(conj(unthunk(x)))
41+
Base.adjoint(x::AbstractThunk) = @thunk(adjoint(unthunk(x)))
42+
Base.transpose(x::AbstractThunk) = @thunk(transpose(unthunk(x)))
43+
1644
#####
1745
##### `Thunk`
1846
#####
@@ -59,42 +87,14 @@ Also if we did `Zero() * res[1]` then the result would be `Zero()` and `f(x)` wo
5987
with a field for each variable used in the expression, and call overloaded.
6088
6189
Do not use `@thunk` if this would be equal or more work than actually evaluating the expression itself.
90+
This is commonly the case for scalar operators.
6291
6392
For more details see the manual section [on using thunks effectively](http://www.juliadiff.org/ChainRulesCore.jl/dev/writing_good_rules.html#Use-Thunks-appropriately-1)
6493
"""
6594
struct Thunk{F} <: AbstractThunk
6695
f::F
6796
end
6897

69-
70-
"""
71-
@thunk expr
72-
73-
Define a [`Thunk`](@ref) wrapping the `expr`, to lazily defer its evaluation.
74-
"""
75-
macro thunk(body)
76-
# Basically `:(Thunk(() -> $(esc(body))))` but use the location where it is defined.
77-
# so we get useful stack traces if it errors.
78-
func = Expr(:->, Expr(:tuple), Expr(:block, __source__, body))
79-
return :(Thunk($(esc(func))))
80-
end
81-
82-
"""
83-
unthunk(x)
84-
85-
On `AbstractThunk`s this removes 1 layer of thunking.
86-
On any other type, it is the identity operation.
87-
88-
In contrast to [`extern`](@ref) this is nonrecursive.
89-
"""
90-
@inline unthunk(x) = x
91-
92-
@inline extern(x::AbstractThunk) = extern(unthunk(x))
93-
94-
# have to define this here after `@thunk` and `Thunk` is defined
95-
Base.conj(x::AbstractThunk) = @thunk(conj(unthunk(x)))
96-
97-
9898
(x::Thunk)() = x.f()
9999
@inline unthunk(x::Thunk) = x()
100100

test/differentials/composite.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ end
9999

100100
@testset "Tuples" begin
101101
@test ==(
102-
typeof(Composite{Tuple{}}() + Composite{Tuple{}}()),
102+
typeof(Composite{Tuple{}}() + Composite{Tuple{}}()),
103103
Composite{Tuple{}, Tuple{}}
104104
)
105105
@test (
@@ -219,7 +219,12 @@ end
219219

220220
@testset "show" begin
221221
@test repr(Composite{Foo}(x=1,)) == "Composite{Foo}(x = 1,)"
222-
@test repr(Composite{Tuple{Int,Int}}(1, 2)) == "Composite{Tuple{Int64,Int64}}(1, 2)"
222+
# check for exact regex match not occurence( `^...$`)
223+
# and allowing optional whitespace (`\s?`)
224+
@test occursin(
225+
r"^Composite{Tuple{Int64,\s?Int64}}\(1,\s?2\)$",
226+
repr(Composite{Tuple{Int64,Int64}}(1, 2)),
227+
)
223228
end
224229

225230
@testset "internals" begin

test/differentials/thunks.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,15 @@
3636
end
3737
end
3838

39+
@testset "Linear operators" begin
40+
x_real = [2.0 4.0; 8.0 5.0]
41+
x_complex = [(2.0 + im) 4.0; 8.0 (5.0 + 4im)]
42+
@testset "$(typeof(x))" for x in (x_real, x_complex)
43+
x_thunked = @thunk(1.0 * x)
44+
@test unthunk(x_thunked') == x'
45+
@test unthunk(transpose(x_thunked)) == transpose(x)
46+
end
47+
end
3948

4049
@testset "Broadcast" begin
4150
@testset "Array" begin

0 commit comments

Comments
 (0)