44#include <nbl/builtin/hlsl/cpp_compat.hlsl>
55#include <nbl/builtin/hlsl/concepts.hlsl>
66#include <nbl/builtin/hlsl/math/intutil.hlsl>
7+ #include <nbl/builtin/hlsl/pair.hlsl>
78
89namespace nbl
910{
@@ -14,24 +15,22 @@ namespace bitonic_sort
1415
1516template<typename KeyType, typename ValueType, typename Comparator>
1617void compareExchangeWithPartner (
17- bool takeLarger,
18- NBL_REF_ARG (KeyType) loKey,
19- NBL_CONST_REF_ARG (KeyType) partnerLoKey,
20- NBL_REF_ARG (KeyType) hiKey,
21- NBL_CONST_REF_ARG (KeyType) partnerHiKey,
22- NBL_REF_ARG (ValueType) loVal,
23- NBL_CONST_REF_ARG (ValueType) partnerLoVal,
24- NBL_REF_ARG (ValueType) hiVal,
25- NBL_CONST_REF_ARG (ValueType) partnerHiVal,
26- NBL_CONST_REF_ARG (Comparator) comp)
18+ bool takeLarger,
19+ NBL_REF_ARG (KeyType) loKey,
20+ NBL_CONST_REF_ARG (KeyType) partnerLoKey,
21+ NBL_REF_ARG (KeyType) hiKey,
22+ NBL_CONST_REF_ARG (KeyType) partnerHiKey,
23+ NBL_REF_ARG (ValueType) loVal,
24+ NBL_CONST_REF_ARG (ValueType) partnerLoVal,
25+ NBL_REF_ARG (ValueType) hiVal,
26+ NBL_CONST_REF_ARG (ValueType) partnerHiVal,
27+ NBL_CONST_REF_ARG (Comparator) comp)
2728{
28- // Process lo pair
2929 const bool loSelfSmaller = comp (loKey, partnerLoKey);
3030 const bool takePartnerLo = takeLarger ? loSelfSmaller : !loSelfSmaller;
3131 loKey = takePartnerLo ? partnerLoKey : loKey;
3232 loVal = takePartnerLo ? partnerLoVal : loVal;
3333
34- // Process hi pair
3534 const bool hiSelfSmaller = comp (hiKey, partnerHiKey);
3635 const bool takePartnerHi = takeLarger ? hiSelfSmaller : !hiSelfSmaller;
3736 hiKey = takePartnerHi ? partnerHiKey : hiKey;
@@ -41,12 +40,12 @@ void compareExchangeWithPartner(
4140
4241template<typename KeyType, typename ValueType, typename Comparator>
4342void compareSwap (
44- bool ascending,
45- NBL_REF_ARG (KeyType) loKey,
46- NBL_REF_ARG (KeyType) hiKey,
47- NBL_REF_ARG (ValueType) loVal,
48- NBL_REF_ARG (ValueType) hiVal,
49- NBL_CONST_REF_ARG (Comparator) comp)
43+ bool ascending,
44+ NBL_REF_ARG (KeyType) loKey,
45+ NBL_REF_ARG (KeyType) hiKey,
46+ NBL_REF_ARG (ValueType) loVal,
47+ NBL_REF_ARG (ValueType) hiVal,
48+ NBL_CONST_REF_ARG (Comparator) comp)
5049{
5150 const bool shouldSwap = comp (hiKey, loKey);
5251
@@ -63,10 +62,10 @@ void compareSwap(
6362
6463template<typename KeyType, typename ValueType>
6564void swap (
66- NBL_REF_ARG (KeyType) loKey,
67- NBL_REF_ARG (KeyType) hiKey,
68- NBL_REF_ARG (ValueType) loVal,
69- NBL_REF_ARG (ValueType) hiVal)
65+ NBL_REF_ARG (KeyType) loKey,
66+ NBL_REF_ARG (KeyType) hiKey,
67+ NBL_REF_ARG (ValueType) loVal,
68+ NBL_REF_ARG (ValueType) hiVal)
7069{
7170 KeyType tempKey = loKey;
7271 loKey = hiKey;
@@ -77,6 +76,55 @@ void swap(
7776 hiVal = tempVal;
7877}
7978
79+
80+
81+ template<typename KeyType, typename ValueType, typename Comparator>
82+ void compareExchangeWithPartner (
83+ bool takeLarger,
84+ NBL_REF_ARG (pair<KeyType, ValueType>) loPair,
85+ NBL_CONST_REF_ARG (pair<KeyType, ValueType>) partnerLoPair,
86+ NBL_REF_ARG (pair<KeyType, ValueType>) hiPair,
87+ NBL_CONST_REF_ARG (pair<KeyType, ValueType>) partnerHiPair,
88+ NBL_CONST_REF_ARG (Comparator) comp)
89+ {
90+ const bool loSelfSmaller = comp (loPair.first, partnerLoPair.first);
91+ const bool takePartnerLo = takeLarger ? loSelfSmaller : !loSelfSmaller;
92+ loPair.first = takePartnerLo ? partnerLoPair.first : loPair.first;
93+ loPair.second = takePartnerLo ? partnerLoPair.second : loPair.second;
94+
95+ const bool hiSelfSmaller = comp (hiPair.first, partnerHiPair.first);
96+ const bool takePartnerHi = takeLarger ? hiSelfSmaller : !hiSelfSmaller;
97+ hiPair.first = takePartnerHi ? partnerHiPair.first : hiPair.first;
98+ hiPair.second = takePartnerHi ? partnerHiPair.second : hiPair.second;
99+ }
100+
101+ template<typename KeyType, typename ValueType, typename Comparator>
102+ void compareSwap (
103+ bool ascending,
104+ NBL_REF_ARG (pair<KeyType, ValueType>) loPair,
105+ NBL_REF_ARG (pair<KeyType, ValueType>) hiPair,
106+ NBL_CONST_REF_ARG (Comparator) comp)
107+ {
108+ const bool shouldSwap = comp (hiPair.first, loPair.first);
109+ const bool doSwap = (shouldSwap == ascending);
110+
111+ KeyType tempKey = loPair.first;
112+ ValueType tempVal = loPair.second;
113+ loPair.first = doSwap ? hiPair.first : loPair.first;
114+ loPair.second = doSwap ? hiPair.second : loPair.second;
115+ hiPair.first = doSwap ? tempKey : hiPair.first;
116+ hiPair.second = doSwap ? tempVal : hiPair.second;
117+ }
118+
119+ template<typename KeyType, typename ValueType>
120+ void swap (
121+ NBL_REF_ARG (pair<KeyType, ValueType>) loPair,
122+ NBL_REF_ARG (pair<KeyType, ValueType>) hiPair)
123+ {
124+ pair<KeyType, ValueType> temp = loPair;
125+ loPair = hiPair;
126+ hiPair = temp;
127+ }
80128}
81129}
82130}
0 commit comments