Commit 306a19b

squeeze + transpose

fs-eire committed Sep 13, 2022
1 parent 86d8d3a
Showing 4 changed files with 175 additions and 15 deletions.
10 changes: 5 additions & 5 deletions js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts
@@ -13,6 +13,8 @@ import {averagePool, globalAveragePool, globalMaxPool, maxPool, parseAveragePool
 import {reshape} from './ops/reshape';
 import {shape} from './ops/shape';
 import {parseSliceAttributes, slice, sliceV10} from './ops/slice';
+import {parseSqueezeAttributes, squeeze, squeezeV13} from './ops/squeeze';
+import {parseTransposeAttributes, transpose} from './ops/transpose';
 import * as unaryOps from './ops/unary-op';
 import {parseUnsqueezeAttributes, unsqueeze, unsqueezeV13} from './ops/unsqueeze';

@@ -75,14 +77,12 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
   // // When the attribute is missing, we need the count of number of outputs
   // // so that we can determine the 'split' attribute from the runtime input to the Operator
   // ['Split', '', '2-12', split, parseSplitAttributes],
-  ['Sqrt', '', '6+', unaryOps.sqrt],
-  // ['Squeeze', '', '1-12', squeeze, parseSqueezeAttributes],
-  // ['Squeeze', '', '13+', squeezeV13],
-  ['Sub', '', '7+', binaryOps.sub],
+  ['Sqrt', '', '6+', unaryOps.sqrt], ['Squeeze', '', '1-12', squeeze, parseSqueezeAttributes],
+  ['Squeeze', '', '13+', squeezeV13], ['Sub', '', '7+', binaryOps.sub],
   // ['Sum', '', '6+', sum],
   ['Tan', '', '7+', unaryOps.tan], ['Tanh', '', '6+', unaryOps.tanh],
   // ['Tile', '', '6+', tile],
-  // ['Transpose', '', '1+', transpose, parseTransposeAttributes],
+  ['Transpose', '', '1+', transpose, parseTransposeAttributes],
   // ['Upsample', '', '7-8', upsample, parseUpsampleAttributesV7],
   // ['Upsample', '', '9', upsample, parseUpsampleAttributesV9],
   ['Unsqueeze', '', '1-12', unsqueeze, parseUnsqueezeAttributes], ['Unsqueeze', '', '13+', unsqueezeV13],
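For context on the rule format used above: each WEBGPU_OP_RESOLVE_RULES entry is a tuple of op type, domain ('' meaning the default ai.onnx domain), an opset-version selector such as '1-12' or '13+', the operator implementation, and an optional attribute initializer. The sketch below only illustrates how such a table could be matched against a node; matchesVersion and findRule are hypothetical helpers, not the actual onnxjs OpSet resolver.

// Illustration only: the shape of a resolve rule and a naive version-range match.
type OpImpl = (...args: unknown[]) => unknown;
type ResolveRule = readonly [string, string, string, OpImpl, OpImpl?];

// True if `opset` falls in a selector like '1-12', '13+' or '9'.
const matchesVersion = (selector: string, opset: number): boolean => {
  if (selector.endsWith('+')) {
    return opset >= Number(selector.slice(0, -1));
  }
  const [lo, hi] = selector.split('-').map(Number);
  return opset >= lo && opset <= (hi ?? lo);
};

// With the rules enabled in this commit, a Squeeze node at opset 13 would resolve to
// squeezeV13, and a Transpose node at any opset to transpose.
const findRule = (rules: readonly ResolveRule[], opType: string, domain: string, opset: number) =>
    rules.find(([type, dom, selector]) => type === opType && dom === domain && matchesVersion(selector, opset));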
44 changes: 44 additions & 0 deletions js/web/lib/onnxjs/backends/webgpu/ops/squeeze.ts
@@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {Graph} from '../../../graph';
import {OperatorImplementation, OperatorInitialization} from '../../../operators';
import {Tensor} from '../../../tensor';
import {ShapeUtil} from '../../../util';
import {WebGpuInferenceHandler} from '../inference-handler';

export const squeeze: OperatorImplementation<number[]> =
    (inferenceHandler: WebGpuInferenceHandler, inputs: Tensor[], axes: number[]): Tensor[] => {
      validateInputs(inputs);
      const outputShape = ShapeUtil.squeezeShape(inputs[0].dims, axes);
      const output = inferenceHandler.reshape(inputs[0], outputShape);
      return [output];
    };

export const squeezeV13 = (inferenceHandler: WebGpuInferenceHandler, inputs: Tensor[]): Tensor[] => {
  validateInputsV13(inputs);
  return squeeze(inferenceHandler, [inputs[0]], Array.from(inputs[1].integerData));
};

export const parseSqueezeAttributes: OperatorInitialization<number[]> = (node: Graph.Node): number[] =>
    node.attributes.getInts('axes');

const validateInputs = (inputs: Tensor[]): void => {
  if (!inputs || inputs.length !== 1) {
    throw new Error('Squeeze requires 1 input.');
  }

  if (inputs[0].type === 'string') {
    throw new Error('invalid input tensor types.');
  }
};

const validateInputsV13 = (inputs: Tensor[]): void => {
  if (!inputs || inputs.length !== 2) {
    throw new Error('Squeeze requires 2 inputs.');
  }

  if (inputs[1].type !== 'int32') {
    throw new Error('Invalid input type.');
  }
};
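A note on the implementation above: squeeze is metadata-only. ShapeUtil.squeezeShape drops size-1 dimensions (those listed in axes, or every size-1 dimension when axes is empty) and the tensor buffer is reused via inferenceHandler.reshape; squeezeV13 covers opset 13, where axes moved from a node attribute to a second input tensor. The following is only a rough reference of the expected shape computation, not the actual ShapeUtil code; squeezeShapeSketch is a hypothetical name.

// Illustrative reference of squeeze-shape semantics (the real logic lives in ShapeUtil.squeezeShape).
const squeezeShapeSketch = (dims: readonly number[], axes: readonly number[]): number[] => {
  // Normalize negative axes (e.g. -1 refers to the last dimension).
  const normalized = axes.map(a => (a < 0 ? a + dims.length : a));
  return dims.filter((d, i) => {
    const squeezed = normalized.length === 0 ? d === 1 : normalized.includes(i);
    if (squeezed && d !== 1) {
      throw new Error(`cannot squeeze dimension ${i} of size ${d}`);
    }
    return !squeezed;
  });
};

// e.g. squeezeShapeSketch([1, 3, 1, 2], [0]) -> [3, 1, 2]
//      squeezeShapeSketch([1, 3, 1, 2], [])  -> [3, 2]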
116 changes: 116 additions & 0 deletions js/web/lib/onnxjs/backends/webgpu/ops/transpose.ts
@@ -0,0 +1,116 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key';
import {Graph} from '../../../graph';
import {OperatorAsyncImplementation, OperatorInitialization} from '../../../operators';
import {Tensor} from '../../../tensor';
import {ShapeUtil} from '../../../util';
import {WebGpuInferenceHandler} from '../inference-handler';
import {GpuDataType, ProgramInfo} from '../types';

import {createIndicesHelper, WORKGROUP_SIZE} from './common';

export interface TransposeAttributes extends AttributeWithCacheKey {
  readonly perm: number[];
}

const transposeProgramMetadata = {
  name: 'Transpose',
  inputTypes: [GpuDataType.default]
};

export const transpose: OperatorAsyncImplementation<TransposeAttributes> = async(
    inferenceHandler: WebGpuInferenceHandler, inputs: Tensor[], attributes: TransposeAttributes): Promise<Tensor[]> => {
  validateInputs(inputs);
  return inferenceHandler.run(
      {
        ...transposeProgramMetadata,
        cacheHint: attributes.cacheKey,
        get: () => createTransposeProgramInfo(inferenceHandler, inputs[0], attributes.perm)
      },
      inputs);
};

export const parseTransposeAttributes: OperatorInitialization<TransposeAttributes> =
    (node: Graph.Node): TransposeAttributes => createAttributeWithCacheKey({perm: node.attributes.getInts('perm', [])});

const createTransposeProgramInfo =
    (_inferenceHandler: WebGpuInferenceHandler, input: Tensor, perm: number[]): ProgramInfo => {
      const dataType = 'f32'; // TODO: support other data type
      const inputShape = input.dims;
      perm = getAdjustedPerm(inputShape, perm);
      const outputShape = getOutputShape(inputShape, perm);
      const rank = inputShape.length;
      const outputSize = ShapeUtil.size(outputShape);
      // A dims=[${inputs[0].dims.toString()}]
      // out Dims=[${unpackedOutputShape.toString()}]
      // based on perm=[${perm.toString()}]

      const outputIndicesHelper = createIndicesHelper('output', outputShape);
      const inputIndicesHelper = createIndicesHelper('a', inputShape);

      const shaderSource = `
  const WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
  @group(0) @binding(0) var<storage, read> a : array<${dataType}>;
  @group(0) @binding(1) var<storage, read_write> output : array<${dataType}>;
  ${permFunctionBody(perm, rank)}
  ${outputIndicesHelper.o2iImpl}
  ${inputIndicesHelper.i2oImpl}
  @compute @workgroup_size(WORKGROUP_SIZE)
  fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
    // Guard against out-of-bounds work group sizes
    if (global_id.x >= ${outputSize}u) {
      return;
    }
    ${outputIndicesHelper.indicesVariableDeclaration('indices')}
    ${outputIndicesHelper.o2iCall('global_id.x', 'indices')}
    ${inputIndicesHelper.indicesVariableDeclaration('aIndices')}
    perm(&aIndices, &indices);
    output[global_id.x] = a[${inputIndicesHelper.i2oExpression('aIndices')}];
  }`;
      return {
        ...transposeProgramMetadata,
        outputs: [{dims: outputShape, type: input.type, gpuDataType: GpuDataType.default}],
        shaderSource,
        dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
      };
    };

const getAdjustedPerm = (inputShape: readonly number[], perm: number[]): number[] => {
  if (perm && perm.length !== inputShape.length) {
    perm = [...(inputShape.keys())].reverse();
  }
  return perm;
};

const getOutputShape = (inputShape: readonly number[], perm: number[]): readonly number[] => {
  perm = getAdjustedPerm(inputShape, perm);
  return ShapeUtil.sortBasedOnPerm(inputShape, perm);
};

const permFunctionBody = (perm: number[], rank: number): string => {
  const reverseFunc = [];
  reverseFunc.push(`fn perm(a: ptr<function, array<u32, ${rank}>>, i: ptr<function, array<u32, ${rank}>>) {`);
  for (let i = 0; i < rank; ++i) {
    reverseFunc.push(`\t(*a)[${perm[i]}]=(*i)[${i}];`);
  }
  reverseFunc.push('\t}');
  return reverseFunc.join('\n');
};

const validateInputs = (inputs: Tensor[]): void => {
  if (!inputs || inputs.length !== 1) {
    throw new Error('Transpose requires 1 input.');
  }

  if (inputs[0].type !== 'float32' && inputs[0].type !== 'float64') {
    throw new Error('input should be float tensor');
  }
};
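How the shader above works, in short: each invocation handles one output element, decomposes global_id.x into output indices, scatters them into input indices with the generated perm function (aIndices[perm[i]] = indices[i]), and reads the linearized input offset. getAdjustedPerm falls back to reversing the axes when perm is absent or has the wrong length, matching the ONNX default for Transpose. The CPU reference below mirrors that index mapping for illustration only; transposeReference is a hypothetical helper, not part of this commit.

// CPU reference of the index mapping the generated shader performs (illustration only).
const transposeReference =
    (data: Float32Array, inputShape: readonly number[], perm: readonly number[]): Float32Array => {
      const outputShape = perm.map(p => inputShape[p]);
      // Row-major strides for a given shape.
      const strides = (shape: readonly number[]) => shape.map((_, i) => shape.slice(i + 1).reduce((a, b) => a * b, 1));
      const inStrides = strides(inputShape);
      const outStrides = strides(outputShape);
      const output = new Float32Array(data.length);
      for (let o = 0; o < output.length; ++o) {
        // output offset -> output indices
        const indices = outStrides.map((s, i) => Math.floor(o / s) % outputShape[i]);
        // same scatter as the generated WGSL perm(): aIndices[perm[i]] = indices[i]
        const aIndices = new Array<number>(inputShape.length);
        indices.forEach((v, i) => {
          aIndices[perm[i]] = v;
        });
        // input indices -> input offset
        const inOffset = aIndices.reduce((acc, v, i) => acc + v * inStrides[i], 0);
        output[o] = data[inOffset];
      }
      return output;
    };

// e.g. a 2x3 input with perm = [1, 0] yields the 3x2 matrix transpose.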
20 changes: 10 additions & 10 deletions js/web/test/suite-test-list.jsonc
@@ -487,16 +487,16 @@
   "test_tan_example",
   "test_tan",
   "test_tanh_example",
-  "test_tanh"
+  "test_tanh",
   // "test_tile",
   // "test_tile_precomputed",
-  // "test_transpose_all_permutations_0",
-  // "test_transpose_all_permutations_1",
-  // "test_transpose_all_permutations_2",
-  // "test_transpose_all_permutations_3",
-  // "test_transpose_all_permutations_4",
-  // "test_transpose_all_permutations_5",
-  // "test_transpose_default",
+  "test_transpose_all_permutations_0",
+  "test_transpose_all_permutations_1",
+  "test_transpose_all_permutations_2",
+  "test_transpose_all_permutations_3",
+  "test_transpose_all_permutations_4",
+  "test_transpose_all_permutations_5",
+  "test_transpose_default"
   // "test_unsqueeze",
   // "test_xor_bcast3v1d",
   // "test_xor_bcast3v2d",
@@ -547,8 +547,8 @@
   //"split.jsonc",
   "sqrt.jsonc",
   "sub.jsonc",
-  "tan.jsonc"
-  //"transpose.jsonc",
+  "tan.jsonc",
+  "transpose.jsonc"
   //"xor.jsonc"
  ]
 },
