Skip to content

Commit 92beead

Browse files
authored
Merge pull request #39 from christiangnrd/kwargs2
[NFC] Reduce kwarg duplication
2 parents 99c2821 + 207e77f commit 92beead

22 files changed

+244
-644
lines changed

ext/AcceleratedKernelsMetalExt.jl

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -38,73 +38,4 @@ function AK.accumulate!(
3838
)
3939
end
4040

41-
42-
function AK.cumsum(
43-
src::AbstractArray, backend::MetalBackend;
44-
init=zero(eltype(src)),
45-
neutral=zero(eltype(src)),
46-
dims::Union{Nothing, Int}=nothing,
47-
48-
# CPU settings - not used
49-
max_tasks::Int=Threads.nthreads(),
50-
min_elems::Int=1,
51-
52-
# Algorithm choice
53-
alg::AK.AccumulateAlgorithm=AK.ScanPrefixes(),
54-
55-
# GPU settings
56-
block_size::Int=256,
57-
temp::Union{Nothing, AbstractArray}=nothing,
58-
temp_flags::Union{Nothing, AbstractArray}=nothing,
59-
)
60-
AK.accumulate(
61-
+, src, backend;
62-
init=init,
63-
neutral=neutral,
64-
dims=dims,
65-
inclusive=true,
66-
67-
alg=alg,
68-
69-
block_size=block_size,
70-
temp=temp,
71-
temp_flags=temp_flags,
72-
)
73-
end
74-
75-
76-
function AK.cumprod(
77-
src::AbstractArray, backend::MetalBackend;
78-
init=one(eltype(src)),
79-
neutral=one(eltype(src)),
80-
dims::Union{Nothing, Int}=nothing,
81-
82-
# CPU settings - not used
83-
max_tasks::Int=Threads.nthreads(),
84-
min_elems::Int=1,
85-
86-
# Algorithm choice
87-
alg::AK.AccumulateAlgorithm=AK.ScanPrefixes(),
88-
89-
# GPU settings
90-
block_size::Int=256,
91-
temp::Union{Nothing, AbstractArray}=nothing,
92-
temp_flags::Union{Nothing, AbstractArray}=nothing,
93-
)
94-
AK.accumulate(
95-
*, src, backend;
96-
init=init,
97-
neutral=neutral,
98-
dims=dims,
99-
inclusive=true,
100-
101-
alg=alg,
102-
103-
block_size=block_size,
104-
temp=temp,
105-
temp_flags=temp_flags,
106-
)
107-
end
108-
109-
11041
end # module AcceleratedKernelsMetalExt

ext/AcceleratedKernelsoneAPIExt.jl

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,12 @@ function AK.any(
1212

1313
# Algorithm choice
1414
alg::AK.PredicatesAlgorithm=AK.MapReduce(),
15-
16-
# CPU settings
17-
max_tasks=Threads.nthreads(),
18-
min_elems=1,
19-
20-
# GPU settings
21-
block_size::Int=256,
15+
kwargs...
2216
)
2317
AK._any_impl(
2418
pred, v, backend;
25-
alg=alg,
26-
max_tasks=max_tasks,
27-
min_elems=min_elems,
28-
block_size=block_size,
19+
alg,
20+
kwargs...
2921
)
3022
end
3123

@@ -35,20 +27,12 @@ function AK.all(
3527

3628
# Algorithm choice
3729
alg::AK.PredicatesAlgorithm=AK.MapReduce(),
38-
39-
# CPU settings
40-
max_tasks=Threads.nthreads(),
41-
min_elems=1,
42-
43-
# GPU settings
44-
block_size::Int=256,
30+
kwargs...
4531
)
4632
AK._all_impl(
4733
pred, v, backend;
48-
alg=alg,
49-
max_tasks=max_tasks,
50-
min_elems=min_elems,
51-
block_size=block_size,
34+
alg,
35+
kwargs...
5236
)
5337
end
5438

