Skip to content

Commit 51cd6d6

Browse files
committed
Deduplicate reduce_group
1 parent d997769 commit 51cd6d6

File tree

4 files changed

+61
-90
lines changed

4 files changed

+61
-90
lines changed

src/AcceleratedKernels.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ module AcceleratedKernels
1414
using ArgCheck: @argcheck
1515
using GPUArraysCore: AnyGPUArray, @allowscalar
1616
using KernelAbstractions
17+
using KernelAbstractions: @context
1718

1819

1920
# Exposed functions from upstream packages

src/reduce/mapreduce_1d_gpu.jl

Lines changed: 1 addition & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -25,60 +25,7 @@
2525

2626
@synchronize()
2727

28-
if N >= 512u16
29-
if ithread < 256u16
30-
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 256u16 + 0x1])
31-
end
32-
@synchronize()
33-
end
34-
if N >= 256u16
35-
if ithread < 128u16
36-
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 128u16 + 0x1])
37-
end
38-
@synchronize()
39-
end
40-
if N >= 128u16
41-
if ithread < 64u16
42-
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 64u16 + 0x1])
43-
end
44-
@synchronize()
45-
end
46-
if N >= 64u16
47-
if ithread < 32u16
48-
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 32u16 + 0x1])
49-
end
50-
@synchronize()
51-
end
52-
if N >= 32u16
53-
if ithread < 16u16
54-
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 16u16 + 0x1])
55-
end
56-
@synchronize()
57-
end
58-
if N >= 16u16
59-
if ithread < 8u16
60-
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 8u16 + 0x1])
61-
end
62-
@synchronize()
63-
end
64-
if N >= 8u16
65-
if ithread < 4u16
66-
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 4u16 + 0x1])
67-
end
68-
@synchronize()
69-
end
70-
if N >= 4u16
71-
if ithread < 2u16
72-
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 2u16 + 0x1])
73-
end
74-
@synchronize()
75-
end
76-
if N >= 2u16
77-
if ithread < 1u16
78-
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 0x1 + 0x1])
79-
end
80-
@synchronize()
81-
end
28+
@inline reduce_group!(@context, op, sdata, N, ithread)
8229

8330
# Code below would work on NVidia GPUs with warp size of 32, but create race conditions and
8431
# return incorrect results on Intel Graphics. It would be useful to have a way to statically

src/reduce/mapreduce_nd.jl

Lines changed: 2 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -332,42 +332,8 @@ end
332332
sdata[ithread + 0x1] = partial
333333
@synchronize()
334334

335-
if N >= 512u16
336-
ithread < 256u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 256u16 + 0x1]))
337-
@synchronize()
338-
end
339-
if N >= 256u16
340-
ithread < 128u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 128u16 + 0x1]))
341-
@synchronize()
342-
end
343-
if N >= 128u16
344-
ithread < 64u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 64u16 + 0x1]))
345-
@synchronize()
346-
end
347-
if N >= 64u16
348-
ithread < 32u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 32u16 + 0x1]))
349-
@synchronize()
350-
end
351-
if N >= 32u16
352-
ithread < 16u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 16u16 + 0x1]))
353-
@synchronize()
354-
end
355-
if N >= 16u16
356-
ithread < 8u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 8u16 + 0x1]))
357-
@synchronize()
358-
end
359-
if N >= 8u16
360-
ithread < 4u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 4u16 + 0x1]))
361-
@synchronize()
362-
end
363-
if N >= 4u16
364-
ithread < 2u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 2u16 + 0x1]))
365-
@synchronize()
366-
end
367-
if N >= 2u16
368-
ithread < 1u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 1u16 + 0x1]))
369-
@synchronize()
370-
end
335+
@inline reduce_group!(@context, op, sdata, N, ithread)
336+
371337
if ithread == 0x0
372338
dst[iblock + 0x1] = op(init, sdata[0x1])
373339
end

src/reduce/utilities.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,60 @@ function _mapreduce_nd_apply_init!(
4242
dst[i] = op(init, f(src[i]))
4343
end
4444
end
45+
46+
@inline function reduce_group!(@context, op, sdata, N, ithread)
47+
if N >= 512u16
48+
if ithread < 256u16
49+
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 256u16 + 0x1])
50+
end
51+
@synchronize()
52+
end
53+
if N >= 256u16
54+
if ithread < 128u16
55+
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 128u16 + 0x1])
56+
end
57+
@synchronize()
58+
end
59+
if N >= 128u16
60+
if ithread < 64u16
61+
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 64u16 + 0x1])
62+
end
63+
@synchronize()
64+
end
65+
if N >= 64u16
66+
if ithread < 32u16
67+
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 32u16 + 0x1])
68+
end
69+
@synchronize()
70+
end
71+
if N >= 32u16
72+
if ithread < 16u16
73+
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 16u16 + 0x1])
74+
end
75+
@synchronize()
76+
end
77+
if N >= 16u16
78+
if ithread < 8u16
79+
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 8u16 + 0x1])
80+
end
81+
@synchronize()
82+
end
83+
if N >= 8u16
84+
if ithread < 4u16
85+
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 4u16 + 0x1])
86+
end
87+
@synchronize()
88+
end
89+
if N >= 4u16
90+
if ithread < 2u16
91+
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 2u16 + 0x1])
92+
end
93+
@synchronize()
94+
end
95+
if N >= 2u16
96+
if ithread < 1u16
97+
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 1u16 + 0x1])
98+
end
99+
@synchronize()
100+
end
101+
end

0 commit comments

Comments
 (0)