77#include "nbl/builtin/hlsl/bit.hlsl"
88#include "nbl/builtin/hlsl/workgroup/shuffle.hlsl"
99#include "nbl/builtin/hlsl/workgroup/basic.hlsl"
10+ #include "nbl/builtin/hlsl/concepts/accessors/bitonic_sort.hlsl"
1011
1112namespace nbl
1213{
@@ -17,18 +18,19 @@ namespace workgroup
1718namespace bitonic_sort
1819{
1920
20- template<uint16_t _ElementsPerInvocationLog2, uint16_t _WorkgroupSizeLog2, typename KeyType, typename ValueType, typename Comparator = less <KeyType> >
21+ template<uint16_t _ElementsPerInvocationLog2, uint16_t _WorkgroupSizeLog2, typename KeyType, typename ValueType, typename Comparator = less <KeyType> NBL_PRIMARY_REQUIRES (_ElementsPerInvocationLog2 >= 1 && _WorkgroupSizeLog2 >= 5 )
2122struct bitonic_sort_config
2223{
2324 using key_t = KeyType;
2425 using value_t = ValueType;
2526 using comparator_t = Comparator;
26-
2727 NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocationLog2 = _ElementsPerInvocationLog2;
2828 NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2;
2929
3030 NBL_CONSTEXPR_STATIC_INLINE uint32_t ElementsPerInvocation = 1u << ElementsPerInvocationLog2;
3131 NBL_CONSTEXPR_STATIC_INLINE uint32_t WorkgroupSize = 1u << WorkgroupSizeLog2;
32+ NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedmemDWORDs = sizeof (pair<key_t, value_t>) / sizeof (uint32_t) * WorkgroupSize;
33+
3234};
3335}
3436
@@ -48,56 +50,38 @@ struct BitonicSort<bitonic_sort::bitonic_sort_config<1, WorkgroupSizeLog2, KeyTy
4850 using SortConfig = subgroup::bitonic_sort_config<key_t, value_t, comparator_t>;
4951
5052 template<typename SharedMemoryAccessor>
51- static void mergeStage (NBL_REF_ARG (SharedMemoryAccessor) sharedmemAccessor, uint32_t stage, bool bitonicAscending, uint32_t invocationID, NBL_REF_ARG (key_t) loKey, NBL_REF_ARG (key_t) hiKey,
52- NBL_REF_ARG (value_t) loVal , NBL_REF_ARG (value_t) hiVal )
53+ static void mergeStage (NBL_REF_ARG (SharedMemoryAccessor) sharedmemAccessor, uint32_t stage, bool bitonicAscending, uint32_t invocationID,
54+ NBL_REF_ARG (nbl::hlsl::pair<key_t, value_t>) loPair , NBL_REF_ARG (nbl::hlsl::pair<key_t, value_t>) hiPair )
5355 {
5456 const uint32_t WorkgroupSize = config_t::WorkgroupSize;
55- using key_adaptor = accessor_adaptors::StructureOfArrays<SharedMemoryAccessor, key_t, uint32_t, 1 , WorkgroupSize>;
56- using value_adaptor = accessor_adaptors::StructureOfArrays<SharedMemoryAccessor, value_t, uint32_t, 1 , WorkgroupSize>;
57-
58- key_adaptor sharedmemAdaptorKey;
59- sharedmemAdaptorKey.accessor = sharedmemAccessor;
60-
61- value_adaptor sharedmemAdaptorValue;
62- sharedmemAdaptorValue.accessor = sharedmemAccessor;
63-
6457 const uint32_t subgroupSizeLog2 = glsl::gl_SubgroupSizeLog2 ();
6558 comparator_t comp;
6659
6760 [unroll]
6861 for (uint32_t pass = 0 ; pass <= stage; pass ++)
6962 {
7063 if (pass )
71- sharedmemAdaptorValue .workgroupExecutionAndMemoryBarrier ();
64+ sharedmemAccessor .workgroupExecutionAndMemoryBarrier ();
7265
7366 const uint32_t stridePower = (stage - pass + 1 ) + subgroupSizeLog2;
7467 const uint32_t stride = 1u << stridePower;
7568 const uint32_t threadStride = stride >> 1 ;
7669
77- key_t pLoKey = loKey;
78- shuffleXor (pLoKey, threadStride, sharedmemAdaptorKey);
79- sharedmemAdaptorKey.workgroupExecutionAndMemoryBarrier ();
80-
81- value_t pLoVal = loVal;
82- shuffleXor (pLoVal, threadStride, sharedmemAdaptorValue);
83- sharedmemAdaptorValue.workgroupExecutionAndMemoryBarrier ();
70+ nbl::hlsl::pair<key_t, value_t> pLoPair = loPair;
71+ shuffleXor (pLoPair, threadStride, sharedmemAccessor);
72+ sharedmemAccessor.workgroupExecutionAndMemoryBarrier ();
8473
85- key_t pHiKey = hiKey;
86- shuffleXor (pHiKey, threadStride, sharedmemAdaptorKey);
87- sharedmemAdaptorKey.workgroupExecutionAndMemoryBarrier ();
88-
89- value_t pHiVal = hiVal;
90- shuffleXor (pHiVal, threadStride, sharedmemAdaptorValue);
74+ nbl::hlsl::pair<key_t, value_t> pHiPair = hiPair;
75+ shuffleXor (pHiPair, threadStride, sharedmemAccessor);
9176
9277 const bool isUpper = (invocationID & threadStride) != 0 ;
9378 const bool takeLarger = isUpper == bitonicAscending;
9479
95- nbl::hlsl::bitonic_sort::compareExchangeWithPartner (takeLarger, loKey, pLoKey, hiKey, pHiKey, loVal, pLoVal, hiVal, pHiVal, comp);
96-
80+ nbl::hlsl::bitonic_sort::compareExchangeWithPartner (takeLarger, loPair, pLoPair, hiPair, pHiPair, comp);
9781 }
9882 }
9983
100- template<typename Accessor, typename SharedMemoryAccessor>
84+ template<typename Accessor, typename SharedMemoryAccessor NBL_FUNC_REQUIRES (bitonic_sort::BitonicSortAccessor<Accessor, key_t, value_t>&& bitonic_sort::BitonicSortSharedMemoryAccessor<SharedMemoryAccessor>)
10185 static void __call (NBL_REF_ARG (Accessor) accessor, NBL_REF_ARG (SharedMemoryAccessor) sharedmemAccessor)
10286 {
10387 const uint32_t WorkgroupSize = config_t::WorkgroupSize;
@@ -111,15 +95,13 @@ struct BitonicSort<bitonic_sort::bitonic_sort_config<1, WorkgroupSizeLog2, KeyTy
11195
11296 const uint32_t loIdx = invocationID * 2 ;
11397 const uint32_t hiIdx = loIdx | 1 ;
114- key_t loKey, hiKey;
115- value_t loVal, hiVal;
116- accessor.template get<key_t>(loIdx, loKey);
117- accessor.template get<key_t>(hiIdx, hiKey);
118- accessor.template get<value_t>(loIdx, loVal);
119- accessor.template get<value_t>(hiIdx, hiVal);
98+
99+ nbl::hlsl::pair<key_t, value_t> loPair, hiPair;
100+ accessor.template get<nbl::hlsl::pair<key_t, value_t> >(loIdx, loPair);
101+ accessor.template get<nbl::hlsl::pair<key_t, value_t> >(hiIdx, hiPair);
120102
121103 const bool subgroupAscending = (subgroupID & 1 ) == 0 ;
122- subgroup::bitonic_sort<SortConfig>::__call (subgroupAscending, loKey, hiKey, loVal, hiVal );
104+ subgroup::bitonic_sort<SortConfig>::__call (subgroupAscending, loPair.first, hiPair.first, loPair.second, hiPair.second );
123105
124106 const uint32_t subgroupInvocationID = glsl::gl_SubgroupInvocationID ();
125107
@@ -128,139 +110,17 @@ struct BitonicSort<bitonic_sort::bitonic_sort_config<1, WorkgroupSizeLog2, KeyTy
128110 {
129111 const bool bitonicAscending = !bool (invocationID & (subgroupSize << (stage + 1 )));
130112
131- mergeStage (sharedmemAccessor, stage, bitonicAscending, invocationID, loKey, hiKey, loVal, hiVal );
113+ mergeStage (sharedmemAccessor, stage, bitonicAscending, invocationID, loPair, hiPair );
132114
133- subgroup::bitonic_sort<SortConfig>::mergeStage (subgroupSizeLog2, bitonicAscending, subgroupInvocationID, loKey, hiKey, loVal, hiVal );
115+ subgroup::bitonic_sort<SortConfig>::mergeStage (subgroupSizeLog2, bitonicAscending, subgroupInvocationID, loPair.first, hiPair.first, loPair.second, hiPair.second );
134116 }
135117
136-
137- accessor.template set<key_t>(loIdx, loKey);
138- accessor.template set<key_t>(hiIdx, hiKey);
139- accessor.template set<value_t>(loIdx, loVal);
140- accessor.template set<value_t>(hiIdx, hiVal);
118+ accessor.template set<nbl::hlsl::pair<key_t, value_t> >(loIdx, loPair);
119+ accessor.template set<nbl::hlsl::pair<key_t, value_t> >(hiIdx, hiPair);
141120 }
142121};
143- // ==================== ElementsPerThreadLog2 = 2 Specialization (Virtual Threading) ====================
144- template<uint16_t WorkgroupSizeLog2, typename KeyType, typename ValueType, typename Comparator, class device_capabilities>
145- struct BitonicSort<bitonic_sort::bitonic_sort_config<2 , WorkgroupSizeLog2, KeyType, ValueType, Comparator>, device_capabilities>
146- {
147- using config_t = bitonic_sort::bitonic_sort_config<2 , WorkgroupSizeLog2, KeyType, ValueType, Comparator>;
148- using simple_config_t = bitonic_sort::bitonic_sort_config<1 , WorkgroupSizeLog2, KeyType, ValueType, Comparator>;
149122
150- using key_t = KeyType;
151- using value_t = ValueType;
152- using comparator_t = Comparator;
153-
154- template<typename Accessor, typename SharedMemoryAccessor>
155- static void __call (NBL_REF_ARG (Accessor) accessor, NBL_REF_ARG (SharedMemoryAccessor) sharedmemAccessor)
156- {
157- const uint32_t WorkgroupSize = config_t::WorkgroupSize;
158- const uint32_t ElementsPerThread = config_t::ElementsPerInvocation;
159- const uint32_t TotalElements = WorkgroupSize * ElementsPerThread;
160- const uint32_t ElementsPerSimpleSort = WorkgroupSize * 2 ; // E=1 handles WG*2 elements
161-
162- const uint32_t threadID = glsl::gl_LocalInvocationID ().x;
163- comparator_t comp;
164-
165- accessor_adaptors::Offset<Accessor> offsetAccessor;
166- offsetAccessor.accessor = accessor;
167-
168- [unroll]
169- for (uint32_t k = 0 ; k < ElementsPerThread; k += 2 )
170- {
171- if (k)
172- sharedmemAccessor.workgroupExecutionAndMemoryBarrier ();
173-
174- offsetAccessor.offset = ElementsPerSimpleSort * (k / 2 );
175-
176- BitonicSort<simple_config_t, device_capabilities>::template __call (offsetAccessor, sharedmemAccessor);
177- }
178- sharedmemAccessor.workgroupExecutionAndMemoryBarrier ();
179-
180- accessor = offsetAccessor.accessor;
181-
182- const uint32_t simpleLog = hlsl::findMSB (ElementsPerSimpleSort - 1 ) + 1u;
183- const uint32_t totalLog = hlsl::findMSB (TotalElements - 1 ) + 1u;
184-
185- [unroll]
186- for (uint32_t blockLog = simpleLog + 1u; blockLog <= totalLog; blockLog++)
187- {
188- // Reverse odd halves for bitonic property
189- const uint32_t halfLog = blockLog - 1u;
190- const uint32_t halfSize = 1u << halfLog;
191- const uint32_t numHalves = TotalElements >> halfLog;
192-
193- // Process only odd-indexed halves (no thread divergence)
194- [unroll]
195- for (uint32_t halfIdx = 1u; halfIdx < numHalves; halfIdx += 2u)
196- {
197- const uint32_t halfBaseIdx = halfIdx << halfLog;
198-
199- [unroll]
200- for (uint32_t strideLog = halfLog - 1u; strideLog + 1u > 0u; strideLog--)
201- {
202- const uint32_t stride = 1u << strideLog;
203- const uint32_t virtualThreadsInHalf = halfSize >> 1u;
204-
205- [unroll]
206- for (uint32_t virtualThreadID = threadID; virtualThreadID < virtualThreadsInHalf; virtualThreadID += WorkgroupSize)
207- {
208- const uint32_t localLoIx = ((virtualThreadID & (~(stride - 1u))) << 1u) | (virtualThreadID & (stride - 1u));
209- const uint32_t loIx = halfBaseIdx + localLoIx;
210- const uint32_t hiIx = loIx | stride;
211-
212- key_t loKeyGlobal, hiKeyGlobal;
213- value_t loValGlobal, hiValGlobal;
214- accessor.template get<key_t>(loIx, loKeyGlobal);
215- accessor.template get<key_t>(hiIx, hiKeyGlobal);
216- accessor.template get<value_t>(loIx, loValGlobal);
217- accessor.template get<value_t>(hiIx, hiValGlobal);
218-
219- nbl::hlsl::bitonic_sort::swap (loKeyGlobal, hiKeyGlobal, loValGlobal, hiValGlobal);
220-
221- accessor.template set<key_t>(loIx, loKeyGlobal);
222- accessor.template set<key_t>(hiIx, hiKeyGlobal);
223- accessor.template set<value_t>(loIx, loValGlobal);
224- accessor.template set<value_t>(hiIx, hiValGlobal);
225- }
226- sharedmemAccessor.workgroupExecutionAndMemoryBarrier ();
227- }
228- }
229-
230- const uint32_t k = 1u << blockLog;
231- [unroll]
232- for (uint32_t strideLog = blockLog - 1u; strideLog + 1u > 0u; strideLog--)
233- {
234- const uint32_t stride = 1u << strideLog;
235-
236- [unroll]
237- for (uint32_t virtualThreadID = threadID; virtualThreadID < TotalElements / 2 ; virtualThreadID += WorkgroupSize)
238- {
239- const uint32_t loIx = ((virtualThreadID & (~(stride - 1u))) << 1u) | (virtualThreadID & (stride - 1u));
240- const uint32_t hiIx = loIx | stride;
241-
242- const bool bitonicAscending = ((loIx & k) == 0u);
243-
244- key_t loKeyGlobal, hiKeyGlobal;
245- value_t loValGlobal, hiValGlobal;
246- accessor.template get<key_t>(loIx, loKeyGlobal);
247- accessor.template get<key_t>(hiIx, hiKeyGlobal);
248- accessor.template get<value_t>(loIx, loValGlobal);
249- accessor.template get<value_t>(hiIx, hiValGlobal);
250-
251- nbl::hlsl::bitonic_sort::compareSwap (bitonicAscending, loKeyGlobal, hiKeyGlobal, loValGlobal, hiValGlobal, comp);
252-
253- accessor.template set<key_t>(loIx, loKeyGlobal);
254- accessor.template set<key_t>(hiIx, hiKeyGlobal);
255- accessor.template set<value_t>(loIx, loValGlobal);
256- accessor.template set<value_t>(hiIx, hiValGlobal);
257- }
258- sharedmemAccessor.workgroupExecutionAndMemoryBarrier ();
259- }
260- }
261- }
262- };
263- // ==================== ElementsPerThreadLog2 > 2 Specialization (Virtual Threading) ====================
123+ // ==================== ElementsPerThreadLog2 > 1 Specialization (Virtual Threading) ====================
264124// This handles larger arrays by combining global memory stages with recursive E=1 workgroup sorts
265125template<uint16_t ElementsPerThreadLog2, uint16_t WorkgroupSizeLog2, typename KeyType, typename ValueType, typename Comparator, class device_capabilities>
266126struct BitonicSort<bitonic_sort::bitonic_sort_config<ElementsPerThreadLog2, WorkgroupSizeLog2, KeyType, ValueType, Comparator>, device_capabilities>
@@ -295,10 +155,10 @@ struct BitonicSort<bitonic_sort::bitonic_sort_config<ElementsPerThreadLog2, Work
295155 if (sub)
296156 sharedmemAccessor.workgroupExecutionAndMemoryBarrier ();
297157
298- offsetAccessor.offset = sub * ElementsPerSimpleSort;
158+ offsetAccessor.offset = sub * ElementsPerSimpleSort;
299159
300- // Call E=1 workgroup sort
301- BitonicSort<simple_config_t, device_capabilities>::template __call (offsetAccessor, sharedmemAccessor);
160+ // Call E=1 workgroup sort
161+ BitonicSort<simple_config_t, device_capabilities>::template __call (offsetAccessor, sharedmemAccessor);
302162 }
303163 sharedmemAccessor.workgroupExecutionAndMemoryBarrier ();
304164
@@ -318,59 +178,49 @@ struct BitonicSort<bitonic_sort::bitonic_sort_config<ElementsPerThreadLog2, Work
318178 const uint32_t loIx = (((virtualThreadID & (~(stride - 1u))) << 1u) | (virtualThreadID & (stride - 1u))) + offsetAccessor.offset;
319179 const uint32_t hiIx = loIx | stride;
320180
321- key_t loKeyGlobal, hiKeyGlobal;
322- value_t loValGlobal, hiValGlobal;
323- accessor.template get<key_t>(loIx, loKeyGlobal);
324- accessor.template get<key_t>(hiIx, hiKeyGlobal);
325- accessor.template get<value_t>(loIx, loValGlobal);
326- accessor.template get<value_t>(hiIx, hiValGlobal);
181+ nbl::hlsl::pair<key_t, value_t> loPair, hiPair;
182+ accessor.template get<nbl::hlsl::pair<key_t, value_t> >(loIx, loPair);
183+ accessor.template get<nbl::hlsl::pair<key_t, value_t> >(hiIx, hiPair);
327184
328- nbl::hlsl::bitonic_sort::swap (loKeyGlobal, hiKeyGlobal, loValGlobal, hiValGlobal );
185+ nbl::hlsl::bitonic_sort::swap (loPair.first, hiPair.first, loPair.second, hiPair.second );
329186
330- accessor.template set<key_t>(loIx, loKeyGlobal);
331- accessor.template set<key_t>(hiIx, hiKeyGlobal);
332- accessor.template set<value_t>(loIx, loValGlobal);
333- accessor.template set<value_t>(hiIx, hiValGlobal);
187+ accessor.template set<nbl::hlsl::pair<key_t, value_t> >(loIx, loPair);
188+ accessor.template set<nbl::hlsl::pair<key_t, value_t> >(hiIx, hiPair);
334189 }
335190 sharedmemAccessor.workgroupExecutionAndMemoryBarrier ();
191+ }
336192 }
337- }
338193
339- // PHASE 3: Global memory bitonic merge
340- const uint32_t totalLog = hlsl::findMSB (TotalElements - 1 ) + 1u;
341- [unroll]
342- for (uint32_t blockLog = simpleLog + 1u; blockLog <= totalLog; blockLog++)
343- {
344- const uint32_t k = 1u << blockLog;
194+ // PHASE 3: Global memory bitonic merge
195+ const uint32_t totalLog = hlsl::findMSB (TotalElements - 1 ) + 1u;
345196 [unroll]
346- for (uint32_t strideLog = blockLog - 1u; strideLog + 1u > 0u; strideLog-- )
197+ for (uint32_t blockLog = simpleLog + 1u; blockLog <= totalLog; blockLog++ )
347198 {
348- const uint32_t stride = 1u << strideLog ;
199+ const uint32_t k = 1u << blockLog ;
349200 [unroll]
350- for (uint32_t virtualThreadID = threadID; virtualThreadID < TotalElements / 2 ; virtualThreadID += WorkgroupSize )
201+ for (uint32_t strideLog = blockLog - 1u; strideLog + 1u > 0u; strideLog-- )
351202 {
352- const uint32_t loIx = ((virtualThreadID & (~(stride - 1u))) << 1u) | (virtualThreadID & (stride - 1u));
353- const uint32_t hiIx = loIx | stride;
203+ const uint32_t stride = 1u << strideLog;
204+ [unroll]
205+ for (uint32_t virtualThreadID = threadID; virtualThreadID < TotalElements / 2 ; virtualThreadID += WorkgroupSize)
206+ {
207+ const uint32_t loIx = ((virtualThreadID & (~(stride - 1u))) << 1u) | (virtualThreadID & (stride - 1u));
208+ const uint32_t hiIx = loIx | stride;
354209
355- const bool bitonicAscending = ((loIx & k) == 0u);
210+ const bool bitonicAscending = ((loIx & k) == 0u);
356211
357- key_t loKeyGlobal, hiKeyGlobal;
358- value_t loValGlobal, hiValGlobal;
359- accessor.template get<key_t>(loIx, loKeyGlobal);
360- accessor.template get<key_t>(hiIx, hiKeyGlobal);
361- accessor.template get<value_t>(loIx, loValGlobal);
362- accessor.template get<value_t>(hiIx, hiValGlobal);
212+ nbl::hlsl::pair<key_t, value_t> loPair, hiPair;
213+ accessor.template get<nbl::hlsl::pair<key_t, value_t> >(loIx, loPair);
214+ accessor.template get<nbl::hlsl::pair<key_t, value_t> >(hiIx, hiPair);
363215
364- nbl::hlsl::bitonic_sort::compareSwap (bitonicAscending, loKeyGlobal, hiKeyGlobal, loValGlobal, hiValGlobal , comp);
216+ nbl::hlsl::bitonic_sort::compareSwap (bitonicAscending, loPair.first, hiPair.first, loPair.second, hiPair.second , comp);
365217
366- accessor.template set<key_t> (loIx, loKeyGlobal );
367- accessor.template set<key_t> (hiIx, hiKeyGlobal );
368- accessor.template set<value_t>(loIx, loValGlobal);
369- accessor.template set<value_t>(hiIx, hiValGlobal );
218+ accessor.template set<nbl::hlsl::pair< key_t, value_t> > (loIx, loPair );
219+ accessor.template set<nbl::hlsl::pair< key_t, value_t> > (hiIx, hiPair );
220+ }
221+ sharedmemAccessor. workgroupExecutionAndMemoryBarrier ( );
370222 }
371- sharedmemAccessor.workgroupExecutionAndMemoryBarrier ();
372223 }
373- }
374224 }
375225};
376226
0 commit comments