Skip to content

Commit 55eb6a4

Browse files
authored
Merge pull request #38 from christiangnrd/kwargs
Address #37
2 parents 14de3f2 + a9e6950 commit 55eb6a4

File tree

6 files changed

+58
-126
lines changed

6 files changed

+58
-126
lines changed

ext/AcceleratedKernelsMetalExt.jl

Lines changed: 10 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,27 +10,14 @@ import AcceleratedKernels as AK
1010
function AK.accumulate!(
1111
op, v::AbstractArray, backend::MetalBackend;
1212
init,
13-
neutral=AK.neutral_element(op, eltype(v)),
14-
dims::Union{Nothing, Int}=nothing,
15-
inclusive::Bool=true,
16-
17-
# CPU settings - not used
18-
max_tasks::Int=Threads.nthreads(),
19-
min_elems::Int=1,
20-
21-
# Algorithm choice
13+
# Algorithm choice is the only differing default
2214
alg::AK.AccumulateAlgorithm=AK.ScanPrefixes(),
23-
24-
# GPU settings
25-
block_size::Int=256,
26-
temp::Union{Nothing, AbstractArray}=nothing,
27-
temp_flags::Union{Nothing, AbstractArray}=nothing,
15+
kwargs...
2816
)
2917
AK._accumulate_impl!(
30-
op, v, backend,
31-
init=init, neutral=neutral, dims=dims, inclusive=inclusive,
32-
alg=alg,
33-
block_size=block_size, temp=temp, temp_flags=temp_flags,
18+
op, v, backend;
19+
init, alg,
20+
kwargs...
3421
)
3522
end
3623

@@ -39,28 +26,15 @@ end
3926
function AK.accumulate!(
4027
op, dst::AbstractArray, src::AbstractArray, backend::MetalBackend;
4128
init,
42-
neutral=AK.neutral_element(op, eltype(dst)),
43-
dims::Union{Nothing, Int}=nothing,
44-
inclusive::Bool=true,
45-
46-
# CPU settings - not used
47-
max_tasks::Int=Threads.nthreads(),
48-
min_elems::Int=1,
49-
50-
# Algorithm choice
29+
# Algorithm choice is the only differing default
5130
alg::AK.AccumulateAlgorithm=AK.ScanPrefixes(),
52-
53-
# GPU settings
54-
block_size::Int=256,
55-
temp::Union{Nothing, AbstractArray}=nothing,
56-
temp_flags::Union{Nothing, AbstractArray}=nothing,
31+
kwargs...
5732
)
5833
copyto!(dst, src)
5934
AK._accumulate_impl!(
60-
op, dst, backend,
61-
init=init, neutral=neutral, dims=dims, inclusive=inclusive,
62-
alg=alg,
63-
block_size=block_size, temp=temp, temp_flags=temp_flags,
35+
op, dst, backend;
36+
init, alg,
37+
kwargs...
6438
)
6539
end
6640

src/accumulate/accumulate.jl

Lines changed: 19 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -124,58 +124,26 @@ AK.accumulate!(+, v, alg=AK.ScanPrefixes())
124124
function accumulate!(
125125
op, v::AbstractArray, backend::Backend=get_backend(v);
126126
init,
127-
neutral=neutral_element(op, eltype(v)),
128-
dims::Union{Nothing, Int}=nothing,
129-
inclusive::Bool=true,
130-
131-
# CPU settings
132-
max_tasks::Int=Threads.nthreads(),
133-
min_elems::Int=2,
134-
135-
# Algorithm choice
136-
alg::AccumulateAlgorithm=DecoupledLookback(),
137-
138-
# GPU settings
139-
block_size::Int=256,
140-
temp::Union{Nothing, AbstractArray}=nothing,
141-
temp_flags::Union{Nothing, AbstractArray}=nothing,
127+
kwargs...
142128
)
143129
_accumulate_impl!(
144-
op, v, backend,
145-
init=init, neutral=neutral, dims=dims, inclusive=inclusive,
146-
max_tasks=max_tasks, min_elems=min_elems,
147-
alg=alg,
148-
block_size=block_size, temp=temp, temp_flags=temp_flags,
130+
op, v, backend;
131+
init,
132+
kwargs...
149133
)
150134
end
151135

