Skip to content

Commit

Permalink
other f32 unary operators
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 13, 2022
1 parent a1fbcfd commit e8e4d88
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 38 deletions.
13 changes: 5 additions & 8 deletions js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
// ['InstanceNormalization', '', '6+', instanceNormalization, parseInstanceNormalizationAttributes],
['LeakyRelu', '', '6+', unaryOps.leakyRelu, unaryOps.parseLeakyReluAttributes],
// ['Less', '', '7+', binaryOps.less],
// ['Log', '', '6+', unaryOps.log],
['Log', '', '6+', unaryOps.log],
// ['MatMul', '', '1+', matMul, parseMatMulAttributes],
// // TODO: support new attributes for MaxPool-8 and MaxPool-10
// ['MaxPool', '', '1+', maxPool, parseMaxPoolAttributes],
Expand All @@ -59,13 +59,11 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
// ['ReduceProd', '', '1+', reduceProd, parseReduceAttributes],
// ['ReduceSum', '', '1-12', reduceSum, parseReduceAttributes],
// ['ReduceSumSquare', '', '1+', reduceLogSumSquare, parseReduceAttributes],
// ['Relu', '', '6+', unaryOps.relu],
['Reshape', '', '5+', reshape],
['Relu', '', '6+', unaryOps.relu], ['Reshape', '', '5+', reshape],
// ['Resize', '', '10', resize, parseResizeAttributesV10],
// ['Resize', '', '11+', resize, parseResizeAttributesV11],
// ['Shape', '', '1+', shape],
// ['Sigmoid', '', '6+', unaryOps.sigmoid],
// ['Sin', '', '7+', unaryOps.sin],
['Sigmoid', '', '6+', unaryOps.sigmoid], ['Sin', '', '7+', unaryOps.sin],
// ['Slice', '', '10+', sliceV10], // TODO: support 'steps' for Slice-10
// ['Slice', '', '1-9', slice, parseSliceAttributes],
// // The "semantic" meaning of axis has changed in opset-13.
Expand All @@ -76,13 +74,12 @@ export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
// // When the attribute is missing, we need the count of number of outputs
// // so that we can determine the 'split' attribute from the runtime input to the Operator
// ['Split', '', '2-12', split, parseSplitAttributes],
// ['Sqrt', '', '6+', unaryOps.sqrt],
['Sqrt', '', '6+', unaryOps.sqrt],
// ['Squeeze', '', '1-12', squeeze, parseSqueezeAttributes],
// ['Squeeze', '', '13+', squeezeV13],
// ['Sub', '', '7+', binaryOps.sub],
// ['Sum', '', '6+', sum],
// ['Tan', '', '7+', unaryOps.tan],
// ['Tanh', '', '6+', unaryOps.tanh],
['Tan', '', '7+', unaryOps.tan], ['Tanh', '', '6+', unaryOps.tanh],
// ['Tile', '', '6+', tile],
// ['Transpose', '', '1+', transpose, parseTransposeAttributes],
// ['Upsample', '', '7-8', upsample, parseUpsampleAttributesV7],
Expand Down
40 changes: 26 additions & 14 deletions js/web/lib/onnxjs/backends/webgpu/ops/unary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -157,29 +157,41 @@ export const leakyRelu = async(handler: WebGpuInferenceHandler, inputs: Tensor[]
export const parseLeakyReluAttributes = (node: Graph.Node): LeakyReluAttributes =>
createAttributeWithCacheKey({alpha: node.attributes.getFloat('alpha', 0.01)});

// export const log = (handler: WebGLInferenceHandler, inputs: Tensor[]):
// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslLog()), inputs)];
export const log = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
handler.run(createElementwiseProgramInfoLoader(inputs[0], 'log'), inputs);

// export const neg = (handler: WebGLInferenceHandler, inputs: Tensor[]):
// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslNeg()), inputs)];

// export const not = (handler: WebGLInferenceHandler, inputs: Tensor[]):
// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslNot()), inputs)];

// export const relu = (handler: WebGLInferenceHandler, inputs: Tensor[]):
// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslRelu()), inputs)];
export const relu = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[] >=>handler.run(
createElementwiseProgramInfoLoader(inputs[0], 'relu', `
let relu_zero_: vec4<f32> = vec4(0.0, 0.0, 0.0, 0.0);
// export const sigmoid = (handler: WebGLInferenceHandler, inputs: Tensor[]):
// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslSigmoid()), inputs)];
fn relu(v: vec4<f32>) -> vec4<f32> {
return max( v, relu_zero_ );
}`),
inputs);

// export const sin = (handler: WebGLInferenceHandler, inputs: Tensor[]):
// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslSin()), inputs)];
export const sigmoid = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[] >=>handler.run(
createElementwiseProgramInfoLoader(inputs[0], 'sigmoid', `
let sigmoid_one_: vec4<f32> = vec4(1.0, 1.0, 1.0, 1.0);
// export const sqrt = (handler: WebGLInferenceHandler, inputs: Tensor[]):
// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslSqrt()), inputs)];
fn sigmoid(v: vec4<f32>) -> vec4<f32> {
return sigmoid_one_ / (sigmoid_one_ + exp(-v));
}`),
inputs);

// export const tan = (handler: WebGLInferenceHandler, inputs: Tensor[]):
// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslTan()), inputs)];
export const sin = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
handler.run(createElementwiseProgramInfoLoader(inputs[0], 'sin'), inputs);

// export const tanh = (handler: WebGLInferenceHandler, inputs: Tensor[]):
// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslTanh()), inputs)];
export const sqrt = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
handler.run(createElementwiseProgramInfoLoader(inputs[0], 'sqrt'), inputs);

export const tan = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
handler.run(createElementwiseProgramInfoLoader(inputs[0], 'tan'), inputs);

export const tanh = async(handler: WebGpuInferenceHandler, inputs: Tensor[]): Promise<Tensor[]> =>
handler.run(createElementwiseProgramInfoLoader(inputs[0], 'tanh'), inputs);
32 changes: 16 additions & 16 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@
// "test_identity",
"test_leakyrelu_default",
"test_leakyrelu_example",
"test_leakyrelu"
"test_leakyrelu",
// "test_lrn_default", <-- failing due to low precison. If absolute CPU error threshold is increased from 1e-4 to 1e-2 (100x increase), it passes the test.
// "test_lrn", <-- failing due to low precison. If absolute CPU error threshold is increased from 1e-4 to 1e-3 (10x increase), it passes the test.
// "test_matmul_2d",
Expand Down Expand Up @@ -405,16 +405,16 @@
// "test_or_bcast4v4d",
// "test_prelu_broadcast",
// "test_prelu_example",
// "test_relu",
"test_relu",
// "test_reshape_extended_dims",
// "test_reshape_negative_dim",
// "test_reshape_one_dim",
// "test_reshape_reduced_dims",
// "test_reshape_reordered_dims",
// "test_sigmoid",
// "test_sigmoid_example",
// "test_sin_example",
// "test_sin",
"test_sigmoid",
"test_sigmoid_example",
"test_sin_example",
"test_sin",
// "test_softmax_axis_0",
// "test_softmax_axis_1",
// "test_softmax_axis_2",
Expand Down Expand Up @@ -484,10 +484,10 @@
// "v{7,8,9}/test_slice_neg",
// "test_slice_start_out_of_bounds", // tensor shape of 0
// "test_squeeze",
// "test_tan_example",
// "test_tan",
// "test_tanh_example",
// "test_tanh",
"test_tan_example",
"test_tan",
"test_tanh_example",
"test_tanh"
// "test_tile",
// "test_tile_precomputed",
// "test_transpose_all_permutations_0",
Expand Down Expand Up @@ -528,26 +528,26 @@
////"identity.jsonc",
//"image-scaler.jsonc",
//"less.jsonc",
//"log.jsonc",
"log.jsonc",
//"matmul.jsonc",
//"mul.jsonc",
//"neg.jsonc",
//"not.jsonc",
//"or.jsonc",
"leaky-relu.jsonc"
"leaky-relu.jsonc",
//"reduce-min.jsonc",
//"relu.jsonc",
"relu.jsonc",
//"pad.jsonc",
//"pad-big.jsonc",
//"pow.jsonc",
//"pow-big-number.jsonc",
//"reshape.jsonc",
//"softmax.jsonc",
//"sin.jsonc",
"sin.jsonc",
//"split.jsonc",
//"sqrt.jsonc",
"sqrt.jsonc",
//"sub.jsonc",
//"tan.jsonc",
"tan.jsonc"
//"transpose.jsonc",
//"xor.jsonc"
]
Expand Down

0 comments on commit e8e4d88

Please sign in to comment.