Skip to content

Commit c00ce96

Browse files
authored
Merge pull request #49 from christiangnrd/testboth
Test both 1d `accumulate` algorithms when supported
2 parents 09c2c99 + ca2610e commit c00ce96

File tree

2 files changed

+24
-14
lines changed

2 files changed

+24
-14
lines changed

test/accumulate.jl

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1-
@testset "accumulate_1d" begin
1+
ALGS = AK.AccumulateAlgorithm[AK.ScanPrefixes()]
2+
3+
@isdefined(TEST_DL) && TEST_DL[] && push!(ALGS, AK.DecoupledLookback())
4+
5+
@testset "accumulate_1d $(alg isa AK.DecoupledLookback ? "DL" : "SP")" for alg in ALGS
26

37
Random.seed!(0)
48

59
# Single block exlusive scan (each block processes two elements)
610
for num_elems in 1:256
711
x = array_from_host(ones(Int32, num_elems))
812
y = copy(x)
9-
AK.accumulate!(+, y; init=0, inclusive=false, block_size=128)
13+
AK.accumulate!(+, y; init=0, inclusive=false, block_size=128, alg)
1014
yh = Array(y)
1115
@test all(yh .== 0:length(yh) - 1)
1216
end
@@ -15,7 +19,7 @@
1519
for num_elems in 1:256
1620
x = array_from_host(rand(1:1000, num_elems), Int32)
1721
y = copy(x)
18-
AK.accumulate!(+, y; init=0, block_size=128)
22+
AK.accumulate!(+, y; init=0, block_size=128, alg)
1923
@test all(Array(y) .== accumulate(+, Array(x)))
2024
end
2125

@@ -24,7 +28,7 @@
2428
num_elems = rand(1:100_000)
2529
x = array_from_host(ones(Int32, num_elems))
2630
y = copy(x)
27-
AK.accumulate!(+, y; init=0, inclusive=false)
31+
AK.accumulate!(+, y; init=0, inclusive=false, alg)
2832
yh = Array(y)
2933
@test all(yh .== 0:length(yh) - 1)
3034
end
@@ -34,7 +38,7 @@
3438
num_elems = rand(1:100_000)
3539
x = array_from_host(rand(1:1000, num_elems), Int32)
3640
y = copy(x)
37-
AK.accumulate!(+, y; init=0)
41+
AK.accumulate!(+, y; init=0, alg)
3842
@test all(Array(y) .== accumulate(+, Array(x)))
3943
end
4044

@@ -43,7 +47,7 @@
4347
num_elems = rand(1:100_000)
4448
x = array_from_host(rand(1:1000, num_elems), Int32)
4549
y = copy(x)
46-
AK.accumulate!(+, y; init=0, block_size=16)
50+
AK.accumulate!(+, y; init=0, block_size=16, alg)
4751
@test all(Array(y) .== accumulate(+, Array(x)))
4852
end
4953

@@ -54,7 +58,7 @@
5458
n3 = rand(1:100)
5559
vh = rand(Float32, n1, n2, n3)
5660
v = array_from_host(vh)
57-
AK.accumulate!(+, v; init=0)
61+
AK.accumulate!(+, v; init=0, alg)
5862
@test all(Array(v) .≈ accumulate(+, vh))
5963
end
6064

@@ -64,33 +68,33 @@
6468
x = array_from_host(rand(1:1000, num_elems), Int32)
6569
y = similar(x)
6670
init = rand(-1000:1000)
67-
AK.accumulate!(+, y, x; init=Int32(init))
71+
AK.accumulate!(+, y, x; init=Int32(init), alg)
6872
@test all(Array(y) .== accumulate(+, Array(x); init))
6973
end
7074

7175
# Exclusive scan
7276
x = array_from_host(ones(Int32, 10))
7377
y = copy(x)
74-
AK.accumulate!(+, y; init=0, inclusive=false)
78+
AK.accumulate!(+, y; init=0, inclusive=false, alg)
7579
@test all(Array(y) .== 0:9)
7680

7781
# Test init value is respected with exclusive scan too
7882
x = array_from_host(ones(Int32, 10))
7983
y = copy(x)
8084
init = 10
81-
AK.accumulate!(+, y; init=Int32(init), inclusive=false)
85+
AK.accumulate!(+, y; init=Int32(init), inclusive=false, alg)
8286
@test all(Array(y) .== 10:19)
8387

8488
# Test that undefined kwargs are not accepted
8589
@test_throws MethodError AK.accumulate(+, y; init=10, dims=2, inclusive=false, bad=:kwarg)
8690

8791
# Testing different settings
88-
AK.accumulate!(+, array_from_host(ones(Int32, 1000)), init=0, inclusive=false,
89-
block_size=128,
92+
AK.accumulate!(+, array_from_host(ones(Int32, 1000)); init=0, inclusive=false,
93+
block_size=128, alg,
9094
temp=array_from_host(zeros(Int32, 1000)),
9195
temp_flags=array_from_host(zeros(Int8, 1000)))
92-
AK.accumulate(+, array_from_host(ones(Int32, 1000)), init=0, inclusive=false,
93-
block_size=128,
96+
AK.accumulate(+, array_from_host(ones(Int32, 1000)); init=0, inclusive=false,
97+
block_size=128, alg,
9498
temp=array_from_host(zeros(Int64, 1000)),
9599
temp_flags=array_from_host(zeros(Int8, 1000)))
96100
end

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ using Test
44
using Random
55
import Pkg
66

7+
# Set to true when testing backends that support this
8+
const TEST_DL = Ref{Bool}(false)
79

810
# Pass command-line argument to test suite to install the right backend, e.g.
911
# julia> import Pkg
@@ -13,16 +15,19 @@ if "--CUDA" in ARGS
1315
using CUDA
1416
CUDA.versioninfo()
1517
const BACKEND = CUDABackend()
18+
TEST_DL[] = true
1619
elseif "--oneAPI" in ARGS
1720
Pkg.add("oneAPI")
1821
using oneAPI
1922
oneAPI.versioninfo()
2023
const BACKEND = oneAPIBackend()
24+
TEST_DL[] = true
2125
elseif "--AMDGPU" in ARGS
2226
Pkg.add("AMDGPU")
2327
using AMDGPU
2428
AMDGPU.versioninfo()
2529
const BACKEND = ROCBackend()
30+
TEST_DL[] = true
2631
elseif "--Metal" in ARGS
2732
Pkg.add("Metal")
2833
using Metal
@@ -35,6 +40,7 @@ elseif "--OpenCL" in ARGS
3540
using OpenCL
3641
OpenCL.versioninfo()
3742
const BACKEND = OpenCLBackend()
43+
TEST_DL[] = true
3844
elseif !@isdefined(BACKEND)
3945
# Otherwise do CPU tests
4046
using InteractiveUtils

0 commit comments

Comments
 (0)