Skip to content

Commit 17eebef

Browse files
committed
Update bitonic_sort.hlsl
1 parent c750811 commit 17eebef

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

include/nbl/builtin/hlsl/subgroup/bitonic_sort.hlsl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ struct bitonic_sort<bitonic_sort_config<KeyType, ValueType, Comparator>, device_
3030
using value_t = typename config_t::value_t;
3131
using comparator_t = typename config_t::comparator_t;
3232

33-
static void mergeStage(uint32_t stage, bool bitonicAscending, uint32_t invocationID, NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey,
34-
NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal)
33+
static void mergeStage(uint32_t stage, bool bitonicAscending, uint32_t invocationID,
34+
NBL_REF_ARG(pair<key_t, value_t>) loPair, NBL_REF_ARG(pair<key_t, value_t>) hiPair)
3535
{
3636
comparator_t comp;
3737

@@ -43,34 +43,36 @@ struct bitonic_sort<bitonic_sort_config<KeyType, ValueType, Comparator>, device_
4343
if (threadStride == 0)
4444
{
4545
// Local compare and swap for stage 0
46-
nbl::hlsl::bitonic_sort::compareSwap(bitonicAscending, loKey, hiKey, loVal, hiVal, comp);
46+
nbl::hlsl::bitonic_sort::compareSwap(bitonicAscending, loPair, hiPair, comp);
4747
}
4848
else
4949
{
5050
// Shuffle from partner using XOR
51-
const key_t pLoKey = glsl::subgroupShuffleXor<key_t>(loKey, threadStride);
52-
const key_t pHiKey = glsl::subgroupShuffleXor<key_t>(hiKey, threadStride);
53-
const value_t pLoVal = glsl::subgroupShuffleXor<value_t>(loVal, threadStride);
54-
const value_t pHiVal = glsl::subgroupShuffleXor<value_t>(hiVal, threadStride);
51+
const key_t pLoKey = glsl::subgroupShuffleXor<key_t>(loPair.first, threadStride);
52+
const value_t pLoVal = glsl::subgroupShuffleXor<value_t>(loPair.second, threadStride);
53+
const key_t pHiKey = glsl::subgroupShuffleXor<key_t>(hiPair.first, threadStride);
54+
const value_t pHiVal = glsl::subgroupShuffleXor<value_t>(hiPair.second, threadStride);
55+
56+
const pair<key_t, value_t> partnerLoPair = make_pair(pLoKey, pLoVal);
57+
const pair<key_t, value_t> partnerHiPair = make_pair(pHiKey, pHiVal);
5558

5659
const bool isUpper = bool(invocationID & threadStride);
5760
const bool takeLarger = isUpper == bitonicAscending;
5861

59-
nbl::hlsl::bitonic_sort::compareExchangeWithPartner(takeLarger, loKey, pLoKey, hiKey, pHiKey, loVal, pLoVal, hiVal, pHiVal, comp);
62+
nbl::hlsl::bitonic_sort::compareExchangeWithPartner(takeLarger, loPair, partnerLoPair, hiPair, partnerHiPair, comp);
6063
}
6164
}
6265
}
6366

64-
static void __call(bool ascending, NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey,
65-
NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal)
67+
static void __call(bool ascending, NBL_REF_ARG(pair<key_t, value_t>) loPair, NBL_REF_ARG(pair<key_t, value_t>) hiPair)
6668
{
6769
const uint32_t invocationID = glsl::gl_SubgroupInvocationID();
6870
const uint32_t subgroupSizeLog2 = glsl::gl_SubgroupSizeLog2();
6971
[unroll]
7072
for (uint32_t stage = 0; stage <= subgroupSizeLog2; stage++)
7173
{
7274
const bool bitonicAscending = (stage == subgroupSizeLog2) ? ascending : !bool(invocationID & (1u << stage));
73-
mergeStage(stage, bitonicAscending, invocationID, loKey, hiKey, loVal, hiVal);
75+
mergeStage(stage, bitonicAscending, invocationID, loPair, hiPair);
7476
}
7577
}
7678
};

0 commit comments

Comments
 (0)