Skip to content

Commit b2f7d3a

Browse files
committed
Update bitonic_sort.hlsl
1 parent 17eebef commit b2f7d3a

File tree

1 file changed

+49
-27
lines changed

1 file changed

+49
-27
lines changed

include/nbl/builtin/hlsl/workgroup/bitonic_sort.hlsl

Lines changed: 49 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -49,14 +49,37 @@ struct BitonicSort<bitonic_sort::bitonic_sort_config<1, WorkgroupSizeLog2, KeyTy
4949

5050
using SortConfig = subgroup::bitonic_sort_config<key_t, value_t, comparator_t>;
5151

52+
// Writes the pair `p` through the two adaptors at slot `ownedIdx`, calls
// keyAdaptor.workgroupExecutionAndMemoryBarrier(), then overwrites `p` with the
// key/value stored at slot `ownedIdx ^ mask`.
//
// p            - in: pair stored at `ownedIdx`; out: pair loaded from `ownedIdx ^ mask`
// ownedIdx     - slot index this invocation writes to
// mask         - XOR mask applied to `ownedIdx` to select the slot that is read back
// keyAdaptor   - adaptor used to store/load `p.first`; the barrier is issued through it
// valueAdaptor - adaptor used to store/load `p.second`
//
// NOTE(review): no barrier follows the reads, so callers presumably must ensure a slot is
// not overwritten before the invocation reading it has finished - confirm at the call sites.
template<typename Key, typename Value, typename key_adaptor_t, typename value_adaptor_t>
53+
static void shuffleXor(NBL_REF_ARG(pair<Key, Value>) p, uint32_t ownedIdx, uint32_t mask, NBL_REF_ARG(key_adaptor_t) keyAdaptor, NBL_REF_ARG(value_adaptor_t) valueAdaptor)
54+
{
55+
// Publish this invocation's key and value at its own slot.
keyAdaptor.template set<Key>(ownedIdx, p.first);
56+
valueAdaptor.template set<Value>(ownedIdx, p.second);
57+
58+
// Wait until all writes are done before reading. The barrier is issued through keyAdaptor
// only; this presumably relies on both adaptors wrapping the same underlying accessor
// (mergeStage assigns sharedmemAccessor to both) - TODO confirm.
59+
keyAdaptor.workgroupExecutionAndMemoryBarrier();
60+
61+
// Load the pair stored at the XOR-selected slot, replacing the contents of p.
keyAdaptor.template get<Key>(ownedIdx ^ mask, p.first);
62+
valueAdaptor.template get<Value>(ownedIdx ^ mask, p.second);
63+
}
64+
65+
5266
template<typename SharedMemoryAccessor>
5367
static void mergeStage(NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor, uint32_t stage, bool bitonicAscending, uint32_t invocationID,
54-
NBL_REF_ARG(nbl::hlsl::pair<key_t, value_t>) loPair, NBL_REF_ARG(nbl::hlsl::pair<key_t, value_t>) hiPair)
68+
NBL_REF_ARG(nbl::hlsl::pair<key_t, value_t>) lopair, NBL_REF_ARG(nbl::hlsl::pair<key_t, value_t>) hipair)
5569
{
5670
const uint32_t WorkgroupSize = config_t::WorkgroupSize;
5771
const uint32_t subgroupSizeLog2 = glsl::gl_SubgroupSizeLog2();
5872
comparator_t comp;
5973

74+
75+
using key_adaptor_t = accessor_adaptors::StructureOfArrays<SharedMemoryAccessor, uint32_t, uint32_t, 1, WorkgroupSize>;
76+
using value_adaptor_t = accessor_adaptors::StructureOfArrays<SharedMemoryAccessor, uint32_t, uint32_t, 1, WorkgroupSize, integral_constant<uint32_t, WorkgroupSize * sizeof(key_t) / sizeof(uint32_t)> >;
77+
78+
key_adaptor_t keyAdaptor;
79+
keyAdaptor.accessor = sharedmemAccessor;
80+
value_adaptor_t valueAdaptor;
81+
valueAdaptor.accessor = sharedmemAccessor;
82+
6083
[unroll]
6184
for (uint32_t pass = 0; pass <= stage; pass++)
6285
{
@@ -67,17 +90,16 @@ struct BitonicSort<bitonic_sort::bitonic_sort_config<1, WorkgroupSizeLog2, KeyTy
6790
const uint32_t stride = 1u << stridePower;
6891
const uint32_t threadStride = stride >> 1;
6992

70-
nbl::hlsl::pair<key_t, value_t> pLoPair = loPair;
71-
shuffleXor(pLoPair, threadStride, sharedmemAccessor);
72-
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
93+
nbl::hlsl::pair<key_t, value_t> plopair = lopair;
94+
shuffleXor(plopair, invocationID, threadStride, keyAdaptor, valueAdaptor);
7395

74-
nbl::hlsl::pair<key_t, value_t> pHiPair = hiPair;
75-
shuffleXor(pHiPair, threadStride, sharedmemAccessor);
96+
nbl::hlsl::pair<key_t, value_t> phipair = hipair;
97+
shuffleXor(phipair, invocationID ^ threadStride, threadStride, keyAdaptor, valueAdaptor);
7698

7799
const bool isUpper = (invocationID & threadStride) != 0;
78100
const bool takeLarger = isUpper == bitonicAscending;
79101

80-
nbl::hlsl::bitonic_sort::compareExchangeWithPartner(takeLarger, loPair, pLoPair, hiPair, pHiPair, comp);
102+
nbl::hlsl::bitonic_sort::compareExchangeWithPartner(takeLarger, lopair, plopair, hipair, phipair, comp);
81103
}
82104
}
83105

