Skip to content

Commit

Permalink
gemm (scalar)
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 13, 2022
1 parent c1185b4 commit 9d92513
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
30 changes: 25 additions & 5 deletions js/web/lib/onnxjs/backends/webgpu/ops/gemm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,21 @@ const createGemmProgramInfoLoader = (inputs: Tensor[], attributes: GemmAttribute
return {...metadata, get: () => createGemmProgramInfo(metadata, inputs, attributes)};
};

const offsetC = (m: number, n: number, dims: readonly number[]): string => {
const broadcastM = (dims.length === 1 && m !== 1) || (dims.length === 2 && dims[0] !== m);
const broadcastN = dims[dims.length - 1] !== n;

let offset = '0u';
if (!broadcastM) {
offset += `+ m * ${dims[dims.length - 1]}u`;
}
if (!broadcastN) {
offset += '+n';
}

return offset;
};

const createGemmProgramInfo =
(metadata: ProgramMetadata, inputs: Tensor[], attributes: GemmAttributes): ProgramInfo => {
const aShape = inputs[0].dims.slice();
Expand All @@ -63,17 +78,18 @@ const createGemmProgramInfo =
const outputSize = ShapeUtil.size(outputShape);
let line = '';
if (attributes.transA && attributes.transB) {
line = 'value += _A_T(a) * _B_T(b);';
line = 'value += a[k * M + m] * b[n * K + k];';
} else if (attributes.transA && !attributes.transB) {
line = 'value += _A_T(a) * _B(b);';
line = 'value += a[k * M + m] * b[k * N + n];';
} else if (!attributes.transA && attributes.transB) {
line = 'value += _A(a) * _B_T(b);';
line = 'value += a[m * K + k] * b[n * K + k];';
} else if (!attributes.transA && !attributes.transB) {
line = 'value += a[m * K + k] * b[k * N + n];';
}

const dataType = 'f32'; // TODO: support other data type
const calculateC = inputs.length === 3 ? `value += ${dataType}(${attributes.beta}) * c[TODO];` : '';
const calculateAlpha = attributes.alpha === 1 ? '' : 'value *= alpha;';
const calculateC = inputs.length === 3 ? `value += beta * c[${offsetC(M, N, inputs[2].dims)}];` : '';
const inputStorageBuffersDeclarations = [
`@group(0) @binding(0) var<storage, read> a : array<${dataType}>;`,
`@group(0) @binding(1) var<storage, read> b : array<${dataType}>;`
Expand All @@ -83,8 +99,11 @@ const createGemmProgramInfo =
}
const shaderSource = `
let WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
let M: u32 = ${M}u;
let N: u32 = ${N}u;
let K: u32 = ${K}u;
let alpha = ${dataType}(${attributes.alpha});
let beta = ${dataType}(${attributes.beta});
${inputStorageBuffersDeclarations.join('\n')}
@group(0) @binding(${inputs.length}) var<storage, write> output : array<${dataType}>;
Expand All @@ -100,11 +119,12 @@ const createGemmProgramInfo =
let m = global_id.x / N;
let n = global_id.x % N;
let value = ${dataType}(0);
var value = ${dataType}(0);
for (var k: u32 = 0u; k<${K}u; k++) {
${line}
}
${calculateAlpha}
${calculateC}
output[global_id.x] = value;
Expand Down
3 changes: 2 additions & 1 deletion js/web/lib/onnxjs/backends/webgpu/program-manager.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {Profiler} from '../../instrument';
import {Logger, Profiler} from '../../instrument';

import {Artifact, GpuData, ProgramInfo} from './types';

Expand Down Expand Up @@ -62,6 +62,7 @@ export class ProgramManager {
const device = this.device;

const shaderModule = device.createShaderModule({code: programInfo.shaderSource});
Logger.verbose('WebGpuProgram', programInfo.shaderSource);

const computePipeline = device.createComputePipeline({compute: {module: shaderModule, entryPoint: 'main'}});

Expand Down

0 comments on commit 9d92513

Please sign in to comment.