Skip to content

Commit 8596ad6

Browse files
committed
Update bitonic_sort.hlsl
1 parent 69e178a commit 8596ad6

File tree

1 file changed

+54
-204
lines changed

1 file changed

+54
-204
lines changed

include/nbl/builtin/hlsl/workgroup/bitonic_sort.hlsl

Lines changed: 54 additions & 204 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
#include "nbl/builtin/hlsl/bit.hlsl"
88
#include "nbl/builtin/hlsl/workgroup/shuffle.hlsl"
99
#include "nbl/builtin/hlsl/workgroup/basic.hlsl"
10+
#include "nbl/builtin/hlsl/concepts/accessors/bitonic_sort.hlsl"
1011

1112
namespace nbl
1213
{
@@ -17,18 +18,19 @@ namespace workgroup
1718
namespace bitonic_sort
1819
{
1920

20-
template<uint16_t _ElementsPerInvocationLog2, uint16_t _WorkgroupSizeLog2, typename KeyType, typename ValueType, typename Comparator = less<KeyType> >
21+
template<uint16_t _ElementsPerInvocationLog2, uint16_t _WorkgroupSizeLog2, typename KeyType, typename ValueType, typename Comparator = less<KeyType> NBL_PRIMARY_REQUIRES(_ElementsPerInvocationLog2 >= 1 && _WorkgroupSizeLog2 >= 5)
2122
struct bitonic_sort_config
2223
{
2324
using key_t = KeyType;
2425
using value_t = ValueType;
2526
using comparator_t = Comparator;
26-
2727
NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocationLog2 = _ElementsPerInvocationLog2;
2828
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2;
2929

3030
NBL_CONSTEXPR_STATIC_INLINE uint32_t ElementsPerInvocation = 1u << ElementsPerInvocationLog2;
3131
NBL_CONSTEXPR_STATIC_INLINE uint32_t WorkgroupSize = 1u << WorkgroupSizeLog2;
32+
NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedmemDWORDs = sizeof(pair<key_t, value_t>) / sizeof(uint32_t) * WorkgroupSize;
33+
3234
};
3335
}
3436

@@ -48,56 +50,38 @@ struct BitonicSort<bitonic_sort::bitonic_sort_config<1, WorkgroupSizeLog2, KeyTy
4850
using SortConfig = subgroup::bitonic_sort_config<key_t, value_t, comparator_t>;
4951

5052
template<typename SharedMemoryAccessor>
51-
static void mergeStage(NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor, uint32_t stage, bool bitonicAscending, uint32_t invocationID, NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey,
52-
NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal)
53+
static void mergeStage(NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor, uint32_t stage, bool bitonicAscending, uint32_t invocationID,
54+
NBL_REF_ARG(nbl::hlsl::pair<key_t, value_t>) loPair, NBL_REF_ARG(nbl::hlsl::pair<key_t, value_t>) hiPair)
5355
{
5456
const uint32_t WorkgroupSize = config_t::WorkgroupSize;
55-
using key_adaptor = accessor_adaptors::StructureOfArrays<SharedMemoryAccessor, key_t, uint32_t, 1, WorkgroupSize>;
56-
using value_adaptor = accessor_adaptors::StructureOfArrays<SharedMemoryAccessor, value_t, uint32_t, 1, WorkgroupSize>;
57-
58-
key_adaptor sharedmemAdaptorKey;
59-
sharedmemAdaptorKey.accessor = sharedmemAccessor;
60-
61-
value_adaptor sharedmemAdaptorValue;
62-
sharedmemAdaptorValue.accessor = sharedmemAccessor;
63-
6457
const uint32_t subgroupSizeLog2 = glsl::gl_SubgroupSizeLog2();
6558
comparator_t comp;
6659

6760
[unroll]
6861
for (uint32_t pass = 0; pass <= stage; pass++)
6962
{
7063
if (pass)
71-
sharedmemAdaptorValue.workgroupExecutionAndMemoryBarrier();
64+
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
7265

7366
const uint32_t stridePower = (stage - pass + 1) + subgroupSizeLog2;
7467
const uint32_t stride = 1u << stridePower;
7568
const uint32_t threadStride = stride >> 1;
7669

77-
key_t pLoKey = loKey;
78-
shuffleXor(pLoKey, threadStride, sharedmemAdaptorKey);
79-
sharedmemAdaptorKey.workgroupExecutionAndMemoryBarrier();
80-
81-
value_t pLoVal = loVal;
82-
shuffleXor(pLoVal, threadStride, sharedmemAdaptorValue);
83-
sharedmemAdaptorValue.workgroupExecutionAndMemoryBarrier();
70+
nbl::hlsl::pair<key_t, value_t> pLoPair = loPair;
71+
shuffleXor(pLoPair, threadStride, sharedmemAccessor);
72+
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
8473

85-
key_t pHiKey = hiKey;
86-
shuffleXor(pHiKey, threadStride, sharedmemAdaptorKey);
87-
sharedmemAdaptorKey.workgroupExecutionAndMemoryBarrier();
88-
89-
value_t pHiVal = hiVal;
90-
shuffleXor(pHiVal, threadStride, sharedmemAdaptorValue);
74+
nbl::hlsl::pair<key_t, value_t> pHiPair = hiPair;
75+
shuffleXor(pHiPair, threadStride, sharedmemAccessor);
9176

9277
const bool isUpper = (invocationID & threadStride) != 0;
9378
const bool takeLarger = isUpper == bitonicAscending;
9479

95-
nbl::hlsl::bitonic_sort::compareExchangeWithPartner(takeLarger, loKey, pLoKey, hiKey, pHiKey, loVal, pLoVal, hiVal, pHiVal, comp);
96-
80+
nbl::hlsl::bitonic_sort::compareExchangeWithPartner(takeLarger, loPair, pLoPair, hiPair, pHiPair, comp);
9781
}
9882
}
9983

100-
template<typename Accessor, typename SharedMemoryAccessor>
84+
template<typename Accessor, typename SharedMemoryAccessor NBL_FUNC_REQUIRES(bitonic_sort::BitonicSortAccessor<Accessor, key_t, value_t>&& bitonic_sort::BitonicSortSharedMemoryAccessor<SharedMemoryAccessor>)
10185
static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
10286
{
10387
const uint32_t WorkgroupSize = config_t::WorkgroupSize;
@@ -111,15 +95,13 @@ struct BitonicSort<bitonic_sort::bitonic_sort_config<1, WorkgroupSizeLog2, KeyTy
11195

11296
const uint32_t loIdx = invocationID * 2;
11397
const uint32_t hiIdx = loIdx | 1;
114-
key_t loKey, hiKey;
115-
value_t loVal, hiVal;
116-
accessor.template get<key_t>(loIdx, loKey);
117-
accessor.template get<key_t>(hiIdx, hiKey);
118-
accessor.template get<value_t>(loIdx, loVal);
119-
accessor.template get<value_t>(hiIdx, hiVal);
98+
99+
nbl::hlsl::pair<key_t, value_t> loPair, hiPair;
100+
accessor.template get<nbl::hlsl::pair<key_t, value_t> >(loIdx, loPair);
101+
accessor.template get<nbl::hlsl::pair<key_t, value_t> >(hiIdx, hiPair);
120102

121103
const bool subgroupAscending = (subgroupID & 1) == 0;
122-
subgroup::bitonic_sort<SortConfig>::__call(subgroupAscending, loKey, hiKey, loVal, hiVal);
104+
subgroup::bitonic_sort<SortConfig>::__call(subgroupAscending, loPair.first, hiPair.first, loPair.second, hiPair.second);
123105

124106
const uint32_t subgroupInvocationID = glsl::gl_SubgroupInvocationID();
125107

@@ -128,139 +110,17 @@ struct BitonicSort<bitonic_sort::bitonic_sort_config<1, WorkgroupSizeLog2, KeyTy
128110
{
129111
const bool bitonicAscending = !bool(invocationID & (subgroupSize << (stage + 1)));
130112

131-
mergeStage(sharedmemAccessor, stage, bitonicAscending, invocationID, loKey, hiKey, loVal, hiVal);
113+
mergeStage(sharedmemAccessor, stage, bitonicAscending, invocationID, loPair, hiPair);
132114

133-
subgroup::bitonic_sort<SortConfig>::mergeStage(subgroupSizeLog2, bitonicAscending, subgroupInvocationID, loKey, hiKey, loVal, hiVal);
115+
subgroup::bitonic_sort<SortConfig>::mergeStage(subgroupSizeLog2, bitonicAscending, subgroupInvocationID, loPair.first, hiPair.first, loPair.second, hiPair.second);
134116
}
135117

136-
137-
accessor.template set<key_t>(loIdx, loKey);
138-
accessor.template set<key_t>(hiIdx, hiKey);
139-
accessor.template set<value_t>(loIdx, loVal);
140-
accessor.template set<value_t>(hiIdx, hiVal);
118+
accessor.template set<nbl::hlsl::pair<key_t, value_t> >(loIdx, loPair);
119+
accessor.template set<nbl::hlsl::pair<key_t, value_t> >(hiIdx, hiPair);
141120
}
142121
};
143-
// ==================== ElementsPerThreadLog2 = 2 Specialization (Virtual Threading) ====================
144-
template<uint16_t WorkgroupSizeLog2, typename KeyType, typename ValueType, typename Comparator, class device_capabilities>
145-
struct BitonicSort<bitonic_sort::bitonic_sort_config<2, WorkgroupSizeLog2, KeyType, ValueType, Comparator>, device_capabilities>
146-
{
147-
using config_t = bitonic_sort::bitonic_sort_config<2, WorkgroupSizeLog2, KeyType, ValueType, Comparator>;
148-
using simple_config_t = bitonic_sort::bitonic_sort_config<1, WorkgroupSizeLog2, KeyType, ValueType, Comparator>;
149122

150-
using key_t = KeyType;
151-
using value_t = ValueType;
152-
using comparator_t = Comparator;
153-
154-
template<typename Accessor, typename SharedMemoryAccessor>
155-
static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
156-
{
157-
const uint32_t WorkgroupSize = config_t::WorkgroupSize;
158-
const uint32_t ElementsPerThread = config_t::ElementsPerInvocation;
159-
const uint32_t TotalElements = WorkgroupSize * ElementsPerThread;
160-
const uint32_t ElementsPerSimpleSort = WorkgroupSize * 2; // E=1 handles WG*2 elements
161-
162-
const uint32_t threadID = glsl::gl_LocalInvocationID().x;
163-
comparator_t comp;
164-
165-
accessor_adaptors::Offset<Accessor> offsetAccessor;
166-
offsetAccessor.accessor = accessor;
167-
168-
[unroll]
169-
for (uint32_t k = 0; k < ElementsPerThread; k += 2)
170-
{
171-
if (k)
172-
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
173-
174-
offsetAccessor.offset = ElementsPerSimpleSort * (k / 2);
175-
176-
BitonicSort<simple_config_t, device_capabilities>::template __call(offsetAccessor, sharedmemAccessor);
177-
}
178-
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
179-
180-
accessor = offsetAccessor.accessor;
181-
182-
const uint32_t simpleLog = hlsl::findMSB(ElementsPerSimpleSort - 1) + 1u;
183-
const uint32_t totalLog = hlsl::findMSB(TotalElements - 1) + 1u;
184-
185-
[unroll]
186-
for (uint32_t blockLog = simpleLog + 1u; blockLog <= totalLog; blockLog++)
187-
{
188-
// Reverse odd halves for bitonic property
189-
const uint32_t halfLog = blockLog - 1u;
190-
const uint32_t halfSize = 1u << halfLog;
191-
const uint32_t numHalves = TotalElements >> halfLog;
192-
193-
// Process only odd-indexed halves (no thread divergence)
194-
[unroll]
195-
for (uint32_t halfIdx = 1u; halfIdx < numHalves; halfIdx += 2u)
196-
{
197-
const uint32_t halfBaseIdx = halfIdx << halfLog;
198-
199-
[unroll]
200-
for (uint32_t strideLog = halfLog - 1u; strideLog + 1u > 0u; strideLog--)
201-
{
202-
const uint32_t stride = 1u << strideLog;
203-
const uint32_t virtualThreadsInHalf = halfSize >> 1u;
204-
205-
[unroll]
206-
for (uint32_t virtualThreadID = threadID; virtualThreadID < virtualThreadsInHalf; virtualThreadID += WorkgroupSize)
207-
{
208-
const uint32_t localLoIx = ((virtualThreadID & (~(stride - 1u))) << 1u) | (virtualThreadID & (stride - 1u));
209-
const uint32_t loIx = halfBaseIdx + localLoIx;
210-
const uint32_t hiIx = loIx | stride;
211-
212-
key_t loKeyGlobal, hiKeyGlobal;
213-
value_t loValGlobal, hiValGlobal;
214-
accessor.template get<key_t>(loIx, loKeyGlobal);
215-
accessor.template get<key_t>(hiIx, hiKeyGlobal);
216-
accessor.template get<value_t>(loIx, loValGlobal);
217-
accessor.template get<value_t>(hiIx, hiValGlobal);
218-
219-
nbl::hlsl::bitonic_sort::swap(loKeyGlobal, hiKeyGlobal, loValGlobal, hiValGlobal);
220-
221-
accessor.template set<key_t>(loIx, loKeyGlobal);
222-
accessor.template set<key_t>(hiIx, hiKeyGlobal);
223-
accessor.template set<value_t>(loIx, loValGlobal);
224-
accessor.template set<value_t>(hiIx, hiValGlobal);
225-
}
226-
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
227-
}
228-
}
229-
230-
const uint32_t k = 1u << blockLog;
231-
[unroll]
232-
for (uint32_t strideLog = blockLog - 1u; strideLog + 1u > 0u; strideLog--)
233-
{
234-
const uint32_t stride = 1u << strideLog;
235-
236-
[unroll]
237-
for (uint32_t virtualThreadID = threadID; virtualThreadID < TotalElements / 2; virtualThreadID += WorkgroupSize)
238-
{
239-
const uint32_t loIx = ((virtualThreadID & (~(stride - 1u))) << 1u) | (virtualThreadID & (stride - 1u));
240-
const uint32_t hiIx = loIx | stride;
241-
242-
const bool bitonicAscending = ((loIx & k) == 0u);
243-
244-
key_t loKeyGlobal, hiKeyGlobal;
245-
value_t loValGlobal, hiValGlobal;
246-
accessor.template get<key_t>(loIx, loKeyGlobal);
247-
accessor.template get<key_t>(hiIx, hiKeyGlobal);
248-
accessor.template get<value_t>(loIx, loValGlobal);
249-
accessor.template get<value_t>(hiIx, hiValGlobal);
250-
251-
nbl::hlsl::bitonic_sort::compareSwap(bitonicAscending, loKeyGlobal, hiKeyGlobal, loValGlobal, hiValGlobal, comp);
252-
253-
accessor.template set<key_t>(loIx, loKeyGlobal);
254-
accessor.template set<key_t>(hiIx, hiKeyGlobal);
255-
accessor.template set<value_t>(loIx, loValGlobal);
256-
accessor.template set<value_t>(hiIx, hiValGlobal);
257-
}
258-
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
259-
}
260-
}
261-
}
262-
};
263-
// ==================== ElementsPerThreadLog2 > 2 Specialization (Virtual Threading) ====================
123+
// ==================== ElementsPerThreadLog2 > 1 Specialization (Virtual Threading) ====================
264124
// This handles larger arrays by running the E=1 workgroup sort on each sub-range in turn (selected via the accessor offset), then combining the sorted sub-ranges with global memory stages
265125
template<uint16_t ElementsPerThreadLog2, uint16_t WorkgroupSizeLog2, typename KeyType, typename ValueType, typename Comparator, class device_capabilities>
266126
struct BitonicSort<bitonic_sort::bitonic_sort_config<ElementsPerThreadLog2, WorkgroupSizeLog2, KeyType, ValueType, Comparator>, device_capabilities>
@@ -295,10 +155,10 @@ struct BitonicSort<bitonic_sort::bitonic_sort_config<ElementsPerThreadLog2, Work
295155
if (sub)
296156
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
297157

