Skip to content

Commit fcefc67

Browse files
axingingfs-eire
authored andcommitted
[js/webgpu] Refactor createTensorShapeVariables (#18883)
1 parent e305794 commit fcefc67

22 files changed

+40
-64
lines changed

js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts

+1-2
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,7 @@ export const createConv2DMatMulProgramInfo =
195195
{type: DataType.int32, data: attributes.strides}, {type: DataType.int32, data: attributes.dilations}
196196
];
197197
appendActivationUniformsData(attributes, programUniforms);
198-
programUniforms.push(
199-
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims));
198+
programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims));
200199
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
201200
if (hasBias) {
202201
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));

js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts

+1-2
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,7 @@ export const createConv2DTransposeMatMulProgramInfo =
204204
{type: DataType.int32, data: pads}
205205
];
206206
appendActivationUniformsData(attributes, programUniforms);
207-
programUniforms.push(
208-
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims));
207+
programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims));
209208

210209
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
211210
if (hasBias) {

js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ export const createConvTranspose2DProgramInfo =
269269
{type: DataType.uint32, data: filterDims}, {type: DataType.uint32, data: dilations},
270270
{type: DataType.uint32, data: effectiveFilterDims}, {type: DataType.int32, data: pads},
271271
{type: DataType.uint32, data: inputChannelsPerGroup}, {type: DataType.uint32, data: outputChannelsPerGroup},
272-
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)
272+
...createTensorShapeVariables(inputs[0].dims, inputs[1].dims)
273273
];
274274
if (hasBias) {
275275
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));

js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts

+1-3
Original file line numberDiff line numberDiff line change
@@ -453,9 +453,7 @@ export const createMatmulProgramInfo =
453453
{type: DataType.int32, data: dimInner}
454454
];
455455
appendActivationUniformsData(activationAttributes, programUniforms);
456-
programUniforms.push(
457-
...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShapeTemp),
458-
...createTensorShapeVariables(bShapeTemp));
456+
programUniforms.push(...createTensorShapeVariables(outerDims, aShapeTemp, bShapeTemp));
459457
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
460458

461459
const hasBias = inputs.length > 2;

js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts

+1-3
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,7 @@ const createBinaryOpProgramInfo =
180180
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)},
181181
programUniforms: [
182182
{type: DataType.uint32, data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
183-
...createTensorShapeVariables(a.dims),
184-
...createTensorShapeVariables(b.dims),
185-
...createTensorShapeVariables(outputShape),
183+
...createTensorShapeVariables(a.dims, b.dims, outputShape)
186184
],
187185
}),
188186
};

js/web/lib/wasm/jsep/webgpu/ops/common.ts

+10-3
Original file line numberDiff line numberDiff line change
@@ -259,9 +259,16 @@ export const tensorTypeToWsglValueType = (type: DataType, components: 1|2|3|4 =
259259
return typeof mappedType === 'string' ? mappedType : mappedType[1];
260260
};
261261

262-
export const createTensorShapeVariables = (dims: readonly number[]): ProgramUniform[] => dims.length === 0 ?
263-
[] :
264-
[{type: DataType.uint32, data: dims}, {type: DataType.uint32, data: ShapeUtil.computeStrides(dims)}];
262+
export const createTensorShapeVariables = (...dims: ReadonlyArray<readonly number[]>): ProgramUniform[] => {
263+
const programUniforms: ProgramUniform[] = [];
264+
dims.forEach(dim => {
265+
if (dim.length !== 0) {
266+
programUniforms.push(
267+
{type: DataType.uint32, data: dim}, {type: DataType.uint32, data: ShapeUtil.computeStrides(dim)});
268+
}
269+
});
270+
return programUniforms;
271+
};
265272

