Skip to content

Commit 4902654

Browse files
Revert unintended changes to where.js
1 parent 6b247f5 commit 4902654

File tree

1 file changed

+31
-29
lines changed

1 file changed

+31
-29
lines changed

js/web/lib/wasm/jsep/webgpu/ops/where.ts

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,15 @@ import {TensorView} from '../../tensor-view';
66
import {BroadcastUtil, ShapeUtil} from '../../util';
77
import {ComputeContext, ProgramInfo} from '../types';
88

9-
import {inputVariable, outputVariable, ShaderHelper} from './common';
9+
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';
1010

1111
const createWhereOpProgramShader =
1212
(shaderHelper: ShaderHelper, inputs: readonly TensorView[], dimsOutput: readonly number[], isBroadcast: boolean,
1313
typeOutput: number) => {
14-
const outputSize = ShapeUtil.size(dimsOutput);
15-
const vecSize = Math.ceil(outputSize / 4);
16-
17-
const output = outputVariable('outputData', typeOutput, dimsOutput, 4);
18-
const a = inputVariable('aData', inputs[1].dataType, inputs[1].dims, 4);
19-
const b = inputVariable('bData', inputs[2].dataType, inputs[2].dims, 4);
20-
const c = inputVariable('cData', inputs[0].dataType, inputs[0].dims, 4);
14+
const output = outputVariable('output_data', typeOutput, dimsOutput.length, 4);
15+
const a = inputVariable('a_data', inputs[1].dataType, inputs[1].dims.length, 4);
16+
const b = inputVariable('b_data', inputs[2].dataType, inputs[2].dims.length, 4);
17+
const c = inputVariable('c_data', inputs[0].dataType, inputs[0].dims.length, 4);
2118

2219
let assignment: string;
2320
const expression = (a: string, b: string, c: string) => `select(${b}, ${a}, ${c})`;
@@ -27,21 +24,21 @@ const createWhereOpProgramShader =
2724
expression(a.getByOffset('global_idx'), b.getByOffset('global_idx'), c.getByOffset('global_idx')));
2825
} else {
2926
const singleAssignment = (resStr: string, x: number, typeCast = '') => {
30-
const expressionA = `aData[indexA${x}][componentA${x}]`;
31-
const expressionB = `bData[indexB${x}][componentB${x}]`;
27+
const expressionA = `a_data[index_a${x}][component_a${x}]`;
28+
const expressionB = `b_data[index_b${x}][component_b${x}]`;
3229
// eslint-disable-next-line no-bitwise
33-
const expressionC = `bool(cData[indexC${x}] & (0xffu << (componentC${x} * 8)))`;
30+
const expressionC = `bool(c_data[index_c${x}] & (0xffu << (component_c${x} * 8)))`;
3431
return `
35-
let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
36-
let offsetA${x} = ${a.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
37-
let offsetB${x} = ${b.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
38-
let offsetC${x} = ${c.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
39-
let indexA${x} = offsetA${x} / 4u;
40-
let indexB${x} = offsetB${x} / 4u;
41-
let indexC${x} = offsetC${x} / 4u;
42-
let componentA${x} = offsetA${x} % 4u;
43-
let componentB${x} = offsetB${x} % 4u;
44-
let componentC${x} = offsetC${x} % 4u;
32+
let output_indices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
33+
let offset_a${x} = ${a.broadcastedIndicesToOffset(`output_indices${x}`, output)};
34+
let offset_b${x} = ${b.broadcastedIndicesToOffset(`output_indices${x}`, output)};
35+
let offset_c${x} = ${c.broadcastedIndicesToOffset(`output_indices${x}`, output)};
36+
let index_a${x} = offset_a${x} / 4u;
37+
let index_b${x} = offset_b${x} / 4u;
38+
let index_c${x} = offset_c${x} / 4u;
39+
let component_a${x} = offset_a${x} % 4u;
40+
let component_b${x} = offset_b${x} % 4u;
41+
let component_c${x} = offset_c${x} % 4u;
4542
${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)});
4643
`;
4744
};
@@ -52,21 +49,21 @@ const createWhereOpProgramShader =
5249
${singleAssignment('data', 1, 'u32')}
5350
${singleAssignment('data', 2, 'u32')}
5451
${singleAssignment('data', 3, 'u32')}
55-
outputData[global_idx] = dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(data));`;
52+
output_data[global_idx] = dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(data));`;
5653
} else {
5754
assignment = `
58-
${singleAssignment('outputData[global_idx]', 0)}
59-
${singleAssignment('outputData[global_idx]', 1)}
60-
${singleAssignment('outputData[global_idx]', 2)}
61-
${singleAssignment('outputData[global_idx]', 3)}
55+
${singleAssignment('output_data[global_idx]', 0)}
56+
${singleAssignment('output_data[global_idx]', 1)}
57+
${singleAssignment('output_data[global_idx]', 2)}
58+
${singleAssignment('output_data[global_idx]', 3)}
6259
`;
6360
}
6461
}
6562

6663
return `
67-
${shaderHelper.declareVariables(c, a, b, output)}
64+
${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(c, a, b, output)}
6865
${shaderHelper.mainStart()}
69-
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)}
66+
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')}
7067
${assignment}
7168
}`;
7269
};
@@ -91,13 +88,18 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
9188
outputSize = ShapeUtil.size(outputShape);
9289
}
9390

91+
const vecSize = Math.ceil(outputSize / 4);
92+
9493
return {
9594
name: 'Where',
95+
shaderCache: {inputDependencies: ['rank', 'rank', 'rank']},
9696
getShaderSource: (shaderHelper) =>
9797
createWhereOpProgramShader(shaderHelper, inputs, outputShape, isBroadcast, outputDataType),
9898
getRunData: () => ({
9999
outputs: [{dims: outputShape, dataType: outputDataType}],
100-
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)}
100+
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)},
101+
programUniforms:
102+
[{type: DataType.uint32, data: vecSize}, ...createTensorShapeVariables(dimsC, dimsA, dimsB, outputShape)],
101103
}),
102104
};
103105
};

0 commit comments

Comments
 (0)