Skip to content

Commit 6073a47

Browse files
authored
Add methods for thunks (#371)
1 parent bf01ddf commit 6073a47

File tree

7 files changed

+198
-13
lines changed

7 files changed

+198
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.10.5"
3+
version = "0.10.6"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

docs/Manifest.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
1313
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
1414
path = ".."
1515
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
16-
version = "0.10.1"
16+
version = "0.10.6"
1717

1818
[[Compat]]
1919
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]

src/ChainRulesCore.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module ChainRulesCore
22
using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize!
3-
using LinearAlgebra: LinearAlgebra
3+
using LinearAlgebra
44
using SparseArrays: SparseVector, SparseMatrixCSC
55
using Compat: hasfield
66

src/differentials/abstract_zero.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ Base.:/(z::AbstractZero, ::Any) = z
2626

2727
Base.convert(::Type{T}, x::AbstractZero) where T <: Number = zero(T)
2828

29+
Base.getindex(z::AbstractZero, k) = z
30+
2931
"""
3032
ZeroTangent() <: AbstractZero
3133

src/differentials/thunks.jl

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
abstract type AbstractThunk <: AbstractTangent end
22

3+
struct MutateThunkException <: Exception end
4+
5+
function Base.showerror(io::IO, e::MutateThunkException)
6+
print(io, "Tried to mutate a thunk, this is not supported. `unthunk` it first.")
7+
return nothing
8+
end
9+
310
Base.Broadcast.broadcastable(x::AbstractThunk) = broadcastable(unthunk(x))
411

512
@inline function Base.iterate(x::AbstractThunk)
@@ -19,6 +26,94 @@ Base.:(==)(a::AbstractThunk, b::AbstractThunk) = unthunk(a) == unthunk(b)
1926
Base.:(==)(a::AbstractThunk, b) = unthunk(a) == b
2027
Base.:(==)(a, b::AbstractThunk) = a == unthunk(b)
2128

29+
Base.:(-)(a::AbstractThunk) = -unthunk(a)
30+
Base.:(-)(a::AbstractThunk, b) = unthunk(a) - b
31+
Base.:(-)(a, b::AbstractThunk) = a - unthunk(b)
32+
Base.:(/)(a::AbstractThunk, b) = unthunk(a) / b
33+
Base.:(/)(a, b::AbstractThunk) = a / unthunk(b)
34+
35+
Base.real(a::AbstractThunk) = real(unthunk(a))
36+
Base.imag(a::AbstractThunk) = imag(unthunk(a))
37+
Base.Complex(a::AbstractThunk) = Complex(unthunk(a))
38+
Base.Complex(a::AbstractThunk, b::AbstractThunk) = Complex(unthunk(a), unthunk(b))
39+
40+
Base.mapreduce(f, op, a::AbstractThunk; kws...) = mapreduce(f, op, unthunk(a); kws...)
41+
function Base.mapreduce(f, op, itr, a::AbstractThunk; kws...)
42+
return mapreduce(f, op, itr, unthunk(a); kws...)
43+
end
44+
Base.sum!(r, A::AbstractThunk; kws...) = sum!(r, unthunk(A); kws...)
45+
46+
Base.vec(a::AbstractThunk) = vec(unthunk(a))
47+
Base.reshape(a::AbstractThunk, args...) = reshape(unthunk(a), args...)
48+
Base.getindex(a::AbstractThunk, args...) = getindex(unthunk(a), args...)
49+
Base.setindex!(a::AbstractThunk, value, key...) = throw(MutateThunkException())
50+
Base.selectdim(a::AbstractThunk, args...) = selectdim(unthunk(a), args...)
51+
52+
LinearAlgebra.Array(a::AbstractThunk) = Array(unthunk(a))
53+
LinearAlgebra.Matrix(a::AbstractThunk) = Matrix(unthunk(a))
54+
LinearAlgebra.Diagonal(a::AbstractThunk) = Diagonal(unthunk(a))
55+
LinearAlgebra.LowerTriangular(a::AbstractThunk) = LowerTriangular(unthunk(a))
56+
LinearAlgebra.UpperTriangular(a::AbstractThunk) = UpperTriangular(unthunk(a))
57+
LinearAlgebra.Symmetric(a::AbstractThunk, uplo=:U) = Symmetric(unthunk(a), uplo)
58+
LinearAlgebra.Hermitian(a::AbstractThunk, uplo=:U) = Hermitian(unthunk(a), uplo)
59+
60+
function LinearAlgebra.diagm(kv::Pair{<:Integer,<:AbstractThunk}...)
61+
return diagm((k => unthunk(v) for (k, v) in kv)...)
62+
end
63+
function LinearAlgebra.diagm(m, n, kv::Pair{<:Integer,<:AbstractThunk}...)
64+
return diagm(m, n, (k => unthunk(v) for (k, v) in kv)...)
65+
end
66+
LinearAlgebra.tril(a::AbstractThunk) = tril(unthunk(a))
67+
LinearAlgebra.tril(a::AbstractThunk, k) = tril(unthunk(a), k)
68+
LinearAlgebra.triu(a::AbstractThunk) = triu(unthunk(a))
69+
LinearAlgebra.triu(a::AbstractThunk, k) = triu(unthunk(a), k)
70+
LinearAlgebra.tr(a::AbstractThunk) = tr(unthunk(a))
71+
LinearAlgebra.cross(a::AbstractThunk, b) = cross(unthunk(a), b)
72+
LinearAlgebra.cross(a, b::AbstractThunk) = cross(a, unthunk(b))
73+
LinearAlgebra.cross(a::AbstractThunk, b::AbstractThunk) = cross(unthunk(a), unthunk(b))
74+
LinearAlgebra.dot(a::AbstractThunk, b) = dot(unthunk(a), b)
75+
LinearAlgebra.dot(a, b::AbstractThunk) = dot(a, unthunk(b))
76+
LinearAlgebra.dot(a::AbstractThunk, b::AbstractThunk) = dot(unthunk(a), unthunk(b))
77+
78+
LinearAlgebra.ldiv!(a, b::AbstractThunk) = throw(MutateThunkException())
79+
LinearAlgebra.rdiv!(a::AbstractThunk, b) = throw(MutateThunkException())
80+
81+
LinearAlgebra.mul!(C::AbstractThunk, A, B, α, β) = throw(MutateThunkException())
82+
function LinearAlgebra.mul!(C::AbstractThunk, A::AbstractThunk, B, α, β)
83+
return throw(MutateThunkException())
84+
end
85+
function LinearAlgebra.mul!(C::AbstractThunk, A, B::AbstractThunk, α, β)
86+
return throw(MutateThunkException())
87+
end
88+
function LinearAlgebra.mul!(C::AbstractThunk, A::AbstractThunk, B::AbstractThunk, α, β)
89+
return throw(MutateThunkException())
90+
end
91+
LinearAlgebra.mul!(C, A::AbstractThunk, B, α, β) = mul!(C, unthunk(A), B, α, β)
92+
LinearAlgebra.mul!(C, A, B::AbstractThunk, α, β) = mul!(C, A, unthunk(B), α, β)
93+
function LinearAlgebra.mul!(C, A::AbstractThunk, B::AbstractThunk, α, β)
94+
return mul!(C, unthunk(A), unthunk(B), α, β)
95+
end
96+
97+
function LinearAlgebra.BLAS.ger!(alpha, x::AbstractThunk, y, A)
98+
return LinearAlgebra.BLAS.ger!(alpha, unthunk(x), y, A)
99+
end
100+
function LinearAlgebra.BLAS.ger!(alpha, x, y::AbstractThunk, A)
101+
return LinearAlgebra.BLAS.ger!(alpha, x, unthunk(y), A)
102+
end
103+
function LinearAlgebra.BLAS.gemv!(tA, alpha, A, x::AbstractThunk, beta, y)
104+
return LinearAlgebra.BLAS.gemv!(tA, alpha, A, unthunk(x), beta, y)
105+
end
106+
function LinearAlgebra.BLAS.gemv(tA, alpha, A, x::AbstractThunk)
107+
return LinearAlgebra.BLAS.gemv(tA, alpha, A, unthunk(x))
108+
end
109+
function LinearAlgebra.BLAS.scal!(n, a::AbstractThunk, X, incx)
110+
return LinearAlgebra.BLAS.scal!(n, unthunk(a), X, incx)
111+
end
112+
113+
function LinearAlgebra.LAPACK.trsyl!(transa, transb, A, B, C::AbstractThunk, isgn=1)
114+
return throw(MutateThunkException())
115+
end
116+
22117
"""
23118
@thunk expr
24119
@@ -109,7 +204,7 @@ but it should do this more efficently than simply doing this directly.
109204
Most operations on an `InplaceableThunk` treat it just like a normal `Thunk`;
110205
and destroy its inplacability.
111206
"""
112-
struct InplaceableThunk{T<:Thunk, F} <: AbstractThunk
207+
struct InplaceableThunk{T<:Thunk,F} <: AbstractThunk
113208
val::T
114209
add!::F
115210
end
@@ -118,5 +213,5 @@ unthunk(x::InplaceableThunk) = unthunk(x.val)
118213
(x::InplaceableThunk)() = unthunk(x)
119214

120215
function Base.show(io::IO, x::InplaceableThunk)
121-
print(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))")
216+
return print(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))")
122217
end

test/differentials/thunks.jl

Lines changed: 94 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
@testset "Thunk" begin
2+
MutateThunkException = ChainRulesCore.MutateThunkException
3+
24
@test @thunk(3) isa Thunk
35

46
@testset "==" begin
@@ -47,7 +49,7 @@
4749

4850
@testset "Linear operators" begin
4951
x_real = [2.0 4.0; 8.0 5.0]
50-
x_complex = [(2.0 + im) 4.0; 8.0 (5.0 + 4im)]
52+
x_complex = [(2.0+im) 4.0; 8.0 (5.0+4im)]
5153
@testset "$(typeof(x))" for x in (x_real, x_complex)
5254
x_thunked = @thunk(1.0 * x)
5355
@test unthunk(x_thunked') == x'
@@ -59,34 +61,119 @@
5961
@testset "Array" begin
6062
was_unthunked = 0
6163
array_thunk = @thunk begin
62-
was_unthunked += 1;
64+
was_unthunked += 1
6365
[1.0, 2.0, 3.0]
6466
end
6567

6668
was_unthunked = 0
67-
@test array_thunk .+ fill(10, 3) .+ fill(10, 3) == [21.0, 22.0, 23.0]
69+
@test array_thunk .+ fill(10, 3) .+ fill(10, 3) == [21.0, 22.0, 23.0]
6870
@test was_unthunked == 1
6971

7072
was_unthunked = 0
7173
@test array_thunk .+ 10.0 .+ 10.0 == [21.0, 22.0, 23.0]
7274
@test was_unthunked == 1
73-
7475
end
7576

7677
@testset "Scalar" begin
77-
was_unthunked=0
78+
was_unthunked = 0
7879
scalar_thunk = @thunk begin
79-
was_unthunked += 1;
80+
was_unthunked += 1
8081
sqrt(4.0)
8182
end
8283

8384
was_unthunked = 0
84-
@test scalar_thunk .+ fill(10, 3) .+ fill(10, 3) == [22.0, 22.0, 22.0]
85+
@test scalar_thunk .+ fill(10, 3) .+ fill(10, 3) == [22.0, 22.0, 22.0]
8586
@test was_unthunked == 1
8687

8788
was_unthunked = 0
8889
@test scalar_thunk .+ 10.0 .+ 10.0 == 22.0
8990
@test was_unthunked == 1
9091
end
9192
end
93+
94+
@testset "basic math" begin
95+
@test 1 == -@thunk(-1)
96+
@test 1 == @thunk(2) - 1
97+
@test 1 == 2 - @thunk(1)
98+
@test 1.0 == @thunk(1) / 1.0
99+
@test 1.0 == 1.0 / @thunk(1)
100+
101+
@test 1 == real(@thunk(1 + 1im))
102+
@test 1 == imag(@thunk(1 + 1im))
103+
@test 1 + 1im == Complex(@thunk(1 + 1im))
104+
@test 1 + 1im == Complex(@thunk(1), @thunk(1))
105+
end
106+
107+
@testset "Base functions" begin
108+
v = [1, 2, 3]
109+
t = @thunk(v)
110+
111+
if VERSION >= v"1.2"
112+
@test 3 == mapreduce(_ -> 1, +, t)
113+
@test 3 == mapreduce((_, _) -> 1, +, v, t)
114+
end
115+
@test [4 6] == sum!([1 1], @thunk([1 2; 3 4]))
116+
117+
@test v == vec(t)
118+
@test [1 2 3] == reshape(t, 1, 3)
119+
@test 1 == getindex(t, 1)
120+
@test_throws MutateThunkException setindex!(t, 0.0, 1)
121+
@test [4; 5; 6] == selectdim([1 2 3; 4 5 6], 1, 2)
122+
end
123+
124+
@testset "LinearAlgebra" begin
125+
v = [1.0, 2.0, 3.0]
126+
tv = @thunk(v)
127+
a = [1.0 2.0; 3.0 4.0]
128+
t = @thunk(a)
129+
@test Array(a) == Array(t)
130+
@test Matrix(a) == Matrix(t)
131+
@test Diagonal(a) == Diagonal(t)
132+
@test LowerTriangular(a) == LowerTriangular(t)
133+
@test UpperTriangular(a) == UpperTriangular(t)
134+
@test Symmetric(a) == Symmetric(t)
135+
@test Hermitian(a) == Hermitian(t)
136+
137+
if VERSION >= v"1.2"
138+
@test diagm(0 => v) == diagm(0 => tv)
139+
@test diagm(3, 4, 0 => v) == diagm(3, 4, 0 => tv)
140+
end
141+
@test tril(a) == tril(t)
142+
@test tril(a, 1) == tril(t, 1)
143+
@test triu(a) == triu(t)
144+
@test triu(a, 1) == triu(t, 1)
145+
@test tr(a) == tr(t)
146+
@test cross(v, v) == cross(v, tv)
147+
@test cross(v, v) == cross(tv, v)
148+
@test cross(v, v) == cross(tv, tv)
149+
@test dot(v, v) == dot(v, tv)
150+
@test dot(v, v) == dot(tv, v)
151+
@test dot(v, v) == dot(tv, tv)
152+
153+
if VERSION >= v"1.2"
154+
@test_throws MutateThunkException ldiv!(2.0, deepcopy(t)) ==
155+
ldiv!(2.0, deepcopy(a))
156+
@test_throws MutateThunkException rdiv!(deepcopy(t), 2.0) ==
157+
rdiv!(deepcopy(a), 2.0)
158+
end
159+
160+
res = mul!(deepcopy(a), a, a, true, true)
161+
@test_throws MutateThunkException mul!(deepcopy(t), a, a, true, true)
162+
@test_throws MutateThunkException mul!(deepcopy(t), t, a, true, true)
163+
@test_throws MutateThunkException mul!(deepcopy(t), a, t, true, true)
164+
@test_throws MutateThunkException mul!(deepcopy(t), t, t, true, true)
165+
@test res == mul!(deepcopy(a), t, a, true, true)
166+
@test res == mul!(deepcopy(a), a, t, true, true)
167+
@test res == mul!(deepcopy(a), t, t, true, true)
168+
169+
m = rand(3, 3)
170+
@test ger!(1.0, v, v, deepcopy(m)) == ger!(1.0, tv, v, deepcopy(m))
171+
@test ger!(1.0, v, v, deepcopy(m)) == ger!(1.0, v, tv, deepcopy(m))
172+
@test gemv!('C', 1.0, m, v, 1.0, deepcopy(v)) ==
173+
gemv!('C', 1.0, m, tv, 1.0, deepcopy(v))
174+
@test gemv('N', 1.0, m, v) == gemv('N', 1.0, m, tv)
175+
176+
@test scal!(2, 2.0, v, 1) == scal!(2, @thunk(2.0), v, 1)
177+
@test_throws MutateThunkException LAPACK.trsyl!('C', 'C', m, m, @thunk(m))
178+
end
92179
end

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
using Base.Broadcast: broadcastable
22
using BenchmarkTools
33
using ChainRulesCore
4-
using LinearAlgebra: Diagonal, dot, Hermitian, Symmetric
4+
using LinearAlgebra
5+
using LinearAlgebra.BLAS: ger!, gemv!, gemv, scal!
56
using StaticArrays
67
using SparseArrays
78
using Test

0 commit comments

Comments
 (0)