From b6e7fbae4d20e323daff6e92572dd7755afdbc19 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 23 Mar 2022 15:36:58 -0700 Subject: [PATCH] reshape --- .../backends/webgpu/gpu-data-manager.ts | 80 +++++++---- .../backends/webgpu/inference-handler.ts | 45 +++---- .../backends/webgpu/op-resolve-rules.ts | 7 +- .../lib/onnxjs/backends/webgpu/ops/reshape.ts | 11 ++ .../onnxjs/backends/webgpu/ops/unsqueeze.ts | 43 ++++++ .../onnxjs/backends/webgpu/session-handler.ts | 6 +- .../backends/webgpu/tensor-data-manager.ts | 127 ++++++++++++++++++ js/web/lib/onnxjs/operators.ts | 6 +- js/web/lib/onnxjs/opset.ts | 10 +- js/web/test/test-runner.ts | 5 +- .../unittests/backends/webgl/test-conv-new.ts | 2 +- 11 files changed, 275 insertions(+), 67 deletions(-) create mode 100644 js/web/lib/onnxjs/backends/webgpu/ops/reshape.ts create mode 100644 js/web/lib/onnxjs/backends/webgpu/ops/unsqueeze.ts create mode 100644 js/web/lib/onnxjs/backends/webgpu/tensor-data-manager.ts diff --git a/js/web/lib/onnxjs/backends/webgpu/gpu-data-manager.ts b/js/web/lib/onnxjs/backends/webgpu/gpu-data-manager.ts index 71caefedd0f6f..3b3fd430caa32 100644 --- a/js/web/lib/onnxjs/backends/webgpu/gpu-data-manager.ts +++ b/js/web/lib/onnxjs/backends/webgpu/gpu-data-manager.ts @@ -10,13 +10,29 @@ import {GpuData, GpuDataId, GpuDataType} from './types'; * manages GpuDataId -> GpuBuffer */ export interface GpuDataManager { - uploadData(tensor: Tensor, gpuDataType: GpuDataType): GpuData; - createData(type: Tensor.DataType, dims: readonly number[], gpuDataType: GpuDataType): GpuData; - releaseData(tensorId: Tensor.Id): void; - downloadData(tensorId: Tensor.Id): Promise; + /** + * upload data to GPU. if the ID already exists in cache, returns the cached value without uploading anything. + */ + upload(id: GpuDataId, data: Tensor.NumberType, gpuDataType: GpuDataType): Promise; + /** + * create new data on GPU. + */ + create(type: Tensor.DataType, dims: readonly number[], gpuDataType: GpuDataType): GpuData; + /** + * get GPU data by ID. + */ + get(id: GpuDataId): GpuData|undefined; + /** + * release the data on GPU by ID. + */ + release(id: GpuDataId): void; + /** + * download the data from GPU. + */ + download(id: GpuDataId): Promise; } -interface DefaultCacheValue { +interface StorageCacheValue { gpuData: GpuData; size: number; } @@ -27,27 +43,25 @@ interface DownloadCacheValue { } class GpuDataManagerImpl implements GpuDataManager { - defaultCache: Map; + // GPU Data ID => GPU Data ( storage buffer ) + storageCache: Map; + + // GPU Data ID => GPU Data ( read buffer ) downloadCache: Map; + constructor(private device: GPUDevice) { - this.defaultCache = new Map(); + this.storageCache = new Map(); this.downloadCache = new Map(); } - uploadData(tensor: Tensor, gpuDataType: GpuDataType): GpuData { + async upload(id: GpuDataId, data: Tensor.NumberType, gpuDataType: GpuDataType): Promise { if (gpuDataType !== GpuDataType.default) { throw new Error('we only support default GPU data type now'); } - const cachedData = this.defaultCache.get(tensor.dataId); - if (cachedData) { - return cachedData.gpuData; - } - - const src = tensor.numberData; - const srcArrayBuffer = src.buffer; - const srcOffset = src.byteOffset; - const srcLength = src.byteLength; + const srcArrayBuffer = data.buffer; + const srcOffset = data.byteOffset; + const srcLength = data.byteLength; // create gpu buffer const gpuBuffer = @@ -58,12 +72,12 @@ class GpuDataManagerImpl implements GpuDataManager { new Uint8Array(arrayBuffer).set(new Uint8Array(srcArrayBuffer, srcOffset, srcLength)); gpuBuffer.unmap(); - const gpuData = {id: tensor.dataId, type: GpuDataType.default, buffer: gpuBuffer}; - this.defaultCache.set(gpuData.id, {gpuData, size: srcLength}); + const gpuData = {id, type: GpuDataType.default, buffer: gpuBuffer}; + this.storageCache.set(id, {gpuData, size: srcLength}); return gpuData; } - createData(type: Tensor.DataType, dims: readonly number[], gpuDataType: GpuDataType): GpuData { + create(type: Tensor.DataType, dims: readonly number[], gpuDataType: GpuDataType): GpuData { if (gpuDataType !== GpuDataType.default) { throw new Error('we only support default GPU data type now'); } @@ -82,27 +96,39 @@ class GpuDataManagerImpl implements GpuDataManager { this.device.createBuffer({size: bufferLength, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC}); const gpuData = {id: Guid.create(), type: GpuDataType.default, buffer: gpuBuffer}; - this.defaultCache.set(gpuData.id, {gpuData, size: bufferLength}); + this.storageCache.set(gpuData.id, {gpuData, size: bufferLength}); return gpuData; } - releaseData(tensorId: Tensor.Id): void { - const cachedData = this.defaultCache.get(tensorId); + get(id: GpuDataId): GpuData|undefined { + return this.storageCache.get(id)?.gpuData; + } + + release(id: GpuDataId): void { + const cachedData = this.storageCache.get(id); if (!cachedData) { throw new Error('releasing data does not exist'); } - this.defaultCache.delete(tensorId); + this.storageCache.delete(id); cachedData.gpuData.buffer.destroy(); + + const downloadingData = this.downloadCache.get(id); + if (downloadingData) { + void downloadingData.data.then(() => { + downloadingData.gpuData.buffer.destroy(); + }); + this.downloadCache.delete(id); + } } - async downloadData(tensorId: Tensor.Id): Promise { - const downloadData = this.downloadCache.get(tensorId); + async download(id: GpuDataId): Promise { + const downloadData = this.downloadCache.get(id); if (downloadData) { return downloadData.data; } - const cachedData = this.defaultCache.get(tensorId); + const cachedData = this.storageCache.get(id); if (!cachedData) { throw new Error('data does not exist'); } diff --git a/js/web/lib/onnxjs/backends/webgpu/inference-handler.ts b/js/web/lib/onnxjs/backends/webgpu/inference-handler.ts index aa2d6db8be111..5b296df903538 100644 --- a/js/web/lib/onnxjs/backends/webgpu/inference-handler.ts +++ b/js/web/lib/onnxjs/backends/webgpu/inference-handler.ts @@ -2,10 +2,10 @@ // Licensed under the MIT License. import {InferenceHandler} from '../../backend'; -import {createView, Tensor} from '../../tensor'; +import {Tensor} from '../../tensor'; -import {createGpuDataManager, GpuDataManager} from './gpu-data-manager'; import {WebGpuSessionHandler} from './session-handler'; +import {createTensorDataManager, TensorDataManager} from './tensor-data-manager'; import {GpuData, GpuDataType, ProgramInfo, ProgramInfoLoader} from './types'; const getProgramInfoUniqueKey = (programInfo: ProgramInfo|ProgramInfoLoader, inputGpuDatas: GpuData[]): string => { @@ -19,24 +19,26 @@ const getProgramInfoUniqueKey = (programInfo: ProgramInfo|ProgramInfoLoader, inp }; export class WebGpuInferenceHandler implements InferenceHandler { - dataManager: GpuDataManager; + // per inference context + dataManager: TensorDataManager; + constructor(public session: WebGpuSessionHandler) { - this.dataManager = createGpuDataManager(session.backend.device); + this.dataManager = createTensorDataManager(session.backend.device); } private uploadGpuData(tensor: Tensor, textureType: GpuDataType): GpuData { if (this.session.isInitializer(tensor.dataId)) { - return this.session.dataManager.uploadData(tensor, textureType); + return this.session.dataManager.uploadTensorToGpu(tensor, textureType); } - return this.dataManager.uploadData(tensor, textureType); + return this.dataManager.uploadTensorToGpu(tensor, textureType); } - private createGpuData(type: Tensor.DataType, dims: readonly number[], gpuDataType: GpuDataType): GpuData { - return this.dataManager.createData(type, dims, gpuDataType); + private createGpuData(type: Tensor.DataType, dims: readonly number[], gpuDataType: GpuDataType): [Tensor, GpuData] { + return this.dataManager.createGpuTensor(type, dims, gpuDataType); } - run(program: ProgramInfoLoader|ProgramInfo, inputs: readonly Tensor[]): Tensor[] { + async run(program: ProgramInfoLoader|ProgramInfo, inputs: readonly Tensor[]): Promise { if (inputs.length !== program.inputTypes.length) { throw new Error(`Input size must be equal to ${program.inputTypes.length}.`); } @@ -56,9 +58,12 @@ export class WebGpuInferenceHandler implements InferenceHandler { // create texture info for outputs const outputDatas: GpuData[] = []; + const outputTensors: Tensor[] = []; for (let i = 0; i < programInfo.outputs.length; ++i) { - outputDatas.push(this.createGpuData( - programInfo.outputs[i].type, programInfo.outputs[i].dims, programInfo.outputs[i].gpuDataType)); + const [tensor, gpuData] = this.createGpuData( + programInfo.outputs[i].type, programInfo.outputs[i].dims, programInfo.outputs[i].gpuDataType); + outputTensors.push(tensor); + outputDatas.push(gpuData); } if (!artifact) { @@ -68,20 +73,14 @@ export class WebGpuInferenceHandler implements InferenceHandler { this.session.programManager.run(artifact, inputDatas, outputDatas, artifact.programInfo.dispatchGroup(inputs)); - const outputTensors: Tensor[] = []; - for (let i = 0; i < outputDatas.length; i++) { - const outputTensorInfo = artifact.programInfo.outputs[i]; - const dims = outputTensorInfo.dims; - const type = outputTensorInfo.type; - const outputData = outputDatas[i]; - const tensor = new Tensor(dims, type, undefined, async () => { - const data = await this.dataManager.downloadData(outputData.id); - return createView(data, type); - }, undefined, outputData.id); - outputTensors.push(tensor); - } return outputTensors; } + reshape(input: Tensor, reshapedDims: readonly number[]): Tensor { + return this.dataManager.hasGpuData(input.dataId) ? + this.dataManager.createGpuRef(input.dataId, input.type, reshapedDims)[0] : + new Tensor(reshapedDims, input.type, undefined, undefined, input.data); + } + dispose(): void {} } diff --git a/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts b/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts index 432fa4c7eeb00..c7595d8325661 100644 --- a/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts +++ b/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts @@ -3,7 +3,9 @@ import {OpSet} from '../../opset'; +import {reshape} from './ops/reshape'; import * as unaryOps from './ops/unary-op'; +import {parseUnsqueezeAttributes, unsqueeze, unsqueezeV13} from './ops/unsqueeze'; export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [ ['Abs', '', '6+', unaryOps.abs], ['Acos', '', '7+', unaryOps.acos], @@ -59,7 +61,7 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [ // ['ReduceSum', '', '1-12', reduceSum, parseReduceAttributes], // ['ReduceSumSquare', '', '1+', reduceLogSumSquare, parseReduceAttributes], // ['Relu', '', '6+', unaryOps.relu], - // ['Reshape', '', '5+', reshape], + ['Reshape', '', '5+', reshape], // ['Resize', '', '10', resize, parseResizeAttributesV10], // ['Resize', '', '11+', resize, parseResizeAttributesV11], // ['Shape', '', '1+', shape], @@ -86,7 +88,6 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [ // ['Transpose', '', '1+', transpose, parseTransposeAttributes], // ['Upsample', '', '7-8', upsample, parseUpsampleAttributesV7], // ['Upsample', '', '9', upsample, parseUpsampleAttributesV9], - // ['Unsqueeze', '', '1-12', unsqueeze, parseUnsqueezeAttributes], - // ['Unsqueeze', '', '13+', unsqueezeV13], + ['Unsqueeze', '', '1-12', unsqueeze, parseUnsqueezeAttributes], ['Unsqueeze', '', '13+', unsqueezeV13], // ['Xor', '', '7+', binaryOps.xor], ]; diff --git a/js/web/lib/onnxjs/backends/webgpu/ops/reshape.ts b/js/web/lib/onnxjs/backends/webgpu/ops/reshape.ts new file mode 100644 index 0000000000000..0efb46960c2c9 --- /dev/null +++ b/js/web/lib/onnxjs/backends/webgpu/ops/reshape.ts @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {Tensor} from '../../../tensor'; +import {ShapeUtil} from '../../../util'; +import {WebGpuInferenceHandler} from '../inference-handler'; + +export const reshape = (handler: WebGpuInferenceHandler, inputs: Tensor[]): Tensor[] => { + const reshapedDims = ShapeUtil.calculateReshapedDims(inputs[0].dims, inputs[1].integerData); + return [handler.reshape(inputs[0], reshapedDims)]; +}; diff --git a/js/web/lib/onnxjs/backends/webgpu/ops/unsqueeze.ts b/js/web/lib/onnxjs/backends/webgpu/ops/unsqueeze.ts new file mode 100644 index 0000000000000..8a099dc92cbd9 --- /dev/null +++ b/js/web/lib/onnxjs/backends/webgpu/ops/unsqueeze.ts @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {Graph} from '../../../graph'; +import {OperatorInitialization} from '../../../operators'; +import {Tensor} from '../../../tensor'; +import {ShapeUtil} from '../../../util'; +import {WebGpuInferenceHandler} from '../inference-handler'; + +export const unsqueeze = (inferenceHandler: WebGpuInferenceHandler, inputs: Tensor[], axes: number[]): Tensor[] => { + validateInputs(inputs); + const outputShape = ShapeUtil.unsqueezeShape(inputs[0].dims, axes); + const output = inferenceHandler.reshape(inputs[0], outputShape); + return [output]; +}; + +export const unsqueezeV13 = (inferenceHandler: WebGpuInferenceHandler, inputs: Tensor[]): Tensor[] => { + validateInputsV13(inputs); + return unsqueeze(inferenceHandler, [inputs[0]], Array.from(inputs[1].integerData)); +}; + +export const parseUnsqueezeAttributes: OperatorInitialization = (node: Graph.Node): number[] => + node.attributes.getInts('axes'); + +const validateInputs = (inputs: Tensor[]): void => { + if (!inputs || inputs.length !== 1) { + throw new Error('Unsqueeze requires 1 input.'); + } + + if (inputs[0].type === 'string') { + throw new Error('invalid input tensor types.'); + } +}; + +const validateInputsV13 = (inputs: Tensor[]): void => { + if (!inputs || inputs.length !== 2) { + throw new Error('Unsqueeze requires 2 inputs.'); + } + + if (inputs[1].type !== 'int32') { + throw new Error('Invalid input type.'); + } +}; diff --git a/js/web/lib/onnxjs/backends/webgpu/session-handler.ts b/js/web/lib/onnxjs/backends/webgpu/session-handler.ts index 91ce347f28c56..db0d893f2d5bf 100644 --- a/js/web/lib/onnxjs/backends/webgpu/session-handler.ts +++ b/js/web/lib/onnxjs/backends/webgpu/session-handler.ts @@ -9,18 +9,18 @@ import {Session} from '../../session'; import {Tensor} from '../../tensor'; import {WebGpuBackend} from '../backend-webgpu'; -import {createGpuDataManager, GpuDataManager} from './gpu-data-manager'; import {WebGpuInferenceHandler} from './inference-handler'; import {WEBGPU_OP_RESOLVE_RULES} from './op-resolve-rules'; import {ProgramManager} from './program-manager'; +import {createTensorDataManager, TensorDataManager} from './tensor-data-manager'; export class WebGpuSessionHandler implements SessionHandler { private initializers: Set; - readonly dataManager: GpuDataManager; + readonly dataManager: TensorDataManager; programManager: ProgramManager; constructor(public readonly backend: WebGpuBackend, public readonly context: Session.Context) { - this.dataManager = createGpuDataManager(this.backend.device); + this.dataManager = createTensorDataManager(this.backend.device); this.programManager = new ProgramManager(this.backend.device, this.context.profiler); } diff --git a/js/web/lib/onnxjs/backends/webgpu/tensor-data-manager.ts b/js/web/lib/onnxjs/backends/webgpu/tensor-data-manager.ts new file mode 100644 index 0000000000000..72a5239a5f4c2 --- /dev/null +++ b/js/web/lib/onnxjs/backends/webgpu/tensor-data-manager.ts @@ -0,0 +1,127 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {createView, Tensor} from '../../tensor'; +import {createGpuDataManager, GpuDataManager} from './gpu-data-manager'; +import {GpuData, GpuDataId, GpuDataType} from './types'; + +/** + * manages Tensor ID -> Gpu Data ID + */ +export interface TensorDataManager { + /** + * upload a CPU tensor to GPU. + */ + uploadTensorToGpu(tensor: Tensor, gpuDataType: GpuDataType): GpuData; + + /** + * create a new GPU tensor. + */ + createGpuTensor(type: Tensor.DataType, dims: readonly number[], gpuDataType: GpuDataType): [Tensor, GpuData]; + + /** + * check whether the tensor has GPU data + */ + hasGpuData(tensorId: Tensor.Id): boolean; + + /** + * create a reference to the GPU data. + */ + createGpuRef(tensorId: Tensor.Id, type: Tensor.DataType, dims: readonly number[]): [Tensor, GpuData]; + + /** + * release the GPU resources referred by the tensor. + */ + releaseGpuTensor(tensorId: Tensor.Id): void; +} + +class TensorDataManagerImpl implements TensorDataManager { + private map: Map; + private reverseMap: Map>; + + constructor(private gpuDataManager: GpuDataManager) { + this.map = new Map(); + this.reverseMap = new Map(); + } + + private registerIdMapping(tensorId: Tensor.Id, gpuDataId: GpuDataId): void { + this.map.set(tensorId, gpuDataId); + + let tensorIds = this.reverseMap.get(gpuDataId); + if (!tensorIds) { + tensorIds = new Set(); + this.reverseMap.set(gpuDataId, tensorIds); + } + tensorIds.add(tensorId); + } + + uploadTensorToGpu(tensor: Tensor, gpuDataType: GpuDataType): GpuData { + const gpuDataId = this.map.get(tensor.dataId); + if (gpuDataId) { + const gpuData = this.gpuDataManager.get(gpuDataId); + if (!gpuData) { + throw new Error('internal error. this should never happen'); + } + return gpuData; + } + + const gpuData = this.gpuDataManager.create(tensor.type, tensor.dims, gpuDataType); + this.registerIdMapping(tensor.dataId, gpuData.id); + return gpuData; + } + + createGpuTensor(type: Tensor.DataType, dims: readonly number[], gpuDataType: GpuDataType): [Tensor, GpuData] { + const gpuData = this.gpuDataManager.create(type, dims, gpuDataType); + const tensor = new Tensor(dims, type, undefined, async () => { + const data = await this.gpuDataManager.download(gpuData.id); + return createView(data, type); + }); + + this.registerIdMapping(tensor.dataId, gpuData.id); + return [tensor, gpuData]; + } + + hasGpuData(tensorId: Tensor.Id): boolean { + return this.map.has(tensorId); + } + + createGpuRef(tensorId: Tensor.Id, type: Tensor.DataType, dims: readonly number[]): [Tensor, GpuData] { + const gpuDataId = this.map.get(tensorId); + if (!gpuDataId) { + throw new Error('internal error. this should never happen'); + } + + const gpuData = this.gpuDataManager.get(gpuDataId); + if (!gpuData) { + throw new Error('internal error. this should never happen'); + } + + const tensor = new Tensor(dims, type, undefined, async () => { + const data = await this.gpuDataManager.download(gpuData.id); + return createView(data, type); + }); + + this.registerIdMapping(tensor.dataId, gpuData.id); + return [tensor, gpuData]; + } + + releaseGpuTensor(tensorId: Tensor.Id): void { + const gpuDataId = this.map.get(tensorId); + if (gpuDataId) { + this.map.delete(tensorId); + + const tensorIds = this.reverseMap.get(gpuDataId); + if (!tensorIds) { + throw new Error('internal error. this should never happen'); + } + tensorIds.delete(tensorId); + if (tensorIds.size === 0) { + this.gpuDataManager.release(gpuDataId); + this.reverseMap.delete(gpuDataId); + } + } + } +} + +export const createTensorDataManager = (device: GPUDevice): TensorDataManager => + new TensorDataManagerImpl(createGpuDataManager(device)); diff --git a/js/web/lib/onnxjs/operators.ts b/js/web/lib/onnxjs/operators.ts index 4d664f6dcda5a..2117484316dca 100644 --- a/js/web/lib/onnxjs/operators.ts +++ b/js/web/lib/onnxjs/operators.ts @@ -5,11 +5,13 @@ import {InferenceHandler} from './backend'; import {Graph} from './graph'; import {Tensor} from './tensor'; -export type OperatorImplementation = (inferenceHandler: InferenceHandler, inputs: Tensor[], context: T) => Tensor[]; +export type OperatorImplementation = Tensor[]> = + (inferenceHandler: InferenceHandler, inputs: Tensor[], context: ContextType) => ReturnType; +export type OperatorAsyncImplementation = OperatorImplementation>; export type OperatorInitialization = (node: Graph.Node, graph: Graph) => T; export interface Operator { - readonly impl: OperatorImplementation; + readonly impl: OperatorImplementation|OperatorAsyncImplementation; readonly context: Graph.Node|unknown; } diff --git a/js/web/lib/onnxjs/opset.ts b/js/web/lib/onnxjs/opset.ts index e23a288b4e22b..12618969efc1a 100644 --- a/js/web/lib/onnxjs/opset.ts +++ b/js/web/lib/onnxjs/opset.ts @@ -2,7 +2,7 @@ // Licensed under the MIT License. import {Graph} from './graph'; -import {OperatorImplementation, OperatorInitialization} from './operators'; +import {OperatorAsyncImplementation, OperatorImplementation, OperatorInitialization} from './operators'; export interface OpSet { domain: string; @@ -19,9 +19,11 @@ export declare namespace OpSet { * A resolve rule consists of 4 or 5 items: opType, opSetDomain, versionSelector, operatorImplementation and * operatorInitialization (optional) */ - type ResolveRule = [ - string, Domain, string, OperatorImplementation - ]|[string, Domain, string, OperatorImplementation, OperatorInitialization]; + type ResolveRule = + [ + string, Domain, string, OperatorImplementation| OperatorAsyncImplementation + ]|[string, Domain, string, OperatorImplementation| OperatorAsyncImplementation, + OperatorInitialization]; } export function resolveOperator(node: Graph.Node, opsets: readonly OpSet[], rules: readonly OpSet.ResolveRule[]) { diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 8a11322ebca94..0478715f29cda 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -558,10 +558,7 @@ async function runOpTestcase( const inputTensors = testcase.inputs.map(input => createTensor(input.dims, input.type as Tensor.DataType, input.data)); - const results = operator.impl(inferenceHandler, inputTensors, operator.context); - // if ('then' in results) { - // results = await results; - // } + const results = await operator.impl(inferenceHandler, inputTensors, operator.context); results.forEach((output, i) => { Logger.verbose('TestOpRunner', ` Result'${i}': ${output.type}[${output.dims.join(',')}]`); diff --git a/js/web/test/unittests/backends/webgl/test-conv-new.ts b/js/web/test/unittests/backends/webgl/test-conv-new.ts index 0fddddf58181c..fa783acb6c4d0 100644 --- a/js/web/test/unittests/backends/webgl/test-conv-new.ts +++ b/js/web/test/unittests/backends/webgl/test-conv-new.ts @@ -832,7 +832,7 @@ function webglConv( if (biasTensor) { inputs.push(biasTensor); } - return (op.impl(webglInferenceHandler!, inputs, op.context))[0]; + return (op.impl(webglInferenceHandler!, inputs, op.context) as Tensor[])[0]; } function cpuConv( inputTensor: Tensor, kernelTensor: Tensor, biasTensor: Tensor|null, autoPad: string|undefined, dilations: number[],