152136

153137
function accumulate!(
154138
op, dst::AbstractArray, src::AbstractArray, backend::Backend=get_backend(dst);
155139
init,
156-
neutral=neutral_element(op, eltype(dst)),
157-
dims::Union{Nothing, Int}=nothing,
158-
inclusive::Bool=true,
159-
160-
# CPU settings
161-
max_tasks::Int=Threads.nthreads(),
162-
min_elems::Int=2,
163-
164-
# Algorithm choice
165-
alg::AccumulateAlgorithm=DecoupledLookback(),
166-
167-
# GPU settings
168-
block_size::Int=256,
169-
temp::Union{Nothing, AbstractArray}=nothing,
170-
temp_flags::Union{Nothing, AbstractArray}=nothing,
140+
kwargs...
171141
)
172142
copyto!(dst, src)
173143
_accumulate_impl!(
174-
op, dst, backend,
175-
init=init, neutral=neutral, dims=dims, inclusive=inclusive,
176-
max_tasks=max_tasks, min_elems=min_elems,
177-
alg=alg,
178-
block_size=block_size, temp=temp, temp_flags=temp_flags,
144+
op, dst, backend;
145+
init,
146+
kwargs...
179147
)
180148
end
181149

@@ -200,17 +168,17 @@ function _accumulate_impl!(
200168
)
201169
if isnothing(dims)
202170
return accumulate_1d!(
203-
op, v, backend, alg,
204-
init=init, neutral=neutral, inclusive=inclusive,
205-
max_tasks=max_tasks, min_elems=min_elems,
206-
block_size=block_size, temp=temp, temp_flags=temp_flags,
171+
op, v, backend, alg;
172+
init, neutral, inclusive,
173+
max_tasks, min_elems,
174+
block_size, temp, temp_flags,
207175
)
208176
else
209177
return accumulate_nd!(
210-
op, v, backend,
211-
init=init, neutral=neutral, dims=dims, inclusive=inclusive,
212-
max_tasks=max_tasks, min_elems=min_elems,
213-
block_size=block_size,
178+
op, v, backend;
179+
init, neutral, dims, inclusive,
180+
max_tasks, min_elems,
181+
block_size,
214182
)
215183
end
216184
end
@@ -242,31 +210,15 @@ Out-of-place version of [`accumulate!`](@ref).
242210
function accumulate(
243211
op, v::AbstractArray, backend::Backend=get_backend(v);
244212
init,
245-
neutral=neutral_element(op, eltype(v)),
246-
dims::Union{Nothing, Int}=nothing,
247-
inclusive::Bool=true,
248-
249-
# CPU settings
250-
max_tasks::Int=Threads.nthreads(),
251-
min_elems::Int=2,
252-
253-
# Algorithm choice
254-
alg::AccumulateAlgorithm=DecoupledLookback(),
255-
256-
# GPU settings
257-
block_size::Int=256,
258-
temp::Union{Nothing, AbstractArray}=nothing,
259-
temp_flags::Union{Nothing, AbstractArray}=nothing,
213+
kwargs...
260214
)
261215
dst_type = Base.promote_op(op, eltype(v), typeof(init))
262216
vcopy = similar(v, dst_type)
263217
copyto!(vcopy, v)
264218
accumulate!(
265219
op, vcopy, backend;
266-
init=init, neutral=neutral, dims=dims, inclusive=inclusive,
267-
max_tasks=max_tasks, min_elems=min_elems,
268-
alg=alg,
269-
block_size=block_size, temp=temp, temp_flags=temp_flags,
220+
init,
221+
kwargs...
270222
)
271223
vcopy
272224
end

