Skip to content

Commit 5ff27ef

Browse files
authored
[js/webgpu] support customop FastGelu (#19392)
### Description Support WebGPU custom operator FastGelu.
1 parent a4cfdc1 commit 5ff27ef

File tree

10 files changed

+353
-8
lines changed

10 files changed

+353
-8
lines changed

js/web/docs/webgpu-operators.md

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ Do not modify directly.*
4141
| Erf | ai.onnx(9-12,13+) | |
4242
| Exp | ai.onnx(6-12,13+) | |
4343
| Expand | ai.onnx(8-12,13+) | |
44+
| FastGelu | com.microsoft(1+) | |
4445
| Flatten | ai.onnx(1-8,9-10,11-12,13+) | |
4546
| Floor | ai.onnx(6-12,13+) | |
4647
| FusedConv | com.microsoft(1+) | |

js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose'
1313
import {cumsum, parseCumSumAttributes} from './ops/cumsum';
1414
import {einsum, parseEinsumAttributes} from './ops/einsum';
1515
import {expand} from './ops/expand';
16+
import {fastGelu} from './ops/fast-gelu';
1617
import {gather, parseGatherAttributes} from './ops/gather';
1718
import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements';
1819
import {gemm, parseGemmAttributes} from './ops/gemm';
@@ -72,6 +73,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
7273
['Erf', [unaryOps.erf]],
7374
['Exp', [unaryOps.exp]],
7475
['Expand', [expand]],
76+
['FastGelu', [fastGelu]],
7577
['Floor', [unaryOps.floor]],
7678
['FusedConv', [conv, parseConvAttributes]],
7779
['Gather', [gather, parseGatherAttributes]],

js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramI
4343
4444
${shaderHelper.declareVariables(input, bias, output)}
4545
46-
${erfImpl(`vec4<${dataType}>`, dataType)}
46+
${erfImpl(dataType)}
4747
4848
${shaderHelper.mainStart()}
4949
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
import {DataType} from '../../../wasm-common';
5+
import {TensorView} from '../../tensor-view';
6+
import {ShapeUtil} from '../../util';
7+
import {ComputeContext, ProgramInfo} from '../types';
8+
9+
import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglValueType, UniformsArrayType, WORKGROUP_SIZE} from './common';
10+
import * as unary from './unary-op';
11+
12+
// FastGelu is defined as Y = 0.5 * X * (1 + tanh(0.797885 * X + 0.035677 * X * X * X)), where X is the input with an optional bias added to it beforehand.
13+
14+
/**
 * Builds the program description for FastGelu with a bias input, i.e. y = FastGelu(x + bias).
 *
 * Each shader invocation processes four consecutive elements of the flattened input. The bias is
 * applied cyclically over the flattened input: flat element k is paired with bias[k % biasLength].
 *
 * @param inputTensors inputTensors[0] is the data tensor x, inputTensors[1] is the bias tensor.
 * @returns the ProgramInfo (name, shader cache info, shader source generator and run data).
 */
