From 21b5dfe1c4de3b0d70fab044fb1ac92c51ffe152 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 7 Jun 2022 16:13:12 -0700 Subject: [PATCH] matmul (no-broadcast) --- .../backends/webgpu/op-resolve-rules.ts | 4 +- .../onnxjs/backends/webgpu/ops/fuse-utils.ts | 23 ++-- .../lib/onnxjs/backends/webgpu/ops/matmul.ts | 111 +++++++----------- js/web/test/suite-test-list.jsonc | 6 +- 4 files changed, 56 insertions(+), 88 deletions(-) diff --git a/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts b/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts index 1d8258ec7399f..c5b59565a68da 100644 --- a/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts +++ b/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts @@ -7,6 +7,7 @@ import * as binaryOps from './ops/binary-op'; import {concat, parseConcatAttributes} from './ops/concat'; import {gather, parseGatherAttributes} from './ops/gather'; import {gemm, parseGemmAttributesV11, parseGemmAttributesV7} from './ops/gemm'; +import {matMul, parseMatMulAttributes} from './ops/matmul'; import {reshape} from './ops/reshape'; import {parseSliceAttributes, slice, sliceV10} from './ops/slice'; import * as unaryOps from './ops/unary-op'; @@ -41,8 +42,7 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [ // ['InstanceNormalization', '', '6+', instanceNormalization, parseInstanceNormalizationAttributes], ['LeakyRelu', '', '6+', unaryOps.leakyRelu, unaryOps.parseLeakyReluAttributes], // ['Less', '', '7+', binaryOps.less], - ['Log', '', '6+', unaryOps.log], - // ['MatMul', '', '1+', matMul, parseMatMulAttributes], + ['Log', '', '6+', unaryOps.log], ['MatMul', '', '1+', matMul, parseMatMulAttributes], // // TODO: support new attributes for MaxPool-8 and MaxPool-10 // ['MaxPool', '', '1+', maxPool, parseMaxPoolAttributes], ['Mul', '', '7+', binaryOps.mul], ['Neg', '', '6+', unaryOps.neg], diff --git a/js/web/lib/onnxjs/backends/webgpu/ops/fuse-utils.ts b/js/web/lib/onnxjs/backends/webgpu/ops/fuse-utils.ts index 124000801a6c8..fae2c9fb6e9b2 100644 --- a/js/web/lib/onnxjs/backends/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/onnxjs/backends/webgpu/ops/fuse-utils.ts @@ -3,9 +3,6 @@ import {Attribute} from '../../../attribute'; import {MAX_CLIP, MIN_CLIP} from '../../../util'; -import {GlslValueFunction} from '../glsl-definitions'; - -import {glslClip, glslRelu, glslSigmoid} from './unary-op'; export interface InternalActivationAttributes { readonly activation: string; @@ -15,26 +12,20 @@ export interface InternalActivationAttributes { } export function getActicationSnippet(attributes: InternalActivationAttributes) { - let func: GlslValueFunction; switch (attributes.activation) { case 'Relu': - func = glslRelu(); - break; + return {activationFunction: '', applyActivation: 'value = max(value, 0.0);'}; case 'Sigmoid': - func = glslSigmoid(); - break; + return {activationFunction: '', applyActivation: 'value = (1.0 / (1.0 + exp(-value)));'}; case 'Clip': - func = glslClip(attributes.clipMin!, attributes.clipMax!); - break; - // TODO: adding other activations that can be fused. + return { + activationFunction: `let clip_min_=f32(${attributes.clipMin!});let clip_max_=f32(${attributes.clipMax!});`, + applyActivation: 'value = clamp(value, clip_min_, clip_max_);' + }; + // TODO: adding other activations that can be fused. default: return {activationFunction: '', applyActivation: ''}; } - - const activationName = func.name; - const activationFunction = func.body; - const applyActivation = `value = ${activationName}_(value);`; - return {activationFunction, applyActivation}; } export const parseInternalActivationAttributes = (attributes: Attribute): InternalActivationAttributes => { diff --git a/js/web/lib/onnxjs/backends/webgpu/ops/matmul.ts b/js/web/lib/onnxjs/backends/webgpu/ops/matmul.ts index fc7177c000a01..e0b7ec2cb848b 100644 --- a/js/web/lib/onnxjs/backends/webgpu/ops/matmul.ts +++ b/js/web/lib/onnxjs/backends/webgpu/ops/matmul.ts @@ -8,6 +8,7 @@ import {BroadcastUtil, ShapeUtil} from '../../../util'; import {WebGpuInferenceHandler} from '../inference-handler'; import {GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; +import {WORKGROUP_SIZE} from './common'; import {getActicationSnippet, InternalActivationAttributes, parseInternalActivationAttributes} from './fuse-utils'; export const matMul: OperatorAsyncImplementation = @@ -36,42 +37,55 @@ function createMatmulProgramInfo( if (!outputShape) { throw new Error('Can\'t use matmul on the given tensors'); } - const coordsDataType = getCoordsDataType(outputShape.length); - const allGlChannels = getGlChannels(); - const {activationFunction, applyActivation} = getActicationSnippet(activationAttributes); + const outputSize = ShapeUtil.size(outputShape); + // TODO: support broadcasting - const hasBias = inputs.length > 2; - const processBias = hasBias ? 'value += getBiasForMatmul();' : ''; - const getBiasForMatmulSnippet = - hasBias ? `${getBiasForMatmul(coordsDataType, allGlChannels, inputs[2].dims, outputShape, false)}` : ''; + const dataType = 'f32'; // TODO: support other data type + const {activationFunction, applyActivation} = getActicationSnippet(activationAttributes); - const rank = outputShape.length; - const arank = aShape.length; - const brank = bShape.length; - const sharedDim = aShape[aShape.length - 1]; + const M = outputShape[outputShape.length - 2]; + const K = aShape[aShape.length - 1]; + const N = outputShape[outputShape.length - 1]; const shaderSource = ` - ${activationFunction} - ${getBiasForMatmulSnippet} - float process(int indices[${rank}]) { - int a[${arank}]; - int b[${brank}]; - bcastMatmulIndices_A(indices, a); - bcastMatmulIndices_B(indices, b); - - float value; - for (int k=0; k<${sharedDim}; ++k) { - a[${arank - 1}] = k; - b[${brank - 2}] = k; - value += _A(a) * _B(b); - } - ${processBias} - ${applyActivation} - return value; - }`; + let WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u; + let M: u32 = ${M}u; + let N: u32 = ${N}u; + let K: u32 = ${K}u; + + @group(0) @binding(0) var a : array<${dataType}>; + @group(0) @binding(1) var b : array<${dataType}>; + @group(0) @binding(2) var output : array<${dataType}>; + + ${activationFunction} + + @stage(compute) @workgroup_size(WORKGROUP_SIZE) + fn main(@builtin(global_invocation_id) global_id : vec3) { + + // Guard against out-of-bounds work group sizes + if (global_id.x >= ${outputSize}u) { + return; + } + + let stack = global_id.x / (M * N); + let mn = global_id.x % (M * N); + let n = global_id.x % N; + let m = mn / N; + + let offsetA = stack * (M * K); + let offsetB = stack * (K * N); + + var value = ${dataType}(0); + for (var k: u32 = 0u; k<${K}u; k++) { + value += a[offsetA + m * K + k] * b[offsetB + k * N + n]; + } + ${applyActivation} + output[global_id.x] = value; + }`; return { ...metadata, - output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked}, + outputs: [{dims: outputShape, type: inputs[0].type, gpuDataType: GpuDataType.default}], shaderSource, + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) }; } @@ -99,40 +113,3 @@ const validateInputs = (inputs: Tensor[]): void => { throw new Error('inputs types should match'); } }; - -export function getBiasForMatmul( - coordsDataType: string, allGlChannels: readonly string[], inShape: readonly number[], outShape: readonly number[], - isPacked: boolean): string { - let unpackedCoordsSnippet = ''; - const inRank = inShape.length; - const outRank = outShape.length; - const rankDiff = outRank - inRank; - if (outRank < 2 && inRank > 0) { - unpackedCoordsSnippet = 'coords'; - } else { - unpackedCoordsSnippet = inShape.map((s, i) => `coords.${allGlChannels[i + rankDiff]}`).join(', '); - } - const broadcastDims = BroadcastUtil.getBroadcastDims(inShape, outShape); - const coordsSnippet = broadcastDims.map(d => `coords.${allGlChannels[d + rankDiff]} = 0;`).join('\n'); - const inSize = ShapeUtil.size(inShape); - const isInputScalar = inSize === 1; - let output = 'vec4(outputValue.xx, outputValue.yy)'; - if (isInputScalar) { - output = 'vec4(outputValue.x)'; - } - const getBiasForMatmulSource = isPacked ? ` -vec4 getBiasForMatmul() { - ${coordsDataType} coords = getOutputCoords(); - ${coordsSnippet} - vec4 outputValue = getBias(${unpackedCoordsSnippet}); - return ${output}; -}` : - ` -float getBiasForMatmul() { - ${coordsDataType} coords = getOutputCoords(); - ${coordsSnippet} - return getBias(coords.x); -}`; - - return getBiasForMatmulSource; -} diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 71e8e6ae225ad..001ce4fccc2f9 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -377,9 +377,9 @@ "test_leakyrelu", // "test_lrn_default", <-- failing due to low precison. If absolute CPU error threshold is increased from 1e-4 to 1e-2 (100x increase), it passes the test. // "test_lrn", <-- failing due to low precison. If absolute CPU error threshold is increased from 1e-4 to 1e-3 (10x increase), it passes the test. - // "test_matmul_2d", - // "test_matmul_3d", - // "test_matmul_4d", + "test_matmul_2d", + "test_matmul_3d", + "test_matmul_4d", // "test_maxpool_1d_default", // "test_maxpool_2d_default", // "v12/test_maxpool_2d_pads",