diff --git a/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts b/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts index 4a3a5dfbf5003..637dfbfd182a6 100644 --- a/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts +++ b/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts @@ -6,6 +6,7 @@ import {OpSet} from '../../opset'; import * as binaryOps from './ops/binary-op'; import {concat, parseConcatAttributes} from './ops/concat'; import {gather, parseGatherAttributes} from './ops/gather'; +import {gemm, parseGemmAttributesV11, parseGemmAttributesV7} from './ops/gemm'; import {reshape} from './ops/reshape'; import * as unaryOps from './ops/unary-op'; import {parseUnsqueezeAttributes, unsqueeze, unsqueezeV13} from './ops/unsqueeze'; @@ -29,9 +30,8 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [ // ['Flatten', '', '1+', flatten, parseFlattenAttributes], ['Floor', '', '6+', unaryOps.floor], // ['FusedConv', 'com.microsoft', '1+', conv, parseConvAttributes], - ['Gather', '', '1+', gather, parseGatherAttributes], - // ['Gemm', '', '7-10', gemm, parseGemmAttributesV7], - // ['Gemm', '', '11+', gemm, parseGemmAttributesV11], + ['Gather', '', '1+', gather, parseGatherAttributes], ['Gemm', '', '7-10', gemm, parseGemmAttributesV7], + ['Gemm', '', '11+', gemm, parseGemmAttributesV11], // ['GlobalAveragePool', '', '1+', globalAveragePool, parseGlobalAveragePoolAttributes], // ['GlobalMaxPool', '', '1+', globalMaxPool], // ['Greater', '', '7+', binaryOps.greater], diff --git a/js/web/lib/onnxjs/backends/webgpu/ops/gemm.ts b/js/web/lib/onnxjs/backends/webgpu/ops/gemm.ts index 670ee2ff82ae8..0c44b6dd5dc49 100644 --- a/js/web/lib/onnxjs/backends/webgpu/ops/gemm.ts +++ b/js/web/lib/onnxjs/backends/webgpu/ops/gemm.ts @@ -5,10 +5,12 @@ import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attri import {Graph} from '../../../graph'; import {OperatorAsyncImplementation, OperatorInitialization} from '../../../operators'; import {Tensor} from '../../../tensor'; -import {GemmUtil} from '../../../util'; +import {GemmUtil, ShapeUtil} from '../../../util'; import {WebGpuInferenceHandler} from '../inference-handler'; import {GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; +import {WORKGROUP_SIZE} from './common'; + export interface GemmAttributes extends AttributeWithCacheKey { transA: boolean; transB: boolean; @@ -52,17 +54,14 @@ const createGemmProgramInfo = (metadata: ProgramMetadata, inputs: Tensor[], attributes: GemmAttributes): ProgramInfo => { const aShape = inputs[0].dims.slice(); const bShape = inputs[1].dims.slice(); - const [M, N] = GemmUtil.getShapeOfGemmResult( + const [M, N, K] = GemmUtil.getShapeOfGemmResult( aShape, attributes.transA, bShape, attributes.transB, inputs.length === 3 ? inputs[2].dims : undefined); const outputShape = [M, N]; if (!outputShape) { throw new Error('Can\'t use gemm on the given tensors'); } - let sharedDim = aShape[aShape.length - 1]; + const outputSize = ShapeUtil.size(outputShape); let line = ''; - if (attributes.transA) { - sharedDim = aShape[0]; - } if (attributes.transA && attributes.transB) { line = 'value += _A_T(a) * _B_T(b);'; } else if (attributes.transA && !attributes.transB) { @@ -70,40 +69,51 @@ const createGemmProgramInfo = } else if (!attributes.transA && attributes.transB) { line = 'value += _A(a) * _B_T(b);'; } else if (!attributes.transA && !attributes.transB) { - line = 'value += _A(a) * _B(b);'; + line = 'value += a[m * K + k] * b[k * N + n];'; + } + + const dataType = 'f32'; // TODO: support other data type + const calculateC = inputs.length === 3 ? `value += ${dataType}(${attributes.beta}) * c[TODO];` : ''; + const inputStorageBuffersDeclarations = [ + `@group(0) @binding(0) var a : array<${dataType}>;`, + `@group(0) @binding(1) var b : array<${dataType}>;` + ]; + if (inputs.length === 3) { + inputStorageBuffersDeclarations.push(`@group(0) @binding(2) var c : array<${dataType}>;`); } - const rank = outputShape.length; - const declareC = inputs.length === 3 ? `int c[${inputs[2].dims.length}];` : ''; - const broadcastC = inputs.length === 3 ? 'bcastIndices_C(indices, c);' : ''; - const calculateC = inputs.length === 3 ? 'value += beta * _C(c);' : ''; const shaderSource = ` - float process(int indices[${rank}]) { - int a[${rank}]; - int b[${rank}]; - ${declareC} - - copyVec(indices, a); - copyVec(indices, b); - ${broadcastC} - - float value = 0.0; - for (int k=0; k<${sharedDim}; ++k) { - a[${rank - 1}] = k; - b[${rank - 2}] = k; - ${line} - } - - value = value * alpha; - ${calculateC} - return value; - }`; + let WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u; + let N: u32 = ${N}u; + let K: u32 = ${K}u; + + ${inputStorageBuffersDeclarations.join('\n')} + @group(0) @binding(${inputs.length}) var output : array<${dataType}>; + + @stage(compute) @workgroup_size(WORKGROUP_SIZE) + fn main(@builtin(global_invocation_id) global_id : vec3) { + + // Guard against out-of-bounds work group sizes + if (global_id.x >= ${outputSize}u) { + return; + } + + let m = global_id.x / N; + let n = global_id.x % N; + + let value = ${dataType}(0); + for (var k: u32 = 0u; k<${K}u; k++) { + ${line} + } + + ${calculateC} + output[global_id.x] = value; + + }`; return { ...metadata, - output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked}, - variables: [ - {name: 'alpha', type: 'float', data: attributes.alpha}, {name: 'beta', type: 'float', data: attributes.beta} - ], - shaderSource + outputs: [{dims: outputShape, type: inputs[0].type, gpuDataType: GpuDataType.default}], + shaderSource, + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) }; }; diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 06270082d7605..ceff8f70f48e2 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -357,8 +357,8 @@ // "test_flatten_default_axis", "test_gather_0", "test_gather_1", - // "test_gemm_nobroadcast", - // "test_gemm_broadcast", + "test_gemm_nobroadcast", + "test_gemm_broadcast", // "test_globalaveragepool_precomputed", // "test_globalaveragepool", // "test_globalmaxpool_precomputed", @@ -523,7 +523,7 @@ "exp.jsonc", "floor.jsonc", //"global-average-pool.jsonc", - //"gemm.jsonc", + "gemm.jsonc", //"greater.jsonc", ////"identity.jsonc", //"image-scaler.jsonc",