Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 78 additions & 17 deletions js/web/lib/wasm/jsep/webgpu/ops/transpose.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,31 +48,73 @@ const squeezeShape = (shape: readonly number[], adjustedPerm: number[]): { newSh
return { newShape, newPerm };
};

/**
 * Decides whether a transpose can be carried out as a plain element-order-preserving copy.
 *
 * The permutation is walked front to back. Axes whose extent in `shape` is exactly 1 are
 * ignored; every remaining axis index must be no smaller than the previous remaining one.
 * Example: shape (1,1,1024,4096) with perm (2,0,3,1) keeps axes 2 and 3 in order.
 *
 * @param perm - The axis permutation to inspect.
 * @param shape - The dimensions of the tensor being permuted, indexed by the entries of `perm`.
 * @returns `true` when the axes with extent other than 1 keep their relative order, `false` otherwise.
 */
const isTransposeReshape = (perm: number[], shape: readonly number[]) => {
  // Largest significant (extent !== 1) axis index encountered so far.
  let highestAxisSeen = 0;
  for (const axis of perm) {
    if (shape[axis] !== 1) {
      if (axis < highestAxisSeen) {
        // A significant axis moved in front of a later one: the data really is reordered.
        return false;
      }
      highestAxisSeen = axis;
    }
  }
  return true;
};

export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: number[]): ProgramInfo => {
const inputDataType = inputTensor.dataType;
const inputRank = inputTensor.dims.length;
const perm = getAdjustedPerm(inputRank, permAttr);
const outputShape = getOutputShape(inputTensor.dims, perm);
let newInputShape = inputTensor.dims;
let newOutputShape = outputShape;
const transposeAsReshape = isTransposeReshape(perm, inputTensor.dims);
let getShaderSource;
if (transposeAsReshape) {
getShaderSource = (shaderHelper: ShaderHelper) => {
const input = inputVariable('input', inputDataType, newInputShape, 4);
const output = outputVariable('output', inputDataType, newOutputShape, 4);
return `
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
output[global_idx] = input[global_idx];
}`;
};

return {
name: 'TransposeCopy',
shaderCache: { inputDependencies: ['type'] },
getRunData: () => {
const outputSize = ShapeUtil.size(outputShape);
return {
outputs: [{ dims: outputShape, dataType: inputTensor.dataType }],
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* components */) },
programUniforms: [{ type: DataType.uint32, data: Math.ceil(outputSize / 4) }],
};
},
getShaderSource,
};
}
const { newShape, newPerm } = squeezeShape(inputTensor.dims, perm);
const channelsLast = ShapeUtil.areEqual(newPerm, [2, 3, 1]);
const channelsFirst = ShapeUtil.areEqual(newPerm, [3, 1, 2]);
const useShared = (newShape.length === 2 && newPerm[0] > newPerm[1]) || channelsLast || channelsFirst;
let newInputShape = useShared ? newShape : inputTensor.dims;
let newOutputShape = outputShape;
const useShared = newShape.length === 2 || channelsLast || channelsFirst;
if (useShared) {
newInputShape = channelsLast
? [newShape[0], newShape[1] * newShape[2]]
: channelsFirst
? [newShape[0] * newShape[1], newShape[2]]
: newShape;
newOutputShape = [newInputShape[1], newInputShape[0]];
}
const input = inputVariable('a', inputDataType, newInputShape.length);
const output = outputVariable('output', inputDataType, newOutputShape.length);
const tileSize = 16;
let getShaderSource;
if (useShared) {
getShaderSource = (shaderHelper: ShaderHelper) => `
const tileSize = 16;
getShaderSource = (shaderHelper: ShaderHelper) => {
const input = inputVariable('a', inputDataType, newInputShape.length);
const output = outputVariable('output', inputDataType, newOutputShape.length);
return `
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
var<workgroup> tile : array<array<${output.type.value}, ${tileSize + 1}>, ${tileSize}>;
${shaderHelper.mainStart([tileSize, tileSize, 1])}
Expand All @@ -92,8 +134,29 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
${output.setByIndices(`${output.type.indices}(output_row, output_col)`, 'tile[local_id.x][local_id.y]')}
}
}`;
} else {
getShaderSource = (shaderHelper: ShaderHelper) => `
};
return {
name: 'TransposeShared',
shaderCache: { inputDependencies: ['type'] },
getRunData: () => {
const outputSize = ShapeUtil.size(outputShape);
return {
outputs: [{ dims: outputShape, dataType: inputTensor.dataType }],
dispatchGroup: { x: Math.ceil(newOutputShape[1] / tileSize), y: Math.ceil(newOutputShape[0] / tileSize) },
programUniforms: [
{ type: DataType.uint32, data: outputSize },
...createTensorShapeVariables(newInputShape, newOutputShape),
],
};
},
getShaderSource,
};
}

getShaderSource = (shaderHelper: ShaderHelper) => {
const input = inputVariable('a', inputDataType, newInputShape.length);
const output = outputVariable('output', inputDataType, newOutputShape.length);
return `
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}

${permFunctionBody(perm, inputRank, input, output)}
Expand All @@ -106,17 +169,15 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu

${output.setByOffset('global_idx', input.getByIndices('aIndices'))}
}`;
}
};
return {
name: useShared ? 'TransposeShared' : 'Transpose',
name: 'Transpose',
shaderCache: { hint: `${permAttr}`, inputDependencies: ['rank'] },
getRunData: () => {
const outputSize = ShapeUtil.size(outputShape);
return {
outputs: [{ dims: outputShape, dataType: inputTensor.dataType }],
dispatchGroup: useShared
? { x: Math.ceil(newOutputShape[1] / tileSize), y: Math.ceil(newOutputShape[0] / tileSize) }
: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
programUniforms: [
{ type: DataType.uint32, data: outputSize },
...createTensorShapeVariables(newInputShape, newOutputShape),
Expand Down
24 changes: 24 additions & 0 deletions js/web/test/data/ops/transpose.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,30 @@
}
]
},
{
"name": "Transpose as reshape - perms:[1, 0, 2, 4, 3]",
"operator": "Transpose",
"attributes": [{ "name": "perm", "data": [1, 0, 2, 4, 3], "type": "ints" }],
"cases": [
{
"name": "T[3, 1, 2, 1, 4]",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
"dims": [3, 1, 2, 1, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
"dims": [1, 3, 2, 4, 1],
"type": "float32"
}
]
}
]
},
{
"name": "Transpose - perms:[1, 0]",
"operator": "Transpose",
Expand Down
Loading