const createFastGeluProgramInfo = (inputTensors: readonly TensorView[]): ProgramInfo => {
15+
const dataType = inputTensors[0].dataType;
16+
const outputSize = ShapeUtil.size(inputTensors[0].dims);
17+
const biasLength = ShapeUtil.size(inputTensors[1].dims);
18+
// can only use vec4 when bias length is multiple of 4
19+
const useVec4 = biasLength % 4 === 0;
20+
// Generates the WGSL source; which bias-loading snippet is emitted depends on useVec4.
const getShaderSource = (shaderHelper: ShaderHelper): string => {
21+
// All three variables are declared with a trailing 4 (presumably the component count, i.e.
// vec4 access - TODO confirm against inputVariable/outputVariable in './common').
const x = inputVariable('x', dataType, [1], 4);
22+
const bias = inputVariable('bias', dataType, [1], 4);
23+
const y = outputVariable('y', dataType, [1], 4);
24+
25+
// Declaration order here must match the order of programUniforms in getRunData below.
const uniforms: UniformsArrayType = [{name: 'output_vec_size', type: 'u32'}, {name: 'bias_size', type: 'u32'}];
26+
27+
// Emits WGSL that loads one scalar bias value for lane i of the current invocation:
// the flat element index (global_idx * 4 + i) is wrapped modulo bias_size, then the value is
// read as component (offset % 4) of the packed bias element at (offset / 4).
const singleElementBias = (i: 0|1|2|3) => `
28+
let bias${i}_offset: u32 = (global_idx * 4 + ${i}) % uniforms.bias_size;
29+
let bias${i} = ${bias.getByOffset(`bias${i}_offset / 4`)}[bias${i}_offset % 4];`;
30+
// Fast path: bias length is a multiple of 4, so a whole packed bias element can be loaded at
// once. Otherwise the four lanes are gathered individually and assembled into one value.
const biasGetExpression = useVec4 ?
31+
`
32+
let bias = ${bias.getByOffset('global_idx % (uniforms.bias_size / 4)')};` :
33+
`${singleElementBias(0)}${singleElementBias(1)}${singleElementBias(2)}${singleElementBias(3)}
34+
let bias = ${x.type.value}(bias0, bias1, bias2, bias3);`;
35+
36+
// Final shader: declarations, FastGelu helper code, then a main body that computes
// y[global_idx] = FastGelu(x[global_idx] + bias).
return `${shaderHelper.registerUniforms(uniforms).declareVariables(x, bias, y)}
37+
38+
${unary.fastGeluImpl(tensorTypeToWsglValueType(dataType))}
39+
40+
${shaderHelper.mainStart(WORKGROUP_SIZE)}
41+
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_vec_size')}
42+
43+
let x = ${x.getByOffset('global_idx')};
44+
${biasGetExpression}
45+
let x_in = x + bias;
46+
${y.setByOffset('global_idx', unary.fastGeluExpression('x_in'))}
47+
}`;
48+
};
49+
50+
return {
51+
name: 'FastGeluWithBias',
52+
// useVec4 changes the generated shader text, so it is included in the cache hint.
shaderCache: {hint: `${useVec4}`, inputDependencies: ['type', 'type']},
53+
getShaderSource,
54+
getRunData: (inputs) => ({
55+
// The output has the same dims and data type as the first input.
outputs: [{dims: inputs[0].dims, dataType: inputs[0].dataType}],
56+
// Values for output_vec_size (number of 4-element groups) and bias_size, in that order.
programUniforms:
57+
[{type: DataType.uint32, data: Math.ceil(outputSize / 4)}, {type: DataType.uint32, data: biasLength}],
58+
// One invocation per 4 elements, WORKGROUP_SIZE invocations per workgroup.
dispatchGroup: {x: Math.ceil(outputSize / WORKGROUP_SIZE / 4)}
59+
})
60+
};
61+
};
62+
63+
/**
 * Entry point for the FastGelu operator.
 *
 * Delegates to the unary (bias-free) implementation when no usable bias is supplied, and to the
 * dedicated bias-aware program otherwise.
 *
 * @param context compute context providing the operator inputs and the compute() call.
 */
export const fastGelu = (context: ComputeContext): void => {
64+
// A missing second input, or a second input with zero elements, both mean "no bias".
if (context.inputs.length < 2 || ShapeUtil.size(context.inputs[1].dims) === 0) {
65+
unary.fastGelu(context);
66+
} else {
67+
context.compute(createFastGeluProgramInfo(context.inputs));
68+
}
69+
};

js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts

+26-7
Original file line numberDiff line numberDiff line change
@@ -178,24 +178,23 @@ export const elu = (context: ComputeContext, attributes: AlphaAttributes): void
178178
attributes.cacheKey));
179179
};
180180

