diff --git a/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts b/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts index 1a04f780be643..6fce8e887ee1e 100644 --- a/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts +++ b/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts @@ -11,6 +11,7 @@ import {gemm, parseGemmAttributesV11, parseGemmAttributesV7} from './ops/gemm'; import {matMul, parseMatMulAttributes} from './ops/matmul'; import {averagePool, globalAveragePool, globalMaxPool, maxPool, parseAveragePoolAttributes, parseGlobalAveragePoolAttributes, parseMaxPoolAttributes} from './ops/pool'; import {reshape} from './ops/reshape'; +import {shape} from './ops/shape'; import {parseSliceAttributes, slice, sliceV10} from './ops/slice'; import * as unaryOps from './ops/unary-op'; import {parseUnsqueezeAttributes, unsqueeze, unsqueezeV13} from './ops/unsqueeze'; @@ -63,8 +64,7 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [ ['Relu', '', '6+', unaryOps.relu], ['Reshape', '', '5+', reshape], // ['Resize', '', '10', resize, parseResizeAttributesV10], // ['Resize', '', '11+', resize, parseResizeAttributesV11], - // ['Shape', '', '1+', shape], - ['Sigmoid', '', '6+', unaryOps.sigmoid], ['Sin', '', '7+', unaryOps.sin], + ['Shape', '', '1+', shape], ['Sigmoid', '', '6+', unaryOps.sigmoid], ['Sin', '', '7+', unaryOps.sin], ['Slice', '', '10+', sliceV10], // TODO: support 'steps' for Slice-10 ['Slice', '', '1-9', slice, parseSliceAttributes], // // The "semantic" meaning of axis has changed in opset-13. diff --git a/js/web/lib/onnxjs/backends/webgpu/ops/reshape.ts b/js/web/lib/onnxjs/backends/webgpu/ops/reshape.ts index 0efb46960c2c9..323e80bdb596a 100644 --- a/js/web/lib/onnxjs/backends/webgpu/ops/reshape.ts +++ b/js/web/lib/onnxjs/backends/webgpu/ops/reshape.ts @@ -5,7 +5,18 @@ import {Tensor} from '../../../tensor'; import {ShapeUtil} from '../../../util'; import {WebGpuInferenceHandler} from '../inference-handler'; -export const reshape = (handler: WebGpuInferenceHandler, inputs: Tensor[]): Tensor[] => { - const reshapedDims = ShapeUtil.calculateReshapedDims(inputs[0].dims, inputs[1].integerData); +export const reshape = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise => { + validateInputs(inputs); + const shape = await inputs[1].getData(); + const reshapedDims = ShapeUtil.calculateReshapedDims(inputs[0].dims, shape as Int32Array); return [handler.reshape(inputs[0], reshapedDims)]; }; + +const validateInputs = (inputs: Tensor[]): void => { + if (!inputs || inputs.length !== 2) { + throw new Error('Reshape requires 2 inputs.'); + } + if (inputs[1].type !== 'int32') { + throw new Error('Invalid input type.'); + } +}; diff --git a/js/web/lib/onnxjs/backends/webgpu/ops/shape.ts b/js/web/lib/onnxjs/backends/webgpu/ops/shape.ts new file mode 100644 index 0000000000000..94ba9293c457a --- /dev/null +++ b/js/web/lib/onnxjs/backends/webgpu/ops/shape.ts @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {Tensor} from '../../../tensor'; +import {WebGpuInferenceHandler} from '../inference-handler'; + +export const shape = async(inferenceHandler: WebGpuInferenceHandler, inputs: Tensor[]): Promise => { + validateInputs(inputs); + return [new Tensor([inputs[0].dims.length], 'int32', undefined, undefined, new Int32Array(inputs[0].dims))]; +}; + +const validateInputs = (inputs: Tensor[]): void => { + if (!inputs || inputs.length !== 1) { + throw new Error('Shape requires 1 input.'); + } +}; diff --git a/js/web/lib/onnxjs/tensor.ts b/js/web/lib/onnxjs/tensor.ts index bd68a44806a26..db5e599fd68dc 100644 --- a/js/web/lib/onnxjs/tensor.ts +++ b/js/web/lib/onnxjs/tensor.ts @@ -131,7 +131,15 @@ export class Tensor { */ async getData(): Promise { if (this.cache === undefined) { - this.cache = await this.asyncDataProvider!(this.dataId); + if (this.asyncDataProvider) { + const data = await this.asyncDataProvider(this.dataId); + if (data.length !== this.size) { + throw new Error('Length of data provided by the Data Provider is inconsistent with the dims of this Tensor.'); + } + this.cache = data; + } else { + return this.data; + } } return this.cache; }