Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 69 additions & 77 deletions js/web/lib/wasm/jsep/webgpu/ops/concat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,32 @@ export interface ConcatAttributes extends AttributeWithCacheKey {
readonly axis: number;
}

const validateInputs = (inputs: readonly TensorView[]): void => {
const validateInputs = (inputs: readonly TensorView[], axis: number): void => {
if (!inputs || inputs.length < 1) {
throw new Error('too few inputs');
}

const inputType = inputs[0].dataType;
const inputDimensionality = inputs[0].dims.length;

for (const input of inputs) {
const referenceIndex = 0;
const referenceInput = inputs[referenceIndex];
const inputType = referenceInput.dataType;
const inputRank = referenceInput.dims.length;
inputs.forEach((input, i) => {
if (i === referenceIndex) {
return;
}
// make sure types of all inputs match
if (input.dataType !== inputType) {
throw new Error('input tensors should be one type');
}

// make sure the dimensionality of all inputs are the same
if (input.dims.length !== inputDimensionality) {
if (input.dims.length !== inputRank) {
throw new Error('input tensors should have the same shape');
}
}
input.dims.forEach((dim, i) => {
if (i !== axis && dim !== referenceInput.dims[i]) {
throw new Error('non concat dimensions must match');
}
});
});
};

const calculateInputIndexImpl = (numberOfTensors: number, sizeInConcatAxisStr: string): string => `
Expand Down Expand Up @@ -64,65 +71,43 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe
return codeLines.join('\n');
};

const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): ProgramInfo => {
const inputShape = inputs[0].dims.slice();
if (axis >= inputShape.length || axis < (-1 * inputShape.length)) {
throw new Error('axis specified for concat doesn\'t match input dimensionality');
}
const adjustedAxis = (axis < 0) ? inputShape.length + axis : axis;
// ensure all of the non-concatenated axes match each other
// calculate the shape of the output tensor while we do that
const outputShape = inputShape.slice(0);
for (let i = 1; i < inputs.length; i++) {
const dataNShape = inputs[i].dims.slice();
for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) {
// add to the placeholder for computing output shape
if (axisIndex === adjustedAxis) {
outputShape[adjustedAxis] += dataNShape[axisIndex];
const createConcatProgramInfo =
(inputs: readonly TensorView[], adjustedAxis: number, outputShape: number[], dataType: DataType): ProgramInfo => {
const outputSize = ShapeUtil.size(outputShape);

const sizeInConcatAxis = new Array<number>(inputs.length);
const inputVars = new Array<IndicesHelper>(inputs.length);

let previousSum = 0;
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
const inputRanks = [];
const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}];
for (let i = 0; i < inputs.length; ++i) {
previousSum += inputs[i].dims[adjustedAxis];
sizeInConcatAxis[i] = previousSum;
inputRanks.push(inputs[i].dims.length);
inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
inputDependencies.push('rank');
programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]});
}
// ensure all non-cancatenated axes match each other
else if (inputShape[axisIndex] !== dataNShape[axisIndex]) {
throw new Error('non concat dimensions must match');
for (let i = 0; i < inputs.length; ++i) {
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
}
}
}

const outputSize = ShapeUtil.size(outputShape);

const sizeInConcatAxis = new Array<number>(inputs.length);
const inputVars = new Array<IndicesHelper>(inputs.length);
const dataType = inputs[0].dataType;

let previousSum = 0;
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
const inputRanks = [];
const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}];
for (let i = 0; i < inputs.length; ++i) {
previousSum += inputs[i].dims[adjustedAxis];
sizeInConcatAxis[i] = previousSum;
inputRanks.push(inputs[i].dims.length);
inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
inputDependencies.push('rank');
programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]});
}
for (let i = 0; i < inputs.length; ++i) {
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
}
programUniforms.push(...createTensorShapeVariables(outputShape));
programUniforms.push(...createTensorShapeVariables(outputShape));

const output = outputVariable('output', dataType, outputShape.length);
const indicesAxis = output.indicesGet('indices', adjustedAxis);
const sizeInConcatAxisStr =
Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(',');
const getShaderSource = (shaderHelper: ShaderHelper) => `
const output = outputVariable('output', dataType, outputShape.length);
const indicesAxis = output.indicesGet('indices', adjustedAxis);
const sizeInConcatAxisStr =
Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(',');
const getShaderSource = (shaderHelper: ShaderHelper) => `

${(() => {
shaderHelper.registerUniform('outputSize', 'u32');
for (let i = 0; i < inputs.length; i++) {
shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32');
}
return shaderHelper.declareVariables(...inputVars, output);
})()}
shaderHelper.registerUniform('outputSize', 'u32');
for (let i = 0; i < inputs.length; i++) {
shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32');
}
return shaderHelper.declareVariables(...inputVars, output);
})()}

${calculateInputIndexImpl(sizeInConcatAxis.length, sizeInConcatAxisStr)}

Expand All @@ -140,23 +125,30 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P
${assignOutputData(inputVars, output)}
}`;

return {
name: 'Concat',
shaderCache: {hint: `${axis}`, inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms,
}),
getShaderSource,
};
};
return {
name: 'Concat',
shaderCache: {hint: `${adjustedAxis}`, inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms,
}),
getShaderSource,
};
};

export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => {
validateInputs(context.inputs);
const inputs = context.inputs;
const inputShape = inputs[0].dims;
const adjustedAxis = attributes.axis + (attributes.axis < 0 ? inputShape.length : 0);
validateInputs(inputs, adjustedAxis);
const outputShape = inputShape.slice();
outputShape[adjustedAxis] =
inputs.reduce((sum, input) => sum + (input.dims.length > adjustedAxis ? input.dims[adjustedAxis] : 0), 0);
// 0 length tensors are valid for concat, remove them
const nonEmptyInputs = context.inputs.filter(input => ShapeUtil.size(input.dims) > 0);
context.compute(createConcatProgramInfo(nonEmptyInputs, attributes.axis), {inputs: nonEmptyInputs});
const nonEmptyInputs = inputs.filter(input => ShapeUtil.size(input.dims) > 0);
context.compute(
createConcatProgramInfo(nonEmptyInputs, adjustedAxis, outputShape, inputs[0].dataType), {inputs: nonEmptyInputs});
};

export const parseConcatAttributes = (attributes: Record<string, unknown>): ConcatAttributes =>
Expand Down
80 changes: 80 additions & 0 deletions js/web/test/data/ops/concat_zero-sized.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -557,5 +557,85 @@
]
}
]
},
{
"name": "Concat 2D axis=1; Preserve dims",
"operator": "Concat",
"attributes": [
{
"name": "axis",
"data": 0,
"type": "int"
}
],
"cases": [
{
"name": "Some but not all input tensors are zero-sized",
"inputs": [
{
"data": [],
"dims": [0, 1],
"type": "float32"
},
{
"data": [1],
"dims": [1, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [1],
"dims": [1, 1],
"type": "float32"
}
]
}
]
},
{
"name": "Concat 2D axis=1; Preserve dims",
"operator": "Concat",
"attributes": [
{
"name": "axis",
"data": 1,
"type": "int"
}
],
"cases": [
{
"name": "All input tensors are zero-sized",
"inputs": [
{
"data": [],
"dims": [0, 0],
"type": "float32"
},
{
"data": [],
"dims": [0, 1],
"type": "float32"
},
{
"data": [],
"dims": [0, 2],
"type": "float32"
},
{
"data": [],
"dims": [0, 3],
"type": "float32"
}
],
"outputs": [
{
"data": [],
"dims": [0, 6],
"type": "float32"
}
]
}
]
}
]