181-
export const erfImpl = (dataType: string, varType = 'f32') => `
181+
export const erfImpl = (varType = 'f32') => `
182182
const r0: ${varType} = 0.3275911;
183183
const r1: ${varType} = 0.254829592;
184184
const r2: ${varType} = -0.284496736;
185185
const r3: ${varType} = 1.421413741;
186186
const r4: ${varType} = -1.453152027;
187187
const r5: ${varType} = 1.061405429;
188188
189-
fn erf_vf32(v: ${dataType}) -> ${dataType} {
189+
fn erf_vf32(v: vec4<${varType}>) -> vec4<${varType}> {
190190
let absv = abs(v);
191191
let x = 1.0 / (1.0 + r0 * absv);
192192
return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv));
193193
}`;
194194

195195
export const erf = (context: ComputeContext): void => {
196196
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
197-
context.compute(createElementwiseProgramInfo(
198-
context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType)));
197+
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(dataType)));
199198
};
200199

201200
export const exp = (context: ComputeContext): void => {
@@ -209,8 +208,7 @@ export const floor = (context: ComputeContext): void => {
209208
export const gelu = (context: ComputeContext): void => {
210209
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
211210
context.compute(createElementwiseProgramInfo(
212-
context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`,
213-
erfImpl(`vec4<${dataType}>`, dataType)));
211+
context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`, erfImpl(dataType)));
214212
};
215213

216214
export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
@@ -278,10 +276,31 @@ export const tan = (context: ComputeContext): void => {
278276
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tan', 'tan'));
279277
};
280278

279+
/**
 * Returns a shader expression string that evaluates tanh of `a` using exp():
 * sign(a) * (1 - exp(-2|a|)) / (1 + exp(-2|a|)).
 *
 * @param a the shader expression to take the tanh of; it is substituted textually three times.
 */
export const tanhExpression = (a: string) => `sign(${a}) * (1 - exp(-2 * abs(${a}))) / (1 + exp(-2 * abs(${a})))`;
280+
281281
/**
 * Tanh operator: runs an elementwise program on the first input, using the exp-based
 * tanhExpression as the per-element expression.
 */
export const tanh = (context: ComputeContext): void => {
282282
// TODO: revisit after https://github.com/gpuweb/gpuweb/issues/4458 is resolved
283+
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tanh', tanhExpression));
284+
};
285+
286+
/**
 * Returns the shader helper code required by fastGeluExpression: three module-scope constants
 * and a `tanh_v` function operating on a 4-component vector of `varType`.
 *
 * The constants are numerically 0.5, sqrt(2/pi) and 0.044715 * sqrt(2/pi).
 *
 * @param varType scalar shader type used for the constants and the vector element type
 *     (defaults to 'f32').
 */
export const fastGeluImpl = (varType = 'f32') => `
287+
const fast_gelu_a: ${varType} = 0.5;
288+
const fast_gelu_b: ${varType} = 0.7978845608028654;
289+
const fast_gelu_c: ${varType} = 0.035677408136300125;
290+
291+
fn tanh_v(v: vec4<${varType}>) -> vec4<${varType}> {
292+
return ${tanhExpression('v')};
293+
}
294+
`;
295+
296+
/**
 * Returns a shader expression string computing FastGelu of `x` as
 * (a + a * tanh_v(x * (c * x * x + b))) * x.
 *
 * The expression refers to fast_gelu_a/b/c and tanh_v, so the code produced by fastGeluImpl
 * must be present in the same shader.
 *
 * @param x the shader expression for the input value; it is substituted textually several times.
 */
export const fastGeluExpression = (x: string) =>
297+
`(fast_gelu_a + fast_gelu_a * tanh_v(${x} * (fast_gelu_c * ${x} * ${x} + fast_gelu_b))) * ${x}`;
298+
299+
export const fastGelu = (context: ComputeContext): void => {
300+
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
283301
context.compute(createElementwiseProgramInfo(
284-
context.inputs[0], 'Tanh', a => `sign(${a}) * (1 - exp(-2 * abs(${a}))) / (1 + exp(-2 * abs(${a})))`));
302+
context.inputs[0], 'FastGelu', fastGeluExpression, fastGeluImpl(dataType), undefined,
303+
context.inputs[0].dataType));
285304
};
286305

287306
export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => {

js/web/test/data/ops/fast-gelu.jsonc

+211
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
[
2+
{
3+
"name": "FastGelu test without bias",
4+
"operator": "FastGelu",
5+
"opset": { "domain": "com.microsoft", "version": 1 },
6+
"cases": [
7+
{
8+
"name": "scalar",
9+
"inputs": [
10+
{
11+
"data": [1],
12+
"dims": [],
13+
"type": "float32"
14+
}
15+
],
16+
"outputs": [
17+
{
18+
"data": [0.841192],
19+
"dims": [],
20+
"type": "float32"
21+
}
22+
]
23+
},
24+
{
25+
"name": "[2x4]",
26+
"inputs": [
27+
{
28+
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
29+
"dims": [2, 4],
30+
"type": "float32"
31+
}
32+
],
33+
"outputs": [
34+
{
35+
"data": [0.0539828, 0.115851, 0.185371, 0.262161, 0.345714, 0.435415, 0.53057, 0.630432],
36+
"dims": [2, 4],
37+
"type": "float32"
38+
}
39+
]
40+
},
41+
{
42+
"name": "[3x5]",
43+
"inputs": [
44+
{
45+
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5],
46+
"dims": [3, 5],
47+
"type": "float32"
48+
}
49+
],
50+
"outputs": [
51+
{
52+
"data": [
53+
0.0539828, 0.115851, 0.185371, 0.262161, 0.345714, 0.841192, 1.9546, 2.99636, 3.99993, 5, 0.950581,
54+
1.0617, 1.17393, 1.28671, 1.39957
55+
],
56+
"dims": [3, 5],
57+
"type": "float32"
58+
}
59+
]
60+
}
61+
]
62+
},
63+
{
64+
"name": "FastGelu test with bias",
65+
"operator": "FastGelu",
66+
"opset": { "domain": "com.microsoft", "version": 1 },
67+
"cases": [
68+
{
69+
"name": "scalar",
70+
"inputs": [
71+
{
72+
"data": [1],
73+
"dims": [],
74+
"type": "float32"
75+
},
76+
{
77+
"data": [0.5],
78+
"dims": [],
79+
"type": "float32"
80+
}
81+
],
82+
"outputs": [
83+
{
84+
"data": [1.39957],
85+
"dims": [],
86+
"type": "float32"
87+
}
88+
]
89+
},
90+
{
91+
"name": "[2x4], [4]",
92+
"inputs": [
93+
{
94+
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
95+
"dims": [2, 4],
96+
"type": "float32"
97+
},
98+
{
99+
"data": [1, 2, 3, 4],
100+
"dims": [4],
101+
"type": "float32"
102+
}
103+
],
104+
"outputs": [
105+
{
106+
"data": [0.950581, 2.16968, 3.29869, 4.39999, 1.39957, 2.58835, 3.69973, 4.8],
107+
"dims": [2, 4],
108+
"type": "float32"
109+
}
110+
]
111+
},
112+
{
113+
"name": "[2x4], [3]",
114+
"inputs": [
115+
{
116+
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
117+
"dims": [2, 4],
118+
"type": "float32"
119+
},
120+
{
121+
"data": [1, 2, 3],
122+
"dims": [3],
123+
"type": "float32"
124+
}
125+
],
126+
"outputs": [
127+
{
128+
"data": [0.950581, 2.16968, 3.29869, 1.28671, 2.48492, 3.59959, 1.62411, 2.79331],
129+
"dims": [2, 4],
130+
"type": "float32"
131+
}
132+
]
133+
},
134+
{
135+
"name": "[3x5], [2]",
136+
"inputs": [
137+
{
138+
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5],
139+
"dims": [3, 5],
140+
"type": "float32"
141+
},
142+
{
143+
"data": [2, 3],
144+
"dims": [2],
145+
"type": "float32"
146+
}
147+
],
148+
"outputs": [
149+
{
150+
"data": [
151+
2.06267, 3.19813, 2.27567, 3.39909, 2.48492, 3.99993, 3.99993, 6, 6, 8, 3.09737, 4.19997, 3.29869,
152+
4.39999, 3.49938
153+
],
154+
"dims": [3, 5],
155+
"type": "float32"
156+
}
157+
]
158+
},
159+
{
160+
"name": "[3x5], [7]",
161+
"inputs": [
162+
{
163+
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5],
164+
"dims": [3, 5],
165+
"type": "float32"
166+
},
167+
{
168+
"data": [2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7],
169+
"dims": [7],
170+
"type": "float32"
171+
}
172+
],
173+
"outputs": [
174+
{
175+
"data": [
176+
2.16968, 2.38072, 2.58835, 2.79331, 2.99636, 3.59959, 4.7, 5.1, 6.2, 7.3, 3.49938, 3.69973, 3.89989,
177+
4.09996, 3.59959
178+
],
179+
"dims": [3, 5],
180+
"type": "float32"
181+
}
182+
]
183+
},
184+
{
185+
"name": "[4x4], [8]",
186+
"inputs": [
187+
{
188+
"data": [0.8, -0.5, 0.0, 1, 1.3, 2.1, -0.2, 1.1, 0.5, 0.2, 0.3, -0.6, 3.1, 2.2, -1.1, 0.0],
189+
"dims": [4, 4],
190+
"type": "float32"
191+
},
192+
{
193+
"data": [-0.5, 0.6, 1.2, 2.1, 1.3, -1, 0, 3.1],
194+
"dims": [8],
195+
"type": "float32"
196+
}
197+
],
198+
"outputs": [
199+
{
200+
"data": [
201+
0.185371, 0.0539828, 1.0617, 3.09737, 2.58835, 0.950581, -0.0841486, 4.19997, 0, 0.630432, 1.39957,
202+
1.39957, 4.39999, 1.0617, -0.149419, 3.09737
203+
],
204+
"dims": [4, 4],
205+
"type": "float32"
206+
}
207+
]
208+
}
209+
]
210+
}
211+
]

js/web/test/suite-test-list.jsonc

+1
Original file line numberDiff line numberDiff line change
@@ -1352,6 +1352,7 @@
13521352
"equal.jsonc",
13531353
"exp.jsonc",
13541354
"expand.jsonc",
1355+
"fast-gelu.jsonc",
13551356
"floor.jsonc",
13561357
"gather-elements.jsonc",
13571358
"gemm.jsonc",

0 commit comments

Comments
 (0)