Skip to content

Commit 24b72d2

Browse files
[JS/WebGPU] Preserve zero size input tensor dims. (#19737)
### Description For the Concat operation, the shape of a zero-size input tensor needs to be preserved and, unlike non-zero tensors, its dims are not constrained to match the other input tensors' dims. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
1 parent 6c3bed6 commit 24b72d2

File tree

2 files changed

+149
-77
lines changed

2 files changed

+149
-77
lines changed

js/web/lib/wasm/jsep/webgpu/ops/concat.ts

+69-77
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,32 @@ export interface ConcatAttributes extends AttributeWithCacheKey {
1313
readonly axis: number;
1414
}
1515

16-
const validateInputs = (inputs: readonly TensorView[]): void => {
16+
const validateInputs = (inputs: readonly TensorView[], axis: number): void => {
1717
if (!inputs || inputs.length < 1) {
1818
throw new Error('too few inputs');
1919
}
20-
21-
const inputType = inputs[0].dataType;
22-
const inputDimensionality = inputs[0].dims.length;
23-
24-
for (const input of inputs) {
20+
const referenceIndex = 0;
21+
const referenceInput = inputs[referenceIndex];
22+
const inputType = referenceInput.dataType;
23+
const inputRank = referenceInput.dims.length;
24+
inputs.forEach((input, i) => {
25+
if (i === referenceIndex) {
26+
return;
27+
}
2528
// make sure types of all inputs match
2629
if (input.dataType !== inputType) {
2730
throw new Error('input tensors should be one type');
2831
}
29-
3032
// make sure the dimensionality of all inputs are the same
31-
if (input.dims.length !== inputDimensionality) {
33+
if (input.dims.length !== inputRank) {
3234
throw new Error('input tensors should have the same shape');
3335
}
34-
}
36+
input.dims.forEach((dim, i) => {
37+
if (i !== axis && dim !== referenceInput.dims[i]) {
38+
throw new Error('non concat dimensions must match');
39+
}
40+
});
41+
});
3542
};
3643

3744
const calculateInputIndexImpl = (numberOfTensors: number, sizeInConcatAxisStr: string): string => `
@@ -64,65 +71,43 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe
6471
return codeLines.join('\n');
6572
};
6673

67-
const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): ProgramInfo => {
68-
const inputShape = inputs[0].dims.slice();
69-
if (axis >= inputShape.length || axis < (-1 * inputShape.length)) {
70-
throw new Error('axis specified for concat doesn\'t match input dimensionality');
71-
}
72-
const adjustedAxis = (axis < 0) ? inputShape.length + axis : axis;
73-
// ensure all of the non-concatenated axes match each other
74-
// calculate the shape of the output tensor while we do that
75-
const outputShape = inputShape.slice(0);
76-
for (let i = 1; i < inputs.length; i++) {
77-
const dataNShape = inputs[i].dims.slice();
78-
for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) {
79-
// add to the placeholder for computing output shape
80-
if (axisIndex === adjustedAxis) {
81-
outputShape[adjustedAxis] += dataNShape[axisIndex];
74+
const createConcatProgramInfo =
75+
(inputs: readonly TensorView[], adjustedAxis: number, outputShape: number[], dataType: DataType): ProgramInfo => {
76+
const outputSize = ShapeUtil.size(outputShape);
77+
78+
const sizeInConcatAxis = new Array<number>(inputs.length);
79+
const inputVars = new Array<IndicesHelper>(inputs.length);
80+
81+
let previousSum = 0;
82+
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
83+
const inputRanks = [];
84+
const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}];
85+
for (let i = 0; i < inputs.length; ++i) {
86+
previousSum += inputs[i].dims[adjustedAxis];
87+
sizeInConcatAxis[i] = previousSum;
88+
inputRanks.push(inputs[i].dims.length);
89+
inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
90+
inputDependencies.push('rank');
91+
programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]});
8292
}
83-
// ensure all non-cancatenated axes match each other
84-
else if (inputShape[axisIndex] !== dataNShape[axisIndex]) {
85-
throw new Error('non concat dimensions must match');
93+
for (let i = 0; i < inputs.length; ++i) {
94+
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
8695
}
87-
}
88-
}
89-
90-
const outputSize = ShapeUtil.size(outputShape);
91-
92-
const sizeInConcatAxis = new Array<number>(inputs.length);
93-
const inputVars = new Array<IndicesHelper>(inputs.length);
94-
const dataType = inputs[0].dataType;
95-
96-
let previousSum = 0;
97-
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
98-
const inputRanks = [];
99-
const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}];
100-
for (let i = 0; i < inputs.length; ++i) {
101-
previousSum += inputs[i].dims[adjustedAxis];
102-
sizeInConcatAxis[i] = previousSum;
103-
inputRanks.push(inputs[i].dims.length);
104-
inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
105-
inputDependencies.push('rank');
106-
programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]});
107-
}
108-
for (let i = 0; i < inputs.length; ++i) {
109-
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
110-
}
111-
programUniforms.push(...createTensorShapeVariables(outputShape));
96+
programUniforms.push(...createTensorShapeVariables(outputShape));
11297

113-
const output = outputVariable('output', dataType, outputShape.length);
114-
const indicesAxis = output.indicesGet('indices', adjustedAxis);
115-
const sizeInConcatAxisStr =
116-
Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(',');
117-
const getShaderSource = (shaderHelper: ShaderHelper) => `
98+
const output = outputVariable('output', dataType, outputShape.length);
99+
const indicesAxis = output.indicesGet('indices', adjustedAxis);
100+
const sizeInConcatAxisStr =
101+
Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(',');
102+
const getShaderSource = (shaderHelper: ShaderHelper) => `
118103
119104
${(() => {
120-
shaderHelper.registerUniform('outputSize', 'u32');
121-
for (let i = 0; i < inputs.length; i++) {
122-
shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32');
123-
}
124-
return shaderHelper.declareVariables(...inputVars, output);
125-
})()}
105+
shaderHelper.registerUniform('outputSize', 'u32');
106+
for (let i = 0; i < inputs.length; i++) {
107+
shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32');
108+
}
109+
return shaderHelper.declareVariables(...inputVars, output);
110+
})()}
126111
127112
${calculateInputIndexImpl(sizeInConcatAxis.length, sizeInConcatAxisStr)}
128113
@@ -140,23 +125,30 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P
140125
${assignOutputData(inputVars, output)}
141126
}`;
142127

143-
return {
144-
name: 'Concat',
145-
shaderCache: {hint: `${axis}`, inputDependencies},
146-
getRunData: () => ({
147-
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
148-
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
149-
programUniforms,
150-
}),
151-
getShaderSource,
152-
};
153-
};
128+
return {
129+
name: 'Concat',
130+
shaderCache: {hint: `${adjustedAxis}`, inputDependencies},
131+
getRunData: () => ({
132+
outputs: [{dims: outputShape, dataType}],
133+
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
134+
programUniforms,
135+
}),
136+
getShaderSource,
137+
};
138+
};
154139

155140
export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => {
156-
validateInputs(context.inputs);
141+
const inputs = context.inputs;
142+
const inputShape = inputs[0].dims;
143+
const adjustedAxis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length);
144+
validateInputs(inputs, adjustedAxis);
145+
const outputShape = inputShape.slice();
146+
outputShape[adjustedAxis] =
147+
inputs.reduce((sum, input) => sum + (input.dims.length > adjustedAxis ? input.dims[adjustedAxis] : 0), 0);
157148
// 0 length tensors are valid for concat, remove them
158-
const nonEmptyInputs = context.inputs.filter(input => ShapeUtil.size(input.dims) > 0);
159-
context.compute(createConcatProgramInfo(nonEmptyInputs, attributes.axis), {inputs: nonEmptyInputs});
149+
const nonEmptyInputs = inputs.filter(input => ShapeUtil.size(input.dims) > 0);
150+
context.compute(
151+
createConcatProgramInfo(nonEmptyInputs, adjustedAxis, outputShape, inputs[0].dataType), {inputs: nonEmptyInputs});
160152
};
161153

