@@ -30,8 +30,8 @@ struct bitonic_sort<bitonic_sort_config<KeyType, ValueType, Comparator>, device_
3030 using value_t = typename config_t::value_t;
3131 using comparator_t = typename config_t::comparator_t;
3232
33- static void mergeStage (uint32_t stage, bool bitonicAscending, uint32_t invocationID, NBL_REF_ARG (key_t) loKey, NBL_REF_ARG (key_t) hiKey,
34- NBL_REF_ARG (value_t) loVal , NBL_REF_ARG (value_t) hiVal )
// mergeStage: performs one merge stage of the subgroup bitonic sort on this
// invocation's two key/value pairs.
//   stage            - index of the stage being run; how it maps to threadStride
//                      is in lines elided from this diff hunk.
//   bitonicAscending - direction flag; passed to compareSwap directly and used
//                      to derive takeLarger for the cross-invocation exchange.
//   invocationID     - subgroup invocation index, tested against threadStride.
//   loPair, hiPair   - the two (key, value) pairs held by this invocation,
//                      passed through NBL_REF_ARG (presumably by reference so the
//                      compare helpers can update them - macro not shown here).
33+ static void mergeStage (uint32_t stage, bool bitonicAscending, uint32_t invocationID,
34+ NBL_REF_ARG (pair<key_t, value_t>) loPair , NBL_REF_ARG (pair<key_t, value_t>) hiPair )
3535 {
// Default-constructed comparator handed to both compare helpers below.
3636 comparator_t comp;
3737
@@ -43,34 +43,36 @@ struct bitonic_sort<bitonic_sort_config<KeyType, ValueType, Comparator>, device_
// NOTE(review): the lines that define threadStride (and open the scope closed
// by the second-to-last brace of this function) are not part of this hunk.
4343 if (threadStride == 0 )
4444 {
4545 // Local compare and swap for stage 0
46- nbl::hlsl::bitonic_sort::compareSwap (bitonicAscending, loKey, hiKey, loVal, hiVal , comp);
46+ nbl::hlsl::bitonic_sort::compareSwap (bitonicAscending, loPair, hiPair , comp);
4747 }
4848 else
4949 {
5050 // Shuffle from partner using XOR
51- const key_t pLoKey = glsl::subgroupShuffleXor<key_t>(loKey, threadStride);
52- const key_t pHiKey = glsl::subgroupShuffleXor<key_t>(hiKey, threadStride);
53- const value_t pLoVal = glsl::subgroupShuffleXor<value_t>(loVal, threadStride);
54- const value_t pHiVal = glsl::subgroupShuffleXor<value_t>(hiVal, threadStride);
// Keys and values are shuffled as separate scalars (first / second of each
// pair) and then re-packed into partner pairs with make_pair.
51+ const key_t pLoKey = glsl::subgroupShuffleXor<key_t>(loPair.first, threadStride);
52+ const value_t pLoVal = glsl::subgroupShuffleXor<value_t>(loPair.second, threadStride);
53+ const key_t pHiKey = glsl::subgroupShuffleXor<key_t>(hiPair.first, threadStride);
54+ const value_t pHiVal = glsl::subgroupShuffleXor<value_t>(hiPair.second, threadStride);
55+
56+ const pair<key_t, value_t> partnerLoPair = make_pair (pLoKey, pLoVal);
57+ const pair<key_t, value_t> partnerHiPair = make_pair (pHiKey, pHiVal);
5558
// isUpper: this invocation has the threadStride bit set in its ID.
// takeLarger is true exactly when isUpper and bitonicAscending agree.
5659 const bool isUpper = bool (invocationID & threadStride);
5760 const bool takeLarger = isUpper == bitonicAscending;
5861
59- nbl::hlsl::bitonic_sort::compareExchangeWithPartner (takeLarger, loKey, pLoKey, hiKey, pHiKey, loVal, pLoVal, hiVal, pHiVal , comp);
62+ nbl::hlsl::bitonic_sort::compareExchangeWithPartner (takeLarger, loPair, partnerLoPair, hiPair, partnerHiPair , comp);
6063 }
6164 }
6265 }
6366
64- static void __call (bool ascending, NBL_REF_ARG (key_t) loKey, NBL_REF_ARG (key_t) hiKey,
65- NBL_REF_ARG (value_t) loVal, NBL_REF_ARG (value_t) hiVal)
// __call: entry point of the subgroup-level sort. Runs mergeStage once for
// every stage in [0, subgroupSizeLog2] (inclusive) on this invocation's two
// key/value pairs.
//   ascending      - direction used for the final stage only.
//   loPair, hiPair - this invocation's two (key, value) pairs, forwarded to
//                    mergeStage through NBL_REF_ARG.
67+ static void __call (bool ascending, NBL_REF_ARG (pair<key_t, value_t>) loPair, NBL_REF_ARG (pair<key_t, value_t>) hiPair)
6668 {
6769 const uint32_t invocationID = glsl::gl_SubgroupInvocationID ();
6870 const uint32_t subgroupSizeLog2 = glsl::gl_SubgroupSizeLog2 ();
6971 [unroll]
// Inclusive upper bound: subgroupSizeLog2 + 1 stages are executed in total.
7072 for (uint32_t stage = 0 ; stage <= subgroupSizeLog2; stage++)
7173 {
// The last stage uses the caller-requested direction. Every earlier stage
// is ascending on invocations whose bit `stage` of invocationID is clear
// and descending on those where it is set.
7274 const bool bitonicAscending = (stage == subgroupSizeLog2) ? ascending : !bool (invocationID & (1u << stage));
73- mergeStage (stage, bitonicAscending, invocationID, loKey, hiKey, loVal, hiVal );
75+ mergeStage (stage, bitonicAscending, invocationID, loPair, hiPair );
7476 }
7577 }
7678};
0 commit comments