@@ -6,18 +6,15 @@ import {TensorView} from '../../tensor-view';
66import { BroadcastUtil , ShapeUtil } from '../../util' ;
77import { ComputeContext , ProgramInfo } from '../types' ;
88
9- import { inputVariable , outputVariable , ShaderHelper } from './common' ;
9+ import { createTensorShapeVariables , inputVariable , outputVariable , ShaderHelper } from './common' ;
1010
1111const createWhereOpProgramShader =
1212 ( shaderHelper : ShaderHelper , inputs : readonly TensorView [ ] , dimsOutput : readonly number [ ] , isBroadcast : boolean ,
1313 typeOutput : number ) => {
14- const outputSize = ShapeUtil . size ( dimsOutput ) ;
15- const vecSize = Math . ceil ( outputSize / 4 ) ;
16-
17- const output = outputVariable ( 'outputData' , typeOutput , dimsOutput , 4 ) ;
18- const a = inputVariable ( 'aData' , inputs [ 1 ] . dataType , inputs [ 1 ] . dims , 4 ) ;
19- const b = inputVariable ( 'bData' , inputs [ 2 ] . dataType , inputs [ 2 ] . dims , 4 ) ;
20- const c = inputVariable ( 'cData' , inputs [ 0 ] . dataType , inputs [ 0 ] . dims , 4 ) ;
14+ const output = outputVariable ( 'output_data' , typeOutput , dimsOutput . length , 4 ) ;
15+ const a = inputVariable ( 'a_data' , inputs [ 1 ] . dataType , inputs [ 1 ] . dims . length , 4 ) ;
16+ const b = inputVariable ( 'b_data' , inputs [ 2 ] . dataType , inputs [ 2 ] . dims . length , 4 ) ;
17+ const c = inputVariable ( 'c_data' , inputs [ 0 ] . dataType , inputs [ 0 ] . dims . length , 4 ) ;
2118
2219 let assignment : string ;
2320 const expression = ( a : string , b : string , c : string ) => `select(${ b } , ${ a } , ${ c } )` ;
@@ -27,21 +24,21 @@ const createWhereOpProgramShader =
2724 expression ( a . getByOffset ( 'global_idx' ) , b . getByOffset ( 'global_idx' ) , c . getByOffset ( 'global_idx' ) ) ) ;
2825 } else {
2926 const singleAssignment = ( resStr : string , x : number , typeCast = '' ) => {
30- const expressionA = `aData[indexA ${ x } ][componentA ${ x } ]` ;
31- const expressionB = `bData[indexB ${ x } ][componentB ${ x } ]` ;
27+ const expressionA = `a_data[index_a ${ x } ][component_a ${ x } ]` ;
28+ const expressionB = `b_data[index_b ${ x } ][component_b ${ x } ]` ;
3229 // eslint-disable-next-line no-bitwise
33- const expressionC = `bool(cData[indexC ${ x } ] & (0xffu << (componentC ${ x } * 8)))` ;
30+ const expressionC = `bool(c_data[index_c ${ x } ] & (0xffu << (component_c ${ x } * 8)))` ;
3431 return `
35- let outputIndices ${ x } = ${ output . offsetToIndices ( `global_idx * 4u + ${ x } u` ) } ;
36- let offsetA ${ x } = ${ a . broadcastedIndicesToOffset ( `outputIndices ${ x } ` , output ) } ;
37- let offsetB ${ x } = ${ b . broadcastedIndicesToOffset ( `outputIndices ${ x } ` , output ) } ;
38- let offsetC ${ x } = ${ c . broadcastedIndicesToOffset ( `outputIndices ${ x } ` , output ) } ;
39- let indexA ${ x } = offsetA ${ x } / 4u;
40- let indexB ${ x } = offsetB ${ x } / 4u;
41- let indexC ${ x } = offsetC ${ x } / 4u;
42- let componentA ${ x } = offsetA ${ x } % 4u;
43- let componentB ${ x } = offsetB ${ x } % 4u;
44- let componentC ${ x } = offsetC ${ x } % 4u;
32+ let output_indices ${ x } = ${ output . offsetToIndices ( `global_idx * 4u + ${ x } u` ) } ;
33+ let offset_a ${ x } = ${ a . broadcastedIndicesToOffset ( `output_indices ${ x } ` , output ) } ;
34+ let offset_b ${ x } = ${ b . broadcastedIndicesToOffset ( `output_indices ${ x } ` , output ) } ;
35+ let offset_c ${ x } = ${ c . broadcastedIndicesToOffset ( `output_indices ${ x } ` , output ) } ;
36+ let index_a ${ x } = offset_a ${ x } / 4u;
37+ let index_b ${ x } = offset_b ${ x } / 4u;
38+ let index_c ${ x } = offset_c ${ x } / 4u;
39+ let component_a ${ x } = offset_a ${ x } % 4u;
40+ let component_b ${ x } = offset_b ${ x } % 4u;
41+ let component_c ${ x } = offset_c ${ x } % 4u;
4542 ${ resStr } [${ x } ] = ${ typeCast } (${ expression ( expressionA , expressionB , expressionC ) } );
4643 ` ;
4744 } ;
@@ -52,21 +49,21 @@ const createWhereOpProgramShader =
5249 ${ singleAssignment ( 'data' , 1 , 'u32' ) }
5350 ${ singleAssignment ( 'data' , 2 , 'u32' ) }
5451 ${ singleAssignment ( 'data' , 3 , 'u32' ) }
55- outputData [global_idx] = dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(data));` ;
52+ output_data [global_idx] = dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(data));` ;
5653 } else {
5754 assignment = `
58- ${ singleAssignment ( 'outputData [global_idx]' , 0 ) }
59- ${ singleAssignment ( 'outputData [global_idx]' , 1 ) }
60- ${ singleAssignment ( 'outputData [global_idx]' , 2 ) }
61- ${ singleAssignment ( 'outputData [global_idx]' , 3 ) }
55+ ${ singleAssignment ( 'output_data [global_idx]' , 0 ) }
56+ ${ singleAssignment ( 'output_data [global_idx]' , 1 ) }
57+ ${ singleAssignment ( 'output_data [global_idx]' , 2 ) }
58+ ${ singleAssignment ( 'output_data [global_idx]' , 3 ) }
6259 ` ;
6360 }
6461 }
6562
6663 return `
67- ${ shaderHelper . declareVariables ( c , a , b , output ) }
64+ ${ shaderHelper . registerUniform ( 'vec_size' , 'u32' ) . declareVariables ( c , a , b , output ) }
6865 ${ shaderHelper . mainStart ( ) }
69- ${ shaderHelper . guardAgainstOutOfBoundsWorkgroupSizes ( vecSize ) }
66+ ${ shaderHelper . guardAgainstOutOfBoundsWorkgroupSizes ( 'uniforms.vec_size' ) }
7067 ${ assignment }
7168 }` ;
7269 } ;
@@ -91,13 +88,18 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
9188 outputSize = ShapeUtil . size ( outputShape ) ;
9289 }
9390
91+ const vecSize = Math . ceil ( outputSize / 4 ) ;
92+
9493 return {
9594 name : 'Where' ,
95+ shaderCache : { inputDependencies : [ 'rank' , 'rank' , 'rank' ] } ,
9696 getShaderSource : ( shaderHelper ) =>
9797 createWhereOpProgramShader ( shaderHelper , inputs , outputShape , isBroadcast , outputDataType ) ,
9898 getRunData : ( ) => ( {
9999 outputs : [ { dims : outputShape , dataType : outputDataType } ] ,
100- dispatchGroup : { x : Math . ceil ( outputSize / 64 /* workgroup size */ / 4 /* vec size */ ) }
100+ dispatchGroup : { x : Math . ceil ( outputSize / 64 /* workgroup size */ / 4 /* vec size */ ) } ,
101+ programUniforms :
102+ [ { type : DataType . uint32 , data : vecSize } , ...createTensorShapeVariables ( dimsC , dimsA , dimsB , outputShape ) ] ,
101103 } ) ,
102104 } ;
103105} ;
0 commit comments