From 67eda7755a8e64e013a32599247049ab264536bb Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Mon, 16 Oct 2023 14:23:14 -0700 Subject: [PATCH 1/2] Enabled 1d spatial input to GlobalAveragePool --- js/web/lib/wasm/jsep/webgpu/ops/pool.ts | 50 ++++++++++++------------- 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts index 05f02b07c4d89..48e0bb0df10f6 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts @@ -18,16 +18,18 @@ const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 1) { throw new Error('Pool ops requires 1 input.'); } - if (inputs[0].dims.length !== 4) { - throw new Error('Pool ops supports 2-D inputs only for now.'); + if (inputs[0].dims.length !== 4 && inputs[0].dims.length !== 3) { + throw new Error('Pool ops supports 1-D or 2-D inputs only for now.'); } }; const getAdjustedPoolAttributesAndOutputShape = ( input: TensorView, attributes: AttributeType, isGlobalOperator: boolean): [AttributeType, number[]] => { const isChannelsLast = attributes.format === 'NHWC'; - const inputShapeAsChannelFirst = - isChannelsLast ? [input.dims[0], input.dims[3], input.dims[1], input.dims[2]] : input.dims.slice(); + const inputShapeAsChannelFirst = input.dims.slice(); + if (isChannelsLast) { + inputShapeAsChannelFirst.splice(1, 0, inputShapeAsChannelFirst.pop()!); // Move channel to the second position. 
+ } const hasDilations = Object.hasOwnProperty.call(attributes, 'dilations'); const kernelShape = attributes.kernelShape.slice(); const strides = attributes.strides.slice(); @@ -44,15 +46,9 @@ const getAdjustedPoolAttributesAndOutputShape = ( @@ -76,22 +72,22 @@ const generatePoolingCode = = ${inputDims[dimIdxW]}) { - pad++; - continue; - } - let x_val = x[${x.indicesToOffset('xIndices')}]; - ${op1} - }`; + for (var i: u32 = 0u; i < ${kw}u; i++) { + xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i; + if (xIndices[${dimIdxW}] < 0 || xIndices[${dimIdxW}] >= ${inputDims[dimIdxW]}) { + pad++; + continue; + } + let x_val = x[${x.indicesToOffset('xIndices')}]; + ${op1} + }`; } else { codeW = ` - for (var i: u32 = 0u; i < ${kw}u; i++) { - xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i; - let x_val = x[${x.indicesToOffset('xIndices')}]; - ${op1} - }`; + for (var i: u32 = 0u; i < ${kw}u; i++) { + xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i; + let x_val = x[${x.indicesToOffset('xIndices')}]; + ${op1} + }`; } if (attributes.kernelShape.length === 2) { From da2fb5e68bf3e513a575187cb041059cc3ff8953 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Tue, 17 Oct 2023 13:15:50 -0700 Subject: [PATCH 2/2] Fixed computing channel-last output shape --- js/web/lib/wasm/jsep/webgpu/ops/pool.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts index 48e0bb0df10f6..1538644412afd 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts @@ -47,7 +47,7 @@ const getAdjustedPoolAttributesAndOutputShape =