@@ -49,14 +49,37 @@ struct BitonicSort<bitonic_sort::bitonic_sort_config<1, WorkgroupSizeLog2, KeyTy
4949
5050 using SortConfig = subgroup::bitonic_sort_config<key_t, value_t, comparator_t>;
5151
52+ template<typename Key, typename Value, typename key_adaptor_t, typename value_adaptor_t>
53+ static void shuffleXor (NBL_REF_ARG (pair<Key, Value>) p, uint32_t ownedIdx, uint32_t mask, NBL_REF_ARG (key_adaptor_t) keyAdaptor, NBL_REF_ARG (value_adaptor_t) valueAdaptor)
54+ {
55+ keyAdaptor.template set<Key>(ownedIdx, p.first);
56+ valueAdaptor.template set<Value>(ownedIdx, p.second);
57+
58+ // Wait until all writes are done before reading - only barrier on one adaptor here
59+ keyAdaptor.workgroupExecutionAndMemoryBarrier ();
60+
61+ keyAdaptor.template get<Key>(ownedIdx ^ mask, p.first);
62+ valueAdaptor.template get<Value>(ownedIdx ^ mask, p.second);
63+ }
64+
65+
5266 template<typename SharedMemoryAccessor>
5367 static void mergeStage (NBL_REF_ARG (SharedMemoryAccessor) sharedmemAccessor, uint32_t stage, bool bitonicAscending, uint32_t invocationID,
54- NBL_REF_ARG (nbl::hlsl::pair<key_t, value_t>) loPair , NBL_REF_ARG (nbl::hlsl::pair<key_t, value_t>) hiPair )
68+ NBL_REF_ARG (nbl::hlsl::pair<key_t, value_t>) lopair , NBL_REF_ARG (nbl::hlsl::pair<key_t, value_t>) hipair )
5569 {
5670 const uint32_t WorkgroupSize = config_t::WorkgroupSize;
5771 const uint32_t subgroupSizeLog2 = glsl::gl_SubgroupSizeLog2 ();
5872 comparator_t comp;
5973
74+
75+ using key_adaptor_t = accessor_adaptors::StructureOfArrays<SharedMemoryAccessor, uint32_t, uint32_t, 1 , WorkgroupSize>;
76+ using value_adaptor_t = accessor_adaptors::StructureOfArrays<SharedMemoryAccessor, uint32_t, uint32_t, 1 , WorkgroupSize, integral_constant<uint32_t, WorkgroupSize * sizeof (key_t) / sizeof (uint32_t)> >;
77+
78+ key_adaptor_t keyAdaptor;
79+ keyAdaptor.accessor = sharedmemAccessor;
80+ value_adaptor_t valueAdaptor;
81+ valueAdaptor.accessor = sharedmemAccessor;
82+
6083 [unroll]
6184 for (uint32_t pass = 0 ; pass <= stage; pass ++)
6285 {
@@ -67,17 +90,16 @@ struct BitonicSort<bitonic_sort::bitonic_sort_config<1, WorkgroupSizeLog2, KeyTy
6790 const uint32_t stride = 1u << stridePower;
6891 const uint32_t threadStride = stride >> 1 ;
6992
70- nbl::hlsl::pair<key_t, value_t> pLoPair = loPair;
71- shuffleXor (pLoPair, threadStride, sharedmemAccessor);
72- sharedmemAccessor.workgroupExecutionAndMemoryBarrier ();
93+ nbl::hlsl::pair<key_t, value_t> plopair = lopair;
94+ shuffleXor (plopair, invocationID, threadStride, keyAdaptor, valueAdaptor);
7395
74- nbl::hlsl::pair<key_t, value_t> pHiPair = hiPair ;
75- shuffleXor (pHiPair, threadStride, sharedmemAccessor );
96+ nbl::hlsl::pair<key_t, value_t> phipair = hipair ;
97+ shuffleXor (phipair, invocationID ^ threadStride, threadStride, keyAdaptor, valueAdaptor );
7698
7799 const bool isUpper = (invocationID & threadStride) != 0 ;
78100 const bool takeLarger = isUpper == bitonicAscending;
79101
80- nbl::hlsl::bitonic_sort::compareExchangeWithPartner (takeLarger, loPair, pLoPair, hiPair, pHiPair , comp);
102+ nbl::hlsl::bitonic_sort::compareExchangeWithPartner (takeLarger, lopair, plopair, hipair, phipair , comp);
81103 }
82104 }
83105
@@ -96,12 +118,12 @@ struct BitonicSort<bitonic_sort::bitonic_sort_config<1, WorkgroupSizeLog2, KeyTy
96118 const uint32_t loIdx = invocationID * 2 ;
97119 const uint32_t hiIdx = loIdx | 1 ;
98120
99- nbl::hlsl::pair<key_t, value_t> loPair, hiPair ;
100- accessor.template get<nbl::hlsl::pair<key_t, value_t> >(loIdx, loPair );
101- accessor.template get<nbl::hlsl::pair<key_t, value_t> >(hiIdx, hiPair );
121+ nbl::hlsl::pair<key_t, value_t> lopair, hipair ;
122+ accessor.template get<nbl::hlsl::pair<key_t, value_t> >(loIdx, lopair );
123+ accessor.template get<nbl::hlsl::pair<key_t, value_t> >(hiIdx, hipair );
102124
103125 const bool subgroupAscending = (subgroupID & 1 ) == 0 ;
104- subgroup::bitonic_sort<SortConfig>::__call (subgroupAscending, loPair.first, hiPair.first, loPair.second, hiPair.second );
126+ subgroup::bitonic_sort<SortConfig>::__call (subgroupAscending, lopair, hipair );
105127
106128 const uint32_t subgroupInvocationID = glsl::gl_SubgroupInvocationID ();
107129
@@ -110,13 +132,13 @@ struct BitonicSort<bitonic_sort::bitonic_sort_config<1, WorkgroupSizeLog2, KeyTy
110132 {
111133 const bool bitonicAscending = !bool (invocationID & (subgroupSize << (stage + 1 )));
112134
113- mergeStage (sharedmemAccessor, stage, bitonicAscending, invocationID, loPair, hiPair );
135+ mergeStage (sharedmemAccessor, stage, bitonicAscending, invocationID, lopair, hipair );
114136
115- subgroup::bitonic_sort<SortConfig>::mergeStage (subgroupSizeLog2, bitonicAscending, subgroupInvocationID, loPair.first, hiPair.first, loPair.second, hiPair.second );
137+ subgroup::bitonic_sort<SortConfig>::mergeStage (subgroupSizeLog2, bitonicAscending, subgroupInvocationID, lopair, hipair );
116138 }
117139
118- accessor.template set<nbl::hlsl::pair<key_t, value_t> >(loIdx, loPair );
119- accessor.template set<nbl::hlsl::pair<key_t, value_t> >(hiIdx, hiPair );
140+ accessor.template set<nbl::hlsl::pair<key_t, value_t> >(loIdx, lopair );
141+ accessor.template set<nbl::hlsl::pair<key_t, value_t> >(hiIdx, hipair );
120142 }
121143};
122144
@@ -178,14 +200,14 @@ struct BitonicSort<bitonic_sort::bitonic_sort_config<ElementsPerThreadLog2, Work
178200 const uint32_t loIx = (((virtualThreadID & (~(stride - 1u))) << 1u) | (virtualThreadID & (stride - 1u))) + offsetAccessor.offset;
179201 const uint32_t hiIx = loIx | stride;
180202
181- nbl::hlsl::pair<key_t, value_t> loPair, hiPair ;
182- accessor.template get<nbl::hlsl::pair<key_t, value_t> >(loIx, loPair );
183- accessor.template get<nbl::hlsl::pair<key_t, value_t> >(hiIx, hiPair );
203+ nbl::hlsl::pair<key_t, value_t> lopair, hipair ;
204+ accessor.template get<nbl::hlsl::pair<key_t, value_t> >(loIx, lopair );
205+ accessor.template get<nbl::hlsl::pair<key_t, value_t> >(hiIx, hipair );
184206
185- nbl::hlsl::bitonic_sort:: swap (loPair.first, hiPair.first, loPair.second, hiPair.second );
207+ swap (lopair, hipair );
186208
187- accessor.template set<nbl::hlsl::pair<key_t, value_t> >(loIx, loPair );
188- accessor.template set<nbl::hlsl::pair<key_t, value_t> >(hiIx, hiPair );
209+ accessor.template set<nbl::hlsl::pair<key_t, value_t> >(loIx, lopair );
210+ accessor.template set<nbl::hlsl::pair<key_t, value_t> >(hiIx, hipair );
189211 }
190212 sharedmemAccessor.workgroupExecutionAndMemoryBarrier ();
191213 }
@@ -209,14 +231,14 @@ struct BitonicSort<bitonic_sort::bitonic_sort_config<ElementsPerThreadLog2, Work
209231
210232 const bool bitonicAscending = ((loIx & k) == 0u);
211233
212- nbl::hlsl::pair<key_t, value_t> loPair, hiPair ;
213- accessor.template get<nbl::hlsl::pair<key_t, value_t> >(loIx, loPair );
214- accessor.template get<nbl::hlsl::pair<key_t, value_t> >(hiIx, hiPair );
234+ nbl::hlsl::pair<key_t, value_t> lopair, hipair ;
235+ accessor.template get<nbl::hlsl::pair<key_t, value_t> >(loIx, lopair );
236+ accessor.template get<nbl::hlsl::pair<key_t, value_t> >(hiIx, hipair );
215237
216- nbl::hlsl::bitonic_sort::compareSwap (bitonicAscending, loPair.first, hiPair.first, loPair.second, hiPair.second , comp);
238+ nbl::hlsl::bitonic_sort::compareSwap (bitonicAscending, lopair, hipair , comp);
217239
218- accessor.template set<nbl::hlsl::pair<key_t, value_t> >(loIx, loPair );
219- accessor.template set<nbl::hlsl::pair<key_t, value_t> >(hiIx, hiPair );
240+ accessor.template set<nbl::hlsl::pair<key_t, value_t> >(loIx, lopair );
241+ accessor.template set<nbl::hlsl::pair<key_t, value_t> >(hiIx, hipair );
220242 }
221243 sharedmemAccessor.workgroupExecutionAndMemoryBarrier ();
222244 }
0 commit comments