src/accumulate/accumulate_1d_cpu.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ function accumulate_1d!(
3636
if itask == 1
3737
_accumulate_1d_cpu_section!(
3838
op, @view(v[irange]);
39-
init=init,
40-
inclusive=inclusive,
39+
init, inclusive,
4140
)
4241
else
4342
# Later sections should always be inclusively accumulated

src/arithmetics.jl

Lines changed: 24 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -44,28 +44,12 @@ s = AK.sum(m, dims=2, temp=temp)
4444
function sum(
4545
src::AbstractArray, backend::Backend=get_backend(src);
4646
init=zero(eltype(src)),
47-
dims::Union{Nothing, Int}=nothing,
48-
49-
# CPU settings
50-
max_tasks=Threads.nthreads(),
51-
min_elems=1,
52-
53-
# GPU settings
54-
block_size::Int=256,
55-
temp::Union{Nothing, AbstractArray}=nothing,
56-
switch_below::Int=0,
47+
kwargs...
5748
)
5849
reduce(
5950
+, src, backend;
60-
init=init,
61-
dims=dims,
62-
63-
max_tasks=max_tasks,
64-
min_elems=min_elems,
65-
66-
block_size=block_size,
67-
temp=temp,
68-
switch_below=switch_below,
51+
init,
52+
kwargs...
6953
)
7054
end
7155

@@ -116,28 +100,12 @@ p = AK.prod(m, dims=2, temp=temp)
116100
function prod(
117101
src::AbstractArray, backend::Backend=get_backend(src);
118102
init=one(eltype(src)),
119-
dims::Union{Nothing, Int}=nothing,
120-
121-
# CPU settings
122-
max_tasks=Threads.nthreads(),
123-
min_elems=1,
124-
125-
# GPU settings
126-
block_size::Int=256,
127-
temp::Union{Nothing, AbstractArray}=nothing,
128-
switch_below::Int=0,
103+
kwargs...
129104
)
130105
reduce(
131106
*, src, backend;
132-
init=init,
133-
dims=dims,
134-
135-
max_tasks=max_tasks,
136-
min_elems=min_elems,
137-
138-
block_size=block_size,
139-
temp=temp,
140-
switch_below=switch_below,
107+
init,
108+
kwargs...
141109
)
142110
end
143111

@@ -188,28 +156,12 @@ m = AK.maximum(m, dims=2, temp=temp)
188156
function maximum(
189157
src::AbstractArray, backend::Backend=get_backend(src);
190158
init=typemin(eltype(src)),
191-
dims::Union{Nothing, Int}=nothing,
192-
193-
# CPU settings
194-
max_tasks=Threads.nthreads(),
195-
min_elems=1,
196-
197-
# GPU settings
198-
block_size::Int=256,
199-
temp::Union{Nothing, AbstractArray}=nothing,
200-
switch_below::Int=0,
159+
kwargs...
201160
)
202161
reduce(
203162
max, src, backend;
204-
init=init,
205-
dims=dims,
206-
207-
max_tasks=max_tasks,
208-
min_elems=min_elems,
209-
210-
block_size=block_size,
211-
temp=temp,
212-
switch_below=switch_below,
163+
init,
164+
kwargs...
213165
)
214166
end
215167

@@ -260,28 +212,12 @@ m = AK.minimum(m, dims=2, temp=temp)
260212
function minimum(
261213
src::AbstractArray, backend::Backend=get_backend(src);
262214
init=typemax(eltype(src)),
263-
dims::Union{Nothing, Int}=nothing,
264-
265-
# CPU settings
266-
max_tasks=Threads.nthreads(),
267-
min_elems=1,
268-
269-
# GPU settings
270-
block_size::Int=256,
271-
temp::Union{Nothing, AbstractArray}=nothing,
272-
switch_below::Int=0,
215+
kwargs...
273216
)
274217
reduce(
275218
min, src, backend;
276-
init=init,
277-
dims=dims,
278-
279-
max_tasks=max_tasks,
280-
min_elems=min_elems,
281-
282-
block_size=block_size,
283-
temp=temp,
284-
switch_below=switch_below,
219+
init,
220+
kwargs...
285221
)
286222
end
287223

@@ -338,59 +274,27 @@ c = AK.count(m, dims=2, temp=temp)
338274
function count(
339275
src::AbstractArray, backend::Backend=get_backend(src);
340276
init=0,
341-
dims::Union{Nothing, Int}=nothing,
342-
343-
# CPU settings
344-
max_tasks=Threads.nthreads(),
345-
min_elems=1,
346-
347-
# GPU settings
348-
block_size::Int=256,
349-
temp::Union{Nothing, AbstractArray}=nothing,
350-
switch_below::Int=0,
277+
kwargs...
351278
)
352279
mapreduce(
353280
x -> x ? one(typeof(init)) : zero(typeof(init)), +, src, backend;
354-
init=init,
281+
init,
355282
neutral=zero(typeof(init)),
356-
dims=dims,
357-
358-
max_tasks=max_tasks,
359-
min_elems=min_elems,
360-
361-
block_size=block_size,
362-
temp=temp,
363-
switch_below=switch_below,
283+
kwargs...
364284
)
365285
end
366286

367287

368288
function count(
369289
f, src::AbstractArray, backend::Backend=get_backend(src);
370290
init=0,
371-
dims::Union{Nothing, Int}=nothing,
372-
373-
# CPU settings
374-
max_tasks=Threads.nthreads(),
375-
min_elems=1,
376-
377-
# GPU settings
378-
block_size::Int=256,
379-
temp::Union{Nothing, AbstractArray}=nothing,
380-
switch_below::Int=0,
291+
kwargs...
381292
)
382293
mapreduce(
383294
x -> f(x) ? one(typeof(init)) : zero(typeof(init)), +, src, backend;
384-
init=init,
295+
init,
385296
neutral=zero(typeof(init)),
386-
dims=dims,
387-
388-
max_tasks=max_tasks,
389-
min_elems=min_elems,
390-
391-
block_size=block_size,
392-
temp=temp,
393-
switch_below=switch_below,
297+
kwargs...
394298
)
395299
end
396300

@@ -437,28 +341,13 @@ function cumsum(
437341
src::AbstractArray, backend::Backend=get_backend(src);
438342
init=zero(eltype(src)),
439343
neutral=zero(eltype(src)),
440-
dims::Union{Nothing, Int}=nothing,
441-
442-
# Algorithm choice
443-
alg::AccumulateAlgorithm=DecoupledLookback(),
444-
445-
# GPU settings
446-
block_size::Int=256,
447-
temp::Union{Nothing, AbstractArray}=nothing,
448-
temp_flags::Union{Nothing, AbstractArray}=nothing,
344+
kwargs...
449345
)
450346
accumulate(
451347
+, src, backend;
452-
init=init,
453-
neutral=neutral,
454-
dims=dims,
348+
init, neutral,
455349
inclusive=true,
456-
457-
alg=alg,
458-
459-
block_size=block_size,
460-
temp=temp,
461-
temp_flags=temp_flags,
350+
kwargs...
462351
)
463352
end
464353

@@ -505,27 +394,12 @@ function cumprod(
505394
src::AbstractArray, backend::Backend=get_backend(src);
506395
init=one(eltype(src)),
507396
neutral=one(eltype(src)),
508-
dims::Union{Nothing, Int}=nothing,
509-
510-
# Algorithm choice
511-
alg::AccumulateAlgorithm=DecoupledLookback(),
512-
513-
# GPU settings
514-
block_size::Int=256,
515-
temp::Union{Nothing, AbstractArray}=nothing,
516-
temp_flags::Union{Nothing, AbstractArray}=nothing,
397+
kwargs...
517398
)
518399
accumulate(
519400
*, src, backend;
520-
init=init,
521-
neutral=neutral,
522-
dims=dims,
401+
init, neutral,
523402
inclusive=true,
524-
525-
alg=alg,
526-
527-
block_size=block_size,
528-
temp=temp,
529-
temp_flags=temp_flags,
403+
kwargs...
530404
)
531405
end

0 commit comments

Comments
 (0)