Skip to content

Commit 45e0013

Browse files
committed
make add!! responsible for deciding if to use InplaceThunk.add!
1 parent db9d5a7 commit 45e0013

File tree

5 files changed

+111
-27
lines changed

5 files changed

+111
-27
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.9.16"
3+
version = "0.9.17"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
77
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
8+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
89

910
[compat]
1011
BenchmarkTools = "0.5"

src/ChainRulesCore.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module ChainRulesCore
22
using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize!
33
using LinearAlgebra: LinearAlgebra
4+
using SparseArrays: SparseVector, SparseMatrixCSC
45
using MuladdMacro: @muladd
56

67
export on_new_rule, refresh_rules # generation tools

src/accumulation.jl

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,48 @@
33
44
Returns `x+y`, potentially mutating `x` in-place to hold this value.
55
This avoids allocations when `x` can be mutated in this way.
6-
7-
See also: [`InplaceableThunk`](@ref).
86
"""
97
add!!(x, y) = x + y
108

11-
add!!(x, t::InplaceableThunk) = t.add!(x)
9+
"""
10+
add!!(x, t::ImplacableThunk)
11+
12+
The specialization of `add!!` for [`InplaceableThunk`](@ref) promises to only call
13+
`t.add!` on `x` if `x` is suitably mutable; otherwise it will be out of place.
14+
"""
15+
function add!!(x, t::InplaceableThunk)
16+
return if is_inplaceable_destination(x)
17+
t.add!(x)
18+
else
19+
x + t
20+
end
21+
end
1222

13-
function add!!(x::Array{<:Any, N}, y::AbstractArray{<:Any, N}) where N
14-
return x .+= y
23+
function add!!(x::AbstractArray{<:Any, N}, y::AbstractArray{<:Any, N}) where N
24+
return if is_inplaceable_destination(x)
25+
x .+= y
26+
else
27+
x + y
28+
end
29+
end
30+
31+
32+
"""
33+
is_inplaceable_destination(x)
34+
35+
Returns true if `x` is suitable for for storing inplace accumulation of gradients.
36+
For arrays this boils down `x .= y` if will work to mutate `x`, if `y` is an appropriate
37+
differential.
38+
"""
39+
is_inplaceable_destination(::Any) = false
40+
is_inplaceable_destination(::Array) = true
41+
is_inplaceable_destination(::SparseVector) = true
42+
is_inplaceable_destination(::SparseMatrixCSC) = true
43+
is_inplaceable_destination(::BitArray) = true
44+
function is_inplaceable_destination(x::AbstractArray)
45+
p = parent(x)
46+
p === x && return false # no parent
47+
# basically all wrapper types delegate `setindex!` to their `parent` after some
48+
# processing and so are mutable if their `parent` is.
49+
return is_inplaceable_destination(p)
1550
end

test/accumulation.jl

Lines changed: 67 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,23 @@
11
@testset "accumulation.jl" begin
2+
@testset "is_inplaceable_destination" begin
3+
is_inplaceable_destination = ChainRulesCore.is_inplaceable_destination
4+
5+
@test is_inplaceable_destination([1, 2, 3, 4])
6+
@test !is_inplaceable_destination(1:4)
7+
8+
@test is_inplaceable_destination(Diagonal([1, 2, 3, 4]))
9+
@test !is_inplaceable_destination(Diagonal(1:4))
10+
11+
@test is_inplaceable_destination(view([1, 2, 3, 4], :, :))
12+
@test !is_inplaceable_destination(view(1:4, :, :))
13+
14+
@test !is_inplaceable_destination(1.3)
15+
@test is_inplaceable_destination(falses(4))
16+
@test is_inplaceable_destination(spzeros(4))
17+
@test is_inplaceable_destination(spzeros(2, 2))
18+
@test !is_inplaceable_destination(@SVector [1, 2, 3])
19+
end
20+
221
@testset "add!!" begin
322
@testset "scalar" begin
423
@test 16 == add!!(12, 4)
@@ -11,33 +30,50 @@
1130
@test 16 == add!!(16, DoesNotExist()) # Should this be an error?
1231
end
1332

14-
@testset "Array" begin
15-
@testset "Happy Path" begin
33+
@testset "add!!(::AbstractArray, ::AbstractArray)" begin
34+
@testset "LHS Array (inplace)" begin
1635
@testset "RHS Array" begin
1736
A = [1.0 2.0; 3.0 4.0]
18-
result = -1.0*ones(2,2)
19-
ret = add!!(result, A)
20-
@test ret === result # must be same object
21-
@test result == [0.0 1.0; 2.0 3.0]
37+
accumuland = -1.0*ones(2,2)
38+
ret = add!!(accumuland, A)
39+
@test ret === accumuland # must be same object
40+
@test accumuland == [0.0 1.0; 2.0 3.0]
2241
end
2342

2443
@testset "RHS StaticArray" begin
2544
A = @SMatrix[1.0 2.0; 3.0 4.0]
26-
result = -1.0*ones(2,2)
27-
ret = add!!(result, A)
28-
@test ret === result # must be same object
29-
@test result == [0.0 1.0; 2.0 3.0]
45+
accumuland = -1.0*ones(2,2)
46+
ret = add!!(accumuland, A)
47+
@test ret === accumuland # must be same object
48+
@test accumuland == [0.0 1.0; 2.0 3.0]
3049
end
3150

3251
@testset "RHS Diagonal" begin
3352
A = Diagonal([1.0, 2.0])
34-
result = -1.0*ones(2,2)
35-
ret = add!!(result, A)
36-
@test ret === result # must be same object
37-
@test result == [0.0 -1.0; -1.0 1.0]
53+
accumuland = -1.0*ones(2,2)
54+
ret = add!!(accumuland, A)
55+
@test ret === accumuland # must be same object
56+
@test accumuland == [0.0 -1.0; -1.0 1.0]
3857
end
3958
end
4059

60+
@testset "add!!(::StaticArray, ::Array) (out of place)" begin
61+
A = [1.0 2.0; 3.0 4.0]
62+
accumuland = @SMatrix [-1.0 -1.0; -1.0 -1.0]
63+
ret = add!!(accumuland, A)
64+
@test ret == [0.0 1.0; 2.0 3.0] # must return right answer
65+
@test ret !== accumuland # must not be same object
66+
@test accumuland == [-1.0 -1.0; -1.0 -1.0] # must not have changed
67+
end
68+
69+
@testset "add!!(::Diagonal{<:Vector}, ::Diagonal{<:Vector}) (inplace)" begin
70+
A = Diagonal([1.0, 2.0])
71+
accumuland = Diagonal([-2.0, -2.0])
72+
ret = add!!(accumuland, A)
73+
@test ret === accumuland # must be same object
74+
@test accumuland == Diagonal([-1.0, 0.0])
75+
end
76+
4177
@testset "Unhappy Path" begin
4278
# wrong length
4379
@test_throws DimensionMismatch add!!(ones(4,4), ones(2,2))
@@ -49,16 +85,26 @@
4985
end
5086

5187
@testset "InplaceableThunk" begin
52-
A=[1.0 2.0; 3.0 4.0]
5388
ithunk = InplaceableThunk(
54-
@thunk(A*B),
55-
x -> x.+=A
89+
@thunk(-1.0*ones(2, 2)),
90+
x -> x .-= ones(2, 2)
5691
)
5792

58-
accumuland = -1.0*ones(2,2)
59-
ret = add!!(accumuland, ithunk)
60-
@test ret === accumuland # must be same object
61-
@test accumuland == [0.0 1.0; 2.0 3.0]
93+
@testset "in place" begin
94+
accumuland = [1.0 2.0; 3.0 4.0]
95+
ret = add!!(accumuland, ithunk)
96+
@test ret == [0.0 1.0; 2.0 3.0] # must return right answer
97+
@test ret === accumuland # must be same object
98+
end
99+
100+
@testset "out of place" begin
101+
accumuland = @SMatrix [1.0 2.0; 3.0 4.0]
102+
103+
ret = add!!(accumuland, ithunk)
104+
@test ret == [0.0 1.0; 2.0 3.0] # must return right answer
105+
@test ret !== accumuland # must not be same object
106+
@test accumuland == [1.0 2.0; 3.0 4.0] # must not have mutated
107+
end
62108
end
63109
end
64110
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using BenchmarkTools
33
using ChainRulesCore
44
using LinearAlgebra: Diagonal, dot
55
using StaticArrays
6+
using SparseArrays
67
using Test
78

89
@testset "ChainRulesCore" begin

0 commit comments

Comments
 (0)