diff --git a/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts b/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts index 6fce8e887ee1e..f5a741478926a 100644 --- a/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts +++ b/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts @@ -13,6 +13,8 @@ import {averagePool, globalAveragePool, globalMaxPool, maxPool, parseAveragePool import {reshape} from './ops/reshape'; import {shape} from './ops/shape'; import {parseSliceAttributes, slice, sliceV10} from './ops/slice'; +import {parseSqueezeAttributes, squeeze, squeezeV13} from './ops/squeeze'; +import {parseTransposeAttributes, transpose} from './ops/transpose'; import * as unaryOps from './ops/unary-op'; import {parseUnsqueezeAttributes, unsqueeze, unsqueezeV13} from './ops/unsqueeze'; @@ -75,14 +77,12 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [ // // When the attribute is missing, we need the count of number of outputs // // so that we can determine the 'split' attribute from the runtime input to the Operator // ['Split', '', '2-12', split, parseSplitAttributes], - ['Sqrt', '', '6+', unaryOps.sqrt], - // ['Squeeze', '', '1-12', squeeze, parseSqueezeAttributes], - // ['Squeeze', '', '13+', squeezeV13], - ['Sub', '', '7+', binaryOps.sub], + ['Sqrt', '', '6+', unaryOps.sqrt], ['Squeeze', '', '1-12', squeeze, parseSqueezeAttributes], + ['Squeeze', '', '13+', squeezeV13], ['Sub', '', '7+', binaryOps.sub], // ['Sum', '', '6+', sum], ['Tan', '', '7+', unaryOps.tan], ['Tanh', '', '6+', unaryOps.tanh], // ['Tile', '', '6+', tile], - // ['Transpose', '', '1+', transpose, parseTransposeAttributes], + ['Transpose', '', '1+', transpose, parseTransposeAttributes], // ['Upsample', '', '7-8', upsample, parseUpsampleAttributesV7], // ['Upsample', '', '9', upsample, parseUpsampleAttributesV9], ['Unsqueeze', '', '1-12', unsqueeze, parseUnsqueezeAttributes], ['Unsqueeze', '', '13+', unsqueezeV13], diff --git a/js/web/lib/onnxjs/backends/webgpu/ops/squeeze.ts b/js/web/lib/onnxjs/backends/webgpu/ops/squeeze.ts new file mode 100644 index 0000000000000..7cd85e6877b03 --- /dev/null +++ b/js/web/lib/onnxjs/backends/webgpu/ops/squeeze.ts @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {Graph} from '../../../graph'; +import {OperatorImplementation, OperatorInitialization} from '../../../operators'; +import {Tensor} from '../../../tensor'; +import {ShapeUtil} from '../../../util'; +import {WebGpuInferenceHandler} from '../inference-handler'; + +export const squeeze: OperatorImplementation = + (inferenceHandler: WebGpuInferenceHandler, inputs: Tensor[], axes: number[]): Tensor[] => { + validateInputs(inputs); + const outputShape = ShapeUtil.squeezeShape(inputs[0].dims, axes); + const output = inferenceHandler.reshape(inputs[0], outputShape); + return [output]; + }; + +export const squeezeV13 = (inferenceHandler: WebGpuInferenceHandler, inputs: Tensor[]): Tensor[] => { + validateInputsV13(inputs); + return squeeze(inferenceHandler, [inputs[0]], Array.from(inputs[1].integerData)); +}; + +export const parseSqueezeAttributes: OperatorInitialization = (node: Graph.Node): number[] => + node.attributes.getInts('axes'); + +const validateInputs = (inputs: Tensor[]): void => { + if (!inputs || inputs.length !== 1) { + throw new Error('Squeeze requires 1 input.'); + } + + if (inputs[0].type === 'string') { + throw new Error('invalid input tensor types.'); + } +}; + +const validateInputsV13 = (inputs: Tensor[]): void => { + if (!inputs || inputs.length !== 2) { + throw new Error('Squeeze requires 2 inputs.'); + } + + if (inputs[1].type !== 'int32') { + throw new Error('Invalid input type.'); + } +}; diff --git a/js/web/lib/onnxjs/backends/webgpu/ops/transpose.ts b/js/web/lib/onnxjs/backends/webgpu/ops/transpose.ts new file mode 100644 index 0000000000000..e83dd7fcbb0b9 --- /dev/null +++ b/js/web/lib/onnxjs/backends/webgpu/ops/transpose.ts @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key'; +import {Graph} from '../../../graph'; +import {OperatorAsyncImplementation, OperatorInitialization} from '../../../operators'; +import {Tensor} from '../../../tensor'; +import {ShapeUtil} from '../../../util'; +import {WebGpuInferenceHandler} from '../inference-handler'; +import {GpuDataType, ProgramInfo} from '../types'; + +import {createIndicesHelper, WORKGROUP_SIZE} from './common'; + +export interface TransposeAttributes extends AttributeWithCacheKey { + readonly perm: number[]; +} + +const transposeProgramMetadata = { + name: 'Transpose', + inputTypes: [GpuDataType.default] +}; + +export const transpose: OperatorAsyncImplementation = async( + inferenceHandler: WebGpuInferenceHandler, inputs: Tensor[], attributes: TransposeAttributes): Promise => { + validateInputs(inputs); + return inferenceHandler.run( + { + ...transposeProgramMetadata, + cacheHint: attributes.cacheKey, + get: () => createTransposeProgramInfo(inferenceHandler, inputs[0], attributes.perm) + }, + inputs); +}; + +export const parseTransposeAttributes: OperatorInitialization = + (node: Graph.Node): TransposeAttributes => createAttributeWithCacheKey({perm: node.attributes.getInts('perm', [])}); + +const createTransposeProgramInfo = + (_inferenceHandler: WebGpuInferenceHandler, input: Tensor, perm: number[]): ProgramInfo => { + const dataType = 'f32'; // TODO: support other data type + const inputShape = input.dims; + perm = getAdjustedPerm(inputShape, perm); + const outputShape = getOutputShape(inputShape, perm); + const rank = inputShape.length; + const outputSize = ShapeUtil.size(outputShape); + // A dims=[${inputs[0].dims.toString()}] + // out Dims=[${unpackedOutputShape.toString()}] + // based on perm=[${perm.toString()}] + + const outputIndicesHelper = createIndicesHelper('output', outputShape); + const inputIndicesHelper = createIndicesHelper('a', inputShape); + + const shaderSource = ` + const WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u; + + @group(0) @binding(0) var a : array<${dataType}>; + @group(0) @binding(1) var output : array<${dataType}>; + + ${permFunctionBody(perm, rank)} + ${outputIndicesHelper.o2iImpl} + ${inputIndicesHelper.i2oImpl} + + @compute @workgroup_size(WORKGROUP_SIZE) + fn main(@builtin(global_invocation_id) global_id : vec3) { + + // Guard against out-of-bounds work group sizes + if (global_id.x >= ${outputSize}u) { + return; + } + + ${outputIndicesHelper.indicesVariableDeclaration('indices')} + ${outputIndicesHelper.o2iCall('global_id.x', 'indices')} + ${inputIndicesHelper.indicesVariableDeclaration('aIndices')} + perm(&aIndices, &indices); + + output[global_id.x] = a[${inputIndicesHelper.i2oExpression('aIndices')}]; + }`; + return { + ...transposeProgramMetadata, + outputs: [{dims: outputShape, type: input.type, gpuDataType: GpuDataType.default}], + shaderSource, + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) + }; + }; + +const getAdjustedPerm = (inputShape: readonly number[], perm: number[]): number[] => { + if (perm && perm.length !== inputShape.length) { + perm = [...(inputShape.keys())].reverse(); + } + return perm; +}; + +const getOutputShape = (inputShape: readonly number[], perm: number[]): readonly number[] => { + perm = getAdjustedPerm(inputShape, perm); + return ShapeUtil.sortBasedOnPerm(inputShape, perm); +}; + +const permFunctionBody = (perm: number[], rank: number): string => { + const reverseFunc = []; + reverseFunc.push(`fn perm(a: ptr>, i: ptr>) {`); + for (let i = 0; i < rank; ++i) { + reverseFunc.push(`\t(*a)[${perm[i]}]=(*i)[${i}];`); + } + reverseFunc.push('\t}'); + return reverseFunc.join('\n'); +}; + +const validateInputs = (inputs: Tensor[]): void => { + if (!inputs || inputs.length !== 1) { + throw new Error('Transpose requires 1 input.'); + } + + if (inputs[0].type !== 'float32' && inputs[0].type !== 'float64') { + throw new Error('input should be float tensor'); + } +}; diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index dc7bd9859383b..da3f01abce432 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -487,16 +487,16 @@ "test_tan_example", "test_tan", "test_tanh_example", - "test_tanh" + "test_tanh", // "test_tile", // "test_tile_precomputed", - // "test_transpose_all_permutations_0", - // "test_transpose_all_permutations_1", - // "test_transpose_all_permutations_2", - // "test_transpose_all_permutations_3", - // "test_transpose_all_permutations_4", - // "test_transpose_all_permutations_5", - // "test_transpose_default", + "test_transpose_all_permutations_0", + "test_transpose_all_permutations_1", + "test_transpose_all_permutations_2", + "test_transpose_all_permutations_3", + "test_transpose_all_permutations_4", + "test_transpose_all_permutations_5", + "test_transpose_default" // "test_unsqueeze", // "test_xor_bcast3v1d", // "test_xor_bcast3v2d", @@ -547,8 +547,8 @@ //"split.jsonc", "sqrt.jsonc", "sub.jsonc", - "tan.jsonc" - //"transpose.jsonc", + "tan.jsonc", + "transpose.jsonc" //"xor.jsonc" ] },