From 41274ba24fe3e3979326bdbdf1f4c81347ceac20 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 24 Mar 2022 14:58:23 -0700 Subject: [PATCH] try more unary ops --- .../backends/webgpu/op-resolve-rules.ts | 10 +- .../onnxjs/backends/webgpu/ops/unary-op.ts | 204 +++++++++++++--- js/web/test/suite-test-list.jsonc | 227 ++++++++++++++++++ 3 files changed, 399 insertions(+), 42 deletions(-) diff --git a/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts b/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts index 5e144831dbaaa..c7390764456dc 100644 --- a/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts +++ b/js/web/lib/onnxjs/backends/webgpu/op-resolve-rules.ts @@ -6,20 +6,16 @@ import {OpSet} from '../../opset'; import * as unaryOps from './ops/unary-op'; export const WEBGPU_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [ - ['Abs', '', '6+', unaryOps.abs] - // ['Abs', '', '6+', unaryOps.abs], - // ['Acos', '', '7+', unaryOps.acos], + ['Abs', '', '6+', unaryOps.abs], ['Acos', '', '7+', unaryOps.acos], // ['Add', '', '7+', binaryOps.add], // ['And', '', '7+', binaryOps.and], - // ['Asin', '', '7+', unaryOps.asin], - // ['Atan', '', '7+', unaryOps.atan], + ['Asin', '', '7+', unaryOps.asin], ['Atan', '', '7+', unaryOps.atan], // // TODO: support new attributes for AveragePool-10 // ['AveragePool', '', '7+', averagePool, parseAveragePoolAttributes], // ['BatchNormalization', '', '7+', batchNormalization, parseBatchNormalizationAttributes], // ['Cast', '', '6+', cast, parseCastAttributes], // ['Ceil', '', '6+', unaryOps.ceil], - // ['Clip', '', '6-10', unaryOps.clip, unaryOps.parseClipAttributes], - // ['Clip', '', '11+', unaryOps.clipV11], + ['Clip', '', '6-10', unaryOps.clip, unaryOps.parseClipAttributes], ['Clip', '', '11+', unaryOps.clipV11], // ['Concat', '', '4+', concat, parseConcatAttributes], // ['Conv', '', '1+', conv, parseConvAttributes], // ['Cos', '', '7+', unaryOps.cos], diff --git a/js/web/lib/onnxjs/backends/webgpu/ops/unary-op.ts b/js/web/lib/onnxjs/backends/webgpu/ops/unary-op.ts index 1fbc571b331ad..7ec90f40ad543 100644 --- a/js/web/lib/onnxjs/backends/webgpu/ops/unary-op.ts +++ b/js/web/lib/onnxjs/backends/webgpu/ops/unary-op.ts @@ -1,40 +1,174 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key'; +import {Graph} from '../../../graph'; import {Tensor} from '../../../tensor'; +import {MAX_CLIP, MIN_CLIP} from '../../../util'; import {WebGpuInferenceHandler} from '../inference-handler'; -import {GpuDataType} from '../types'; - -export const abs = (handler: WebGpuInferenceHandler, inputs: Tensor[]): Tensor[] => handler.run( - { - name: 'Abs', - inputTypes: [GpuDataType.default], - // inputLayouts: [], - // outputLayouts: [], - shaderSource: ` - @group(0) @binding(0) var inputData : array; - @group(0) @binding(1) var outputData : array; - - @stage(compute) @workgroup_size(32) - fn main(@builtin(global_invocation_id) global_id : vec3) { - // Guard against out-of-bounds work group sizes - if (global_id.x * 32u >= ${inputs[0].size}u) { - return; - } - - // - // TODO: SIMD? - // - - let start = global_id.x * 32u; - let end = select(start + 32u, ${inputs[0].size}u, start + 32u > ${inputs[0].size}u); - - for (var i = start; i < end; i = i + 1u) { - outputData[i] = abs(inputData[i]); - } - }`, - outputs: [{dims: inputs[0].dims, type: inputs[0].type, gpuDataType: GpuDataType.default}], - // entryPoint: 'main', - dispatchGroup: (inputTensors) => ({x: Math.ceil(inputTensors[0].size / 32)}) - }, - inputs); +import {GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; + +const createElementwiseProgramShader = (funcName: string, funcImpl: string): (datasize: number) => string => + (datasize) => { + const vecSize = Math.ceil(datasize / 4); + return ` + // constant value for a workgroup size. + // + // We definitely can do further optimization in future, but for now we use 64. + // + // rule of thumb: Use [a workgroup size of] 64 unless you know what GPU you are targeting or that your workload + // needs something different. + // + // from: https://surma.dev/things/webgpu/ + // + let WORKGROUP_SIZE: u32 = 64u; + + @group(0) @binding(0) var inputData : array>; + @group(0) @binding(1) var outputData : array>; + + ${funcImpl} + + @stage(compute) @workgroup_size(WORKGROUP_SIZE) + fn main(@builtin(global_invocation_id) global_id : vec3) { + + // Guard against out-of-bounds work group sizes + if (global_id.x >= ${vecSize}u) { + return; + } + + outputData[global_id.x] = ${funcName}(inputData[global_id.x]); + }`; + }; + +const createElementwiseProgramInfo = + (metadata: ProgramMetadata, input: Tensor, funcName: string, funcImpl = ''): ProgramInfo => ({ + ...metadata, + shaderSource: createElementwiseProgramShader(funcName, funcImpl)(input.size), + outputs: [{dims: input.dims, type: input.type, gpuDataType: GpuDataType.default}], + dispatchGroup: (inputTensors) => + ({x: Math.ceil(inputTensors[0].size / 64 /* workgroup size */ / 4 /* vec size */)}) + }); + +const createElementwiseProgramInfoLoader = + (input: Tensor, functionName: string, functionImplementation = '', cacheKey?: string): ProgramInfoLoader => { + const metadata: ProgramMetadata = {name: functionName, inputTypes: [GpuDataType.default], cacheHint: cacheKey}; + return { + ...metadata, + get: () => createElementwiseProgramInfo(metadata, input, functionName, functionImplementation) + }; + }; + +export const abs = (handler: WebGpuInferenceHandler, inputs: Tensor[]): Tensor[] => + handler.run(createElementwiseProgramInfoLoader(inputs[0], 'abs'), inputs); + +export const acos = (handler: WebGpuInferenceHandler, inputs: Tensor[]): Tensor[] => + handler.run(createElementwiseProgramInfoLoader(inputs[0], 'acos'), inputs); + +export const asin = (handler: WebGpuInferenceHandler, inputs: Tensor[]): Tensor[] => + handler.run(createElementwiseProgramInfoLoader(inputs[0], 'asin'), inputs); + +export const atan = (handler: WebGpuInferenceHandler, inputs: Tensor[]): Tensor[] => + handler.run(createElementwiseProgramInfoLoader(inputs[0], 'atan'), inputs); + +export interface ClipAttributes extends AttributeWithCacheKey { + readonly min: number; + readonly max: number; +} + +export const clip = (handler: WebGpuInferenceHandler, inputs: Tensor[], attributes: ClipAttributes): Tensor[] => + handler.run( + createElementwiseProgramInfoLoader( + inputs[0], 'clip', ` + let clip_min_: f32 = f32(${attributes.min}); + let clip_max_: f32 = f32(${attributes.max}); + + fn clip(x: vec4) -> vec4 { + return clamp(x, clip_min_, clip_max_); + }`, + attributes.cacheKey), + inputs); + +export const parseClipAttributes = (node: Graph.Node): ClipAttributes => createAttributeWithCacheKey( + {min: node.attributes.getFloat('min', MIN_CLIP), max: node.attributes.getFloat('max', MAX_CLIP)}); + +const generateClipAttributesFromInputs = (handler: WebGpuInferenceHandler, inputs: Tensor[]): ClipAttributes => { + if (inputs.length >= 3 && + (!handler.session.isInitializer(inputs[1].dataId) || !handler.session.isInitializer(inputs[2].dataId))) { + throw new Error('dynamic clip attributes are not allowed'); + } + + const min = (inputs.length >= 3) ? inputs[1].numberData[0] : MIN_CLIP; + const max = (inputs.length >= 3) ? inputs[2].numberData[0] : MAX_CLIP; + return createAttributeWithCacheKey({min, max}); +}; + +export const clipV11 = (handler: WebGpuInferenceHandler, inputs: Tensor[]): Tensor[] => { + const attributes = generateClipAttributesFromInputs(handler, inputs); + return clip(handler, [inputs[0]], attributes); +}; + +// export const ceil = (handler: WebGLInferenceHandler, inputs: Tensor[]): +// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslCeil()), inputs)]; + +// export const cos = (handler: WebGLInferenceHandler, inputs: Tensor[]): +// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslCos()), inputs)]; + +// export interface EluAttributes extends AttributeWithCacheKey { +// readonly alpha: number; +// } + +// export const elu = +// (handler: WebGLInferenceHandler, inputs: Tensor[], attributes: EluAttributes): Tensor[] => [handler.run( +// createElementwiseProgramInfoLoader(handler, inputs[0], glslElu(attributes.alpha), attributes.cacheKey), +// inputs)]; + +// export const parseEluAttributes = (node: Graph.Node): EluAttributes => +// createAttributeWithCacheKey({alpha: node.attributes.getFloat('alpha', 1.0)}); + +// export const exp = (handler: WebGLInferenceHandler, inputs: Tensor[]): +// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslExp()), inputs)]; + +// export const floor = (handler: WebGLInferenceHandler, inputs: Tensor[]): +// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslFloor()), inputs)]; + +// export const identity = (handler: WebGLInferenceHandler, inputs: Tensor[]): +// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslIdentity()), inputs)]; + +// export interface LeakyReluAttributes extends AttributeWithCacheKey { +// readonly alpha: number; +// } + +// export const leakyRelu = +// (handler: WebGLInferenceHandler, inputs: Tensor[], attributes: LeakyReluAttributes): Tensor[] => [handler.run( +// createElementwiseProgramInfoLoader(handler, inputs[0], glslLeakyRelu(attributes.alpha), attributes.cacheKey), +// inputs)]; + +// export const parseLeakyReluAttributes = (node: Graph.Node): LeakyReluAttributes => +// createAttributeWithCacheKey({alpha: node.attributes.getFloat('alpha', 0.01)}); + +// export const log = (handler: WebGLInferenceHandler, inputs: Tensor[]): +// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslLog()), inputs)]; + +// export const neg = (handler: WebGLInferenceHandler, inputs: Tensor[]): +// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslNeg()), inputs)]; + +// export const not = (handler: WebGLInferenceHandler, inputs: Tensor[]): +// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslNot()), inputs)]; + +// export const relu = (handler: WebGLInferenceHandler, inputs: Tensor[]): +// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslRelu()), inputs)]; + +// export const sigmoid = (handler: WebGLInferenceHandler, inputs: Tensor[]): +// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslSigmoid()), inputs)]; + +// export const sin = (handler: WebGLInferenceHandler, inputs: Tensor[]): +// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslSin()), inputs)]; + +// export const sqrt = (handler: WebGLInferenceHandler, inputs: Tensor[]): +// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslSqrt()), inputs)]; + +// export const tan = (handler: WebGLInferenceHandler, inputs: Tensor[]): +// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslTan()), inputs)]; + +// export const tanh = (handler: WebGLInferenceHandler, inputs: Tensor[]): +// Tensor[] => [handler.run(createElementwiseProgramInfoLoader(handler, inputs[0], glslTanh()), inputs)]; diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index b7f954eea8266..e4f020da86204 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -282,6 +282,233 @@ "xor.jsonc" ] }, + "webgpu": { + "onnx": [], + "node": [ + "test_abs", + "test_acos_example", + "test_acos", + // "test_add_bcast", + // "test_add", + // "test_and_bcast3v1d", + // "test_and_bcast3v2d", + // "test_and_bcast4v2d", + // "test_and_bcast4v3d", + // "test_and_bcast4v4d", + // "test_and2d", + // "test_and3d", + // "test_and4d", + "test_asin_example", + "test_asin", + "test_atan_example", + "test_atan", + // "test_averagepool_1d_default", + // "test_averagepool_2d_default", + //"v12/test_averagepool_2d_pads", // TODO: fix avgpool and maxpool on VM + // "v12/test_averagepool_2d_precomputed_pads", + // "v12/test_averagepool_2d_precomputed_same_upper", + // "v12/test_averagepool_2d_precomputed_strides", + // "v12/test_averagepool_2d_same_upper", + // "v12/test_averagepool_2d_same_lower", + // "v12/test_averagepool_2d_strides", + // "test_averagepool_3d_default", + // "test_basic_conv_with_padding", + // "test_basic_conv_without_padding", + // "test_batchnorm_epsilon", + // "test_batchnorm_example", + // "test_cast_DOUBLE_to_FLOAT", + // "test_cast_FLOAT_to_DOUBLE", + "v{7,8,9,10}/test_clip_splitbounds", + "v{7,8,9,10}/test_clip_outbounds", + "v{7,8,9,10}/test_clip_inbounds", + "v{7,8,9,10}/test_clip_example", + "v{7,8,9,10}/test_clip_default_min", + "v{7,8,9,10}/test_clip_default_max", + "v{7,8,9,10}/test_clip_default_inbounds", + "v{7,8,9,10}/test_clip", + // "test_concat_1d_axis_0", + // "test_concat_2d_axis_0", + // "test_concat_2d_axis_1", + // "test_concat_3d_axis_0", + // "test_concat_3d_axis_1", + // "test_concat_3d_axis_2", + // "test_conv_with_strides_and_asymmetric_padding", + // "test_conv_with_strides_no_padding", + // "test_conv_with_strides_padding", + "test_constant", + "test_cos_example", + "test_cos", + // "test_div_bcast", + // "test_div_example", + // "test_div", + // "test_dropout_default", + // "test_dropout_random", + // "test_depthtospace_crd_mode", + // "test_depthtospace_crd_mode_example", + // "test_depthtospace_dcr_mode", + // "test_depthtospace_example", + "test_elu_example", + "test_elu", + "test_elu_default" + // "test_flatten_axis0", + // "test_flatten_axis1", + // "test_flatten_axis2", + // "test_flatten_axis3", + // "test_flatten_default_axis", + // "test_gather_0", + // "test_gather_1", + // "test_gemm_nobroadcast", + // "test_gemm_broadcast", + // "test_globalaveragepool_precomputed", + // "test_globalaveragepool", + // "test_globalmaxpool_precomputed", + // "test_globalmaxpool", + // "test_greater_bcast", + // "test_greater", + // "test_instancenorm_epsilon", + // "test_instancenorm_example", + // "test_less_bcast", + // "test_less", + // "test_equal_bcast", + // "test_equal", + // "test_identity", + // "test_leakyrelu_default", + // "test_leakyrelu_example", + // "test_leakyrelu", + // "test_lrn_default", <-- failing due to low precison. If absolute CPU error threshold is increased from 1e-4 to 1e-2 (100x increase), it passes the test. + // "test_lrn", <-- failing due to low precison. If absolute CPU error threshold is increased from 1e-4 to 1e-3 (10x increase), it passes the test. + // "test_matmul_2d", + // "test_matmul_3d", + // "test_matmul_4d", + // "test_maxpool_1d_default", + // "test_maxpool_2d_default", + // "v12/test_maxpool_2d_pads", + // "v12/test_maxpool_2d_precomputed_pads", + // "v12/test_maxpool_2d_precomputed_same_upper", + // "v12/test_maxpool_2d_precomputed_strides", + // "v12/test_maxpool_2d_same_lower", + // "v12/test_maxpool_2d_same_upper", + // "v12/test_maxpool_2d_strides", + // "test_maxpool_3d_default", + // "test_mul_bcast", + // "test_mul_example", + // "test_mul", + // "test_neg", + // "test_neg_example", + // "test_not_2d", + // "test_not_3d", + // "test_not_4d", + // "test_or_bcast3v1d", + // "test_or_bcast3v2d", + // "test_or_bcast4v2d", + // "test_or_bcast4v3d", + // "test_or_bcast4v4d", + // "test_prelu_broadcast", + // "test_prelu_example", + // "test_relu", + // "test_reshape_extended_dims", + // "test_reshape_negative_dim", + // "test_reshape_one_dim", + // "test_reshape_reduced_dims", + // "test_reshape_reordered_dims", + // "test_sigmoid", + // "test_sigmoid_example", + // "test_sin_example", + // "test_sin", + // "test_softmax_axis_0", + // "test_softmax_axis_1", + // "test_softmax_axis_2", + // "test_softmax_default_axis", + // "test_softmax_example", + // { + // "name": "test_softmax_large_number", + // "condition": "^((?!iOS).)*$" // does NOT contains 'iOS': large number cannot be handled in a half_float environment + // }, + // "test_sub_bcast", + // "test_sub_example", + // "test_sub", + // "test_sum_example", + // "test_sum_one_input", + // "test_sum_two_inputs", + // "test_reduce_log_sum_asc_axes", + // "test_reduce_log_sum_default", + // "test_reduce_log_sum_desc_axes", + // "test_reduce_max_default_axes_keepdim_example", + // "test_reduce_max_default_axes_keepdims_random", + // "test_reduce_max_do_not_keepdims_example", + // "test_reduce_max_do_not_keepdims_random", + // "test_reduce_max_keepdims_example", + // "test_reduce_max_keepdims_random", + // "test_reduce_mean_default_axes_keepdims_example", + // "test_reduce_mean_default_axes_keepdims_random", + // "test_reduce_mean_do_not_keepdims_example", + // "test_reduce_mean_do_not_keepdims_random", + // "test_reduce_mean_keepdims_example", + // "test_reduce_mean_keepdims_random", + // "test_reduce_min_default_axes_keepdims_example", + // "test_reduce_min_default_axes_keepdims_random", + // "test_reduce_min_do_not_keepdims_example", + // "test_reduce_min_do_not_keepdims_random", + // "test_reduce_min_keepdims_example", + // "test_reduce_min_keepdims_random", + // { + // "name": "test_reduce_prod_default_axes_keepdims_example", + // "condition": "^((?!iOS).)*$" // does NOT contains 'iOS': large number cannot be handled in a half_float environment + // }, + // "test_reduce_prod_default_axes_keepdims_random", + // "test_reduce_prod_do_not_keepdims_example", + // "test_reduce_prod_do_not_keepdims_random", + // "test_reduce_prod_keepdims_example", + // "test_reduce_prod_keepdims_random", + // "v{7,8,9,10,11,12}/test_reduce_sum_default_axes_keepdims_example", + // "v{7,8,9,10,11,12}/test_reduce_sum_default_axes_keepdims_random", + // "v{7,8,9,10,11,12}/test_reduce_sum_do_not_keepdims_example", + // "v{7,8,9,10,11,12}/test_reduce_sum_do_not_keepdims_random", + // "v{7,8,9,10,11,12}/test_reduce_sum_keepdims_example", + // "v{7,8,9,10,11,12}/test_reduce_sum_keepdims_random", + // "v{7,8,9,10,11,12}/test_reduce_sum_square_default_axes_keepdims_example", + // "v{7,8,9,10,11,12}/test_reduce_sum_square_default_axes_keepdims_random", + // "v{7,8,9,10,11,12}/test_reduce_sum_square_do_not_keepdims_example", + // "v{7,8,9,10,11,12}/test_reduce_sum_square_do_not_keepdims_random", + // "v{7,8,9,10,11,12}/test_reduce_sum_square_keepdims_example", + // "v{7,8,9,10,11,12}/test_reduce_sum_square_keepdims_random", + // "v{7,8,9,10,11,12}/test_split_variable_parts_default_axis", + // "v{7,8,9,10,11,12}/test_split_variable_parts_1d", + // "v{7,8,9,10,11,12}/test_split_variable_parts_2d", + // "v{7,8,9,10,11,12}/test_split_equal_parts_default_axis", + // "v{7,8,9,10,11,12}/test_split_equal_parts_1d", + // "v{7,8,9,10,11,12}/test_split_equal_parts_2d", + // "v{7,8,9}/test_slice", + // "v{7,8,9}/test_slice_default_axes", + // "v{7,8,9}/test_slice_end_out_of_bounds", + // "v{7,8,9}/test_slice_neg", + // "test_slice_start_out_of_bounds", // tensor shape of 0 + // "test_squeeze", + // "test_tan_example", + // "test_tan", + // "test_tanh_example", + // "test_tanh", + // "test_tile", + // "test_tile_precomputed", + // "test_transpose_all_permutations_0", + // "test_transpose_all_permutations_1", + // "test_transpose_all_permutations_2", + // "test_transpose_all_permutations_3", + // "test_transpose_all_permutations_4", + // "test_transpose_all_permutations_5", + // "test_transpose_default", + // "test_unsqueeze", + // "test_xor_bcast3v1d", + // "test_xor_bcast3v2d", + // "test_xor_bcast4v2d", + // "test_xor_bcast4v3d", + // "test_xor_bcast4v4d", + // "test_xor2d", + // "test_xor3d", + // "test_xor4d" + ], + "ops": [] + }, "wasm": { "onnx": ["resnet50", "squeezenet", "tiny_yolov2", "emotion_ferplus"], "node": [