|
1 | 1 | @testset "accumulation.jl" begin |
| 2 | + @testset "is_inplaceable_destination" begin |
| 3 | + is_inplaceable_destination = ChainRulesCore.is_inplaceable_destination |
| 4 | + |
| 5 | + @test is_inplaceable_destination([1, 2, 3, 4]) |
| 6 | + @test !is_inplaceable_destination(1:4) |
| 7 | + |
| 8 | + @test is_inplaceable_destination(Diagonal([1, 2, 3, 4])) |
| 9 | + @test !is_inplaceable_destination(Diagonal(1:4)) |
| 10 | + |
| 11 | + @test is_inplaceable_destination(view([1, 2, 3, 4], :, :)) |
| 12 | + @test !is_inplaceable_destination(view(1:4, :, :)) |
| 13 | + |
| 14 | + @test !is_inplaceable_destination(1.3) |
| 15 | + @test is_inplaceable_destination(falses(4)) |
| 16 | + @test is_inplaceable_destination(spzeros(4)) |
| 17 | + @test is_inplaceable_destination(spzeros(2, 2)) |
| 18 | + @test !is_inplaceable_destination(@SVector [1, 2, 3]) |
| 19 | + end |
| 20 | + |
2 | 21 | @testset "add!!" begin |
3 | 22 | @testset "scalar" begin |
4 | 23 | @test 16 == add!!(12, 4) |
|
11 | 30 | @test 16 == add!!(16, DoesNotExist()) # Should this be an error? |
12 | 31 | end |
13 | 32 |
|
14 | | - @testset "Array" begin |
15 | | - @testset "Happy Path" begin |
| 33 | + @testset "add!!(::AbstractArray, ::AbstractArray)" begin |
| 34 | + @testset "LHS Array (inplace)" begin |
16 | 35 | @testset "RHS Array" begin |
17 | 36 | A = [1.0 2.0; 3.0 4.0] |
18 | | - result = -1.0*ones(2,2) |
19 | | - ret = add!!(result, A) |
20 | | - @test ret === result # must be same object |
21 | | - @test result == [0.0 1.0; 2.0 3.0] |
| 37 | + accumuland = -1.0*ones(2,2) |
| 38 | + ret = add!!(accumuland, A) |
| 39 | + @test ret === accumuland # must be same object |
| 40 | + @test accumuland == [0.0 1.0; 2.0 3.0] |
22 | 41 | end |
23 | 42 |
|
24 | 43 | @testset "RHS StaticArray" begin |
25 | 44 | A = @SMatrix[1.0 2.0; 3.0 4.0] |
26 | | - result = -1.0*ones(2,2) |
27 | | - ret = add!!(result, A) |
28 | | - @test ret === result # must be same object |
29 | | - @test result == [0.0 1.0; 2.0 3.0] |
| 45 | + accumuland = -1.0*ones(2,2) |
| 46 | + ret = add!!(accumuland, A) |
| 47 | + @test ret === accumuland # must be same object |
| 48 | + @test accumuland == [0.0 1.0; 2.0 3.0] |
30 | 49 | end |
31 | 50 |
|
32 | 51 | @testset "RHS Diagonal" begin |
33 | 52 | A = Diagonal([1.0, 2.0]) |
34 | | - result = -1.0*ones(2,2) |
35 | | - ret = add!!(result, A) |
36 | | - @test ret === result # must be same object |
37 | | - @test result == [0.0 -1.0; -1.0 1.0] |
| 53 | + accumuland = -1.0*ones(2,2) |
| 54 | + ret = add!!(accumuland, A) |
| 55 | + @test ret === accumuland # must be same object |
| 56 | + @test accumuland == [0.0 -1.0; -1.0 1.0] |
38 | 57 | end |
39 | 58 | end |
40 | 59 |
|
| 60 | + @testset "add!!(::StaticArray, ::Array) (out of place)" begin |
| 61 | + A = [1.0 2.0; 3.0 4.0] |
| 62 | + accumuland = @SMatrix [-1.0 -1.0; -1.0 -1.0] |
| 63 | + ret = add!!(accumuland, A) |
| 64 | + @test ret == [0.0 1.0; 2.0 3.0] # must return right answer |
| 65 | + @test ret !== accumuland # must not be same object |
| 66 | + @test accumuland == [-1.0 -1.0; -1.0 -1.0] # must not have changed |
| 67 | + end |
| 68 | + |
| 69 | + @testset "add!!(::Diagonal{<:Vector}, ::Diagonal{<:Vector}) (inplace)" begin |
| 70 | + A = Diagonal([1.0, 2.0]) |
| 71 | + accumuland = Diagonal([-2.0, -2.0]) |
| 72 | + ret = add!!(accumuland, A) |
| 73 | + @test ret === accumuland # must be same object |
| 74 | + @test accumuland == Diagonal([-1.0, 0.0]) |
| 75 | + end |
| 76 | + |
41 | 77 | @testset "Unhappy Path" begin |
42 | 78 | # wrong length |
43 | 79 | @test_throws DimensionMismatch add!!(ones(4,4), ones(2,2)) |
|
49 | 85 | end |
50 | 86 |
|
51 | 87 | @testset "InplaceableThunk" begin |
52 | | - A=[1.0 2.0; 3.0 4.0] |
53 | 88 | ithunk = InplaceableThunk( |
54 | | - @thunk(A*B), |
55 | | - x -> x.+=A |
| 89 | + @thunk(-1.0*ones(2, 2)), |
| 90 | + x -> x .-= ones(2, 2) |
56 | 91 | ) |
57 | 92 |
|
58 | | - accumuland = -1.0*ones(2,2) |
59 | | - ret = add!!(accumuland, ithunk) |
60 | | - @test ret === accumuland # must be same object |
61 | | - @test accumuland == [0.0 1.0; 2.0 3.0] |
| 93 | + @testset "in place" begin |
| 94 | + accumuland = [1.0 2.0; 3.0 4.0] |
| 95 | + ret = add!!(accumuland, ithunk) |
| 96 | + @test ret == [0.0 1.0; 2.0 3.0] # must return right answer |
| 97 | + @test ret === accumuland # must be same object |
| 98 | + end |
| 99 | + |
| 100 | + @testset "out of place" begin |
| 101 | + accumuland = @SMatrix [1.0 2.0; 3.0 4.0] |
| 102 | + |
| 103 | + ret = add!!(accumuland, ithunk) |
| 104 | + @test ret == [0.0 1.0; 2.0 3.0] # must return right answer |
| 105 | + @test ret !== accumuland # must not be same object |
| 106 | + @test accumuland == [1.0 2.0; 3.0 4.0] # must not have mutated |
| 107 | + end |
62 | 108 | end |
63 | 109 | end |
64 | 110 | end |
0 commit comments