diff --git a/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts b/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts
index c7595d8325661..ca9a24cfa3692 100644
--- a/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts
+++ b/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts
@@ -25,7 +25,7 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
   // ['Dropout', '', '7+', unaryOps.identity],
   // ['DepthToSpace', '', '1+', depthToSpace, parseDepthToSpaceAttributes],
   // ['Equal', '', '7+', binaryOps.equal],
-  // ['Elu', '', '6+', unaryOps.elu, unaryOps.parseEluAttributes],
+  ['Elu', '', '6+', unaryOps.elu, unaryOps.parseEluAttributes],
   // ['Exp', '', '6+', unaryOps.exp],
   // ['Flatten', '', '1+', flatten, parseFlattenAttributes],
   // ['Floor', '', '6+', unaryOps.floor],
diff --git a/js/web/lib/onnxjs/backends/webgpu/ops/common.ts b/js/web/lib/onnxjs/backends/webgpu/ops/common.ts
new file mode 100644
index 0000000000000..b436e82f0d25a
--- /dev/null
+++ b/js/web/lib/onnxjs/backends/webgpu/ops/common.ts
@@ -0,0 +1,14 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+/**
+ * constant value for a workgroup size.
+ *
+ * We definitely can do further optimization in future, but for now we use 64.
+ *
+ * rule of thumb: Use [a workgroup size of] 64 unless you know what GPU you are targeting or that your workload
+ * needs something different.
+ *
+ * from: https://surma.dev/things/webgpu/
+ **/
+export const WORKGROUP_SIZE = 64;
diff --git a/js/web/lib/onnxjs/backends/webgpu/ops/unary-op.ts b/js/web/lib/onnxjs/backends/webgpu/ops/unary-op.ts
index 81f52de3f9c93..16120fc808aee 100644
--- a/js/web/lib/onnxjs/backends/webgpu/ops/unary-op.ts
+++ b/js/web/lib/onnxjs/backends/webgpu/ops/unary-op.ts
@@ -7,21 +7,13 @@ import {Tensor} from '../../../tensor';
 import {MAX_CLIP, MIN_CLIP} from '../../../util';
 import {WebGpuInferenceHandler} from '../inference-handler';
 import {GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
+import {WORKGROUP_SIZE} from './common';
 
 const createElementwiseProgramShader =
     (funcName: string, funcImpl: string): (datasize: number) => string => (datasize) => {
       const vecSize = Math.ceil(datasize / 4);
       return `
-  // constant value for a workgroup size.
-  //
-  // We definitely can do further optimization in future, but for now we use 64.
-  //
-  // rule of thumb: Use [a workgroup size of] 64 unless you know what GPU you are targeting or that your workload
-  // needs something different.
-  //
-  // from: https://surma.dev/things/webgpu/
-  //
-  let WORKGROUP_SIZE: u32 = 64u;
+  let WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
 
   @group(0) @binding(0) var<storage, read> inputData : array<vec4<f32>>;
   @group(0) @binding(1) var<storage, write> outputData : array<vec4<f32>>;
@@ -113,17 +105,28 @@ export const ceil = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Pr
 export const cos = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
     handler.run(createElementwiseProgramInfoLoader(inputs[0], 'cos'), inputs);
 
-// export interface EluAttributes extends AttributeWithCacheKey {
-//   readonly alpha: number;
-// }
+export interface EluAttributes extends AttributeWithCacheKey {
+  readonly alpha: number;
+}
 
-// export const elu =
-//     (handler: WebGLInferenceHandler, inputs: Tensor[], attributes: EluAttributes): Tensor[] => [handler.run(
-//         createElementwiseProgramInfoLoader(handler, inputs[0], glslElu(attributes.alpha), attributes.cacheKey),
-//         inputs)];
+export const elu = async(handler: WebGpuInferenceHandler, inputs: Tensor[], attributes: EluAttributes):
+    Promise<Tensor[]>=>handler.run(
+        createElementwiseProgramInfoLoader(
+            inputs[0], 'elu', `
+    let elu_alpha_: f32 = f32(${attributes.alpha});
+
+    fn elu_(a: f32) -> f32 {
+      return select((exp(a) - 1.0) * elu_alpha_, a, a >= 0.0);
+    }
+
+    fn elu(v: vec4<f32>) -> vec4<f32> {
+      return vec4<f32>(elu_(v.x), elu_(v.y), elu_(v.z), elu_(v.w));
+    }`,
+        attributes.cacheKey),
+    inputs);
 
-// export const parseEluAttributes = (node: Graph.Node): EluAttributes =>
-//     createAttributeWithCacheKey({alpha: node.attributes.getFloat('alpha', 1.0)});
+export const parseEluAttributes = (node: Graph.Node): EluAttributes =>
+    createAttributeWithCacheKey({alpha: node.attributes.getFloat('alpha', 1.0)});
 
 // export const exp = (handler: WebGLInferenceHandler, inputs: Tensor[]):
 //     Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslExp()), inputs)];