Skip to content

Commit 8f333b6

Browse files
authored
Merge pull request #55 from christiangnrd/redgroup
Deduplicate `reduce_group`
2 parents 0060568 + 1d7b59d commit 8f333b6

File tree

4 files changed

+61
-90
lines changed

4 files changed

+61
-90
lines changed

src/AcceleratedKernels.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ module AcceleratedKernels
1414
using ArgCheck: @argcheck
1515
using GPUArraysCore: AnyGPUArray, @allowscalar
1616
using KernelAbstractions
17+
using KernelAbstractions: @context
1718
import UnsafeAtomics
1819

1920

src/reduce/mapreduce_1d_gpu.jl

Lines changed: 1 addition & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -25,60 +25,7 @@
2525

2626
@synchronize()
2727

28-
if N >= 512u16
29-
if ithread < 256u16
30-
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 256u16 + 0x1])
31-
end
32-
@synchronize()
33-
end
34-
if N >= 256u16
35-
if ithread < 128u16
36-
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 128u16 + 0x1])
37-
end
38-
@synchronize()
39-
end
40-
if N >= 128u16
41-
if ithread < 64u16
42-
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 64u16 + 0x1])
43-
end
44-
@synchronize()
45-
end
46-
if N >= 64u16
47-
if ithread < 32u16
48-
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 32u16 + 0x1])
49-
end
50-
@synchronize()
51-
end
52-
if N >= 32u16
53-
if ithread < 16u16
54-
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 16u16 + 0x1])
55-
end
56-
@synchronize()
57-
end
58-
if N >= 16u16
59-
if ithread < 8u16
60-
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 8u16 + 0x1])
61-
end
62-
@synchronize()
63-
end
64-
if N >= 8u16
65-
if ithread < 4u16
66-
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 4u16 + 0x1])
67-
end
68-
@synchronize()
69-
end
70-
if N >= 4u16
71-
if ithread < 2u16
72-
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 2u16 + 0x1])
73-
end
74-
@synchronize()
75-
end
76-
if N >= 2u16
77-
if ithread < 1u16
78-
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 0x1 + 0x1])
79-
end
80-
@synchronize()
81-
end
28+
@inline reduce_group!(@context, op, sdata, N, ithread)
8229

8330
# Code below would work on NVidia GPUs with warp size of 32, but create race conditions and
8431
# return incorrect results on Intel Graphics. It would be useful to have a way to statically

src/reduce/mapreduce_nd.jl

Lines changed: 2 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -332,42 +332,8 @@ end
332332
sdata[ithread + 0x1] = partial
333333
@synchronize()
334334

335-
if N >= 512u16
336-
ithread < 256u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 256u16 + 0x1]))
337-
@synchronize()
338-
end
339-
if N >= 256u16
340-
ithread < 128u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 128u16 + 0x1]))
341-
@synchronize()
342-
end
343-
if N >= 128u16
344-
ithread < 64u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 64u16 + 0x1]))
345-
@synchronize()
346-
end
347-
if N >= 64u16
348-
ithread < 32u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 32u16 + 0x1]))
349-
@synchronize()
350-
end
351-
if N >= 32u16
352-
ithread < 16u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 16u16 + 0x1]))
353-
@synchronize()
354-
end
355-
if N >= 16u16
356-
ithread < 8u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 8u16 + 0x1]))
357-
@synchronize()
358-
end
359-
if N >= 8u16
360-
ithread < 4u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 4u16 + 0x1]))
361-
@synchronize()
362-
end
363-
if N >= 4u16
364-
ithread < 2u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 2u16 + 0x1]))
365-
@synchronize()
366-
end
367-
if N >= 2u16
368-
ithread < 1u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 1u16 + 0x1]))
369-
@synchronize()
370-
end
335+
@inline reduce_group!(@context, op, sdata, N, ithread)
336+
371337
if ithread == 0x0
372338
dst[iblock + 0x1] = op(init, sdata[0x1])
373339
end

src/reduce/utilities.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,60 @@ function _mapreduce_nd_apply_init!(
4343
dst[i] = op(init, f(src[i]))
4444
end
4545
end
46+
47+
@inline function reduce_group!(@context, op, sdata, N, ithread)
48+
if N >= 512u16
49+
if ithread < 256u16
50+
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 256u16 + 0x1])
51+
end
52+
@synchronize()
53+
end
54+
if N >= 256u16
55+
if ithread < 128u16
56+
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 128u16 + 0x1])
57+
end
58+
@synchronize()
59+
end
60+
if N >= 128u16
61+
if ithread < 64u16
62+
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 64u16 + 0x1])
63+
end
64+
@synchronize()
65+
end
66+
if N >= 64u16
67+
if ithread < 32u16
68+
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 32u16 + 0x1])
69+
end
70+
@synchronize()
71+
end
72+
if N >= 32u16
73+
if ithread < 16u16
74+
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 16u16 + 0x1])
75+
end
76+
@synchronize()
77+
end
78+
if N >= 16u16
79+
if ithread < 8u16
80+
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 8u16 + 0x1])
81+
end
82+
@synchronize()
83+
end
84+
if N >= 8u16
85+
if ithread < 4u16
86+
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 4u16 + 0x1])
87+
end
88+
@synchronize()
89+
end
90+
if N >= 4u16
91+
if ithread < 2u16
92+
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 2u16 + 0x1])
93+
end
94+
@synchronize()
95+
end
96+
if N >= 2u16
97+
if ithread < 1u16
98+
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 1u16 + 0x1])
99+
end
100+
@synchronize()
101+
end
102+
end

0 commit comments

Comments
 (0)