diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts index 05f02b07c4d89..1538644412afd 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts @@ -18,16 +18,18 @@ const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 1) { throw new Error('Pool ops requires 1 input.'); } - if (inputs[0].dims.length !== 4) { - throw new Error('Pool ops supports 2-D inputs only for now.'); + if (inputs[0].dims.length !== 4 && inputs[0].dims.length !== 3) { + throw new Error('Pool ops supports 1-D or 2-D inputs only for now.'); } }; const getAdjustedPoolAttributesAndOutputShape = ( input: TensorView, attributes: AttributeType, isGlobalOperator: boolean): [AttributeType, number[]] => { const isChannelsLast = attributes.format === 'NHWC'; - const inputShapeAsChannelFirst = - isChannelsLast ? [input.dims[0], input.dims[3], input.dims[1], input.dims[2]] : input.dims.slice(); + const inputShapeAsChannelFirst = input.dims.slice(); + if (isChannelsLast) { + inputShapeAsChannelFirst.splice(1, 0, inputShapeAsChannelFirst.pop()!); // Move channel to the second position. + } const hasDilations = Object.hasOwnProperty.call(attributes, 'dilations'); const kernelShape = attributes.kernelShape.slice(); const strides = attributes.strides.slice(); @@ -44,15 +46,9 @@ const getAdjustedPoolAttributesAndOutputShape = ( @@ -76,22 +72,22 @@ const generatePoolingCode = = ${inputDims[dimIdxW]}) { - pad++; - continue; - } - let x_val = x[${x.indicesToOffset('xIndices')}]; - ${op1} - }`; + for (var i: u32 = 0u; i < ${kw}u; i++) { + xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i; + if (xIndices[${dimIdxW}] < 0 || xIndices[${dimIdxW}] >= ${inputDims[dimIdxW]}) { + pad++; + continue; + } + let x_val = x[${x.indicesToOffset('xIndices')}]; + ${op1} + }`; } else { codeW = ` - for (var i: u32 = 0u; i < ${kw}u; i++) { - xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i; - let x_val = x[${x.indicesToOffset('xIndices')}]; - ${op1} - }`; + for (var i: u32 = 0u; i < ${kw}u; i++) { + xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i; + let x_val = x[${x.indicesToOffset('xIndices')}]; + ${op1} + }`; } if (attributes.kernelShape.length === 2) {