From 86c75bb486e503d80d170de4884bedda1ca3fecd Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 24 May 2022 11:39:35 -0700 Subject: [PATCH] gemm --- js/web/lib/onnxjs/backends/webgpu/ops/gemm.ts | 135 ++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 js/web/lib/onnxjs/backends/webgpu/ops/gemm.ts diff --git a/js/web/lib/onnxjs/backends/webgpu/ops/gemm.ts b/js/web/lib/onnxjs/backends/webgpu/ops/gemm.ts new file mode 100644 index 0000000000000..670ee2ff82ae8 --- /dev/null +++ b/js/web/lib/onnxjs/backends/webgpu/ops/gemm.ts @@ -0,0 +1,135 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key'; +import {Graph} from '../../../graph'; +import {OperatorAsyncImplementation, OperatorInitialization} from '../../../operators'; +import {Tensor} from '../../../tensor'; +import {GemmUtil} from '../../../util'; +import {WebGpuInferenceHandler} from '../inference-handler'; +import {GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; + +export interface GemmAttributes extends AttributeWithCacheKey { + transA: boolean; + transB: boolean; + alpha: number; + beta: number; + isOptionalC: boolean; // in opset 11, C becomes optional +} + +export const gemm: OperatorAsyncImplementation = async( + inferenceHandler: WebGpuInferenceHandler, inputs: Tensor[], attributes: GemmAttributes): Promise => { + validateInputs(inputs, attributes); + return inferenceHandler.run(createGemmProgramInfoLoader(inputs, attributes), inputs); +}; + +const parseGemmAttributes = (node: Graph.Node, isOptionalC: boolean): GemmAttributes => { + const transA = node.attributes.getInt('transA', 0) !== 0; + const transB = node.attributes.getInt('transB', 0) !== 0; + const alpha = node.attributes.getFloat('alpha', 1.0); + const beta = node.attributes.getFloat('beta', 1.0); + return createAttributeWithCacheKey({transA, transB, alpha, beta, isOptionalC}); +}; + +export const parseGemmAttributesV7: OperatorInitialization = (node: Graph.Node): GemmAttributes => + parseGemmAttributes(node, false); + +export const parseGemmAttributesV11: OperatorInitialization = (node: Graph.Node): GemmAttributes => + parseGemmAttributes(node, true); + +const createGemmProgramInfoLoader = (inputs: Tensor[], attributes: GemmAttributes): ProgramInfoLoader => { + const metadata = { + name: 'Gemm', + inputTypes: inputs.length === 3 ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] : + [GpuDataType.default, GpuDataType.default], + cacheHint: attributes.cacheKey + }; + + return {...metadata, get: () => createGemmProgramInfo(metadata, inputs, attributes)}; +}; + +const createGemmProgramInfo = + (metadata: ProgramMetadata, inputs: Tensor[], attributes: GemmAttributes): ProgramInfo => { + const aShape = inputs[0].dims.slice(); + const bShape = inputs[1].dims.slice(); + const [M, N] = GemmUtil.getShapeOfGemmResult( + aShape, attributes.transA, bShape, attributes.transB, inputs.length === 3 ? inputs[2].dims : undefined); + const outputShape = [M, N]; + if (!outputShape) { + throw new Error('Can\'t use gemm on the given tensors'); + } + let sharedDim = aShape[aShape.length - 1]; + let line = ''; + if (attributes.transA) { + sharedDim = aShape[0]; + } + if (attributes.transA && attributes.transB) { + line = 'value += _A_T(a) * _B_T(b);'; + } else if (attributes.transA && !attributes.transB) { + line = 'value += _A_T(a) * _B(b);'; + } else if (!attributes.transA && attributes.transB) { + line = 'value += _A(a) * _B_T(b);'; + } else if (!attributes.transA && !attributes.transB) { + line = 'value += _A(a) * _B(b);'; + } + const rank = outputShape.length; + const declareC = inputs.length === 3 ? `int c[${inputs[2].dims.length}];` : ''; + const broadcastC = inputs.length === 3 ? 'bcastIndices_C(indices, c);' : ''; + const calculateC = inputs.length === 3 ? 'value += beta * _C(c);' : ''; + const shaderSource = ` + float process(int indices[${rank}]) { + int a[${rank}]; + int b[${rank}]; + ${declareC} + + copyVec(indices, a); + copyVec(indices, b); + ${broadcastC} + + float value = 0.0; + for (int k=0; k<${sharedDim}; ++k) { + a[${rank - 1}] = k; + b[${rank - 2}] = k; + ${line} + } + + value = value * alpha; + ${calculateC} + return value; + }`; + return { + ...metadata, + output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked}, + variables: [ + {name: 'alpha', type: 'float', data: attributes.alpha}, {name: 'beta', type: 'float', data: attributes.beta} + ], + shaderSource + }; + }; + +const validateInputs = (inputs: Tensor[], attributes: GemmAttributes): void => { + if (!inputs) { + throw new Error('Input is missing'); + } + if (attributes.isOptionalC && (inputs.length < 2 || inputs.length > 3)) { + throw new Error('Invaid input shape.'); + } + if (!attributes.isOptionalC && inputs.length !== 3) { + throw new Error('Gemm requires 3 inputs'); + } + + // 'C' can be of dimensionality 1 or 2 only + if (inputs.length === 3 && inputs[2].dims.length !== 1 && inputs[2].dims.length !== 2) { + throw new Error('Invalid input shape of C'); + } + + if ((inputs[0].type !== 'float32' && inputs[0].type !== 'float64') || + (inputs[1].type !== 'float32' && inputs[1].type !== 'float64') || + (inputs.length === 3 && inputs[2].type !== 'float32' && inputs[2].type !== 'float64')) { + throw new Error('Invalid input type.'); + } + + if ((inputs[0].type !== inputs[1].type) || (inputs.length === 3 && inputs[0].type !== inputs[2].type)) { + throw new Error('Input types are mismatched'); + } +};