From e8711389907caf9f21e5150bfa58660be12b395d Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 27 May 2022 16:12:56 -0700 Subject: [PATCH] slice (scalar) --- .../lib/onnxjs/backends/webgpu/ops/slice.ts | 35 ++++++++++++++++--- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/js/web/lib/onnxjs/backends/webgpu/ops/slice.ts b/js/web/lib/onnxjs/backends/webgpu/ops/slice.ts index fda8659993188..c5642c0921811 100644 --- a/js/web/lib/onnxjs/backends/webgpu/ops/slice.ts +++ b/js/web/lib/onnxjs/backends/webgpu/ops/slice.ts @@ -41,6 +41,29 @@ export const parseSliceAttributes: OperatorInitialization = (no return createAttributeWithCacheKey({starts, ends, axes}); }; +const offsetToIndices = (offset: string, strides: readonly number[], indicesPrefix: string): string => { + const outputLines: string[] = []; + + for (let i = 0; i < strides.length - 1; i++) { + outputLines.push(`var ${indicesPrefix}${i}=${offset}/${strides[i]}u;`); + outputLines.push(`${offset}%=${strides[i]}u;`); + } + outputLines.push(`var ${indicesPrefix}${strides.length - 1}=${offset};`); + + return outputLines.join('\n'); +}; + +const indicesToOffset = (indicesPrefix: string, strides: readonly number[], offset: string): string => { + const outputLines: string[] = []; + + for (let i = 0; i < strides.length - 1; i++) { + outputLines.push(`${offset}+=${indicesPrefix}${i} * ${strides[i]}u;`); + } + outputLines.push(`${offset}+=${indicesPrefix}${strides.length - 1};`); + + return outputLines.join('\n'); +}; + const createSliceProgramInfo = (input: Tensor, attributes: SliceAttributes, dataType = 'f32'): ProgramInfo => { const axes = (attributes.axes.length === 0) ? input.dims.slice(0).map((val, i) => i) : attributes.axes; const normalizedAxes = ShapeUtil.normalizeAxes(axes, input.dims.length); @@ -59,12 +82,11 @@ const createSliceProgramInfo = (input: Tensor, attributes: SliceAttributes, data const outputShape = input.dims.slice(); - const sliceOps: Array<[number, number]> = []; + const sliceOps: string[] = []; for (let i = 0; i < normalizedAxes.length; i++) { outputShape[normalizedAxes[i]] = ends[i] - starts[i]; if (starts[i] > 0) { - // sliceOps.push(`outputIdx[${normalizedAxes[i]}] += ${starts[i]};`); - sliceOps.push([normalizedAxes[i], starts[i]]); + sliceOps.push(`idx_${normalizedAxes[i]} += ${starts[i]}u;`); } // else { sliceOps.push(`outputIdx[${normalizedAxes[i]}] += 0;`); } } @@ -84,8 +106,11 @@ const createSliceProgramInfo = (input: Tensor, attributes: SliceAttributes, data } var offset = global_id.x; - ${sliceOps.map(i => `offset += ${i[1]}u * ${outputStrides[i[0]]}u;`).join('')} - output[global_id.x] = input[offset]; + ${offsetToIndices('offset', outputStrides, 'idx_')} + ${sliceOps.join('')} + var offsetInput = 0u; + ${indicesToOffset('idx_', ShapeUtil.computeStrides(input.dims), 'offsetInput')} + output[global_id.x] = input[offsetInput]; }`; return { ...sliceProgramMetadata,