From 073695ff1b3fbf4b614ae4b056a11fe4b6fec30f Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Wed, 6 Apr 2022 17:54:24 -0700
Subject: [PATCH] optimize types

---
 .../onnxjs/backends/webgpu/ops/unary-op.ts | 95 ++++++++-----------
 1 file changed, 41 insertions(+), 54 deletions(-)

diff --git a/js/web/lib/onnxjs/backends/webgpu/ops/unary-op.ts b/js/web/lib/onnxjs/backends/webgpu/ops/unary-op.ts
index 2ea5c905754a6..0b15b4b77b51e 100644
--- a/js/web/lib/onnxjs/backends/webgpu/ops/unary-op.ts
+++ b/js/web/lib/onnxjs/backends/webgpu/ops/unary-op.ts
@@ -9,25 +9,19 @@ import {WebGpuInferenceHandler} from '../inference-handler';
 import {GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
 import {WORKGROUP_SIZE} from './common';

-type ElementwiseFunctionImplementation =
-    // name, builtin function call.
-    // eg. ['Abs', 'abs']
-    [string, string]|
-    // name, function call builder, extra implementation (optional)
-    // eg. ['Neg', a => `-${a}`]
-    [string, (variableName: string) => string, string?];
+type BuiltinFunctionName = string;
+type ElementwiseCustomExpression = (expression: string) => string;
+type ElementwiseFunctionCall = BuiltinFunctionName|ElementwiseCustomExpression;

 const createElementwiseProgramShader =
-    (functionImplementation: ElementwiseFunctionImplementation, datasize: number): string => {
+    (datasize: number, funcCall: ElementwiseFunctionCall, additionalImplementation?: string): string => {
       const vecSize = Math.ceil(datasize / 4);
-      let funcImpl: string;
-      let funcCall = functionImplementation[1];
-      if (typeof funcCall === 'function') {
-        funcImpl = functionImplementation[2] ?? '';
-        funcCall = funcCall('a');
+
+      let expression = '';
+      if (typeof funcCall === 'string') {
+        expression = `${funcCall}(a)`;
       } else {
-        funcImpl = '';
-        funcCall = `${funcCall}(a)`;
+        expression = funcCall('a');
       }
       return `
   let WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
@@ -35,7 +29,7 @@ const createElementwiseProgramShader =
   @group(0) @binding(0) var<storage, read> inputData : array<vec4<f32>>;
   @group(0) @binding(1) var<storage, write> outputData : array<vec4<f32>>;

-  ${funcImpl}
+  ${additionalImplementation ?? ''}

   @stage(compute) @workgroup_size(WORKGROUP_SIZE)
   fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
@@ -46,39 +40,41 @@ const createElementwiseProgramShader =
     }

     let a = inputData[global_id.x];
-    outputData[global_id.x] = ${funcCall};
+    outputData[global_id.x] = ${expression};
   }`;
     };

 const createElementwiseProgramInfo =
-    (metadata: ProgramMetadata, input: Tensor, functionImplementation: ElementwiseFunctionImplementation):
+    (metadata: ProgramMetadata, input: Tensor, funcCall: ElementwiseFunctionCall, additionalImplementation?: string):
         ProgramInfo => ({
           ...metadata,
-          shaderSource: createElementwiseProgramShader(functionImplementation, input.size),
+          shaderSource: createElementwiseProgramShader(input.size, funcCall, additionalImplementation),
           outputs: [{dims: input.dims, type: input.type, gpuDataType: GpuDataType.default}],
           dispatchGroup: (inputTensors) =>
               ({x: Math.ceil(inputTensors[0].size / 64 /* workgroup size */ / 4 /* vec size */)})
         });

 const createElementwiseProgramInfoLoader =
-    (input: Tensor, functionImplementation: ElementwiseFunctionImplementation,
+    (input: Tensor, name: string, funcCall: ElementwiseFunctionCall, additionalImplementation?: string,
      cacheKey?: string): ProgramInfoLoader => {
-      const metadata:
-          ProgramMetadata = {name: functionImplementation[0], inputTypes: [GpuDataType.default], cacheHint: cacheKey};
-      return {...metadata, get: () => createElementwiseProgramInfo(metadata, input, functionImplementation)};
+      const metadata: ProgramMetadata = {name, inputTypes: [GpuDataType.default], cacheHint: cacheKey};
+      return {
+        ...metadata,
+        get: () => createElementwiseProgramInfo(metadata, input, funcCall, additionalImplementation)
+      };
     };

 export const abs = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-    handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Abs', 'abs']), inputs);
+    handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Abs', 'abs'), inputs);

 export const acos = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-    handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Acos', 'acos']), inputs);
+    handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Acos', 'acos'), inputs);

 export const asin = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-    handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Asin', 'asin']), inputs);
+    handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Asin', 'asin'), inputs);

 export const atan = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-    handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Atan', 'atan']), inputs);
+    handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Atan', 'atan'), inputs);

 export interface ClipAttributes extends AttributeWithCacheKey {
   readonly min: number;
@@ -88,13 +84,10 @@ export interface ClipAttributes extends AttributeWithCacheKey {
 export const clip = async(handler: WebGpuInferenceHandler, inputs: Tensor[], attributes: ClipAttributes):
     Promise<Tensor[]>=>handler.run(
         createElementwiseProgramInfoLoader(
-            inputs[0],
-            [
-              'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, `
+            inputs[0], 'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, `
    let clip_min_: vec4<f32> = vec4<f32>(f32(${attributes.min}));
    let clip_max_: vec4<f32> = vec4<f32>(f32(${attributes.max}));
-`
-            ],
+`,
             attributes.cacheKey),
         inputs);

@@ -118,10 +111,10 @@ export const clipV11 = async(handler: WebGpuInferenceHandler, inputs: Tensor[]):
 };

 export const ceil = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-    handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Ceil', 'ceil']), inputs);
+    handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Ceil', 'ceil'), inputs);

 export const cos = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-    handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Cos', 'cos']), inputs);
+    handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Cos', 'cos'), inputs);

 export interface EluAttributes extends AttributeWithCacheKey {
   readonly alpha: number;
@@ -130,9 +123,7 @@ export interface EluAttributes extends AttributeWithCacheKey {
 export const elu = async(handler: WebGpuInferenceHandler, inputs: Tensor[], attributes: EluAttributes):
     Promise<Tensor[]>=>handler.run(
         createElementwiseProgramInfoLoader(
-            inputs[0],
-            [
-              'Elu', a => `elu_vf32(${a})`, `
+            inputs[0], 'Elu', a => `elu_vf32(${a})`, `
   let elu_alpha_: f32 = f32(${attributes.alpha});

   fn elu_f32(a: f32) -> f32 {
@@ -141,8 +132,7 @@ export const elu = async(handler: WebGpuInferenceHandler, inputs: Tensor[], attr

   fn elu_vf32(v: vec4<f32>) -> vec4<f32> {
   return vec4<f32>(elu_f32(v.x), elu_f32(v.y), elu_f32(v.z), elu_f32(v.w));
-  }`
-            ],
+  }`,
             attributes.cacheKey),
         inputs);

@@ -150,10 +140,10 @@ export const parseEluAttributes = (node: Graph.Node): EluAttributes =>
     createAttributeWithCacheKey({alpha: node.attributes.getFloat('alpha', 1.0)});

 export const exp = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-    handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Exp', 'exp']), inputs);
+    handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Exp', 'exp'), inputs);

 export const floor = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-    handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Floor', 'floor']), inputs);
+    handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Floor', 'floor'), inputs);

 export interface LeakyReluAttributes extends AttributeWithCacheKey {
   readonly alpha: number;
@@ -162,9 +152,7 @@ export interface LeakyReluAttributes extends AttributeWithCacheKey {
 export const leakyRelu = async(handler: WebGpuInferenceHandler, inputs: Tensor[], attributes: EluAttributes):
     Promise<Tensor[]>=>handler.run(
         createElementwiseProgramInfoLoader(
-            inputs[0],
-            [
-              'LeakyRelu', a => `leaky_relu_vf32(${a})`, `
+            inputs[0], 'LeakyRelu', a => `leaky_relu_vf32(${a})`, `
   let leaky_relu_alpha_: f32 = f32(${attributes.alpha});

   fn leaky_relu_f32(a: f32) -> f32 {
@@ -173,8 +161,7 @@ export const leakyRelu = async(handler: WebGpuInferenceHandler, inputs: Tensor[]

   fn leaky_relu_vf32(v: vec4<f32>) -> vec4<f32> {
   return vec4<f32>(leaky_relu_f32(v.x), leaky_relu_f32(v.y), leaky_relu_f32(v.z), leaky_relu_f32(v.w));
-  }`
-            ],
+  }`,
             attributes.cacheKey),
         inputs);

@@ -182,28 +169,28 @@ export const parseLeakyReluAttributes = (node: Graph.Node): LeakyReluAttributes
     createAttributeWithCacheKey({alpha: node.attributes.getFloat('alpha', 0.01)});

 export const log = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-    handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Log', 'log']), inputs);
+    handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Log', 'log'), inputs);

 export const neg = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-    handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Neg', a => `-${a}`]), inputs);
+    handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Neg', a => `-${a}`), inputs);

 // export const not = (handler: WebGLInferenceHandler, inputs: Tensor[]):
 //     Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslNot()), inputs)];

 export const relu = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]>=>handler.run(
-    createElementwiseProgramInfoLoader(inputs[0], ['Relu', a => `max(${a}, vec4<f32>(0.0))`]), inputs);
+    createElementwiseProgramInfoLoader(inputs[0], 'Relu', a => `max(${a}, vec4<f32>(0.0))`), inputs);

 export const sigmoid = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]>=>handler.run(
-    createElementwiseProgramInfoLoader(inputs[0], ['Sigmoid', a => `(vec4<f32>(1.0) / (vec4<f32>(1.0) + exp(-${a})))`]), inputs);
+    createElementwiseProgramInfoLoader(inputs[0], 'Sigmoid', a => `(vec4<f32>(1.0) / (vec4<f32>(1.0) + exp(-${a})))`), inputs);

 export const sin = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-    handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Sin', 'sin']), inputs);
+    handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Sin', 'sin'), inputs);

 export const sqrt = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-    handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Sqrt', 'sqrt']), inputs);
+    handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Sqrt', 'sqrt'), inputs);

 export const tan = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-    handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Tan', 'tan']), inputs);
+    handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Tan', 'tan'), inputs);

 export const tanh = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
-    handler.run(createElementwiseProgramInfoLoader(inputs[0], ['Tanh', 'tanh']), inputs);
+    handler.run(createElementwiseProgramInfoLoader(inputs[0], 'Tanh', 'tanh'), inputs);
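
Note (illustration only, not part of the patch): a minimal usage sketch of the reworked
createElementwiseProgramInfoLoader signature, assuming the types introduced above. The 'Square'
operator, its square_vf32 helper, and the inputTensor variable are hypothetical placeholders.

    // Builtin case: the operator name and the WGSL builtin are passed as two separate strings;
    // the shader builder expands the string form to `sqrt(a)`.
    const sqrtLoader = createElementwiseProgramInfoLoader(inputTensor, 'Sqrt', 'sqrt');

    // Custom-expression case: a function builds the expression from the input variable name,
    // and extra WGSL is appended through the optional additionalImplementation argument.
    const squareLoader = createElementwiseProgramInfoLoader(
        inputTensor, 'Square', a => `square_vf32(${a})`, `
  fn square_vf32(v: vec4<f32>) -> vec4<f32> {
    return v * v;
  }`);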