Commit 21b5dfe
matmul (no-broadcast)
fs-eire committed Sep 13, 2022
1 parent: a8def8e
Showing 4 changed files with 56 additions and 88 deletions.
js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts (2 additions & 2 deletions)

@@ -7,6 +7,7 @@ import * as binaryOps from './ops/binary-op';
 import {concat, parseConcatAttributes} from './ops/concat';
 import {gather, parseGatherAttributes} from './ops/gather';
 import {gemm, parseGemmAttributesV11, parseGemmAttributesV7} from './ops/gemm';
+import {matMul, parseMatMulAttributes} from './ops/matmul';
 import {reshape} from './ops/reshape';
 import {parseSliceAttributes, slice, sliceV10} from './ops/slice';
 import * as unaryOps from './ops/unary-op';
@@ -41,8 +42,7 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
   // ['InstanceNormalization', '', '6+', instanceNormalization, parseInstanceNormalizationAttributes],
   ['LeakyRelu', '', '6+', unaryOps.leakyRelu, unaryOps.parseLeakyReluAttributes],
   // ['Less', '', '7+', binaryOps.less],
-  ['Log', '', '6+', unaryOps.log],
-  // ['MatMul', '', '1+', matMul, parseMatMulAttributes],
+  ['Log', '', '6+', unaryOps.log], ['MatMul', '', '1+', matMul, parseMatMulAttributes],
   // // TODO: support new attributes for MaxPool-8 and MaxPool-10
   // ['MaxPool', '', '1+', maxPool, parseMaxPoolAttributes],
   ['Mul', '', '7+', binaryOps.mul], ['Neg', '', '6+', unaryOps.neg],
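
A note on the registration format: each entry in WEBGPU_OP_RESOLVE_RULES pairs an ONNX op type with an opset domain, a supported version range, the operator implementation, and an optional attribute parser. A minimal sketch of the shape the new MatMul rule follows (field meanings inferred from the entries above; the real ResolveRule type lives in onnxjs's OpSet):

    // Sketch only: approximate shape of a resolve rule, inferred from the table above.
    type ResolveRuleSketch = [
      string,     // ONNX op type, e.g. 'MatMul'
      string,     // opset domain ('' = default ONNX domain)
      string,     // supported version range, e.g. '1+'
      Function,   // operator implementation (matMul)
      Function?   // optional attribute parser (parseMatMulAttributes)
    ];
    const matMulRule: ResolveRuleSketch = ['MatMul', '', '1+', matMul, parseMatMulAttributes];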
js/web/lib/onnxjs/backends/webgpu/ops/fuse-utils.ts (7 additions & 16 deletions)

@@ -3,9 +3,6 @@
 
 import {Attribute} from '../../../attribute';
 import {MAX_CLIP, MIN_CLIP} from '../../../util';
-import {GlslValueFunction} from '../glsl-definitions';
-
-import {glslClip, glslRelu, glslSigmoid} from './unary-op';
 
 export interface InternalActivationAttributes {
   readonly activation: string;
@@ -15,26 +12,20 @@ export interface InternalActivationAttributes {
 }
 
 export function getActicationSnippet(attributes: InternalActivationAttributes) {
-  let func: GlslValueFunction;
   switch (attributes.activation) {
     case 'Relu':
-      func = glslRelu();
-      break;
+      return {activationFunction: '', applyActivation: 'value = max(value, 0.0);'};
     case 'Sigmoid':
-      func = glslSigmoid();
-      break;
+      return {activationFunction: '', applyActivation: 'value = (1.0 / (1.0 + exp(-value)));'};
     case 'Clip':
-      func = glslClip(attributes.clipMin!, attributes.clipMax!);
-      break;
-      // TODO: adding other activations that can be fused.
+      return {
+        activationFunction: `let clip_min_=f32(${attributes.clipMin!});let clip_max_=f32(${attributes.clipMax!});`,
+        applyActivation: 'value = clamp(value, clip_min_, clip_max_);'
+      };
+      // TODO: adding other activations that can be fused.
     default:
       return {activationFunction: '', applyActivation: ''};
   }
-
-  const activationName = func.name;
-  const activationFunction = func.body;
-  const applyActivation = `value = ${activationName}_(value);`;
-  return {activationFunction, applyActivation};
 }
 
 export const parseInternalActivationAttributes = (attributes: Attribute): InternalActivationAttributes => {
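
For illustration, the two strings returned here are meant to be spliced into a WGSL shader by the calling kernel: activationFunction contributes preamble declarations (only Clip needs any), and applyActivation is the statement applied to the accumulator, as matmul does below. A hedged sketch of the splice (the surrounding shader text is illustrative, not from this commit):

    // Sketch: how a kernel consumes the snippet pair (illustrative shader text).
    const {activationFunction, applyActivation} =
        getActicationSnippet({activation: 'Clip', clipMin: 0, clipMax: 6} as InternalActivationAttributes);
    const shader = `
    ${activationFunction}  // -> let clip_min_=f32(0);let clip_max_=f32(6);
    fn apply(v: f32) -> f32 {
      var value = v;
      ${applyActivation}   // -> value = clamp(value, clip_min_, clip_max_);
      return value;
    }`;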
js/web/lib/onnxjs/backends/webgpu/ops/matmul.ts (44 additions & 67 deletions)

@@ -8,6 +8,7 @@ import {BroadcastUtil, ShapeUtil} from '../../../util';
 import {WebGpuInferenceHandler} from '../inference-handler';
 import {GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
 
+import {WORKGROUP_SIZE} from './common';
 import {getActicationSnippet, InternalActivationAttributes, parseInternalActivationAttributes} from './fuse-utils';
 
 export const matMul: OperatorAsyncImplementation<InternalActivationAttributes> =
@@ -36,42 +37,55 @@ function createMatmulProgramInfo(
   if (!outputShape) {
     throw new Error('Can\'t use matmul on the given tensors');
   }
-  const coordsDataType = getCoordsDataType(outputShape.length);
-  const allGlChannels = getGlChannels();
-  const {activationFunction, applyActivation} = getActicationSnippet(activationAttributes);
+  const outputSize = ShapeUtil.size(outputShape);
+  // TODO: support broadcasting
 
-  const hasBias = inputs.length > 2;
-  const processBias = hasBias ? 'value += getBiasForMatmul();' : '';
-  const getBiasForMatmulSnippet =
-      hasBias ? `${getBiasForMatmul(coordsDataType, allGlChannels, inputs[2].dims, outputShape, false)}` : '';
+  const dataType = 'f32';  // TODO: support other data type
+  const {activationFunction, applyActivation} = getActicationSnippet(activationAttributes);
 
-  const rank = outputShape.length;
-  const arank = aShape.length;
-  const brank = bShape.length;
-  const sharedDim = aShape[aShape.length - 1];
+  const M = outputShape[outputShape.length - 2];
+  const K = aShape[aShape.length - 1];
+  const N = outputShape[outputShape.length - 1];
   const shaderSource = `
-  ${activationFunction}
-  ${getBiasForMatmulSnippet}
-  float process(int indices[${rank}]) {
-    int a[${arank}];
-    int b[${brank}];
-    bcastMatmulIndices_A(indices, a);
-    bcastMatmulIndices_B(indices, b);
-    float value;
-    for (int k=0; k<${sharedDim}; ++k) {
-      a[${arank - 1}] = k;
-      b[${brank - 2}] = k;
-      value += _A(a) * _B(b);
-    }
-    ${processBias}
-    ${applyActivation}
-    return value;
-  }`;
+  let WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
+  let M: u32 = ${M}u;
+  let N: u32 = ${N}u;
+  let K: u32 = ${K}u;
+  @group(0) @binding(0) var<storage, read> a : array<${dataType}>;
+  @group(0) @binding(1) var<storage, read> b : array<${dataType}>;
+  @group(0) @binding(2) var<storage, write> output : array<${dataType}>;
+  ${activationFunction}
+  @stage(compute) @workgroup_size(WORKGROUP_SIZE)
+  fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
+    // Guard against out-of-bounds work group sizes
+    if (global_id.x >= ${outputSize}u) {
+      return;
+    }
+    let stack = global_id.x / (M * N);
+    let mn = global_id.x % (M * N);
+    let n = global_id.x % N;
+    let m = mn / N;
+    let offsetA = stack * (M * K);
+    let offsetB = stack * (K * N);
+    var value = ${dataType}(0);
+    for (var k: u32 = 0u; k<${K}u; k++) {
+      value += a[offsetA + m * K + k] * b[offsetB + k * N + n];
+    }
+    ${applyActivation}
+    output[global_id.x] = value;
+  }`;
   return {
     ...metadata,
-    output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked},
+    outputs: [{dims: outputShape, type: inputs[0].type, gpuDataType: GpuDataType.default}],
     shaderSource,
+    dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
   };
 }

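To make the new indexing concrete: the output buffer is treated as a stack of M×N matrices, and each invocation's flat index decomposes into (stack, m, n). Note that n can be taken from global_id.x directly because N divides M*N, so global_id.x % N equals mn % N. A CPU-side sketch of the same arithmetic, assuming small example sizes:

    // CPU-side check of the shader's index arithmetic (sketch, not part of the commit).
    const M = 2, N = 3, K = 4;
    for (let i = 0; i < 2 * M * N; i++) {        // two stacked output matrices
      const stack = Math.floor(i / (M * N));     // which matrix in the stack
      const mn = i % (M * N);                    // flat index within that matrix
      const m = Math.floor(mn / N);              // output row
      const n = i % N;                           // output column; equals mn % N since N divides M*N
      const offsetA = stack * (M * K);           // start of this stack's A matrix
      const offsetB = stack * (K * N);           // start of this stack's B matrix
      console.log(i, {stack, m, n, offsetA, offsetB});
    }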
@@ -99,40 +113,3 @@ const validateInputs = (inputs: Tensor[]): void => {
     throw new Error('inputs types should match');
   }
 };
-
-export function getBiasForMatmul(
-    coordsDataType: string, allGlChannels: readonly string[], inShape: readonly number[], outShape: readonly number[],
-    isPacked: boolean): string {
-  let unpackedCoordsSnippet = '';
-  const inRank = inShape.length;
-  const outRank = outShape.length;
-  const rankDiff = outRank - inRank;
-  if (outRank < 2 && inRank > 0) {
-    unpackedCoordsSnippet = 'coords';
-  } else {
-    unpackedCoordsSnippet = inShape.map((s, i) => `coords.${allGlChannels[i + rankDiff]}`).join(', ');
-  }
-  const broadcastDims = BroadcastUtil.getBroadcastDims(inShape, outShape);
-  const coordsSnippet = broadcastDims.map(d => `coords.${allGlChannels[d + rankDiff]} = 0;`).join('\n');
-  const inSize = ShapeUtil.size(inShape);
-  const isInputScalar = inSize === 1;
-  let output = 'vec4(outputValue.xx, outputValue.yy)';
-  if (isInputScalar) {
-    output = 'vec4(outputValue.x)';
-  }
-  const getBiasForMatmulSource = isPacked ? `
-      vec4 getBiasForMatmul() {
-        ${coordsDataType} coords = getOutputCoords();
-        ${coordsSnippet}
-        vec4 outputValue = getBias(${unpackedCoordsSnippet});
-        return ${output};
-      }` :
-      `
-      float getBiasForMatmul() {
-        ${coordsDataType} coords = getOutputCoords();
-        ${coordsSnippet}
-        return getBias(coords.x);
-      }`;
-
-  return getBiasForMatmulSource;
-}
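
With bias and broadcasting removed (per the commit title and the TODO above), the kernel now assumes A is [batch..., M, K] and B is [batch..., K, N] with identical batch dimensions. A plain TypeScript reference of the same stacked matmul, useful for sanity-checking shader output (a sketch, not part of this commit):

    // Stacked (no-broadcast) matmul on CPU, mirroring the WGSL kernel above. Sketch only.
    function matmulReference(
        a: Float32Array, b: Float32Array, stacks: number, M: number, K: number, N: number): Float32Array {
      const output = new Float32Array(stacks * M * N);
      for (let s = 0; s < stacks; s++) {
        for (let m = 0; m < M; m++) {
          for (let n = 0; n < N; n++) {
            let value = 0;
            for (let k = 0; k < K; k++) {
              value += a[s * M * K + m * K + k] * b[s * K * N + k * N + n];
            }
            output[s * M * N + m * N + n] = value;
          }
        }
      }
      return output;
    }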
js/web/test/suite-test-list.jsonc (3 additions & 3 deletions)

@@ -377,9 +377,9 @@
       "test_leakyrelu",
       // "test_lrn_default", <-- failing due to low precison. If absolute CPU error threshold is increased from 1e-4 to 1e-2 (100x increase), it passes the test.
       // "test_lrn", <-- failing due to low precison. If absolute CPU error threshold is increased from 1e-4 to 1e-3 (10x increase), it passes the test.
-      // "test_matmul_2d",
-      // "test_matmul_3d",
-      // "test_matmul_4d",
+      "test_matmul_2d",
+      "test_matmul_3d",
+      "test_matmul_4d",
       // "test_maxpool_1d_default",
       // "test_maxpool_2d_default",
       // "v12/test_maxpool_2d_pads",
