Skip to content

Commit

Permalink
shape
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 13, 2022
1 parent a2197f0 commit e104d17
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 5 deletions.
4 changes: 2 additions & 2 deletions js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {gemm, parseGemmAttributesV11, parseGemmAttributesV7} from './ops/gemm';
import {matMul, parseMatMulAttributes} from './ops/matmul';
import {averagePool, globalAveragePool, globalMaxPool, maxPool, parseAveragePoolAttributes, parseGlobalAveragePoolAttributes, parseMaxPoolAttributes} from './ops/pool';
import {reshape} from './ops/reshape';
import {shape} from './ops/shape';
import {parseSliceAttributes, slice, sliceV10} from './ops/slice';
import * as unaryOps from './ops/unary-op';
import {parseUnsqueezeAttributes, unsqueeze, unsqueezeV13} from './ops/unsqueeze';
Expand Down Expand Up @@ -63,8 +64,7 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
['Relu', '', '6+', unaryOps.relu], ['Reshape', '', '5+', reshape],
// ['Resize', '', '10', resize, parseResizeAttributesV10],
// ['Resize', '', '11+', resize, parseResizeAttributesV11],
// ['Shape', '', '1+', shape],
['Sigmoid', '', '6+', unaryOps.sigmoid], ['Sin', '', '7+', unaryOps.sin],
['Shape', '', '1+', shape], ['Sigmoid', '', '6+', unaryOps.sigmoid], ['Sin', '', '7+', unaryOps.sin],
['Slice', '', '10+', sliceV10], // TODO: support 'steps' for Slice-10
['Slice', '', '1-9', slice, parseSliceAttributes],
// // The "semantic" meaning of axis has changed in opset-13.
Expand Down
15 changes: 13 additions & 2 deletions js/web/lib/onnxjs/backends/webgpu/ops/reshape.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,18 @@ import {Tensor} from '../../../tensor';
import {ShapeUtil} from '../../../util';
import {WebGpuInferenceHandler} from '../inference-handler';

export const reshape = (handler: WebGpuInferenceHandler, inputs: Tensor[]): Tensor[] => {
const reshapedDims = ShapeUtil.calculateReshapedDims(inputs[0].dims, inputs[1].integerData);
export const reshape = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> => {
validateInputs(inputs);
const shape = await inputs[1].getData();
const reshapedDims = ShapeUtil.calculateReshapedDims(inputs[0].dims, shape as Int32Array);
return [handler.reshape(inputs[0], reshapedDims)];
};

const validateInputs = (inputs: Tensor[]): void => {
if (!inputs || inputs.length !== 2) {
throw new Error('Reshape requires 2 inputs.');
}
if (inputs[1].type !== 'int32') {
throw new Error('Invalid input type.');
}
};
16 changes: 16 additions & 0 deletions js/web/lib/onnxjs/backends/webgpu/ops/shape.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {Tensor} from '../../../tensor';
import {WebGpuInferenceHandler} from '../inference-handler';

export const shape = async(inferenceHandler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> => {
validateInputs(inputs);
return [new Tensor([inputs[0].dims.length], 'int32', undefined, undefined, new Int32Array(inputs[0].dims))];
};

const validateInputs = (inputs: Tensor[]): void => {
if (!inputs || inputs.length !== 1) {
throw new Error('Shape requires 1 input.');
}
};
10 changes: 9 additions & 1 deletion js/web/lib/onnxjs/tensor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,15 @@ export class Tensor {
*/
async getData(): Promise<TensorData> {
if (this.cache === undefined) {
this.cache = await this.asyncDataProvider!(this.dataId);
if (this.asyncDataProvider) {
const data = await this.asyncDataProvider(this.dataId);
if (data.length !== this.size) {
throw new Error('Length of data provided by the Data Provider is inconsistent with the dims of this Tensor.');
}
this.cache = data;
} else {
return this.data;
}
}
return this.cache;
}
Expand Down

0 comments on commit e104d17

Please sign in to comment.