src/accumulate/accumulate_1d_cpu.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@ function accumulate_1d!(
22
op, v::AbstractArray, backend::CPU, alg;
33
init,
44
neutral,
5-
inclusive::Bool=true,
5+
inclusive::Bool,
66

77
# CPU settings
8-
max_tasks::Int=Threads.nthreads(),
9-
min_elems::Int=2,
8+
max_tasks::Int,
9+
min_elems::Int,
1010

1111
# GPU settings - not used
12-
block_size::Int=256,
13-
temp::Union{Nothing, AbstractArray}=nothing,
14-
temp_flags::Union{Nothing, AbstractArray}=nothing,
12+
block_size::Int,
13+
temp::Union{Nothing, AbstractArray},
14+
temp_flags::Union{Nothing, AbstractArray},
1515
)
1616
# Trivial case
1717
if length(v) == 0

src/accumulate/accumulate_1d_gpu.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -252,16 +252,16 @@ function accumulate_1d!(
252252
op, v::AbstractArray, backend::GPU, ::DecoupledLookback;
253253
init,
254254
neutral,
255-
inclusive::Bool=true,
255+
inclusive::Bool,
256256

257257
# CPU settings - not used
258-
max_tasks::Int=Threads.nthreads(),
259-
min_elems::Int=1,
258+
max_tasks::Int,
259+
min_elems::Int,
260260

261261
# GPU settings
262-
block_size::Int=256,
263-
temp::Union{Nothing, AbstractArray}=nothing,
264-
temp_flags::Union{Nothing, AbstractArray}=nothing,
262+
block_size::Int,
263+
temp::Union{Nothing, AbstractArray},
264+
temp_flags::Union{Nothing, AbstractArray},
265265
)
266266
# Correctness checks
267267
@argcheck block_size > 0
@@ -311,16 +311,16 @@ function accumulate_1d!(
311311
op, v::AbstractArray, backend::GPU, ::ScanPrefixes;
312312
init,
313313
neutral,
314-
inclusive::Bool=true,
314+
inclusive::Bool,
315315

316316
# CPU settings - not used
317-
max_tasks::Int=Threads.nthreads(),
318-
min_elems::Int=1,
317+
max_tasks::Int,
318+
min_elems::Int,
319319

320320
# GPU settings
321-
block_size::Int=256,
322-
temp::Union{Nothing, AbstractArray}=nothing,
323-
temp_flags::Union{Nothing, AbstractArray}=nothing,
321+
block_size::Int,
322+
temp::Union{Nothing, AbstractArray},
323+
temp_flags::Union{Nothing, AbstractArray},
324324
)
325325
# Correctness checks
326326
@argcheck block_size > 0

src/accumulate/accumulate_nd.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
function accumulate_nd!(
22
op, v::AbstractArray, backend::Backend;
33
init,
4-
neutral=neutral_element(op, eltype(v)),
4+
neutral,
55
dims::Int,
6-
inclusive::Bool=true,
6+
inclusive::Bool,
77

88
# CPU settings
9-
max_tasks::Int=Threads.nthreads(),
10-
min_elems::Int=1,
9+
max_tasks::Int,
10+
min_elems::Int,
1111

1212
# GPU settings
13-
block_size::Int=256,
13+
block_size::Int,
1414
)
1515
# Correctness checks
1616
@argcheck block_size > 0

test/accumulate.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@
8181
AK.accumulate!(+, y; init=Int32(init), inclusive=false)
8282
@test all(Array(y) .== 10:19)
8383

84+
# Test that undefined kwargs are not accepted
85+
@test_throws MethodError AK.accumulate(+, y; init=10, dims=2, inclusive=false, bad=:kwarg)
86+
8487
# Testing different settings
8588
AK.accumulate!(+, array_from_host(ones(Int32, 1000)), init=0, inclusive=false,
8689
block_size=128,
@@ -186,6 +189,9 @@ end
186189
sh = Array(s)
187190
@test all([sh[i, :] == 10:19 for i in 1:10])
188191

192+
# Test that undefined kwargs are not accepted
193+
@test_throws MethodError AK.accumulate(+, v; init=10, dims=2, inclusive=false, bad=:kwarg)
194+
189195
# Testing different settings
190196
AK.accumulate(
191197
(x, y) -> x + 1,

0 commit comments

Comments
 (0)