Skip to content

Commit 7d73099

Browse files
authored
Nodes: Access Remaining Compute Builtins (#29469)
* init * remove duplicated function
1 parent 52f640d commit 7d73099

File tree

5 files changed

+143
-4
lines changed

5 files changed

+143
-4
lines changed

examples/webgpu_compute_sort_bitonic.html

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
<script type="module">
5555

5656
import * as THREE from 'three';
57-
import { storageObject, If, vec3, not, uniform, uv, uint, float, Fn, vec2, abs, int, invocationLocalIndex, workgroupArray, uvec2, floor, instanceIndex, workgroupBarrier, atomicAdd, atomicStore } from 'three/tsl';
57+
import { storageObject, If, vec3, not, uniform, uv, uint, float, Fn, vec2, abs, int, invocationLocalIndex, workgroupArray, uvec2, floor, instanceIndex, workgroupBarrier, atomicAdd, atomicStore, workgroupId } from 'three/tsl';
5858

5959
import { GUI } from 'three/addons/libs/lil-gui.module.min.js';
6060

@@ -264,8 +264,7 @@
264264

265265
// Get ids of indices needed to populate workgroup local buffer.
266266
// Use .toVar() to prevent these values from being recalculated multiple times.
267-
const workgroupId = instanceIndex.div( WORKGROUP_SIZE[ 0 ] ).toVar();
268-
const localOffset = uint( WORKGROUP_SIZE[ 0 ] ).mul( 2 ).mul( workgroupId ).toVar();
267+
const localOffset = uint( WORKGROUP_SIZE[ 0 ] ).mul( 2 ).mul( workgroupId.x ).toVar();
269268

270269
const localID1 = invocationLocalIndex.mul( 2 );
271270
const localID2 = invocationLocalIndex.mul( 2 ).add( 1 );

src/nodes/TSL.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ export * from './geometry/RangeNode.js';
142142

143143
// gpgpu
144144
export * from './gpgpu/ComputeNode.js';
145+
export * from './gpgpu/ComputeBuiltinNode.js';
145146
export * from './gpgpu/BarrierNode.js';
146147
export * from './gpgpu/WorkgroupInfoNode.js';
147148
export * from './gpgpu/AtomicFunctionNode.js';

src/nodes/core/IndexNode.js

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,34 @@ class IndexNode extends Node {
2828

2929
if ( scope === IndexNode.VERTEX ) {
3030

31+
// The index of a vertex within a mesh.
3132
propertyName = builder.getVertexIndex();
3233

3334
} else if ( scope === IndexNode.INSTANCE ) {
3435

36+
// The index of either a mesh instance or an invocation of a compute shader.
3537
propertyName = builder.getInstanceIndex();
3638

3739
} else if ( scope === IndexNode.DRAW ) {
3840

41+
// The index of a draw call.
3942
propertyName = builder.getDrawIndex();
4043

4144
} else if ( scope === IndexNode.INVOCATION_LOCAL ) {
4245

46+
// The index of a compute invocation within the scope of a workgroup load.
4347
propertyName = builder.getInvocationLocalIndex();
4448

49+
} else if ( scope === IndexNode.INVOCATION_SUBGROUP ) {
50+
51+
// The index of a compute invocation within the scope of a subgroup.
52+
propertyName = builder.getInvocationSubgroupIndex();
53+
54+
} else if ( scope === IndexNode.SUBGROUP ) {
55+
56+
// The index of the subgroup the current compute invocation belongs to.
57+
propertyName = builder.getSubgroupIndex();
58+
4559
} else {
4660

4761
throw new Error( 'THREE.IndexNode: Unknown scope: ' + scope );
@@ -70,12 +84,16 @@ class IndexNode extends Node {
7084

7185
IndexNode.VERTEX = 'vertex';
7286
IndexNode.INSTANCE = 'instance';
87+
IndexNode.SUBGROUP = 'subgroup';
7388
IndexNode.INVOCATION_LOCAL = 'invocationLocal';
89+
IndexNode.INVOCATION_SUBGROUP = 'invocationSubgroup';
7490
IndexNode.DRAW = 'draw';
7591

7692
export default IndexNode;
7793

7894
export const vertexIndex = /*@__PURE__*/ nodeImmutable( IndexNode, IndexNode.VERTEX );
7995
export const instanceIndex = /*@__PURE__*/ nodeImmutable( IndexNode, IndexNode.INSTANCE );
96+
export const subgroupIndex = /*@__PURE__*/ nodeImmutable( IndexNode, IndexNode.SUBGROUP );
97+
export const invocationSubgroupIndex = /*@__PURE__*/ nodeImmutable( IndexNode, IndexNode.INVOCATION_SUBGROUP );
8098
export const invocationLocalIndex = /*@__PURE__*/ nodeImmutable( IndexNode, IndexNode.INVOCATION_LOCAL );
8199
export const drawIndex = /*@__PURE__*/ nodeImmutable( IndexNode, IndexNode.DRAW );
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import Node from '../core/Node.js';
2+
import { nodeObject } from '../tsl/TSLBase.js';
3+
4+
class ComputeBuiltinNode extends Node {
5+
6+
static get type() {
7+
8+
return 'ComputeBuiltinNode';
9+
10+
}
11+
12+
constructor( builtinName, nodeType ) {
13+
14+
super( nodeType );
15+
16+
this._builtinName = builtinName;
17+
18+
}
19+
20+
getHash( builder ) {
21+
22+
return this.getBuiltinName( builder );
23+
24+
}
25+
26+
getNodeType( /*builder*/ ) {
27+
28+
return this.nodeType;
29+
30+
}
31+
32+
setBuiltinName( builtinName ) {
33+
34+
this._builtinName = builtinName;
35+
36+
return this;
37+
38+
}
39+
40+
getBuiltinName( /*builder*/ ) {
41+
42+
console.log( this._builtinName );
43+
44+
return this._builtinName;
45+
46+
}
47+
48+
hasBuiltin( builder ) {
49+
50+
builder.hasBuiltin( this._builtinName );
51+
52+
}
53+
54+
generate( builder, output ) {
55+
56+
const builtinName = this.getBuiltinName( builder );
57+
const nodeType = this.getNodeType( builder );
58+
59+
if ( builder.shaderStage === 'compute' ) {
60+
61+
return builder.format( builtinName, nodeType, output );
62+
63+
} else {
64+
65+
console.warn( `ComputeBuiltinNode: Compute built-in value ${builtinName} can not be accessed in the ${builder.shaderStage} stage` );
66+
return builder.generateConst( nodeType );
67+
68+
}
69+
70+
}
71+
72+
serialize( data ) {
73+
74+
super.serialize( data );
75+
76+
data.global = this.global;
77+
data._builtinName = this._builtinName;
78+
79+
}
80+
81+
deserialize( data ) {
82+
83+
super.deserialize( data );
84+
85+
this.global = data.global;
86+
this._builtinName = data._builtinName;
87+
88+
}
89+
90+
}
91+
92+
export default ComputeBuiltinNode;
93+
94+
const computeBuiltin = ( name, nodeType ) => nodeObject( new ComputeBuiltinNode( name, nodeType ) );
95+
96+
export const numWorkgroups = /*@__PURE__*/ computeBuiltin( 'numWorkgroups', 'uvec3' );
97+
export const workgroupId = /*@__PURE__*/ computeBuiltin( 'workgroupId', 'uvec3' );
98+
export const localId = /*@__PURE__*/ computeBuiltin( 'localId', 'uvec3' );
99+
export const subgroupSize = /*@__PURE__*/ computeBuiltin( 'subgroupSize', 'uint' );
100+

src/renderers/webgpu/nodes/WGSLNodeBuilder.js

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,12 @@ class WGSLNodeBuilder extends NodeBuilder {
584584

585585
}
586586

587+
hasBuiltin( name, shaderStage = this.shaderStage ) {
588+
589+
return ( this.builtins[ shaderStage ] !== undefined && this.builtins[ shaderStage ].has( name ) );
590+
591+
}
592+
587593
getVertexIndex() {
588594

589595
if ( this.shaderStage === 'vertex' ) {
@@ -656,11 +662,19 @@ ${ flowData.code }
656662

657663
}
658664

665+
getInvocationSubgroupIndex() {
666+
667+
this.enableSubGroups();
668+
669+
return this.getBuiltin( 'subgroup_invocation_id', 'invocationSubgroupIndex', 'u32', 'attribute' );
670+
671+
}
672+
659673
getSubgroupIndex() {
660674

661675
this.enableSubGroups();
662676

663-
return this.getBuiltin( 'subgroup_invocation_id', 'subgroupIndex', 'u32', 'attribute' );
677+
return this.getBuiltin( 'subgroup_id', 'subgroupIndex', 'u32', 'attribute' );
664678

665679
}
666680

@@ -819,6 +833,13 @@ ${ flowData.code }
819833
this.getBuiltin( 'local_invocation_id', 'localId', 'vec3<u32>', 'attribute' );
820834
this.getBuiltin( 'num_workgroups', 'numWorkgroups', 'vec3<u32>', 'attribute' );
821835

836+
if ( this.renderer.hasFeature( 'subgroups' ) ) {
837+
838+
this.enableDirective( 'subgroups', shaderStage );
839+
this.getBuiltin( 'subgroup_size', 'subgroupSize', 'u32', 'attribute' );
840+
841+
}
842+
822843
}
823844

824845
if ( shaderStage === 'vertex' || shaderStage === 'compute' ) {

0 commit comments

Comments
 (0)