@@ -96,12 +118,12 @@ struct BitonicSort<bitonic_sort::bitonic_sort_config<1, WorkgroupSizeLog2, KeyTy
96118
const uint32_t loIdx = invocationID * 2;
97119
const uint32_t hiIdx = loIdx | 1;
98120

99-
nbl::hlsl::pair<key_t, value_t> loPair, hiPair;
100-
accessor.template get<nbl::hlsl::pair<key_t, value_t> >(loIdx, loPair);
101-
accessor.template get<nbl::hlsl::pair<key_t, value_t> >(hiIdx, hiPair);
121+
nbl::hlsl::pair<key_t, value_t> lopair, hipair;
122+
accessor.template get<nbl::hlsl::pair<key_t, value_t> >(loIdx, lopair);
123+
accessor.template get<nbl::hlsl::pair<key_t, value_t> >(hiIdx, hipair);
102124

103125
const bool subgroupAscending = (subgroupID & 1) == 0;
104-
subgroup::bitonic_sort<SortConfig>::__call(subgroupAscending, loPair.first, hiPair.first, loPair.second, hiPair.second);
126+
subgroup::bitonic_sort<SortConfig>::__call(subgroupAscending, lopair, hipair);
105127

106128
const uint32_t subgroupInvocationID = glsl::gl_SubgroupInvocationID();
107129

@@ -110,13 +132,13 @@ struct BitonicSort<bitonic_sort::bitonic_sort_config<1, WorkgroupSizeLog2, KeyTy
110132
{
111133
const bool bitonicAscending = !bool(invocationID & (subgroupSize << (stage + 1)));
112134

113-
mergeStage(sharedmemAccessor, stage, bitonicAscending, invocationID, loPair, hiPair);
135+
mergeStage(sharedmemAccessor, stage, bitonicAscending, invocationID, lopair, hipair);
114136

115-
subgroup::bitonic_sort<SortConfig>::mergeStage(subgroupSizeLog2, bitonicAscending, subgroupInvocationID, loPair.first, hiPair.first, loPair.second, hiPair.second);
137+
subgroup::bitonic_sort<SortConfig>::mergeStage(subgroupSizeLog2, bitonicAscending, subgroupInvocationID, lopair, hipair);
116138
}
117139

