|
1 | | -@testset "accumulate_1d" begin |
| 1 | +ALGS = AK.AccumulateAlgorithm[AK.ScanPrefixes()] |
| 2 | + |
| 3 | +@isdefined(TEST_DL) && TEST_DL[] && push!(ALGS, AK.DecoupledLookback()) |
| 4 | + |
| 5 | +@testset "accumulate_1d $(alg isa AK.DecoupledLookback ? "DL" : "SP")" for alg in ALGS |
2 | 6 |
|
3 | 7 | Random.seed!(0) |
4 | 8 |
|
5 | 9 | # Single block exlusive scan (each block processes two elements) |
6 | 10 | for num_elems in 1:256 |
7 | 11 | x = array_from_host(ones(Int32, num_elems)) |
8 | 12 | y = copy(x) |
9 | | - AK.accumulate!(+, y; init=0, inclusive=false, block_size=128) |
| 13 | + AK.accumulate!(+, y; init=0, inclusive=false, block_size=128, alg) |
10 | 14 | yh = Array(y) |
11 | 15 | @test all(yh .== 0:length(yh) - 1) |
12 | 16 | end |
|
15 | 19 | for num_elems in 1:256 |
16 | 20 | x = array_from_host(rand(1:1000, num_elems), Int32) |
17 | 21 | y = copy(x) |
18 | | - AK.accumulate!(+, y; init=0, block_size=128) |
| 22 | + AK.accumulate!(+, y; init=0, block_size=128, alg) |
19 | 23 | @test all(Array(y) .== accumulate(+, Array(x))) |
20 | 24 | end |
21 | 25 |
|
|
24 | 28 | num_elems = rand(1:100_000) |
25 | 29 | x = array_from_host(ones(Int32, num_elems)) |
26 | 30 | y = copy(x) |
27 | | - AK.accumulate!(+, y; init=0, inclusive=false) |
| 31 | + AK.accumulate!(+, y; init=0, inclusive=false, alg) |
28 | 32 | yh = Array(y) |
29 | 33 | @test all(yh .== 0:length(yh) - 1) |
30 | 34 | end |
|
34 | 38 | num_elems = rand(1:100_000) |
35 | 39 | x = array_from_host(rand(1:1000, num_elems), Int32) |
36 | 40 | y = copy(x) |
37 | | - AK.accumulate!(+, y; init=0) |
| 41 | + AK.accumulate!(+, y; init=0, alg) |
38 | 42 | @test all(Array(y) .== accumulate(+, Array(x))) |
39 | 43 | end |
40 | 44 |
|
|
43 | 47 | num_elems = rand(1:100_000) |
44 | 48 | x = array_from_host(rand(1:1000, num_elems), Int32) |
45 | 49 | y = copy(x) |
46 | | - AK.accumulate!(+, y; init=0, block_size=16) |
| 50 | + AK.accumulate!(+, y; init=0, block_size=16, alg) |
47 | 51 | @test all(Array(y) .== accumulate(+, Array(x))) |
48 | 52 | end |
49 | 53 |
|
|
54 | 58 | n3 = rand(1:100) |
55 | 59 | vh = rand(Float32, n1, n2, n3) |
56 | 60 | v = array_from_host(vh) |
57 | | - AK.accumulate!(+, v; init=0) |
| 61 | + AK.accumulate!(+, v; init=0, alg) |
58 | 62 | @test all(Array(v) .≈ accumulate(+, vh)) |
59 | 63 | end |
60 | 64 |
|
|
64 | 68 | x = array_from_host(rand(1:1000, num_elems), Int32) |
65 | 69 | y = similar(x) |
66 | 70 | init = rand(-1000:1000) |
67 | | - AK.accumulate!(+, y, x; init=Int32(init)) |
| 71 | + AK.accumulate!(+, y, x; init=Int32(init), alg) |
68 | 72 | @test all(Array(y) .== accumulate(+, Array(x); init)) |
69 | 73 | end |
70 | 74 |
|
71 | 75 | # Exclusive scan |
72 | 76 | x = array_from_host(ones(Int32, 10)) |
73 | 77 | y = copy(x) |
74 | | - AK.accumulate!(+, y; init=0, inclusive=false) |
| 78 | + AK.accumulate!(+, y; init=0, inclusive=false, alg) |
75 | 79 | @test all(Array(y) .== 0:9) |
76 | 80 |
|
77 | 81 | # Test init value is respected with exclusive scan too |
78 | 82 | x = array_from_host(ones(Int32, 10)) |
79 | 83 | y = copy(x) |
80 | 84 | init = 10 |
81 | | - AK.accumulate!(+, y; init=Int32(init), inclusive=false) |
| 85 | + AK.accumulate!(+, y; init=Int32(init), inclusive=false, alg) |
82 | 86 | @test all(Array(y) .== 10:19) |
83 | 87 |
|
84 | 88 | # Test that undefined kwargs are not accepted |
85 | 89 | @test_throws MethodError AK.accumulate(+, y; init=10, dims=2, inclusive=false, bad=:kwarg) |
86 | 90 |
|
87 | 91 | # Testing different settings |
88 | | - AK.accumulate!(+, array_from_host(ones(Int32, 1000)), init=0, inclusive=false, |
89 | | - block_size=128, |
| 92 | + AK.accumulate!(+, array_from_host(ones(Int32, 1000)); init=0, inclusive=false, |
| 93 | + block_size=128, alg, |
90 | 94 | temp=array_from_host(zeros(Int32, 1000)), |
91 | 95 | temp_flags=array_from_host(zeros(Int8, 1000))) |
92 | | - AK.accumulate(+, array_from_host(ones(Int32, 1000)), init=0, inclusive=false, |
93 | | - block_size=128, |
| 96 | + AK.accumulate(+, array_from_host(ones(Int32, 1000)); init=0, inclusive=false, |
| 97 | + block_size=128, alg, |
94 | 98 | temp=array_from_host(zeros(Int64, 1000)), |
95 | 99 | temp_flags=array_from_host(zeros(Int8, 1000))) |
96 | 100 | end |
|
0 commit comments