Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 77 additions & 19 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ export const createConvTranspose2DProgramInfo = (
const inputChannelsPerGroup = wShape[2] / group;
const outputChannelsPerGroup = wShape[3];
const aComponents = isChannelsLast ? getMaxComponents(inputChannelsPerGroup) : 1;
const packInputAs4 = isChannelsLast && outputChannelsPerGroup === 1 && inputChannelsPerGroup >= 4;
const inputChannelsPerGroupInt = packInputAs4
? Math.floor(inputChannelsPerGroup / 4) * 4
: Math.floor(inputChannelsPerGroup / aComponents) * aComponents;
const inputChannelsRemainder = inputChannelsPerGroup - inputChannelsPerGroupInt;
const components = isChannelsLast ? getMaxComponents(outputChannelsPerGroup) : 1;
const bComponents = isChannelsLast ? (outputChannelsPerGroup === 1 ? aComponents : components) : 1;
const outputSize = ShapeUtil.size(outputShape) / components;
Expand Down Expand Up @@ -78,6 +83,7 @@ export const createConvTranspose2DProgramInfo = (
{ type: DataType.uint32, data: dilations },
{ type: DataType.uint32, data: effectiveFilterDims },
{ type: DataType.int32, data: pads },
{ type: DataType.uint32, data: inputChannelsPerGroupInt },
{ type: DataType.uint32, data: inputChannelsPerGroup },
{ type: DataType.uint32, data: outputChannelsPerGroup },
...createTensorShapeVariables(inputs[0].dims, inputs[1].dims),
Expand All @@ -96,6 +102,7 @@ export const createConvTranspose2DProgramInfo = (
{ name: 'dilations', type: 'u32', length: filterDims.length },
{ name: 'effective_filter_dims', type: 'u32', length: effectiveFilterDims.length },
{ name: 'pads', type: 'i32', length: pads.length },
{ name: 'input_channels_per_group_int', type: 'u32' },
{ name: 'input_channels_per_group', type: 'u32' },
{ name: 'output_channels_per_group', type: 'u32' },
];
Expand All @@ -114,16 +121,40 @@ export const createConvTranspose2DProgramInfo = (

const calculateResult = (): string => {
let calcStr = '';
if (aComponents === 1) {
calcStr += `
let w_offset = ${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)};
let wValue = ${w.getByOffset(`w_offset / ${bComponents}`)};
dotProd = dotProd + xValue * wValue;`;
if (packInputAs4) {
if (aComponents === 4) {
calcStr += `
let xValue = ${dy.getByOffset('x_offset')};
let wValue = ${w.getByOffset('w_offset')};
dotProd = dotProd + dot(xValue, wValue);
x_offset += 1u;
w_offset += 1u;`;
} else if (aComponents === 2) {
calcStr += `
dotProd = dotProd + dot(vec4<${dataType}>(${dy.getByOffset('x_offset')}, ${dy.getByOffset('x_offset + 1u')}), vec4<${dataType}>(${w.getByOffset('w_offset')}, ${w.getByOffset('w_offset + 1u')}));
x_offset += 2u;
w_offset += 2u;`;
} else if (aComponents === 1) {
calcStr += `
dotProd = dotProd + dot(vec4<${dataType}>(${dy.getByOffset('x_offset')}, ${dy.getByOffset('x_offset + 1u')}, ${dy.getByOffset('x_offset + 2u')}, ${dy.getByOffset('x_offset + 3u')}), vec4<${dataType}>(${w.getByOffset('w_offset')}, ${w.getByOffset('w_offset + 1u')}, ${w.getByOffset('w_offset + 2u')}, ${w.getByOffset('w_offset + 3u')}));
x_offset += 4u;
w_offset += 4u;`;
}
} else {
if (outputChannelsPerGroup === 1) {
calcStr += `
let xValue = ${
isChannelsLast
? dy.getByOffset(
`${dy.indicesToOffset(`${dy.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${aComponents}`,
)
: dy.get('batch', 'inputChannel', 'idyR', 'idyC')
};
`;
if (aComponents === 1) {
calcStr += `
let wValue = ${w.getByOffset(`${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)} / ${bComponents}`)};
dotProd = dotProd + dot(xValue, wValue);`;
let w_offset = ${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)};
let wValue = ${w.getByOffset(`w_offset / ${bComponents}`)};
dotProd = dotProd + xValue * wValue;`;
} else {
for (let c = 0; c < aComponents; c++) {
calcStr += `
Expand All @@ -134,6 +165,32 @@ export const createConvTranspose2DProgramInfo = (
}
return calcStr;
};
// Builds the WGSL statements that accumulate into `dotProd` the input channels
// left over after the main channel loop (which only covers
// `inputChannelsPerGroupInt` channels). The emitted code reads through the
// `x_offset` / `w_offset` shader variables, which the packed-input path declares
// before the loop. Returns an empty string when no channels are left over.
// Throws if there is a remainder but the packed-input path is not active, or if
// a two-component layout is left with anything other than exactly two channels.
const calculateRemainder = (): string => {
  // Nothing to emit when the main loop already consumed every input channel.
  if (inputChannelsRemainder === 0) {
    return '';
  }
  // A remainder is only expected when four input channels are packed per iteration.
  if (!packInputAs4) {
    throw new Error(`packInputAs4 ${packInputAs4} is not true.`);
  }

  switch (aComponents) {
    case 1: {
      // Scalar loads: append one multiply-add term per leftover channel.
      const terms: string[] = [];
      for (let k = 0; k < inputChannelsRemainder; k++) {
        terms.push(`\n+ ${dy.getByOffset(`x_offset + ${k}`)} * ${w.getByOffset(`w_offset + ${k}`)}`);
      }
      return `dotProd = dotProd${terms.join('')};`;
    }
    case 2: {
      // Two-component loads: a single extra load must cover exactly two channels.
      if (inputChannelsRemainder !== 2) {
        throw new Error(`Invalid inputChannelsRemainder ${inputChannelsRemainder}.`);
      }
      const xLoad = dy.getByOffset('x_offset');
      const wLoad = w.getByOffset('w_offset');
      return `\nlet xValue = ${xLoad};\nlet wValue = ${wLoad};\ndotProd = dotProd + dot(xValue, wValue);`;
    }
    default:
      // Any other component width emits no remainder code.
      return '';
  }
};
const codeSnippet = `
let outputIndices = ${output.offsetToIndices(`global_idx * ${components}`)};
let batch = ${output.indicesGet('outputIndices', 0)};
Expand Down Expand Up @@ -169,7 +226,6 @@ export const createConvTranspose2DProgramInfo = (
// Minimum wC >= 0 that satisfies (dyCCorner + wC) % (uniforms.strides.y) == 0
wC = u32(((dyCCorner + i32(uniforms.strides.y) - 1) / i32(uniforms.strides.y)) * i32(uniforms.strides.y) - dyCCorner);
}

for (; wC < uniforms.effective_filter_dims.y; wC = wC + 1) {
if (wC % uniforms.dilations.y != 0) {
continue;
Expand All @@ -182,17 +238,19 @@ export const createConvTranspose2DProgramInfo = (
}
let idyC: u32 = u32(dyC);
var inputChannel = groupId * uniforms.input_channels_per_group;
for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + ${aComponents}) {
let xValue = ${
isChannelsLast
? dy.getByOffset(
`${dy.indicesToOffset(`${dy.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${aComponents}`,
)
: dy.get('batch', 'inputChannel', 'idyR', 'idyC')
};
${
packInputAs4
? `
var x_offset = ${dy.indicesToOffset(`${dy.type.indices}(batch, idyR, idyC, inputChannel)`)} / ${aComponents};
var w_offset = ${w.indicesToOffset(`${w.type.indices}(wRPerm, wCPerm, inputChannel, wOutChannel)`)} / ${bComponents};
`
: ''
}
for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group_int; d2 = d2 + ${packInputAs4 ? 4 : aComponents}) {
${calculateResult()}
inputChannel = inputChannel + ${aComponents};
inputChannel = inputChannel + ${packInputAs4 ? 4 : aComponents};
}
${calculateRemainder()}
wC = wC + uniforms.strides.y - 1;
}
wR = wR + uniforms.strides[0] - 1;
Expand All @@ -211,7 +269,7 @@ export const createConvTranspose2DProgramInfo = (
return {
name: 'ConvTranspose2D',
shaderCache: {
hint: `${attributes.cacheKey};${aComponents}${bComponents}${components}${outputChannelsPerGroup === 1}`,
hint: `${attributes.cacheKey};${aComponents}${bComponents}${components}${packInputAs4}${inputChannelsRemainder}`,
inputDependencies,
},
getRunData: () => ({
Expand Down
122 changes: 122 additions & 0 deletions js/web/test/data/ops/conv-transpose.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,128 @@
}
]
},
{
"name": "ConvTranspose NHWC- group - A",
"operator": "ConvTranspose",
"inputShapeDefinitions": "rankOnly",
"opset": { "domain": "", "version": 17 },
"attributes": [
{ "name": "kernel_shape", "data": [1, 1], "type": "ints" },
{ "name": "group", "data": 2, "type": "int" }
],
"cases": [
{
"name": "T[0]",
"inputs": [
{
"data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0, 32.0, 34.0],
"dims": [1, 2, 3, 3],
"type": "float32"
},
{
"data": [1.0, 2.0],
"dims": [2, 1, 1, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 36, 40, 44, 48, 52, 56, 60, 64, 68],
"dims": [1, 2, 3, 3],
"type": "float32"
}
]
}
]
},
{
"name": "ConvTranspose NHWC- group - B",
"operator": "ConvTranspose",
"inputShapeDefinitions": "rankOnly",
"opset": { "domain": "", "version": 17 },
"attributes": [
{ "name": "kernel_shape", "data": [2, 2], "type": "ints" },
{ "name": "group", "data": 3, "type": "int" }
],
"cases": [
{
"name": "T[0]",
"inputs": [
{
"data": [
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
19.0, 20.0, 21.0, 22.0, 23.0, 0, 0, 0
],
"dims": [1, 3, 3, 3],
"type": "float32"
},
{
"data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
"dims": [3, 1, 2, 2],
"type": "float32"
},
{
"data": [0.125, 0.25, 0.375],
"dims": [3],
"type": "float32"
}
],
"outputs": [
{
"data": [
0.125, 1.125, 4.125, 4.125, 3.125, 13.125, 23.125, 18.125, 15.125, 43.125, 53.125, 36.125, 18.125, 45.125,
52.125, 32.125, 45.25, 104.25, 115.25, 66.25, 123.25, 279.25, 305.25, 172.25, 159.25, 357.25, 383.25,
214.25, 105.25, 232.25, 247.25, 136.25, 162.375, 351.375, 370.375, 200.375, 387.375, 833.375, 875.375,
470.375, 231.375, 494.375, 517.375, 276.375, 0.375, 0.375, 0.375, 0.375
],
"dims": [1, 3, 4, 4],
"type": "float32"
}
]
}
]
},
{
"name": "ConvTranspose NHWC- group - C",
"operator": "ConvTranspose",
"inputShapeDefinitions": "rankOnly",
"opset": { "domain": "", "version": 17 },
"attributes": [
{ "name": "kernel_shape", "data": [2, 2], "type": "ints" },
{ "name": "group", "data": 3, "type": "int" }
],
"cases": [
{
"name": "T[0]",
"inputs": [
{
"data": [
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0
],
"dims": [1, 3, 3, 4],
"type": "float32"
},
{
"data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
"dims": [3, 1, 2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [
0, 1, 4, 7, 6, 4, 16, 26, 36, 26, 20, 56, 66, 76, 50, 24, 59, 66, 73, 44, 60, 137, 148, 159, 90, 164, 368,
394, 420, 234, 212, 472, 498, 524, 290, 140, 307, 322, 337, 184, 216, 465, 484, 503, 270, 516, 1104, 1146,
1188, 634, 596, 1272, 1314, 1356, 722, 352, 747, 770, 793, 420
],
"dims": [1, 3, 4, 5],
"type": "float32"
}
]
}
]
},
{
"name": "ConvTranspose with bias addition C",
"operator": "ConvTranspose",
Expand Down
Loading