298-
offsetAccessor.offset = sub * ElementsPerSimpleSort;
158+
offsetAccessor.offset = sub * ElementsPerSimpleSort;
299159

300-
// Call E=1 workgroup sort
301-
BitonicSort<simple_config_t, device_capabilities>::template __call(offsetAccessor, sharedmemAccessor);
160+
// Call E=1 workgroup sort
161+
BitonicSort<simple_config_t, device_capabilities>::template __call(offsetAccessor, sharedmemAccessor);
302162
}
303163
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
304164

@@ -318,59 +178,49 @@ struct BitonicSort<bitonic_sort::bitonic_sort_config<ElementsPerThreadLog2, Work
318178
const uint32_t loIx = (((virtualThreadID & (~(stride - 1u))) << 1u) | (virtualThreadID & (stride - 1u))) + offsetAccessor.offset;
319179
const uint32_t hiIx = loIx | stride;
320180

321-
key_t loKeyGlobal, hiKeyGlobal;
322-
value_t loValGlobal, hiValGlobal;
323-
accessor.template get<key_t>(loIx, loKeyGlobal);
324-
accessor.template get<key_t>(hiIx, hiKeyGlobal);
325-
accessor.template get<value_t>(loIx, loValGlobal);
326-
accessor.template get<value_t>(hiIx, hiValGlobal);
181+
nbl::hlsl::pair<key_t, value_t> loPair, hiPair;
182+
accessor.template get<nbl::hlsl::pair<key_t, value_t> >(loIx, loPair);
183+
accessor.template get<nbl::hlsl::pair<key_t, value_t> >(hiIx, hiPair);
327184

