Skip to content

Commit

Permalink
elu
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 13, 2022
1 parent aac2fc6 commit 3b883b9
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 20 deletions.
2 changes: 1 addition & 1 deletion js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
// ['Dropout', '', '7+', unaryOps.identity],
// ['DepthToSpace', '', '1+', depthToSpace, parseDepthToSpaceAttributes],
// ['Equal', '', '7+', binaryOps.equal],
// ['Elu', '', '6+', unaryOps.elu, unaryOps.parseEluAttributes],
['Elu', '', '6+', unaryOps.elu, unaryOps.parseEluAttributes],
// ['Exp', '', '6+', unaryOps.exp],
// ['Flatten', '', '1+', flatten, parseFlattenAttributes],
// ['Floor', '', '6+', unaryOps.floor],
Expand Down
14 changes: 14 additions & 0 deletions js/web/lib/onnxjs/backends/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

/**
 * Default workgroup size (number of invocations per workgroup) used by the
 * WebGPU elementwise shader programs in this backend.
 *
 * We definitely can do further optimization in future, but for now we use 64.
 *
 * rule of thumb: Use [a workgroup size of] 64 unless you know what GPU you are targeting or that your workload
 * needs something different.
 *
 * from: https://surma.dev/things/webgpu/
 **/
export const WORKGROUP_SIZE = 64;
41 changes: 22 additions & 19 deletions js/web/lib/onnxjs/backends/webgpu/ops/unary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,13 @@ import {Tensor} from '../../../tensor';
import {MAX_CLIP, MIN_CLIP} from '../../../util';
import {WebGpuInferenceHandler} from '../inference-handler';
import {GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
import {WORKGROUP_SIZE} from './common';

const createElementwiseProgramShader = (funcName: string, funcImpl: string): (datasize: number) => string =>
(datasize) => {
const vecSize = Math.ceil(datasize / 4);
return `
// constant value for a workgroup size.
//
// We definitely can do further optimization in future, but for now we use 64.
//
// rule of thumb: Use [a workgroup size of] 64 unless you know what GPU you are targeting or that your workload
// needs something different.
//
// from: https://surma.dev/things/webgpu/
//
let WORKGROUP_SIZE: u32 = 64u;
let WORKGROUP_SIZE: u32 = ${WORKGROUP_SIZE}u;
@group(0) @binding(0) var<storage, read> inputData : array<vec4<f32>>;
@group(0) @binding(1) var<storage, write> outputData : array<vec4<f32>>;
Expand Down Expand Up @@ -113,17 +105,28 @@ export const ceil = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Pr
/** Dispatches the elementwise 'cos' program on the first input tensor. */
export const cos = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> => {
  const programLoader = createElementwiseProgramInfoLoader(inputs[0], 'cos');
  return handler.run(programLoader, inputs);
};

// export interface EluAttributes extends AttributeWithCacheKey {
// readonly alpha: number;
// }
/** Attributes for the 'Elu' operator. */
export interface EluAttributes extends AttributeWithCacheKey {
  /** Coefficient applied to the negative branch, i.e. alpha * (exp(x) - 1) for x < 0. */
  readonly alpha: number;
}

// export const elu =
// (handler: WebGLInferenceHandler, inputs: Tensor[], attributes: EluAttributes): Tensor[] => [handler.run(
// createElementwiseProgramInfoLoader(handler, inputs[0], glslElu(attributes.alpha), attributes.cacheKey),
// inputs)];
/**
 * Elu operator: y = x for x >= 0, y = alpha * (exp(x) - 1) for x < 0.
 *
 * Runs an elementwise WGSL program over inputs[0]. `attributes.alpha` is baked
 * into the generated shader source, so `attributes.cacheKey` must be passed to
 * distinguish programs compiled for different alpha values.
 */
export const elu = async(handler: WebGpuInferenceHandler, inputs: Tensor[], attributes: EluAttributes):
    Promise<Tensor[]> => handler.run(
    createElementwiseProgramInfoLoader(
        inputs[0], 'elu', `
  let elu_alpha_: f32 = f32(${attributes.alpha});

  fn elu_(a: f32) -> f32 {
    return select((exp(a) - 1.0) * elu_alpha_, a, a >= 0.0);
  }

  fn elu(v: vec4<f32>) -> vec4<f32> {
    return vec4(elu_(v.x), elu_(v.y), elu_(v.z), elu_(v.w));
  }`,
        attributes.cacheKey),
    inputs);

// export const parseEluAttributes = (node: Graph.Node): EluAttributes =>
// createAttributeWithCacheKey({alpha: node.attributes.getFloat('alpha', 1.0)});
/** Reads the 'alpha' float attribute (default 1.0) from the graph node and wraps it with a cache key. */
export const parseEluAttributes = (node: Graph.Node): EluAttributes => {
  const alpha = node.attributes.getFloat('alpha', 1.0);
  return createAttributeWithCacheKey({alpha});
};

// export const exp = (handler: WebGLInferenceHandler, inputs: Tensor[]):
// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslExp()), inputs)];
Expand Down

0 comments on commit 3b883b9

Please sign in to comment.