From 3d3d6536fdd09c6657ca06875cca4201bf2f99a2 Mon Sep 17 00:00:00 2001 From: Christian Helgeson <62450112+cmhhelgeson@users.noreply.github.com> Date: Tue, 30 Sep 2025 16:04:22 -0700 Subject: [PATCH 01/10] init branch --- src/Three.TSL.js | 3 +++ src/nodes/math/MathNode.js | 46 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/src/Three.TSL.js b/src/Three.TSL.js index e31e626a38a2b9..529be9a5798632 100644 --- a/src/Three.TSL.js +++ b/src/Three.TSL.js @@ -80,6 +80,7 @@ export const batch = TSL.batch; export const bentNormalView = TSL.bentNormalView; export const billboarding = TSL.billboarding; export const bitAnd = TSL.bitAnd; +export const bitCount = TSL.bitCount; export const bitNot = TSL.bitNot; export const bitOr = TSL.bitOr; export const bitXor = TSL.bitXor; @@ -183,6 +184,8 @@ export const expression = TSL.expression; export const faceDirection = TSL.faceDirection; export const faceForward = TSL.faceForward; export const faceforward = TSL.faceforward; +export const findLSB = TSL.findLSB; +export const findMSB = TSL.findMSB; export const float = TSL.float; export const floatBitsToInt = TSL.floatBitsToInt; export const floatBitsToUint = TSL.floatBitsToUint; diff --git a/src/nodes/math/MathNode.js b/src/nodes/math/MathNode.js index 8271477e82f6ec..ca221ec1c0e671 100644 --- a/src/nodes/math/MathNode.js +++ b/src/nodes/math/MathNode.js @@ -364,6 +364,9 @@ MathNode.FWIDTH = 'fwidth'; MathNode.TRANSPOSE = 'transpose'; MathNode.DETERMINANT = 'determinant'; MathNode.INVERSE = 'inverse'; +MathNode.COUNT_TRAILING_ZEROS = 'countTrailingZeros'; +MathNode.COUNT_LEADING_ZEROS = 'countLeadingZeros'; +MathNode.COUNT_ONE_BITS = 'countOneBits'; // 2 inputs @@ -1099,10 +1102,50 @@ export const atan2 = ( y, x ) => { // @deprecated, r172 }; + +/** + * Finds the number of consecutive 0 bits from the least significant bit of the input value, + * which is also the index of the least significant bit of the input value. 
+ * + * Can only be used with {@link WebGPURenderer} and a WebGPU backend. + * + * @tsl + * @function + * @param {Node | number} x - The input value. + * @returns {Node} + */ +export const countTrailingZeros = /*@__PURE__*/ nodeProxyIntent( MathNode, MathNode.COUNT_TRAILING_ZEROS ).setParameterLength( 1 ); + +/** + * Finds the number of consecutive 0 bits starting from the most significant bit of the input value. + * + * Can only be used with {@link WebGPURenderer} and a WebGPU backend. + * + * @tsl + * @function + * @param {Node | number} x - The input value. + * @returns {Node} + */ +export const countLeadingZeros = /*@__PURE__*/ nodeProxyIntent( MathNode, MathNode.COUNT_LEADING_ZEROS ).setParameterLength( 1 ); + +/** + * Finds the number of '1' bits set in the input value + * + * Can only be used with {@link WebGPURenderer} and a WebGPU backend. + * + * @tsl + * @function + * @returns {Node} + */ +export const countOneBits = /*@__PURE__*/ nodeProxyIntent( MathNode, MathNode.COUNT_ONE_BITS ).setParameterLength( 1 ); + // GLSL alias function export const faceforward = faceForward; export const inversesqrt = inverseSqrt; +export const findLSB = countTrailingZeros; +export const findMSB = countLeadingZeros; +export const bitCount = countOneBits; // Method chaining @@ -1165,3 +1208,6 @@ addMethodChaining( 'transpose', transpose ); addMethodChaining( 'determinant', determinant ); addMethodChaining( 'inverse', inverse ); addMethodChaining( 'rand', rand ); +addMethodChaining( 'countTrailingZeros', countTrailingZeros ); +addMethodChaining( 'countLeadingZeros', countLeadingZeros ); +addMethodChaining( 'countOneBits', countOneBits ); From 9ef111a43016b4dec57f71914630acc2fb0065f9 Mon Sep 17 00:00:00 2001 From: Christian Helgeson <62450112+cmhhelgeson@users.noreply.github.com> Date: Wed, 1 Oct 2025 15:33:14 -0700 Subject: [PATCH 02/10] sketch out solve work on src use setup approach and register functions that need to be created bitcountNode revert glslNodeBuilder changes 
lint-fix revert unintended changes to glslNodeBuilder and MathNode, change bitCount helper function names remove extra _include change mainLayout function to accurately reflect WGSL documentation (component wise bitCounts rather than accumulative bitCounts fix nodeType init branch fix prefix sum from other branch eod changes validate reduction, fix display/debug issues fix prefix sum errors fix fix spine scan and work on fixing downsweep add comments work work working prefix sum --- examples/jsm/gpgpu/BitonicSort.js | 4 +- examples/jsm/gpgpu/PrefixSum.js | 967 ++++++++++++++++++++++++ examples/webgpu_compute_prefix_sum.html | 339 +++++++++ src/Three.TSL.js | 1 + src/nodes/core/IndexNode.js | 9 + src/nodes/core/NodeBuilder.js | 7 + src/nodes/math/BitcountNode.js | 46 ++ src/nodes/math/MathNode.js | 46 -- 8 files changed, 1371 insertions(+), 48 deletions(-) create mode 100644 examples/jsm/gpgpu/PrefixSum.js create mode 100644 examples/webgpu_compute_prefix_sum.html diff --git a/examples/jsm/gpgpu/BitonicSort.js b/examples/jsm/gpgpu/BitonicSort.js index 2a1175c198a23d..4059c23b02adca 100644 --- a/examples/jsm/gpgpu/BitonicSort.js +++ b/examples/jsm/gpgpu/BitonicSort.js @@ -79,9 +79,9 @@ export class BitonicSort { /** * Constructs a new light probe helper. * - * @param {Renderer} renderer - The current scene's renderer. + * @param {Renderer} renderer - A renderer with the ability to execute compute operations. * @param {StorageBufferNode} dataBuffer - The data buffer to sort. - * @param {Object} [options={}] - Options that modify the bitonic sort. + * @param {Object} [options={}] - Options that modify the behavior of the bitonic sort. 
*/ constructor( renderer, dataBuffer, options = {} ) { diff --git a/examples/jsm/gpgpu/PrefixSum.js b/examples/jsm/gpgpu/PrefixSum.js new file mode 100644 index 00000000000000..f88d56bed84636 --- /dev/null +++ b/examples/jsm/gpgpu/PrefixSum.js @@ -0,0 +1,967 @@ +import { Fn, If, instancedArray, invocationLocalIndex, countTrailingZeros, Loop, workgroupArray, subgroupSize, workgroupBarrier, workgroupId, uint, select, invocationSubgroupIndex, dot, uvec4, vec4, float, subgroupAdd, array, subgroupShuffle, subgroupInclusiveAdd, subgroupBroadcast, invocationSubgroupMetaIndex, arrayBuffer } from 'three/tsl'; + +const divRoundUp = ( size, part_size ) => { + + return Math.floor( ( size + part_size - 1 ) / part_size ); + +}; + +let id = 0; + +/** + * Storage buffers needed to execute a reduce-then-scan prefix sum`. + * + * @typedef {Object} PrefixSumStorageObjects + * @property {StorageBufferNode} reductionBuffer - Storage data buffer holding the reduction of each workgroup from the reduce step. + * @property {StorageBufferNode} dataBuffer - Storage data buffer holding the vectorized input data. + * @property {StorageBufferNode} unvectorizedDataBuffer - Storage data buffer holding the unvectorized input data. + * @property {StorageBufferNode} outputBuffer - Storage data buffer that returns the unvectorized output data of the prefix sum. + */ + +/** + * Compute functions needed to execute a reduce-then-scan prefix sum`. + * + * @typedef {Object} PrefixSumComputeFunctions + * @property {ComputeNode} reduceFn - A compute shader that executes the reduce step of a reduce-then-scan prefix sum. + * @property {ComputeNode} spineScanFn - A compute shader that executes the spine scan step of a reduce-then-scan prefix sum. + * @property {ComputeNode} downsweepFn - A compute shader that executes the downsweep step of a reduce-then-scan prefix sum. + */ + +/** + * Utility nodes used in multiple shaders across the reduce-then-scan prefix sum`. 
+ * + * @typedef {Object} PrefixSumUtilityNodes + * @property {WorkgroupInfoNode} subgroupReductionArray - A workgroup memory buffer representing a workgroup scoped buffer that holds the result of a subgroup operation from each subgroup in a workgroup. Sized to account for minimumn WGSL subgroup size of 4. + * @property {Node} workgroupOffset - A node representing the vec4-alligned offset at which the workgroup with index 'workgroupId.x' will begin reading vec4 elements from the data buffer. + * @property {Node} subgroupOffset - A node representing the vec4-alligned offset from 'this.workgroupOffset' at which the subgroup with index 'subgroupMetaRank' will begin reading vec4 elements from a data buffer. + * @property {Node} unvectorizedSubgroupOffset - A node representing the uint-alligned offset from 'this.workgroupOffset' at which the subgroup with index 'subgroupMetaRank' will begin reading uint elements from a data buffer. + * @property {Node} subgroupSizeLog - A node that evaulates to n in 2^n = subgroupSize. + * @property {Node} spineSize - A node that calculates the number of partial reductions in a workgroup scan, or the number of subgroups in a workgroup on the current device. + * @property {Node} spineSizeLog - A node that evaluates to n in 2^n = spineSize. + */ + + +/** + * A class that represents a prefix sum running under the reduce/scan strategy. + * Currently limited to one-dimensional data buffers. + * + * @param {Renderer} renderer - A renderer with the ability to execute compute operations. + * @param {StorageBufferNode} dataBuffer - The data buffer to sum. + * @param {Object} [options={}] - Options that modify the reduce/scan prefix sum. + */ +export class PrefixSum { + + /** + * Constructs a new light probe helper. + * + * @param {Renderer} renderer - A renderer with the ability to execute compute operations. + * @param {number[]} inputArray - The data buffer to sum. 
+ * @param {'uint' | 'float'} inputArrayType - Type of input array + * @param {Object} [options={}] - Options that modify the behavior of the prefix sum. + */ + constructor( renderer, inputArray, inputArrayType, options = {} ) { + + /** + * A reference to the renderer. + * + * @type {Renderer} + */ + this.renderer = renderer; + + /** + * @type {PrefixSumStorageObjects} + */ + this.storageBuffers = {}; + + /** + * @type {PrefixSumComputeFunctions} + */ + this.computeFunctions = {}; + + + /** + * @type {PrefixSumUtilityNodes} + */ + this.utilityNodes = {}; + + this.type = inputArrayType; + this.vecType = inputArrayType === 'uint' ? 'uvec4' : 'vec4'; + + /** + * The size of the data. + * + * @type {number} + */ + this.count = inputArray.length; + + /** + * The number of 4-dimensional vectors needed to fully represent the data in the data buffer. + * Buffers where this.count % 4 !== 0 will need an additional vec4 to hold the data buffer's + * remaining elements. + * + * @type {number} + */ + this.vecCount = divRoundUp( this.count, 4 ); + + while ( inputArray.length % 4 !== 0 ) { + + inputArray.push( 0 ); + + } + + /** + * The number of 4-dimensional vectors that will be read from global storage in each invocation of the reduction/downsweep step. + * Defaults to 4. + * + * @type {number} + */ + this.workPerInvocation = options.workPerInvocation ? options.workPerInvocation : 4; + + /** + * The number of unvectorized values to be read from the reduction buffer in each invocation of the spine/scan step. + * Derived from workPerInvocation and thus defaults to 16. + * + * @type {number} + */ + this.unvectorizedWorkPerInvocation = this.workPerInvocation * 4; + + /** + * The workgroup size of the compute shaders executed during the prefix sum. + * If no workgroupSize is defined, the workgroupSize defaults to the minimumn between the number of elements in the + * data buffer and 64. + * + * @type {number} + */ + this.workgroupSize = options.workgroupSize ? 
options.workgroupSize : Math.min( this.vecCount, 64 ); + + /** + * The maximum number of elements that will be read by an individual workgroup in the reduction step. + * Calculated as the number of invocations in the workgroup by the work per invocation by VEC4_SIZE + * + * @type {number} + */ + this.partitionSize = this.workgroupSize * this.unvectorizedWorkPerInvocation; + + /** + * The number of workgroups needed to properly execute the reduction and downsweepsteps. + * Calculated as the number of partitions within the count of elements. + * + * @type {number} + */ + this.numWorkgroups = divRoundUp( this.count, this.partitionSize ); + + /** + * The number of invocations dispatched in each step of the prefix sum. + * + * @type {number} + */ + this.dispatchSize = this.numWorkgroups * this.workgroupSize; + + this._createStorageBuffers( inputArray, inputArrayType, this.vecType, this.numWorkgroups ); + this._createUtilityNodes(); + + /** + * The step of the prefix sum to execute. + * + * @type {'Reduce' | 'Spine_Scan' | 'Downsweep'} + */ + this.currentStep = 'Reduce'; + + + this.computeFunctions.reduceFn = this._getReduceFn(); + this.computeFunctions.spineScanFn = this._getSpineScanFn(); + this.computeFunctions.downsweepFn = this._getDownsweepFn(); + + id += 1; + + } + + _createStorageBuffers( inputArray ) { + + this.arrayBuffer = this.type === 'uint' ? 
Uint32Array.from( inputArray ) : Float32Array.from( inputArray ); + + this.storageBuffers.unvectorizedDataBuffer = instancedArray( this.arrayBuffer, this.type ).setPBO( true ).setName( `Prefix_Sum_Input_Unvec_${id}` ); + this.storageBuffers.dataBuffer = instancedArray( this.arrayBuffer, this.vecType ).setPBO( true ).setName( `Prefix_Sum_Input_Vec_${id}` ); + this.storageBuffers.outputBuffer = instancedArray( this.arrayBuffer, this.vecType ).setName( `Prefix_Sum_Output_${id}` ); + this.storageBuffers.reductionBuffer = instancedArray( this.numWorkgroups, this.type ).setPBO( true ).setName( `Prefix_Sum_Reduction_${id}` ); + + } + + _createUtilityNodes() { + + this.utilityNodes.subgroupReductionArray = workgroupArray( this.type, Math.ceil( this.workgroupSize / 4 ) ); + this.utilityNodes.workgroupOffset = workgroupId.x.mul( uint( this.workgroupSize ).mul( this.workPerInvocation ) ).toVar( 'workgroupOffset' ); + this.utilityNodes.subgroupOffset = invocationSubgroupMetaIndex.mul( subgroupSize ).mul( this.workPerInvocation ).toVar( 'subgroupOffset' ); + this.utilityNodes.unvectorizedSubgroupOffset = invocationSubgroupMetaIndex.mul( subgroupSize ).mul( this.unvectorizedWorkPerInvocation ).toVar( 'unvectorizedSubgroupOffset' ); + this.utilityNodes.subgroupSizeLog = countTrailingZeros( subgroupSize ).toVar( 'subgroupSizeLog' ); + this.utilityNodes.spineSize = uint( this.workgroupSize ).shiftRight( this.utilityNodes.subgroupSizeLog ).toVar( 'spineSize' ); + this.utilityNodes.spineSizeLog = countTrailingZeros( this.utilityNodes.spineSize ).toVar( 'spineSizeLog' ); + + } + + _getSubgroupAlignedSize() { + + const { spineSizeLog, subgroupSizeLog } = this.utilityNodes; + + // Align size to powers of subgroupSize + const squaredSubgroupLog = ( spineSizeLog.add( subgroupSizeLog ).sub( 1 ) ); + squaredSubgroupLog.divAssign( subgroupSizeLog ); + squaredSubgroupLog.mulAssign( subgroupSizeLog ); + const subgroupAlignedSize = ( uint( 1 ).shiftLeft( squaredSubgroupLog ) ).toVar( 
'subgroupAlignedSize' ); + + return subgroupAlignedSize; + + } + + + // NOTE: subgroupSizeLog needs to be defined in this._getSubgroupAlignedSize before this block can execute + _subgroupAlignedSizeBlock( subgroupAlignedSize, subgroupAllignedBlockCallback ) { + + // In cases where the number of subgroups in a workgroup is greater than the subgroup size itself, + // we need to iterate over the array again to capture all the data in the workgroup array buffer + // In many cases this loop will only run once + Loop( { start: subgroupSize, end: subgroupAlignedSize, condition: '<=', name: 'j', type: 'uint', update: '<<= subgroupSizeLog' }, ( { j } ) => { + + subgroupAllignedBlockCallback( j ); + + } ); + + } + + _getSpineAlignedSize() { + + const { numWorkgroups, partitionSize } = this; + + const SPINE_PARTITION_SIZE = uint( partitionSize ).toVar( 'spinePartitionSize' ); + + const spineAlignedSize = ( SPINE_PARTITION_SIZE.add( numWorkgroups ).sub( 1 ) ).toVar( 'spineAlignedSize' ); + spineAlignedSize.divAssign( SPINE_PARTITION_SIZE ); + spineAlignedSize.mulAssign( SPINE_PARTITION_SIZE ); + + return spineAlignedSize; + + } + + _getSpineAlignedBlock( spineAlignedSize, spineAlignedBlockCallback ) { + + // Allignment in cases where num elements is (SPINE_PARTITION_SIZE * SPINE_PARTITION_SIZE) + 1 + Loop( { start: 0, end: spineAlignedSize, condition: '<', name: 'j', type: 'uint', update: '+= spinePartitionSize' }, ( { j } ) => { + + spineAlignedBlockCallback( j ); + + } ); + + } + + _workPerInvocationBlock( workgroupCallback, lastWorkgroupCallback ) { + + const { numWorkgroups, workPerInvocation } = this; + + // Each thread will accumulate values from across 'workPerInvocation' subgroups + If( workgroupId.x.lessThan( uint( numWorkgroups ).sub( 1 ) ), () => { + + Loop( { + start: uint( 0 ), + end: workPerInvocation, + type: 'uint', + condition: '<', + name: 'currentSubgroupInBlock' + }, ( { currentSubgroupInBlock } ) => { + + workgroupCallback( currentSubgroupInBlock ); + + } 
); + + } ); + + // Ensure that the last workgroup does not access out of bounds indices + If( workgroupId.x.equal( uint( numWorkgroups ).sub( 1 ) ), () => { + + Loop( { + start: uint( 0 ), + end: workPerInvocation, + type: 'uint', + condition: '<', + name: 'currentSubgroupInBlock' + }, ( { currentSubgroupInBlock } ) => { + + lastWorkgroupCallback( currentSubgroupInBlock ); + + } ); + + } ); + + } + + /** + * Create the compute shader that performs the reduce operation. + * + * @private + * @returns {ComputeNode} - A compute shader that executes a full local swap. + */ + _getReduceFn() { + + const { reductionBuffer, dataBuffer } = this.storageBuffers; + const { vecCount } = this; + const { subgroupSizeLog, subgroupReductionArray, subgroupOffset, workgroupOffset, spineSize } = this.utilityNodes; + + const fnDef = Fn( () => { + + // Each subgroup block scans across 4 subgroups. So when we move into a new subgroup, + // align that subgroups' accesses to the next 4 subgroups + const threadSubgroupOffset = subgroupOffset.add( invocationSubgroupIndex ).toVar( 'threadSubgroupOffset' ); + + const startThreadBase = threadSubgroupOffset.add( workgroupOffset ).toVar( 'startThreadBase' ); + + const startThread = startThreadBase.toVar( 'startThread' ); + + let subgroupReduction; + + if ( this.type === 'uint' ) { + + subgroupReduction = uint( 0 ); + + } else { + + subgroupReduction = float( 0 ); + + } + + this._workPerInvocationBlock( () => { + + // Get vectorized element from input array + const val = dataBuffer.element( startThread ); + + + // Sum values within vec4 together by using result of dot product + if ( this.vecType === 'uvec4' ) { + + subgroupReduction.addAssign( dot( uvec4( 1 ), val ) ); + + } else { + + subgroupReduction.addAssign( dot( vec4( 1 ), val ) ); + + } + + // Increment so thread will scan value in next subgroup + startThread.addAssign( subgroupSize ); + + + }, () => { + + let val; + if ( this.vecType === 'uvec4' ) { + + // Ensure index is less than number 
of available vectors in inputBuffer + val = select( startThread.lessThan( uint( vecCount ) ), dataBuffer.element( startThread ), uvec4( 0 ) ).uniformFlow(); + + subgroupReduction.addAssign( dot( val, uvec4( 1 ) ) ); + + } else { + + // Ensure index is less than number of available vectors in inputBuffer + val = select( startThread.lessThan( uint( vecCount ) ), dataBuffer.element( startThread ), vec4( 0 ) ).uniformFlow(); + + subgroupReduction.addAssign( dot( val, vec4( 1 ) ) ); + + + } + + startThread.addAssign( subgroupSize ); + + } ); + + subgroupReduction.assign( subgroupAdd( subgroupReduction ) ); + + // Assuming that each element in the input buffer is 1, we generally expect each invocation's subgroupReduction + // value to be ELEMENTS_PER_VEC4 * workPerInvocation * subgroupSize + + // Delegate one thread per subgroup to assign each subgroup's reduction to the workgroup array + If( invocationSubgroupIndex.equal( uint( 0 ) ), () => { + + subgroupReductionArray.element( invocationSubgroupMetaIndex ).assign( subgroupReduction ); + + } ); + + // Ensure that each workgroup has populated the perSubgroupReductionArray with data + // from each of it's subgroups + workgroupBarrier(); + + // WORKGROUP LEVEL REDUCE + + const subgroupAlignedSize = this._getSubgroupAlignedSize(); + + // aligned size 2 * 4 + + const offset = uint( 0 ); + + // In cases where the number of subgroups in a workgroup is greater than the subgroup size itself, + // we need to iterate over the array again to capture all the data in the workgroup array buffer + // In many cases this loop will only run once + this._subgroupAlignedSizeBlock( subgroupAlignedSize, () => { + + const subgroupIndex = ( ( invocationLocalIndex.add( 1 ) ).shiftLeft( offset ) ).sub( 1 ); + + const isValidSubgroupIndex = subgroupIndex.lessThan( spineSize ).toVar( 'isValidSubgroupIndex' ); + + // Reduce values within the local workgroup memory. + // Set toVar to ensure subgroupAdd executes before (not within) the if statement. 
+ const t = subgroupAdd( + select( + isValidSubgroupIndex, + subgroupReductionArray.element( subgroupIndex ), + 0 + ).uniformFlow() + ).toVar( 't' ); + + // Can assign back to workgroupArray since all + // subgroup threads work in lockstop for subgroupAdd + If( isValidSubgroupIndex, () => { + + subgroupReductionArray.element( subgroupIndex ).assign( t ); + + } ); + + // Ensure all threads have completed work + + workgroupBarrier(); + + offset.addAssign( subgroupSizeLog ); + + } ); + + // Assign single thread from workgroup to assign workgroup reduction + If( invocationLocalIndex.equal( uint( 0 ) ), () => { + + const reducedWorkgroupSum = subgroupReductionArray.element( uint( spineSize ).sub( 1 ) ); + + // TODO: Comment out in prod + // dataBuffer.element( workgroupId.x ).assign( reducedWorkgroupSum ); + + reductionBuffer.element( workgroupId.x ).assign( reducedWorkgroupSum ); + + } ); + + } )().compute( this.dispatchSize, [ this.workgroupSize ] ); + + return fnDef; + + } + + /** + * Executes a downsweep operation on the data buffer. + * + * @param {Node} inputNode - The input node. + * @param {Node | number} maskNode - The number of bits to mask. + * @return {Node} + */ + _maskLowerBits( inputNode, maskNode ) { + + return ( inputNode.shiftRight( maskNode ) ).shiftLeft( maskNode ); + + } + + + /** + * Create the compute shader that performs the spine scan operation. + * + * @private + * @returns {ComputeNode} - A compute shader that executes a full local swap. 
+ */ + _getSpineScanFn() { + + const { reductionBuffer } = this.storageBuffers; + const { subgroupReductionArray, unvectorizedSubgroupOffset, spineSize, subgroupSizeLog } = this.utilityNodes; + const { unvectorizedWorkPerInvocation } = this; + + const fnDef = Fn( () => { + + const subgroupAlignedSize = this._getSubgroupAlignedSize(); + const spineAlignedSize = this._getSpineAlignedSize(); + + const t_scan = array( 'uint', 16 ).toVar(); + const previousReduction = uint( 0 ).toVar( 'previousReduction' ); + + const s_offset = unvectorizedSubgroupOffset.add( invocationSubgroupIndex ).toVar( 's_offset' ); + + this._getSpineAlignedBlock( spineAlignedSize, ( devOffset ) => { + + const reducedWorkgroupIndex = s_offset.add( devOffset ); + + Loop( { + start: uint( 0 ), + end: uint( unvectorizedWorkPerInvocation ), + type: 'uint', + condition: '<', + name: 'k' + }, ( { k } ) => { + + // The reduction buffer holds a collection of reductions from within + // each indice's respective workgroup, so ensure that we only access + // valid workgroup indices + + If( reducedWorkgroupIndex.lessThan( this.numWorkgroups ), () => { + + t_scan.element( k ).assign( reductionBuffer.element( reducedWorkgroupIndex ) ); + + } ); + + reducedWorkgroupIndex.addAssign( subgroupSize ); + + } ); + + const prev = uint( 0 ).toVar( 'prev' ); + Loop( { + start: uint( 0 ), + end: uint( unvectorizedWorkPerInvocation ), + type: 'uint', + condition: '<', + update: '+= 1u', + name: 'k' + }, ( { k } ) => { + + const tScanElement = t_scan.element( k ); + + tScanElement.assign( subgroupInclusiveAdd( tScanElement ).add( prev ) ); + prev.assign( subgroupShuffle( tScanElement, subgroupSize.sub( 1 ) ) ); + + } ); + + if ( invocationSubgroupIndex.equal( subgroupSize.sub( 1 ) ) ) { + + subgroupReductionArray.element( invocationSubgroupMetaIndex ).assign( prev ); + + } + + workgroupBarrier(); + + const offset0 = uint( 0 ).toVar(); + const offset1 = uint( 0 ).toVar(); + + this._subgroupAlignedSizeBlock( 
subgroupAlignedSize, ( j ) => { + + const isValidSubgroupIndex = j.notEqual( subgroupSize ); + const isValidSubgroupInt = select( isValidSubgroupIndex, uint( 1 ), uint( 0 ) ).uniformFlow(); + + const i0 = ( invocationLocalIndex.add( offset0 ) ).shiftLeft( offset1 ).sub( isValidSubgroupInt ); + const pred0 = i0.lessThan( spineSize ); + + // Need to cast toVar() here otherwise subgroupInclusiveAdd gets inlined within a non-uniform block + const t0 = subgroupInclusiveAdd( select( pred0, subgroupReductionArray.element( i0 ), uint( 0 ) ).uniformFlow() ).toVar(); + + If( pred0, () => { + + subgroupReductionArray.element( i0 ).assign( t0 ); + + } ); + + If( isValidSubgroupIndex, () => { + + const rShift = j.shiftRight( subgroupSizeLog ); + const i1 = invocationLocalIndex.add( rShift ); + + const weirdValue = i1.bitAnd( j.sub( 1 ) ); + + If( weirdValue.greaterThanEqual( rShift ), () => { + + const pred1 = i1.lessThan( spineSize ); + + const t1 = select( pred1, subgroupReductionArray.element( this._maskLowerBits( i1, offset1 ).sub( 1 ) ), 0 ).uniformFlow(); + + If( + pred1.and( + ( i1.add( 1 ).bitAnd( rShift.sub( 1 ) ) ).notEqual( 0 ) + ), () => { + + subgroupReductionArray.element( i1 ).addAssign( t1 ); + + } ); + + + } ); + + + } ).Else( () => { + + offset0.addAssign( 1 ); + + } ); + + offset1.addAssign( subgroupSizeLog ); + + } ); + + workgroupBarrier(); + + const lastSubgroupReduction = select( + invocationSubgroupMetaIndex.notEqual( 0 ), + subgroupReductionArray.element( invocationSubgroupMetaIndex.sub( 1 ) ), + uint( 0 ) + ).uniformFlow(); + + const newPrev = lastSubgroupReduction.add( previousReduction ); + + const i = s_offset.add( devOffset ); + + Loop( { + start: uint( 0 ), + end: uint( unvectorizedWorkPerInvocation ), + type: 'uint', + condition: '<', + name: 'k' + }, ( { k } ) => { + + If( i.lessThan( this.numWorkgroups ), () => { + + reductionBuffer.element( i ).assign( t_scan.element( k ).add( newPrev ) ); + + } ); + + i.addAssign( subgroupSize ); + + + } ); + 
+ previousReduction.addAssign( subgroupBroadcast( subgroupReductionArray.element( subgroupAlignedSize.sub( 1 ) ), 0 ) ); + workgroupBarrier(); + + } ); + + } )().compute( this.numWorkgroups, [ this.workgroupSize ] ); + + console.log( fnDef ); + + return fnDef; + + } + + _getDownsweepFn() { + + const { dataBuffer, reductionBuffer, outputBuffer } = this.storageBuffers; + const { vecType } = this; + const { subgroupOffset, workgroupOffset, subgroupReductionArray, subgroupSizeLog, spineSize } = this.utilityNodes; + + const { workPerInvocation, vecCount } = this; + + const fnDef = Fn( () => { + + const threadSubgroupOffset = subgroupOffset.add( invocationSubgroupIndex ); + + const startThreadBase = threadSubgroupOffset.add( workgroupOffset ); + + const startThread = startThreadBase.toVar(); + + const vec4FilledWithZeroArray = []; + + for ( let i = 0; i < workPerInvocation; i ++ ) { + + vec4FilledWithZeroArray.push( uvec4( 0 ) ); + + } + + const tScan = array( vec4FilledWithZeroArray ).toVar(); + + // Prefix Sum elements within individual vec4 elements + + this._workPerInvocationBlock( ( currentSubgroupInBlock ) => { + + const scanIn = dataBuffer.element( startThread ); + const currentTScanElement = tScan.element( currentSubgroupInBlock ); + + console.log( currentTScanElement ); + + currentTScanElement.assign( scanIn ); + + currentTScanElement.y.addAssign( currentTScanElement.x ); + currentTScanElement.z.addAssign( currentTScanElement.y ); + currentTScanElement.w.addAssign( currentTScanElement.z ); + + startThread.addAssign( subgroupSize ); + + }, ( currentSubgroupInBlock ) => { + + If( startThread.lessThan( uint( vecCount ) ), () => { + + const scanIn = dataBuffer.element( startThread ); + const currentTScanElement = tScan.element( currentSubgroupInBlock ); + + currentTScanElement.assign( scanIn ); + + currentTScanElement.y.addAssign( currentTScanElement.x ); + currentTScanElement.z.addAssign( currentTScanElement.y ); + currentTScanElement.w.addAssign( 
currentTScanElement.z ); + + startThread.addAssign( subgroupSize ); + + } ); + + } ); + + // Each thread now has prefix sums of the elements in 'workPerInvocation' vec4s + + const prev = uint( 0 ).toVar(); + + const laneMask = subgroupSize.sub( 1 ).toVar( 'laneMask' ); + const clockwiseShift = ( invocationSubgroupIndex.add( laneMask ) ).bitAnd( laneMask ).toVar( 'clockwiseShift' ); + + Loop( { + start: uint( 0 ), + end: uint( workPerInvocation ), + type: 'uint', + condition: '<', + name: 'currentSubgroupInBlock' + }, ( { currentSubgroupInBlock } ) => { + + + // previous greatest accumulated value + const prevAccGreatestValue = subgroupShuffle( + + // Get the largest element within each vector (always w since prefix sum) + // Then add together with the same element in each lane of the subgroup. + // Assume all values in data buffer are 1 and subgroupSize is 4 + // Subgroup 0, 1, 2, 3 values -> 4 + // Invocation 0 value after inclusiveAdd 4 + // Invocation 1 value after inclusiveAdd 8 + // Invocation 2 value after inclusiveAdd 12 + // Invocation 3 value after inclusiveAdd 16 + + subgroupInclusiveAdd( tScan.element( currentSubgroupInBlock ).w ), + + // Shuffle each value between lanes in the subgroup counterClockWise + // Effectively a looping subgroupShuffleDown + // Inv 0 gets inv 3 value 16 + // Invocation 1 gets inv 0 value 4 + // Invocation 2 gets inv 1 value 8 + // Invocation 3 gets inv 2 value 12 + + clockwiseShift + ).toVar( 'prevAccGreatestValue' ); + + const isNotInvocationSubgroupIndex0 = invocationSubgroupIndex.notEqual( uint( 0 ) ); + + let addEle; + + // Vector read by lane 0 does not get changed by since it is already prefix summed + // within context of its subgroup, so we don't want to add greatest value for it. + // The purpose of shuffling to all lanes of the subgroup including lane 0 is simply + // to have the greatest value accessible for the broadcast from lane 0. 
+ + if ( this.vecType === 'uvec4' ) { + + addEle = prev.add( select( isNotInvocationSubgroupIndex0, prevAccGreatestValue, uvec4( 0 ) ).uniformFlow() ); + + } else { + + addEle = prev.add( select( isNotInvocationSubgroupIndex0, prevAccGreatestValue, vec4( 0 ) ).uniformFlow() ); + + } + + tScan.element( currentSubgroupInBlock ).addAssign( addEle ); + + // Broadcast value of invocationSubgroupIndex 0 (which is usually largest value ) to prev + prev.addAssign( subgroupBroadcast( prevAccGreatestValue, uint( 0 ) ) ); + + } ); + + If( invocationSubgroupIndex.equal( uint( 0 ) ), () => { + + subgroupReductionArray.element( invocationSubgroupMetaIndex ).assign( prev ); + + } ); + + workgroupBarrier(); + + + const offset0 = uint( 0 ).toVar(); + const offset1 = uint( 0 ).toVar(); + + + const subgroupAlignedSize = this._getSubgroupAlignedSize(); + + // In cases where the number of subgroups in a workgroup is greater than the subgroup size itself, + // we need to iterate over the array again to capture all the data in the workgroup array buffer + this._subgroupAlignedSizeBlock( subgroupAlignedSize, ( j ) => { + + const i0 = ( + ( invocationLocalIndex.add( offset0 ) ).shiftLeft( offset1 ) + ).sub( offset0 ); + + const pred0 = i0.lessThan( spineSize ); + + const t0 = subgroupInclusiveAdd( + select( pred0, subgroupReductionArray.element( i0 ), uint( 0 ) ).uniformFlow() + ).toVar(); + + If( pred0, () => { + + subgroupReductionArray.element( i0 ).assign( t0 ); + + } ); + + workgroupBarrier(); + + If( j.notEqual( subgroupSize ), () => { + + const rShift = j.shiftRight( subgroupSizeLog ); + const i1 = invocationLocalIndex.add( rShift ); + If( ( i1.bitAnd( j.sub( 1 ) ) ).greaterThanEqual( rShift ), () => { + + const pred1 = i1.lessThan( spineSize ); + const t1 = select( + pred1, + subgroupReductionArray.element( this._maskLowerBits( i1, offset1 ).sub( 1 ) ), + uint( 0 ) + ).uniformFlow(); + + If( + pred1.and( + ( i1.add( 1 ) ).bitAnd( rShift.sub( 1 ) ).notEqual( uint( 0 ) ) ) + , () => 
{ + + subgroupReductionArray.element( i1 ).addAssign( t1 ); + + } + ); + + } ); + + } ).Else( () => { + + offset0.addAssign( 1 ); + + } ); + + offset1.addAssign( subgroupSize ); + + } ); + + workgroupBarrier(); + + const spineScanWorkgroupReduction = select( + workgroupId.x.notEqual( uint( 0 ) ), + reductionBuffer.element( workgroupId.x.sub( 1 ) ), + uint( 0 ) + ).uniformFlow(); + + const downsweepSubgroupReduction = select( + invocationSubgroupMetaIndex.notEqual( 0 ), + subgroupReductionArray.element( invocationSubgroupMetaIndex.sub( 1 ) ), + uint( 0 ) + ).uniformFlow(); + + prev.assign( spineScanWorkgroupReduction.add( downsweepSubgroupReduction ) ); + + // LAST BLOCK + + startThread.assign( startThreadBase ); + + this._workPerInvocationBlock( ( currentSubgroupInBlock ) => { + + const sweepValue = tScan.element( currentSubgroupInBlock ).add( prev ); + outputBuffer.element( startThread ).assign( sweepValue ); + startThread.addAssign( subgroupSize ); + + }, ( currentSubgroupInBlock ) => { + + If( startThread.lessThan( uint( vecCount ) ), () => { + + const sweepValue = tScan.element( currentSubgroupInBlock ).add( prev ); + outputBuffer.element( startThread ).assign( sweepValue ); + startThread.addAssign( subgroupSize ); + + } ); + + } ); + + } )().compute( this.dispatchSize, [ this.workgroupSize ] ); + + return fnDef; + + } + + + /** + * Executes an intermediate reduction operation on the data buffer. + * + * @param {Renderer} renderer - The current scene's renderer. + */ + async computeReduce() { + + this.renderer.compute( this.computeFunctions.reduceFn ); + + } + + /** + * Executes a spine scan operation on the data buffer. + * + * @param {Renderer} renderer - The current scene's renderer. + */ + async computeSpineScan() { + + this.renderer.compute( this.computeFunctions.spineScanFn ); + + } + + /** + * Executes a downsweep operation on the data buffer. + * + * @param {Renderer} renderer - The current scene's renderer. 
+ */ + async computeDownsweep() { + + this.renderer.compute( this.computeFunctions.downsweepFn ); + + } + + /** + * Executes the next subsequent compute step of a prefix sum. + * + * @param {Renderer} renderer - A renderer with the ability to execute compute operations. + */ + async computeStep() { + + switch ( this.currentStep ) { + + case 'Reduce': { + + await this.computeReduce(); + this.currentStep = 'Spine_Scan'; + break; + + } + + case 'Spine_Scan': { + + await this.computeSpineScan(); + this.currentStep = 'Downsweep'; + break; + + } + + case 'Downsweep': { + + await this.computeDownsweep(); + this.currentStep = 'Reduce'; + break; + + } + + } + + } + + /** + * Executes a complete prefix sum on the data buffer. + * + * @param {Renderer} renderer - The current scene's renderer. + */ + async compute() { + + await this.computeStep( this.currentStep ); + await this.computeStep( this.currentStep ); + await this.computeStep( this.currentStep ); + + } + +} diff --git a/examples/webgpu_compute_prefix_sum.html b/examples/webgpu_compute_prefix_sum.html new file mode 100644 index 00000000000000..5428e09fb32b46 --- /dev/null +++ b/examples/webgpu_compute_prefix_sum.html @@ -0,0 +1,339 @@ + + + three.js webgpu - compute reduction + + + + + + +
+ three.js +
This example demonstrates a prefix sum operation on a buffer of data. +
Reference implementations are translated from the WGSL code present in GPUPrefixSums by b0nes164 +
+ + + + + + \ No newline at end of file diff --git a/src/Three.TSL.js b/src/Three.TSL.js index 529be9a5798632..9ca6dd167c08fb 100644 --- a/src/Three.TSL.js +++ b/src/Three.TSL.js @@ -238,6 +238,7 @@ export const inverseSqrt = TSL.inverseSqrt; export const inversesqrt = TSL.inversesqrt; export const invocationLocalIndex = TSL.invocationLocalIndex; export const invocationSubgroupIndex = TSL.invocationSubgroupIndex; +export const invocationSubgroupMetaIndex = TSL.invocationSubgroupMetaIndex; export const ior = TSL.ior; export const iridescence = TSL.iridescence; export const iridescenceIOR = TSL.iridescenceIOR; diff --git a/src/nodes/core/IndexNode.js b/src/nodes/core/IndexNode.js index 5908cc694d75ea..55298557d08a99 100644 --- a/src/nodes/core/IndexNode.js +++ b/src/nodes/core/IndexNode.js @@ -1,5 +1,6 @@ import Node from './Node.js'; import { nodeImmutable, varying } from '../tsl/TSLBase.js'; +import { subgroupSize } from '../gpgpu/ComputeBuiltinNode.js'; /** * This class represents shader indices of different types. The following predefined node @@ -155,6 +156,14 @@ export const invocationSubgroupIndex = /*@__PURE__*/ nodeImmutable( IndexNode, I */ export const invocationLocalIndex = /*@__PURE__*/ nodeImmutable( IndexNode, IndexNode.INVOCATION_LOCAL ); +/** + * TSL object that represents the index of a compute invocation within the scope of a subgroup. + * + * @tsl + * @type {IndexNode} + */ +export const invocationSubgroupMetaIndex = /*@__PURE__*/ invocationLocalIndex.div( subgroupSize ).toVar( 'invocationSubgroupMetaIndex' ); + /** * TSL object that represents the index of a draw call. 
* diff --git a/src/nodes/core/NodeBuilder.js b/src/nodes/core/NodeBuilder.js index 822e74871ee1e4..ef237e1798ed8f 100644 --- a/src/nodes/core/NodeBuilder.js +++ b/src/nodes/core/NodeBuilder.js @@ -1226,9 +1226,16 @@ class NodeBuilder { if ( type === 'float' || type === 'int' || type === 'uint' ) value = 0; else if ( type === 'bool' ) value = false; else if ( type === 'color' ) value = new Color(); +<<<<<<< HEAD else if ( type === 'vec2' || type === 'uvec2' || type === 'ivec2' ) value = new Vector2(); else if ( type === 'vec3' || type === 'uvec3' || type === 'ivec3' ) value = new Vector3(); else if ( type === 'vec4' || type === 'uvec4' || type === 'ivec4' ) value = new Vector4(); +======= + else if ( type === 'vec2' || type === 'uvec2' ) value = new Vector2(); + else if ( type === 'vec3' || type === 'uvec3' ) value = new Vector3(); + // vec4 defaults to (0, 0, 0, 1) + else if ( type === 'vec4' || type === 'uvec4' ) value = new Vector4(); +>>>>>>> d83ef1ebb4 (sketch out solve) } diff --git a/src/nodes/math/BitcountNode.js b/src/nodes/math/BitcountNode.js index e11f63b2810cf0..55572dd76a7a31 100644 --- a/src/nodes/math/BitcountNode.js +++ b/src/nodes/math/BitcountNode.js @@ -1,4 +1,8 @@ +<<<<<<< HEAD import { float, Fn, If, nodeProxyIntent, uint, int, uvec2, uvec3, uvec4, ivec2, ivec3, ivec4 } from '../tsl/TSLCore.js'; +======= +import { addMethodChaining, float, Fn, If, nodeProxyIntent, uint, int, uvec2, uvec3, uvec4, ivec2, ivec3, ivec4 } from '../tsl/TSLCore.js'; +>>>>>>> 277001d084 (sketch out solve) import { bitcast, floatBitsToUint } from './BitcastNode.js'; import MathNode, { negate } from './MathNode.js'; @@ -41,7 +45,10 @@ class BitcountNode extends MathNode { /** * Casts the input value of the function to an integer if necessary. * +<<<<<<< HEAD * @private +======= +>>>>>>> 277001d084 (sketch out solve) * @param {Node|Node} inputNode - The input value. * @param {Node} outputNode - The output value. * @param {string} elementType - The type of the input value. 
@@ -60,6 +67,25 @@ class BitcountNode extends MathNode { } +<<<<<<< HEAD +======= + _returnBaseDataNode( elementType ) { + + if ( elementType === 'uint' ) { + + return uint; + + } + + if ( elementType === 'int' ) { + + return int; + + } + + } + +>>>>>>> 277001d084 (sketch out solve) _returnDataNode( inputType ) { switch ( inputType ) { @@ -112,6 +138,7 @@ class BitcountNode extends MathNode { } +<<<<<<< HEAD } } @@ -124,6 +151,15 @@ class BitcountNode extends MathNode { * @param {string} elementType - The type of the input value. * @returns {Function} - The generated function */ +======= + + } + + + + } + +>>>>>>> 277001d084 (sketch out solve) _createTrailingZerosBaseLayout( method, elementType ) { const outputConvertNode = this._returnDataNode( elementType ); @@ -153,6 +189,7 @@ class BitcountNode extends MathNode { } +<<<<<<< HEAD /** * Creates and registers a reusable GLSL function that emulates the behavior of countLeadingZeros. * @@ -161,6 +198,8 @@ class BitcountNode extends MathNode { * @param {string} elementType - The type of the input value. * @returns {Function} - The generated function */ +======= +>>>>>>> 277001d084 (sketch out solve) _createLeadingZerosBaseLayout( method, elementType ) { const outputConvertNode = this._returnDataNode( elementType ); @@ -225,6 +264,7 @@ class BitcountNode extends MathNode { } +<<<<<<< HEAD /** * Creates and registers a reusable GLSL function that emulates the behavior of countOneBits. * @@ -233,12 +273,18 @@ class BitcountNode extends MathNode { * @param {string} elementType - The type of the input value. 
* @returns {Function} - The generated function */ +======= +>>>>>>> 277001d084 (sketch out solve) _createOneBitsBaseLayout( method, elementType ) { const outputConvertNode = this._returnDataNode( elementType ); const fnDef = Fn( ( [ value ] ) => { +<<<<<<< HEAD +======= + +>>>>>>> 277001d084 (sketch out solve) const v = uint( 0.0 ); this._resolveElementType( value, v, elementType ); diff --git a/src/nodes/math/MathNode.js b/src/nodes/math/MathNode.js index ca221ec1c0e671..8271477e82f6ec 100644 --- a/src/nodes/math/MathNode.js +++ b/src/nodes/math/MathNode.js @@ -364,9 +364,6 @@ MathNode.FWIDTH = 'fwidth'; MathNode.TRANSPOSE = 'transpose'; MathNode.DETERMINANT = 'determinant'; MathNode.INVERSE = 'inverse'; -MathNode.COUNT_TRAILING_ZEROS = 'countTrailingZeros'; -MathNode.COUNT_LEADING_ZEROS = 'countLeadingZeros'; -MathNode.COUNT_ONE_BITS = 'countOneBits'; // 2 inputs @@ -1102,50 +1099,10 @@ export const atan2 = ( y, x ) => { // @deprecated, r172 }; - -/** - * Finds the number of consecutive 0 bits from the least significant bit of the input value, - * which is also the index of the least significant bit of the input value. - * - * Can only be used with {@link WebGPURenderer} and a WebGPU backend. - * - * @tsl - * @function - * @param {Node | number} x - The input value. - * @returns {Node} - */ -export const countTrailingZeros = /*@__PURE__*/ nodeProxyIntent( MathNode, MathNode.COUNT_TRAILING_ZEROS ).setParameterLength( 1 ); - -/** - * Finds the number of consecutive 0 bits starting from the most significant bit of the input value. - * - * Can only be used with {@link WebGPURenderer} and a WebGPU backend. - * - * @tsl - * @function - * @param {Node | number} x - The input value. - * @returns {Node} - */ -export const countLeadingZeros = /*@__PURE__*/ nodeProxyIntent( MathNode, MathNode.COUNT_LEADING_ZEROS ).setParameterLength( 1 ); - -/** - * Finds the number of '1' bits set in the input value - * - * Can only be used with {@link WebGPURenderer} and a WebGPU backend. 
- * - * @tsl - * @function - * @returns {Node} - */ -export const countOneBits = /*@__PURE__*/ nodeProxyIntent( MathNode, MathNode.COUNT_ONE_BITS ).setParameterLength( 1 ); - // GLSL alias function export const faceforward = faceForward; export const inversesqrt = inverseSqrt; -export const findLSB = countTrailingZeros; -export const findMSB = countLeadingZeros; -export const bitCount = countOneBits; // Method chaining @@ -1208,6 +1165,3 @@ addMethodChaining( 'transpose', transpose ); addMethodChaining( 'determinant', determinant ); addMethodChaining( 'inverse', inverse ); addMethodChaining( 'rand', rand ); -addMethodChaining( 'countTrailingZeros', countTrailingZeros ); -addMethodChaining( 'countLeadingZeros', countLeadingZeros ); -addMethodChaining( 'countOneBits', countOneBits ); From f070385215a3d756627ed900f2aca59399748b88 Mon Sep 17 00:00:00 2001 From: Christian Helgeson <62450112+cmhhelgeson@users.noreply.github.com> Date: Sun, 9 Nov 2025 13:58:42 -0800 Subject: [PATCH 03/10] prefix_sum --- examples/jsm/gpgpu/BitonicSort.js | 2 +- examples/jsm/gpgpu/PrefixSum.js | 55 ++++-- examples/webgpu_compute_prefix_sum.html | 221 +++++++--------------- examples/webgpu_compute_reduce.html | 9 +- examples/webgpu_compute_sort_bitonic.html | 5 +- 5 files changed, 124 insertions(+), 168 deletions(-) diff --git a/examples/jsm/gpgpu/BitonicSort.js b/examples/jsm/gpgpu/BitonicSort.js index 4059c23b02adca..32576011a2ca16 100644 --- a/examples/jsm/gpgpu/BitonicSort.js +++ b/examples/jsm/gpgpu/BitonicSort.js @@ -119,7 +119,7 @@ export class BitonicSort { * * @type {StorageBufferNode} */ - this.workgroupSize = options.workgroupSize ? Math.min( this.dispatchSize, options.workgroupSize ) : Math.min( this.dispatchSize, 64 ); + this.workgroupSize = options.workgroupSize ? 
Math.min( this.dispatchSize, options.workgroupSize ) : Math.min( this.dispatchSize, this.renderer.backend.device.limits.maxComputeWorkgroupSizeX ); /** * A node representing a workgroup scoped buffer that holds locally sorted elements. diff --git a/examples/jsm/gpgpu/PrefixSum.js b/examples/jsm/gpgpu/PrefixSum.js index f88d56bed84636..08c518bbd5a1af 100644 --- a/examples/jsm/gpgpu/PrefixSum.js +++ b/examples/jsm/gpgpu/PrefixSum.js @@ -1,4 +1,7 @@ -import { Fn, If, instancedArray, invocationLocalIndex, countTrailingZeros, Loop, workgroupArray, subgroupSize, workgroupBarrier, workgroupId, uint, select, invocationSubgroupIndex, dot, uvec4, vec4, float, subgroupAdd, array, subgroupShuffle, subgroupInclusiveAdd, subgroupBroadcast, invocationSubgroupMetaIndex, arrayBuffer } from 'three/tsl'; +import { + StorageInstancedBufferAttribute +} from 'three'; +import { Fn, If, instancedArray, invocationLocalIndex, countTrailingZeros, Loop, workgroupArray, subgroupSize, workgroupBarrier, workgroupId, uint, select, invocationSubgroupIndex, dot, uvec4, vec4, float, subgroupAdd, array, subgroupShuffle, subgroupInclusiveAdd, subgroupBroadcast, invocationSubgroupMetaIndex, arrayBuffer, storage } from 'three/tsl'; const divRoundUp = ( size, part_size ) => { @@ -68,6 +71,12 @@ export class PrefixSum { */ this.renderer = renderer; + if ( this.renderer.backend.device === null ) { + + renderer.backend.init(); + + } + /** * @type {PrefixSumStorageObjects} */ @@ -132,7 +141,14 @@ export class PrefixSum { * * @type {number} */ - this.workgroupSize = options.workgroupSize ? options.workgroupSize : Math.min( this.vecCount, 64 ); + this.workgroupSize = options.workgroupSize ? options.workgroupSize : Math.min( this.vecCount, this.renderer.backend.device.limits.maxComputeWorkgroupSizeX ); + + /** + * The minimum subgroup size specified by the renderer's graphics device. 
+ * + * @type {number} + */ + this.minSubgroupSize = ( this.renderer.backend.device.adapterInfo && this.renderer.backend.device.adapterInfo.subgroupMinSize ) ? this.renderer.backend.device.adapterInfo.subgroupMinSize : 4; /** * The maximum number of elements that will be read by an individual workgroup in the reduction step. @@ -179,10 +195,17 @@ export class PrefixSum { _createStorageBuffers( inputArray ) { this.arrayBuffer = this.type === 'uint' ? Uint32Array.from( inputArray ) : Float32Array.from( inputArray ); + this.outputArrayBuffer = this.type === 'uint' ? Uint32Array.from( inputArray ) : Float32Array.from( inputArray ); + + const inputAttribute = new StorageInstancedBufferAttribute( this.arrayBuffer, 1 ); + const outputAttribute = new StorageInstancedBufferAttribute( this.outputArrayBuffer, 1 ); + + this.storageBuffers.dataBuffer = storage( inputAttribute, this.vecType, inputAttribute.count / 4 ).setName( `Prefix_Sum_Input_Vec_${id}` ); + this.storageBuffers.unvectorizedDataBuffer = storage( inputAttribute, this.type, inputAttribute.count ).setName( `Prefix_Sum_Input_Unvec_${id}` ); + + this.storageBuffers.outputBuffer = storage( outputAttribute, this.vecType, outputAttribute.count / 4 ).setName( `Prefix_Sum_Output_Vec_${id}` ); + this.storageBuffers.unvectorizedOutputBuffer = storage( outputAttribute, this.type, outputAttribute.count ).setName( `Prefix_Sum_Output_Unvec_${id}` ); - this.storageBuffers.unvectorizedDataBuffer = instancedArray( this.arrayBuffer, this.type ).setPBO( true ).setName( `Prefix_Sum_Input_Unvec_${id}` ); - this.storageBuffers.dataBuffer = instancedArray( this.arrayBuffer, this.vecType ).setPBO( true ).setName( `Prefix_Sum_Input_Vec_${id}` ); - this.storageBuffers.outputBuffer = instancedArray( this.arrayBuffer, this.vecType ).setName( `Prefix_Sum_Output_${id}` ); this.storageBuffers.reductionBuffer = instancedArray( this.numWorkgroups, this.type ).setPBO( true ).setName( `Prefix_Sum_Reduction_${id}` ); } @@ -472,6 +495,19 @@ export 
class PrefixSum { _getSpineScanFn() { const { reductionBuffer } = this.storageBuffers; + + if ( this.numWorkgroups <= this.minSubgroupSize ) { + + const fnDef = Fn( () => { + + reductionBuffer.element( invocationSubgroupIndex ).assign( subgroupInclusiveAdd( reductionBuffer.element( invocationSubgroupIndex ) ) ); + + } )().compute( this.numWorkgroups, [ this.workgroupSize ] ); + + return fnDef; + + } + const { subgroupReductionArray, unvectorizedSubgroupOffset, spineSize, subgroupSizeLog } = this.utilityNodes; const { unvectorizedWorkPerInvocation } = this; @@ -630,8 +666,6 @@ export class PrefixSum { } )().compute( this.numWorkgroups, [ this.workgroupSize ] ); - console.log( fnDef ); - return fnDef; } @@ -639,7 +673,6 @@ export class PrefixSum { _getDownsweepFn() { const { dataBuffer, reductionBuffer, outputBuffer } = this.storageBuffers; - const { vecType } = this; const { subgroupOffset, workgroupOffset, subgroupReductionArray, subgroupSizeLog, spineSize } = this.utilityNodes; const { workPerInvocation, vecCount } = this; @@ -958,9 +991,9 @@ export class PrefixSum { */ async compute() { - await this.computeStep( this.currentStep ); - await this.computeStep( this.currentStep ); - await this.computeStep( this.currentStep ); + await this.computeReduce(); + await this.computeSpineScan(); + await this.computeDownsweep(); } diff --git a/examples/webgpu_compute_prefix_sum.html b/examples/webgpu_compute_prefix_sum.html index 5428e09fb32b46..67cccca332a16b 100644 --- a/examples/webgpu_compute_prefix_sum.html +++ b/examples/webgpu_compute_prefix_sum.html @@ -27,7 +27,7 @@