162154
export const parseConcatAttributes = (attributes: Record<string, unknown>): ConcatAttributes =>

js/web/test/data/ops/concat_zero-sized.jsonc

+80
Original file line numberDiff line numberDiff line change
@@ -557,5 +557,85 @@
557557
]
558558
}
559559
]
560+
},
561+
{
562+
"name": "Concat 2D axis=0; Preserve dims",
563+
"operator": "Concat",
564+
"attributes": [
565+
{
566+
"name": "axis",
567+
"data": 0,
568+
"type": "int"
569+
}
570+
],
571+
"cases": [
572+
{
573+
"name": "Some but not all input tensors are zero-sized",
574+
"inputs": [
575+
{
576+
"data": [],
577+
"dims": [0, 1],
578+
"type": "float32"
579+
},
580+
{
581+
"data": [1],
582+
"dims": [1, 1],
583+
"type": "float32"
584+
}
585+
],
586+
"outputs": [
587+
{
588+
"data": [1],
589+
"dims": [1, 1],
590+
"type": "float32"
591+
}
592+
]
593+
}
594+
]
595+
},
596+
{
597+
"name": "Concat 2D axis=1; Preserve dims",
598+
"operator": "Concat",
599+
"attributes": [
600+
{
601+
"name": "axis",
602+
"data": 1,
603+
"type": "int"
604+
}
605+
],
606+
"cases": [
607+
{
608+
"name": "All input tensors are zero-sized",
609+
"inputs": [
610+
{
611+
"data": [],
612+
"dims": [0, 0],
613+
"type": "float32"
614+
},
615+
{
616+
"data": [],
617+
"dims": [0, 1],
618+
"type": "float32"
619+
},
620+
{
621+
"data": [],
622+
"dims": [0, 2],
623+
"type": "float32"
624+
},
625+
{
626+
"data": [],
627+
"dims": [0, 3],
628+
"type": "float32"
629+
}
630+
],
631+
"outputs": [
632+
{
633+
"data": [],
634+
"dims": [0, 6],
635+
"type": "float32"
636+
}
637+
]
638+
}
639+
]
560640
}
561641
]

0 commit comments

Comments
 (0)