174 changes: 91 additions & 83 deletions js/web/lib/wasm/jsep/webgpu/ops/concat.ts
@@ -13,25 +13,35 @@ export interface ConcatAttributes extends AttributeWithCacheKey {
readonly axis: number;
}

const validateInputs = (inputs: readonly TensorView[]): void => {
const validateInputs = (inputs: readonly TensorView[], referenceIndex: number, axis: number): void => {
if (!inputs || inputs.length < 1) {
throw new Error('too few inputs');
}

const inputType = inputs[0].dataType;
const inputDimensionality = inputs[0].dims.length;

for (const input of inputs) {
const referenceInput = inputs[referenceIndex];
const inputType = referenceInput.dataType;
const inputRank = referenceInput.dims.length;
const referenceInputSize = ShapeUtil.size(referenceInput.dims);
inputs.forEach((input, i) => {
if (i === referenceIndex) {
return;
}
// make sure types of all inputs match
if (input.dataType !== inputType) {
throw new Error('input tensors should be one type');
}

// make sure the dimensionality of all inputs are the same
if (input.dims.length !== inputDimensionality) {
throw new Error('input tensors should have the same shape');
if (referenceInputSize > 0 && ShapeUtil.size(input.dims) > 0) {
// make sure the dimensionality of all inputs is the same
if (input.dims.length !== inputRank) {
throw new Error('input tensors should have the same shape');
}
input.dims.forEach((dim, i) => {
if (i !== axis && dim !== referenceInput.dims[i]) {
throw new Error('non concat dimensions must match');
}
});
}
}
});
};
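The relaxed validation above skips the rank and per-dimension checks for zero-size tensors and excludes the concat axis from the comparison. As a quick illustration (not part of the diff), the following standalone TypeScript sketch mirrors that logic with plain shape arrays; `sizeOf` and `shapesAreConcatenable` are illustrative stand-ins, not onnxruntime APIs, and the dtype check is omitted.

```ts
// Illustrative stand-ins for TensorView.dims and ShapeUtil.size.
type Shape = readonly number[];
const sizeOf = (dims: Shape): number => dims.reduce((p, d) => p * d, 1);

// Mirrors the relaxed check: only non-empty inputs must match the reference
// input's rank and its extents on every axis except the concat axis.
const shapesAreConcatenable = (shapes: readonly Shape[], referenceIndex: number, axis: number): boolean =>
    shapes.every((dims, i) => {
      if (i === referenceIndex || sizeOf(dims) === 0 || sizeOf(shapes[referenceIndex]) === 0) {
        return true;  // zero-size tensors are accepted regardless of shape
      }
      const ref = shapes[referenceIndex];
      return dims.length === ref.length && dims.every((d, j) => j === axis || d === ref[j]);
    });

// A [2, 0, 3] tensor no longer fails validation against a [2, 4, 3] reference on axis 1.
console.log(shapesAreConcatenable([[2, 0, 3], [2, 4, 3]], 1, 1));  // true
console.log(shapesAreConcatenable([[2, 5, 3], [2, 4, 4]], 1, 1));  // false (axis 2 mismatch)
```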

const calculateInputIndexImpl = (numberOfTensors: number, sizeInConcatAxisStr: string): string => `
@@ -64,65 +74,43 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe
return codeLines.join('\n');
};

const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): ProgramInfo => {
const inputShape = inputs[0].dims.slice();
if (axis >= inputShape.length || axis < (-1 * inputShape.length)) {
throw new Error('axis specified for concat doesn\'t match input dimensionality');
}
const adjustedAxis = (axis < 0) ? inputShape.length + axis : axis;
// ensure all of the non-concatenated axes match each other
// calculate the shape of the output tensor while we do that
const outputShape = inputShape.slice(0);
for (let i = 1; i < inputs.length; i++) {
const dataNShape = inputs[i].dims.slice();
for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) {
// add to the placeholder for computing output shape
if (axisIndex === adjustedAxis) {
outputShape[adjustedAxis] += dataNShape[axisIndex];
const createConcatProgramInfo =
(inputs: readonly TensorView[], axis: number, outputShape: number[], dataType: DataType): ProgramInfo => {
const outputSize = ShapeUtil.size(outputShape);

const sizeInConcatAxis = new Array<number>(inputs.length);
const inputVars = new Array<IndicesHelper>(inputs.length);

let previousSum = 0;
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
const inputRanks = [];
const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}];
for (let i = 0; i < inputs.length; ++i) {
previousSum += inputs[i].dims[axis];
sizeInConcatAxis[i] = previousSum;
inputRanks.push(inputs[i].dims.length);
inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
inputDependencies.push('rank');
programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]});
}
// ensure all non-cancatenated axes match each other
else if (inputShape[axisIndex] !== dataNShape[axisIndex]) {
throw new Error('non concat dimensions must match');
for (let i = 0; i < inputs.length; ++i) {
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
}
}
}
programUniforms.push(...createTensorShapeVariables(outputShape));

const outputSize = ShapeUtil.size(outputShape);

const sizeInConcatAxis = new Array<number>(inputs.length);
const inputVars = new Array<IndicesHelper>(inputs.length);
const dataType = inputs[0].dataType;

let previousSum = 0;
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
const inputRanks = [];
const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}];
for (let i = 0; i < inputs.length; ++i) {
previousSum += inputs[i].dims[adjustedAxis];
sizeInConcatAxis[i] = previousSum;
inputRanks.push(inputs[i].dims.length);
inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
inputDependencies.push('rank');
programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]});
}
for (let i = 0; i < inputs.length; ++i) {
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
}
programUniforms.push(...createTensorShapeVariables(outputShape));

const output = outputVariable('output', dataType, outputShape.length);
const indicesAxis = output.indicesGet('indices', adjustedAxis);
const sizeInConcatAxisStr =
Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(',');
const getShaderSource = (shaderHelper: ShaderHelper) => `
const output = outputVariable('output', dataType, outputShape.length);
const indicesAxis = output.indicesGet('indices', axis);
const sizeInConcatAxisStr =
Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(',');
const getShaderSource = (shaderHelper: ShaderHelper) => `

${(() => {
shaderHelper.registerUniform('outputSize', 'u32');
for (let i = 0; i < inputs.length; i++) {
shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32');
}
return shaderHelper.declareVariables(...inputVars, output);
})()}
shaderHelper.registerUniform('outputSize', 'u32');
for (let i = 0; i < inputs.length; i++) {
shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32');
}
return shaderHelper.declareVariables(...inputVars, output);
})()}

${calculateInputIndexImpl(sizeInConcatAxis.length, sizeInConcatAxisStr)}

@@ -132,31 +120,51 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P
var indices = ${output.offsetToIndices('global_idx')};

let inputIndex = calculateInputIndex(${indicesAxis});
if (inputIndex != 0u) {
let sizeInConcatAxis = array<u32, ${sizeInConcatAxis.length}u>(${sizeInConcatAxisStr});
${indicesAxis} -= sizeInConcatAxis[inputIndex - 1u];
}
if (inputIndex < ${inputs.length}u) {
if (inputIndex != 0u) {
let sizeInConcatAxis = array<u32, ${sizeInConcatAxis.length}u>(${sizeInConcatAxisStr});
${indicesAxis} -= sizeInConcatAxis[inputIndex - 1u];
}

${assignOutputData(inputVars, output)}
${assignOutputData(inputVars, output)}
} else {
${output.setByOffset('global_idx', '0')}
}
}`;

return {
name: 'Concat',
shaderCache: {hint: `${axis}`, inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms,
}),
getShaderSource,
};
};
return {
name: 'Concat',
shaderCache: {hint: `${axis}`, inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms,
}),
getShaderSource,
};
};
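The generated shader locates the source tensor for each output element with a running prefix sum of the inputs' extents along the concat axis (`sizeInConcatAxis`), then rebases the axis coordinate into that input. For reference only (not part of the diff), this standalone sketch shows the same mapping on the host side, assuming `sizeInConcatAxis` is the inclusive prefix sum built in the loop above; the WGSL produced by `calculateInputIndexImpl` performs the equivalent search per output element on the GPU.

```ts
// Host-side equivalent of the generated calculateInputIndex plus the
// `indices[axis] -= sizeInConcatAxis[inputIndex - 1]` adjustment in the shader.
const mapOutputToInput = (outputIndexOnAxis: number, sizeInConcatAxis: readonly number[]):
    {inputIndex: number; indexInInput: number} => {
      let inputIndex = 0;
      while (inputIndex < sizeInConcatAxis.length && outputIndexOnAxis >= sizeInConcatAxis[inputIndex]) {
        inputIndex++;
      }
      const indexInInput =
          inputIndex === 0 ? outputIndexOnAxis : outputIndexOnAxis - sizeInConcatAxis[inputIndex - 1];
      return {inputIndex, indexInInput};
    };

// Concatenating extents [2, 3, 4] along the axis gives sizeInConcatAxis = [2, 5, 9]:
// output coordinate 6 on that axis maps to input 2, local coordinate 1.
console.log(mapOutputToInput(6, [2, 5, 9]));
```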

export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => {
validateInputs(context.inputs);
// find a non-zero-size tensor to determine the output shape
// Choose input with max rank if all input tensors are zero size to make the output shape independent of the order of
// the inputs.
const inputs = context.inputs;
let referenceIndex = inputs.findIndex(input => ShapeUtil.size(input.dims) > 0);
if (referenceIndex === -1) {
referenceIndex = inputs.reduce(
(maxRankIndex, input, index, array) => input.dims.length > array[maxRankIndex].dims.length ? index : maxRankIndex, 0);
}

const inputShape = inputs[referenceIndex].dims;
const adjustedAxis = attributes.axis + (attributes.axis < 0 ? inputShape.length : 0);
validateInputs(inputs, referenceIndex, adjustedAxis);
const outputShape = inputShape.slice();
outputShape[adjustedAxis] =
inputs.reduce((sum, input) => sum + (input.dims.length > adjustedAxis ? input.dims[adjustedAxis] : 0), 0);
// 0 length tensors are valid for concat, remove them
const nonEmptyInputs = context.inputs.filter(input => ShapeUtil.size(input.dims) > 0);
context.compute(createConcatProgramInfo(nonEmptyInputs, attributes.axis), {inputs: nonEmptyInputs});
const nonEmptyInputs = inputs.filter(input => ShapeUtil.size(input.dims) > 0);
context.compute(
createConcatProgramInfo(nonEmptyInputs, adjustedAxis, outputShape, inputs[0].dataType), {inputs: nonEmptyInputs});
};
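As a concrete walk-through (not part of the diff), concatenating shapes [2, 0, 3], [2, 4, 3] and [2, 1, 3] along axis 1 exercises the bookkeeping above: the zero-size tensor participates only in validation and the output-shape sum, not in the dispatched program. The snippet below reproduces that arithmetic with a local `size` helper standing in for `ShapeUtil.size`.

```ts
// Standalone walk-through of the host-side bookkeeping for
// shapes [2, 0, 3], [2, 4, 3], [2, 1, 3] concatenated on axis = 1.
const shapes = [[2, 0, 3], [2, 4, 3], [2, 1, 3]];
const axis = 1;
const size = (d: number[]): number => d.reduce((p, x) => p * x, 1);

// Reference input: the first tensor with a non-zero element count -> index 1.
const referenceIndex = shapes.findIndex(d => size(d) > 0);

// Output shape: copy the reference shape, then sum the extents on the concat axis.
const outputShape = shapes[referenceIndex].slice();
outputShape[axis] = shapes.reduce((sum, d) => sum + (d.length > axis ? d[axis] : 0), 0);
console.log(outputShape);  // [2, 5, 3]

// Only the non-empty tensors are passed to createConcatProgramInfo and dispatched.
const nonEmptyCount = shapes.filter(d => size(d) > 0).length;
console.log(nonEmptyCount);  // 2
```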

export const parseConcatAttributes = (attributes: Record<string, unknown>): ConcatAttributes =>