@@ -9,7 +9,7 @@ namespace hlsl
99namespace property_pools
1010{
1111
12- [[vk::push_constant]] GlobalPushContants globals;
12+ [[vk::push_constant]] TransferDispatchInfo globals;
1313
1414template<bool Fill, bool SrcIndexIota, bool DstIndexIota, uint64_t SrcIndexSizeLog2, uint64_t DstIndexSizeLog2>
1515struct TransferLoop
@@ -39,12 +39,12 @@ struct TransferLoop
3939 else if (SrcIndexSizeLog2 == 3 ) vk::RawBufferStore<uint64_t>(dstAddressMapped, vk::RawBufferLoad<uint64_t>(srcAddressMapped));
4040 }
4141
42- void copyLoop (uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
42+ void copyLoop (NBL_CONST_REF_ARG (TransferDispatchInfo) dispatchInfo, uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
4343 {
4444 uint64_t elementCount = uint64_t (transferRequest.elementCount32)
4545 | uint64_t (transferRequest.elementCountExtra) << 32 ;
46- uint64_t lastInvocation = min (elementCount, globals .endOffset);
47- for (uint64_t invocationIndex = globals .beginOffset + baseInvocationIndex; invocationIndex < lastInvocation; invocationIndex += dispatchSize)
46+ uint64_t lastInvocation = min (elementCount, dispatchInfo .endOffset);
47+ for (uint64_t invocationIndex = dispatchInfo .beginOffset + baseInvocationIndex; invocationIndex < lastInvocation; invocationIndex += dispatchSize)
4848 {
4949 iteration (propertyId, transferRequest, invocationIndex);
5050 }
@@ -62,58 +62,53 @@ struct TransferLoop
6262template<bool Fill, bool SrcIndexIota, bool DstIndexIota, uint64_t SrcIndexSizeLog2>
6363struct TransferLoopPermutationSrcIndexSizeLog
6464{
65- void copyLoop (uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
65+ void copyLoop (NBL_CONST_REF_ARG (TransferDispatchInfo) dispatchInfo, uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
6666 {
67- if (transferRequest.dstIndexSizeLog2 == 0 ) { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 0 > loop; loop.copyLoop (baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
68- else if (transferRequest.dstIndexSizeLog2 == 1 ) { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 1 > loop; loop.copyLoop (baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
69- else if (transferRequest.dstIndexSizeLog2 == 2 ) { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 2 > loop; loop.copyLoop (baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
70- else /*if (transferRequest.dstIndexSizeLog2 == 3)*/ { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 3 > loop; loop.copyLoop (baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
67+ if (transferRequest.dstIndexSizeLog2 == 0 ) { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 0 > loop; loop.copyLoop (dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
68+ else if (transferRequest.dstIndexSizeLog2 == 1 ) { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 1 > loop; loop.copyLoop (dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
69+ else if (transferRequest.dstIndexSizeLog2 == 2 ) { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 2 > loop; loop.copyLoop (dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
70+ else /*if (transferRequest.dstIndexSizeLog2 == 3)*/ { TransferLoop<Fill, SrcIndexIota, DstIndexIota, SrcIndexSizeLog2, 3 > loop; loop.copyLoop (dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
7171 }
7272};
7373
7474template<bool Fill, bool SrcIndexIota, bool DstIndexIota>
7575struct TransferLoopPermutationDstIota
7676{
77- void copyLoop (uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
77+ void copyLoop (NBL_CONST_REF_ARG (TransferDispatchInfo) dispatchInfo, uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
7878 {
79- if (transferRequest.srcIndexSizeLog2 == 0 ) { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 0 > loop; loop.copyLoop (baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
80- else if (transferRequest.srcIndexSizeLog2 == 1 ) { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 1 > loop; loop.copyLoop (baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
81- else if (transferRequest.srcIndexSizeLog2 == 2 ) { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 2 > loop; loop.copyLoop (baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
82- else /*if (transferRequest.srcIndexSizeLog2 == 3)*/ { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 3 > loop; loop.copyLoop (baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
79+ if (transferRequest.srcIndexSizeLog2 == 0 ) { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 0 > loop; loop.copyLoop (dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
80+ else if (transferRequest.srcIndexSizeLog2 == 1 ) { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 1 > loop; loop.copyLoop (dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
81+ else if (transferRequest.srcIndexSizeLog2 == 2 ) { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 2 > loop; loop.copyLoop (dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
82+ else /*if (transferRequest.srcIndexSizeLog2 == 3)*/ { TransferLoopPermutationSrcIndexSizeLog<Fill, SrcIndexIota, DstIndexIota, 3 > loop; loop.copyLoop (dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
8383 }
8484};
8585
8686template<bool Fill, bool SrcIndexIota>
8787struct TransferLoopPermutationSrcIota
8888{
89- void copyLoop (uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
89+ void copyLoop (NBL_CONST_REF_ARG (TransferDispatchInfo) dispatchInfo, uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
9090 {
9191 bool dstIota = transferRequest.dstIndexAddr == 0 ;
92- if (dstIota) { TransferLoopPermutationDstIota<Fill, SrcIndexIota, true > loop; loop.copyLoop (baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
93- else { TransferLoopPermutationDstIota<Fill, SrcIndexIota, false > loop; loop.copyLoop (baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
92+ if (dstIota) { TransferLoopPermutationDstIota<Fill, SrcIndexIota, true > loop; loop.copyLoop (dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
93+ else { TransferLoopPermutationDstIota<Fill, SrcIndexIota, false > loop; loop.copyLoop (dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
9494 }
9595};
9696
9797template<bool Fill>
9898struct TransferLoopPermutationFill
9999{
100- void copyLoop (uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
100+ void copyLoop (NBL_CONST_REF_ARG (TransferDispatchInfo) dispatchInfo, uint baseInvocationIndex, uint propertyId, TransferRequest transferRequest, uint dispatchSize)
101101 {
102102 bool srcIota = transferRequest.srcIndexAddr == 0 ;
103- if (srcIota) { TransferLoopPermutationSrcIota<Fill, true > loop; loop.copyLoop (baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
104- else { TransferLoopPermutationSrcIota<Fill, false > loop; loop.copyLoop (baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
103+ if (srcIota) { TransferLoopPermutationSrcIota<Fill, true > loop; loop.copyLoop (dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
104+ else { TransferLoopPermutationSrcIota<Fill, false > loop; loop.copyLoop (dispatchInfo, baseInvocationIndex, propertyId, transferRequest, dispatchSize); }
105105 }
106106};
107107
108- template<typename device_capabilities>
109- void main (uint32_t3 dispatchId)
110- {
111- const uint propertyId = dispatchId.y;
112- const uint invocationIndex = dispatchId.x;
113-
114- // Loading transfer request from the pointer (can't use struct
115- // with BDA on HLSL SPIRV)
116- uint64_t transferCmdAddr = globals.transferCommandsAddress + sizeof (TransferRequest) * propertyId;
108+ // Loading transfer request from the pointer (can't use struct
109+ // with BDA on HLSL SPIRV)
110+ static TransferRequest TransferRequest::newFromAddress (const uint64_t transferCmdAddr)
111+ {
117112 TransferRequest transferRequest;
118113 transferRequest.srcAddr = vk::RawBufferLoad<uint64_t>(transferCmdAddr,8 );
119114 transferRequest.dstAddr = vk::RawBufferLoad<uint64_t>(transferCmdAddr + sizeof (uint64_t),8 );
@@ -129,35 +124,31 @@ void main(uint32_t3 dispatchId)
129124 transferRequest.srcIndexSizeLog2 = uint32_t (bitfieldType >> (32 + 3 + 24 + 1 ));
130125 transferRequest.dstIndexSizeLog2 = uint32_t (bitfieldType >> (32 + 3 + 24 + 1 + 2 ));
131126
132- const uint dispatchSize = nbl::hlsl::device_capabilities_traits<device_capabilities>::maxOptimallyResidentWorkgroupInvocations;
127+ return transferRequest;
128+ }
129+
130+ template<typename device_capabilities>
131+ void main (uint32_t3 dispatchId, const uint dispatchSize)
132+ {
133+ const uint propertyId = dispatchId.y;
134+ const uint invocationIndex = dispatchId.x;
135+
136+ uint64_t transferCmdAddr = globals.transferCommandsAddress + sizeof (TransferRequest) * propertyId;
137+ TransferRequest transferRequest = TransferRequest::newFromAddress (transferCmdAddr);
138+
133139 const bool fill = transferRequest.fill == 1 ;
134140
135- //uint64_t debugWriteAddr = transferRequest.dstAddr + sizeof(uint64_t) * 9 * propertyId;
136- //vk::RawBufferStore<uint64_t>(debugWriteAddr + sizeof(uint64_t) * 0, transferRequest.srcAddr,8);
137- //vk::RawBufferStore<uint64_t>(debugWriteAddr + sizeof(uint64_t) * 1, transferRequest.dstAddr,8);
138- //vk::RawBufferStore<uint64_t>(debugWriteAddr + sizeof(uint64_t) * 2, transferRequest.srcIndexAddr,8);
139- //vk::RawBufferStore<uint64_t>(debugWriteAddr + sizeof(uint64_t) * 3, transferRequest.dstIndexAddr,8);
140- //uint64_t elementCount = uint64_t(transferRequest.elementCount32)
141- // | uint64_t(transferRequest.elementCountExtra) << 32;
142- //vk::RawBufferStore<uint64_t>(debugWriteAddr + sizeof(uint64_t) * 4, elementCount,8);
143- //vk::RawBufferStore<uint32_t>(debugWriteAddr + sizeof(uint64_t) * 5, transferRequest.propertySize,4);
144- //vk::RawBufferStore<uint32_t>(debugWriteAddr + sizeof(uint64_t) * 6, transferRequest.fill,4);
145- //vk::RawBufferStore<uint32_t>(debugWriteAddr + sizeof(uint64_t) * 7, transferRequest.srcIndexSizeLog2,4);
146- //vk::RawBufferStore<uint32_t>(debugWriteAddr + sizeof(uint64_t) * 8, transferRequest.dstIndexSizeLog2,4);
147- //vk::RawBufferStore<uint64_t>(transferRequest.dstAddr + sizeof(uint64_t) * invocationIndex, invocationIndex,8);
148-
149- if (fill) { TransferLoopPermutationFill<true > loop; loop.copyLoop (invocationIndex, propertyId, transferRequest, dispatchSize); }
150- else { TransferLoopPermutationFill<false > loop; loop.copyLoop (invocationIndex, propertyId, transferRequest, dispatchSize); }
141+ if (fill) { TransferLoopPermutationFill<true > loop; loop.copyLoop (globals, invocationIndex, propertyId, transferRequest, dispatchSize); }
142+ else { TransferLoopPermutationFill<false > loop; loop.copyLoop (globals, invocationIndex, propertyId, transferRequest, dispatchSize); }
151143}
152144
153145}
154146}
155147}
156148
157- // TODO: instead use some sort of replace function for getting optimal size?
158- [numthreads (512 ,1 ,1 )]
149+ [numthreads (nbl::hlsl::property_pools::OptimalDispatchSize,1 ,1 )]
159150void main (uint32_t3 dispatchId : SV_DispatchThreadID )
160151{
161- nbl::hlsl::property_pools::main<nbl::hlsl::jit::device_capabilities>(dispatchId);
152+ nbl::hlsl::property_pools::main<nbl::hlsl::jit::device_capabilities>(dispatchId, nbl::hlsl::property_pools::OptimalDispatchSize );
162153}
163154
0 commit comments