266273
/**
267274
* A helper function to get maximum vector size for specified data length

js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts

+2-6
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@ export const createGroupedConvProgramInfo =
3535
{type: DataType.uint32, data: outputChannelsPerGroup}
3636
];
3737
appendActivationUniformsData(attributes, programUniforms);
38-
programUniforms.push(
39-
...createTensorShapeVariables(xShape), ...createTensorShapeVariables(wShape),
40-
...createTensorShapeVariables(outputShape));
38+
programUniforms.push(...createTensorShapeVariables(xShape, wShape, outputShape));
4139
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
4240
if (hasBias) {
4341
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
@@ -134,9 +132,7 @@ export const createGroupedConvVectorizeProgramInfo =
134132
{type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]}
135133
];
136134
appendActivationUniformsData(attributes, programUniforms);
137-
programUniforms.push(
138-
...createTensorShapeVariables(xShape), ...createTensorShapeVariables(wShape),
139-
...createTensorShapeVariables(outputShapeInShader));
135+
programUniforms.push(...createTensorShapeVariables(xShape, wShape, outputShapeInShader));
140136
const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1];
141137
const getShaderSource = (shaderHelper: ShaderHelper) => {
142138
const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components);

js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ const createCumsumProgramInfo =
5555
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
5656
programUniforms: [
5757
{type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axis},
58-
...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape)
58+
...createTensorShapeVariables(inputShape, inputShape)
5959
]
6060

6161
}),

js/web/lib/wasm/jsep/webgpu/ops/expand.ts

+2-4
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,8 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
8484
${assignment}`;
8585
};
8686

87-
const programUniforms: ProgramUniform[] = [
88-
{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape),
89-
...createTensorShapeVariables(outputShape)
90-
];
87+
const programUniforms: ProgramUniform[] =
88+
[{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape, outputShape)];
9189
return {
9290
name: 'Expand',
9391
shaderCache: {hint: `${outputShape.length}`, inputDependencies: ['rank']},

js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts

+1-3
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,7 @@ const createGatherElementsProgramInfo =
5151
{type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axisDimLimit},
5252
{type: DataType.uint32, data: axis}
5353
];
54-
programUniforms.push(...createTensorShapeVariables(inputShape));
55-
programUniforms.push(...createTensorShapeVariables(indicesShape));
56-
programUniforms.push(...createTensorShapeVariables(outputShape));
54+
programUniforms.push(...createTensorShapeVariables(inputShape, indicesShape, outputShape));
5755
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
5856

5957
// int64 indices would be treated as little endian i32 with assumption they fall in i32 limits

js/web/lib/wasm/jsep/webgpu/ops/gather.ts

+1-2
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath
3535

3636
const programUniforms: ProgramUniform[] = [
3737
{type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axisDimLimit},
38-
{type: DataType.uint32, data: axis}, ...createTensorShapeVariables(inputs[0].dims),
39-
...createTensorShapeVariables(inputs[1].dims), ...createTensorShapeVariables(outputShape)
38+
{type: DataType.uint32, data: axis}, ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims, outputShape)
4039
];
4140

4241
const getShaderSource = (shaderHelper: ShaderHelper) => {

js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ const createInstanceNormProgramInfo =
2626
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type'];
2727
const programUniforms: ProgramUniform[] =
2828
[{type: DataType.uint32, data: normSize}, {type: DataType.uint32, data: normPackedSize}];
29-
programUniforms.push(...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape));
29+
programUniforms.push(...createTensorShapeVariables(inputShape, inputShape));
3030

3131
const getShaderSource = (shaderHelper: ShaderHelper) => {
3232
const x = inputVariable('x', inputs[0].dataType, inputShape.length, components);

js/web/lib/wasm/jsep/webgpu/ops/matmul.ts

+1-3
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,7 @@ export const createNaiveMatmulProgramInfo =
3434
{type: DataType.uint32, data: K}
3535
];
3636
appendActivationUniformsData(activationAttributes, programUniforms);
37-
programUniforms.push(
38-
...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape),
39-
...createTensorShapeVariables(bShape));
37+
programUniforms.push(...createTensorShapeVariables(outerDims, aShape, bShape));
4038
if (hasBias) {
4139
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
4240
}

js/web/lib/wasm/jsep/webgpu/ops/pad.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr
158158
programUniforms.push({type: inputs[0].dataType, data: attributes.value});
159159
}
160160

161-
programUniforms.push(...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape));
161+
programUniforms.push(...createTensorShapeVariables(inputs[0].dims, outputShape));
162162
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
163163

164164
const getShaderSource = (shaderHelper: ShaderHelper) => {

js/web/lib/wasm/jsep/webgpu/ops/pool.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ const createAveragePoolProgramInfo =
298298
}
299299
const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] =
300300
getUniformAndPadInfo(outputShape, adjustedAttributes);
301-
programUniforms.push(...createTensorShapeVariables(input.dims), ...createTensorShapeVariables(outputShape));
301+
programUniforms.push(...createTensorShapeVariables(input.dims, outputShape));
302302
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
303303
return {
304304
name,
@@ -370,7 +370,7 @@ const createMaxPoolProgramInfo =
370370
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
371371
const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] =
372372
getUniformAndPadInfo(outputShape, adjustedAttributes);
373-
programUniforms.push(...createTensorShapeVariables(input.dims), ...createTensorShapeVariables(outputShape));
373+
programUniforms.push(...createTensorShapeVariables(input.dims, outputShape));
374374
return {
375375
name,
376376
shaderCache:

js/web/lib/wasm/jsep/webgpu/ops/reduce.ts

+2-4
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,8 @@ export const createReduceProgramInfo =
100100
getRunData: () => ({
101101
outputs: [{dims: outputShape, dataType: outputDataType}],
102102
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
103-
programUniforms: [
104-
{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape),
105-
...createTensorShapeVariables(outputShape)
106-
]
103+
programUniforms:
104+
[{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape, outputShape)]
107105
}),
108106
};
109107
};

js/web/lib/wasm/jsep/webgpu/ops/resize.ts

+2-5
Original file line numberDiff line numberDiff line change
@@ -642,11 +642,8 @@ const createResizeProgramInfo =
642642
outputs: [{dims: outputShape, dataType: inputTensor.dataType}],
643643
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
644644
programUniforms: [
645-
{type: DataType.uint32, data: outputSize},
646-
{type: DataType.float, data: scales},
647-
{type: DataType.float, data: roi},
648-
...createTensorShapeVariables(inputShape),
649-
...createTensorShapeVariables(outputShape),
645+
{type: DataType.uint32, data: outputSize}, {type: DataType.float, data: scales},
646+
{type: DataType.float, data: roi}, ...createTensorShapeVariables(inputShape, outputShape)
650647
]
651648
})
652649
};

js/web/lib/wasm/jsep/webgpu/ops/slice.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice
157157
const programUniforms: ProgramUniform[] = [
158158
{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: starts},
159159
{type: DataType.int32, data: signs}, {type: DataType.uint32, data: steps},
160-
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape)
160+
...createTensorShapeVariables(inputs[0].dims, outputShape)
161161
];
162162

163163
const getShaderSource = (shaderHelper: ShaderHelper) => `

js/web/lib/wasm/jsep/webgpu/ops/split.ts

+2-3
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,8 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split
8383
outputs[i] = outputVariable(`output${i}`, dataType, outputShape);
8484
outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType});
8585
}
86-
programUniforms.push({type: DataType.uint32, data: sizeInSplitAxis});
87-
programUniforms.push(...createTensorShapeVariables(inputShape));
88-
outputShapes.forEach((outputShape) => programUniforms.push(...createTensorShapeVariables(outputShape)));
86+
programUniforms.push(
87+
{type: DataType.uint32, data: sizeInSplitAxis}, ...createTensorShapeVariables(inputShape, ...outputShapes));
8988
const getShaderSource = (shaderHelper: ShaderHelper) => `
9089
${
9190
shaderHelper.registerUniform('input_size', 'u32')

js/web/lib/wasm/jsep/webgpu/ops/tile.ts

+2-4
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,8 @@ export const createTileProgramInfo = (inputs: readonly TensorView[]): ProgramInf
7979
getRunData: () => ({
8080
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
8181
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
82-
programUniforms: [
83-
{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims),
84-
...createTensorShapeVariables(outputShape)
85-
],
82+
programUniforms:
83+
[{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims, outputShape)],
8684
}),
8785
getShaderSource,
8886
};

js/web/lib/wasm/jsep/webgpu/ops/transpose.ts

+2-5
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,8 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
6565
return {
6666
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
6767
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
68-
programUniforms: [
69-
{type: DataType.uint32, data: outputSize},
70-
...createTensorShapeVariables(inputs[0].dims),
71-
...createTensorShapeVariables(outputShape),
72-
],
68+
programUniforms:
69+
[{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims, outputShape)],
7370
};
7471
},
7572
getShaderSource,

js/web/lib/wasm/jsep/webgpu/ops/where.ts

+2-5
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,8 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
9797
getRunData: () => ({
9898
outputs: [{dims: outputShape, dataType: outputDataType}],
9999
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)},
100-
programUniforms: [
101-
{type: DataType.uint32, data: vecSize}, ...createTensorShapeVariables(dimsC),
102-
...createTensorShapeVariables(dimsA), ...createTensorShapeVariables(dimsB),
103-
...createTensorShapeVariables(outputShape)
104-
],
100+
programUniforms:
101+
[{type: DataType.uint32, data: vecSize}, ...createTensorShapeVariables(dimsC, dimsA, dimsB, outputShape)],
105102
}),
106103
};
107104
};

0 commit comments

Comments
 (0)