118-
accessor.template set<nbl::hlsl::pair<key_t, value_t> >(loIdx, loPair);
119-
accessor.template set<nbl::hlsl::pair<key_t, value_t> >(hiIdx, hiPair);
140+
accessor.template set<nbl::hlsl::pair<key_t, value_t> >(loIdx, lopair);
141+
accessor.template set<nbl::hlsl::pair<key_t, value_t> >(hiIdx, hipair);
120142
}
121143
};
122144

@@ -178,14 +200,14 @@ struct BitonicSort<bitonic_sort::bitonic_sort_config<ElementsPerThreadLog2, Work
178200
const uint32_t loIx = (((virtualThreadID & (~(stride - 1u))) << 1u) | (virtualThreadID & (stride - 1u))) + offsetAccessor.offset;
179201
const uint32_t hiIx = loIx | stride;
180202

181-
nbl::hlsl::pair<key_t, value_t> loPair, hiPair;
182-
accessor.template get<nbl::hlsl::pair<key_t, value_t> >(loIx, loPair);
183-
accessor.template get<nbl::hlsl::pair<key_t, value_t> >(hiIx, hiPair);
203+
nbl::hlsl::pair<key_t, value_t> lopair, hipair;
204+
accessor.template get<nbl::hlsl::pair<key_t, value_t> >(loIx, lopair);
205+
accessor.template get<nbl::hlsl::pair<key_t, value_t> >(hiIx, hipair);
184206

185-
nbl::hlsl::bitonic_sort::swap(loPair.first, hiPair.first, loPair.second, hiPair.second);
207+
swap(lopair, hipair);
186208

187-
accessor.template set<nbl::hlsl::pair<key_t, value_t> >(loIx, loPair);
188-
accessor.template set<nbl::hlsl::pair<key_t, value_t> >(hiIx, hiPair);
209+
accessor.template set<nbl::hlsl::pair<key_t, value_t> >(loIx, lopair);
210+
accessor.template set<nbl::hlsl::pair<key_t, value_t> >(hiIx, hipair);
189211
}
190212
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
191213
}
@@ -209,14 +231,14 @@ struct BitonicSort<bitonic_sort::bitonic_sort_config<ElementsPerThreadLog2, Work
209231

210232
const bool bitonicAscending = ((loIx & k) == 0u);
211233

212-
nbl::hlsl::pair<key_t, value_t> loPair, hiPair;
213-
accessor.template get<nbl::hlsl::pair<key_t, value_t> >(loIx, loPair);
214-
accessor.template get<nbl::hlsl::pair<key_t, value_t> >(hiIx, hiPair);
234+
nbl::hlsl::pair<key_t, value_t> lopair, hipair;
235+
accessor.template get<nbl::hlsl::pair<key_t, value_t> >(loIx, lopair);
236+
accessor.template get<nbl::hlsl::pair<key_t, value_t> >(hiIx, hipair);
215237

216-
nbl::hlsl::bitonic_sort::compareSwap(bitonicAscending, loPair.first, hiPair.first, loPair.second, hiPair.second, comp);
238+
nbl::hlsl::bitonic_sort::compareSwap(bitonicAscending, lopair, hipair, comp);
217239

218-
accessor.template set<nbl::hlsl::pair<key_t, value_t> >(loIx, loPair);
219-
accessor.template set<nbl::hlsl::pair<key_t, value_t> >(hiIx, hiPair);
240+
accessor.template set<nbl::hlsl::pair<key_t, value_t> >(loIx, lopair);
241+
accessor.template set<nbl::hlsl::pair<key_t, value_t> >(hiIx, hipair);
220242
}
221243
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
222244
}

0 commit comments

Comments
 (0)