328-
nbl::hlsl::bitonic_sort::swap(loKeyGlobal, hiKeyGlobal, loValGlobal, hiValGlobal);
185+
nbl::hlsl::bitonic_sort::swap(loPair.first, hiPair.first, loPair.second, hiPair.second);
329186

330-
accessor.template set<key_t>(loIx, loKeyGlobal);
331-
accessor.template set<key_t>(hiIx, hiKeyGlobal);
332-
accessor.template set<value_t>(loIx, loValGlobal);
333-
accessor.template set<value_t>(hiIx, hiValGlobal);
187+
accessor.template set<nbl::hlsl::pair<key_t, value_t> >(loIx, loPair);
188+
accessor.template set<nbl::hlsl::pair<key_t, value_t> >(hiIx, hiPair);
334189
}
335190
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
191+
}
336192
}
337-
}
338193

339-
// PHASE 3: Global memory bitonic merge
340-
const uint32_t totalLog = hlsl::findMSB(TotalElements - 1) + 1u;
341-
[unroll]
342-
for (uint32_t blockLog = simpleLog + 1u; blockLog <= totalLog; blockLog++)
343-
{
344-
const uint32_t k = 1u << blockLog;
194+
// PHASE 3: Global memory bitonic merge
195+
const uint32_t totalLog = hlsl::findMSB(TotalElements - 1) + 1u;
345196
[unroll]
346-
for (uint32_t strideLog = blockLog - 1u; strideLog + 1u > 0u; strideLog--)
197+
for (uint32_t blockLog = simpleLog + 1u; blockLog <= totalLog; blockLog++)
347198
{
348-
const uint32_t stride = 1u << strideLog;
199+
const uint32_t k = 1u << blockLog;
349200
[unroll]
350-
for (uint32_t virtualThreadID = threadID; virtualThreadID < TotalElements / 2; virtualThreadID += WorkgroupSize)
201+
for (uint32_t strideLog = blockLog - 1u; strideLog + 1u > 0u; strideLog--)
351202
{
352-
const uint32_t loIx = ((virtualThreadID & (~(stride - 1u))) << 1u) | (virtualThreadID & (stride - 1u));
353-
const uint32_t hiIx = loIx | stride;
203+
const uint32_t stride = 1u << strideLog;
204+
[unroll]
205+
for (uint32_t virtualThreadID = threadID; virtualThreadID < TotalElements / 2; virtualThreadID += WorkgroupSize)
206+
{
207+
const uint32_t loIx = ((virtualThreadID & (~(stride - 1u))) << 1u) | (virtualThreadID & (stride - 1u));
208+
const uint32_t hiIx = loIx | stride;
354209

355-
const bool bitonicAscending = ((loIx & k) == 0u);
210+
const bool bitonicAscending = ((loIx & k) == 0u);
356211

357-
key_t loKeyGlobal, hiKeyGlobal;
358-
value_t loValGlobal, hiValGlobal;
359-
accessor.template get<key_t>(loIx, loKeyGlobal);
360-
accessor.template get<key_t>(hiIx, hiKeyGlobal);
361-
accessor.template get<value_t>(loIx, loValGlobal);
362-
accessor.template get<value_t>(hiIx, hiValGlobal);
212+
nbl::hlsl::pair<key_t, value_t> loPair, hiPair;
213+
accessor.template get<nbl::hlsl::pair<key_t, value_t> >(loIx, loPair);
214+
accessor.template get<nbl::hlsl::pair<key_t, value_t> >(hiIx, hiPair);
363215

364-
nbl::hlsl::bitonic_sort::compareSwap(bitonicAscending, loKeyGlobal, hiKeyGlobal, loValGlobal, hiValGlobal, comp);
216+
nbl::hlsl::bitonic_sort::compareSwap(bitonicAscending, loPair.first, hiPair.first, loPair.second, hiPair.second, comp);
365217

366-
accessor.template set<key_t>(loIx, loKeyGlobal);
367-
accessor.template set<key_t>(hiIx, hiKeyGlobal);
368-
accessor.template set<value_t>(loIx, loValGlobal);
369-
accessor.template set<value_t>(hiIx, hiValGlobal);
218+
accessor.template set<nbl::hlsl::pair<key_t, value_t> >(loIx, loPair);
219+
accessor.template set<nbl::hlsl::pair<key_t, value_t> >(hiIx, hiPair);
220+
}
221+
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
370222
}
371-
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
372223
}
373-
}
374224
}
375225
};
376226

0 commit comments

Comments
 (0)