Skip to content

Commit

Permalink
gemm...
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 13, 2022
1 parent 99653f5 commit c1185b4
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 42 deletions.
6 changes: 3 additions & 3 deletions js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {OpSet} from '../../opset';
import * as binaryOps from './ops/binary-op';
import {concat, parseConcatAttributes} from './ops/concat';
import {gather, parseGatherAttributes} from './ops/gather';
import {gemm, parseGemmAttributesV11, parseGemmAttributesV7} from './ops/gemm';
import {reshape} from './ops/reshape';
import * as unaryOps from './ops/unary-op';
import {parseUnsqueezeAttributes, unsqueeze, unsqueezeV13} from './ops/unsqueeze';
Expand All @@ -29,9 +30,8 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
// ['Flatten', '', '1+', flatten, parseFlattenAttributes],
['Floor', '', '6+', unaryOps.floor],
// ['FusedConv', 'com.microsoft', '1+', conv, parseConvAttributes],
['Gather', '', '1+', gather, parseGatherAttributes],
// ['Gemm', '', '7-10', gemm, parseGemmAttributesV7],
// ['Gemm', '', '11+', gemm, parseGemmAttributesV11],
['Gather', '', '1+', gather, parseGatherAttributes], ['Gemm', '', '7-10', gemm, parseGemmAttributesV7],
['Gemm', '', '11+', gemm, parseGemmAttributesV11],
// ['GlobalAveragePool', '', '1+', globalAveragePool, parseGlobalAveragePoolAttributes],
// ['GlobalMaxPool', '', '1+', globalMaxPool],
// ['Greater', '', '7+', binaryOps.greater],
Expand Down
82 changes: 46 additions & 36 deletions js/web/lib/onnxjs/backends/webgpu/ops/gemm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attri
import {Graph} from '../../../graph';
import {OperatorAsyncImplementation, OperatorInitialization} from '../../../operators';
import {Tensor} from '../../../tensor';
import {GemmUtil} from '../../../util';
import {GemmUtil, ShapeUtil} from '../../../util';
import {WebGpuInferenceHandler} from '../inference-handler';
import {GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';

import {WORKGROUP_SIZE} from './common';

export interface GemmAttributes extends AttributeWithCacheKey {
transA: boolean;
transB: boolean;
Expand Down Expand Up @@ -52,58 +54,66 @@ const createGemmProgramInfo =
(metadata: ProgramMetadata, inputs: Tensor[], attributes: GemmAttributes): ProgramInfo => {
const aShape = inputs[0].dims.slice();
const bShape = inputs[1].dims.slice();
const [M, N] = GemmUtil.getShapeOfGemmResult(
const [M, N, K] = GemmUtil.getShapeOfGemmResult(
aShape, attributes.transA, bShape, attributes.transB, inputs.length === 3 ? inputs[2].dims : undefined);
const outputShape = [M, N];
if (!outputShape) {
throw new Error('Can\'t use gemm on the given tensors');
}
let sharedDim = aShape[aShape.length - 1];
const outputSize = ShapeUtil.size(outputShape);
let line = '';
if (attributes.transA) {
sharedDim = aShape[0];
}
if (attributes.transA && attributes.transB) {
line = 'value += _A_T(a) * _B_T(b);';
} else if (attributes.transA && !attributes.transB) {
line = 'value += _A_T(a) * _B(b);';
} else if (!attributes.transA && attributes.transB) {
line = 'value += _A(a) * _B_T(b);';
} else if (!attributes.transA && !attributes.transB) {
line = 'value += _A(a) * _B(b);';
line = 'value += a[m * K + k] * b[k * N + n];';
}

const dataType = 'f32'; // TODO: support other data type
const calculateC = inputs.length === 3 ? `value += ${dataType}(${attributes.beta}) * c[TODO];` : '';
const inputStorageBuffersDeclarations = [
`@group(0) @binding(0) var<storage, read> a : array<${dataType}>;`,
`@group(0) @binding(1) var<storage, read> b : array<${dataType}>;`
];
if (inputs.length === 3) {
inputStorageBuffersDeclarations.push(`@group(0) @binding(2) var<storage, read> c : array<${dataType}>;`);
}
const rank = outputShape.length;
const declareC = inputs.length === 3 ? `int c[${inputs[2].dims.length}];` : '';
const broadcastC = inputs.length === 3 ? 'bcastIndices_C(indices, c);' : '';
const calculateC = inputs.length === 3 ? 'value += beta * _C(c);' : '';
const shaderSource = `
float process(int indices[${rank}]) {
int a[${rank}];
int b[${rank}];
${declareC}
copyVec(indices, a);
copyVec(indices, b);
${broadcastC}
float value = 0.0;
for (int k=0; k<${sharedDim}; ++k) {
a[${rank - 1}] = k;
b[${rank - 2}] = k;
${line}
}
value = value * alpha;
${calculateC}
return value;
}`;
let WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
let N: u32 = ${N}u;
let K: u32 = ${K}u;
${inputStorageBuffersDeclarations.join('\n')}
@group(0) @binding(${inputs.length}) var<storage, write> output : array<${dataType}>;
@stage(compute) @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
// Guard against out-of-bounds work group sizes
if (global_id.x >= ${outputSize}u) {
return;
}
let m = global_id.x / N;
let n = global_id.x % N;
let value = ${dataType}(0);
for (var k: u32 = 0u; k<${K}u; k++) {
${line}
}
${calculateC}
output[global_id.x] = value;
}`;
return {
...metadata,
output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked},
variables: [
{name: 'alpha', type: 'float', data: attributes.alpha}, {name: 'beta', type: 'float', data: attributes.beta}
],
shaderSource
outputs: [{dims: outputShape, type: inputs[0].type, gpuDataType: GpuDataType.default}],
shaderSource,
dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
};
};

Expand Down
6 changes: 3 additions & 3 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,8 @@
// "test_flatten_default_axis",
"test_gather_0",
"test_gather_1",
// "test_gemm_nobroadcast",
// "test_gemm_broadcast",
"test_gemm_nobroadcast",
"test_gemm_broadcast",
// "test_globalaveragepool_precomputed",
// "test_globalaveragepool",
// "test_globalmaxpool_precomputed",
Expand Down Expand Up @@ -523,7 +523,7 @@
"exp.jsonc",
"floor.jsonc",
//"global-average-pool.jsonc",
//"gemm.jsonc",
"gemm.jsonc",
//"greater.jsonc",
////"identity.jsonc",
//"image-scaler.jsonc",
Expand Down

0 comments on commit c1185b4

Please sign in to comment.