diff --git a/onnxruntime/contrib_ops/webgpu/bert/bias_split_gelu.cc b/onnxruntime/contrib_ops/webgpu/bert/bias_split_gelu.cc new file mode 100644 index 0000000000000..99cd643423400 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/bias_split_gelu.cc @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "contrib_ops/webgpu/bert/bias_split_gelu.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "core/providers/webgpu/webgpu_utils.h" +#include "core/providers/webgpu/math/unary_elementwise_ops.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + BiasSplitGelu, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + BiasSplitGelu); + +Status BiasSplitGeluProgram::GenerateShaderCode(ShaderHelper& shader) const { + const ShaderVariableHelper& input = shader.AddInput("input"); + const ShaderVariableHelper& bias = shader.AddInput("bias"); + const ShaderVariableHelper& output = shader.AddOutput("output"); + + shader.AdditionalImplementation() << ErfImpl; + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << "const M_SQRT2: f32 = sqrt(2.0);\n" + << "const halfChannels = uniforms.channels / 2u;\n" + << "let biasIdx = global_idx % halfChannels;\n" + << "let batchIndex = global_idx / halfChannels;\n" + << "let inputOffset = biasIdx + batchIndex * halfChannels * 2;\n" + << "let valueLeft = " << input.GetByOffset("inputOffset") << " + " << bias.GetByOffset("biasIdx") << ";\n" + << "let valueRight = " << input.GetByOffset("inputOffset + halfChannels") << " + " << bias.GetByOffset("biasIdx + halfChannels") << ";\n" + << "let geluRight = valueRight * 0.5 * (erf_v(valueRight / M_SQRT2) + 1);\n" + << output.SetByOffset("global_idx", "valueLeft * geluRight"); + + return Status::OK(); +} + +Status BiasSplitGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const auto* input = context.Input(0); + const auto* bias = context.Input(1); + + TensorShape input_shape = input->Shape(); + + if (input_shape.NumDimensions() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BiasSplitGelu input should have 3 dimensions."); + } + + int64_t channels = input_shape[2]; + int64_t components = GetMaxComponents(channels); + channels /= components; + input_shape[2] = channels / 2; // for output shape calculation (N,S,D) -> (N,S,D/2) + + TensorShape bias_shape = bias->Shape(); + if (bias_shape.NumDimensions() != 1 || bias_shape[0] != channels) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BiasSplitGelu bias should have 1 dimension with size equal to the number of channels."); + } + + auto* output = context.Output(0, input_shape); + int64_t output_size = output->Shape().Size() / components; + + BiasSplitGeluProgram program{}; + program.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank}, + {bias}}) + .AddOutput({output}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{static_cast(output_size)}, + {static_cast(channels)}}); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/webgpu/bert/bias_split_gelu.h b/onnxruntime/contrib_ops/webgpu/bert/bias_split_gelu.h new file mode 100644 index 0000000000000..ccc3dd8c89b7b --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/bias_split_gelu.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +class BiasSplitGeluProgram final : public Program { + public: + BiasSplitGeluProgram() : Program{"BiasSplitGelu"} {} + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}, + {"channels", ProgramUniformVariableDataType::Uint32}); +}; + +class BiasSplitGelu final : public WebGpuKernel { + public: + BiasSplitGelu(const OpKernelInfo& info) : WebGpuKernel(info) {} + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/webgpu/bert/gelu.cc b/onnxruntime/contrib_ops/webgpu/bert/gelu.cc new file mode 100644 index 0000000000000..8dafecfae83e5 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/gelu.cc @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/math/unary_elementwise_ops.h" // contains Gelu definition +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +ONNX_OPERATOR_KERNEL_EX( + Gelu, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + Gelu); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/webgpu/bert/quick_gelu.cc b/onnxruntime/contrib_ops/webgpu/bert/quick_gelu.cc new file mode 100644 index 0000000000000..7d669e140ef23 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/quick_gelu.cc @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/math/unary_elementwise_ops.h" // contained Gelu definition +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +ONNX_OPERATOR_KERNEL_EX( + QuickGelu, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + QuickGelu); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 068a94c7390e2..6e63ba3a0caa4 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -38,14 +38,14 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing // BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 189d7baafce6a..e16327b9facad 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -256,26 +256,6 @@ WEBGPU_CLIP_KERNEL(MLFloat16) // activation // -class LinearUnit : public UnaryElementwise { - public: - LinearUnit(const OpKernelInfo& info, - const std::string& kernel_name, - const std::string& expression, - const std::string& additional_impl, - float default_alpha) - : UnaryElementwise{info, kernel_name, expression, additional_impl, ShaderUsage::UseElementTypeAlias} { - info.GetAttrOrDefault("alpha", &alpha_, default_alpha); - } - - Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const override { - program.AddUniformVariables({alpha_}); - return Status::OK(); - } - - protected: - float alpha_; -}; - #define WEBGPU_LU_IMPL(OP_TYPE, ...) \ class OP_TYPE final : public LinearUnit { \ public: \ @@ -285,17 +265,17 @@ class LinearUnit : public UnaryElementwise { WEBGPU_LU_IMPL(Elu, "elu_v(a)", EluImpl, 1.0) WEBGPU_ELEMENTWISE_KERNEL(Elu, 6, WebGpuSupportedFloatTypes()) -class Gelu : public UnaryElementwise { - public: - Gelu(const OpKernelInfo& info) - : UnaryElementwise{info, - "Gelu", - info.GetAttrOrDefault("approximate", "none") == "tanh" ? FastGeluExpr : GeluExpr, - info.GetAttrOrDefault("approximate", "none") == "tanh" ? TanhImpl : ErfImpl, - ShaderUsage::UseValueTypeAlias} { - cache_hint = info.GetAttrOrDefault("approximate", "none"); - } -}; +Gelu::Gelu(const OpKernelInfo& info) + : UnaryElementwise{info, + "Gelu", + info.GetAttrOrDefault("approximate", "none") == "tanh" ? FastGeluExpr : GeluExpr, + info.GetAttrOrDefault("approximate", "none") == "tanh" ? TanhImpl : ErfImpl, + ShaderUsage::UseValueTypeAlias} { + cache_hint = info.GetAttrOrDefault("approximate", "none"); +} + +QuickGelu::QuickGelu(const OpKernelInfo& info) + : LinearUnit{info, "QuickGelu", "quick_gelu_v(a)", QuickGeluImpl, 1.702f} {} WEBGPU_ELEMENTWISE_KERNEL(Gelu, 20, WebGpuSupportedFloatTypes()) @@ -312,4 +292,4 @@ WEBGPU_LU_IMPL(ThresholdedRelu, "select(vec4(0), a, a > vec4) -> vec4 { } )"; +constexpr const char QuickGeluImpl[] = R"( +fn quick_gelu_v(a: vec4) -> vec4 { + let one = 1.0; + let zero = 0.0; + let alpha_vec = vec4(uniforms.attr); + let v = a * alpha_vec; + var x1 : vec4; + for (var i = 0; i < 4; i = i + 1) { + if (v[i] >= zero) { + x1[i] = one / (one + exp(-v[i])); + } else { + x1[i] = one - one / (one + exp(v[i])); + } + } + return a * x1; +} +)"; + // default GELU expression, depending on ErfImpl constexpr const char GeluExpr[] = "0.5 * a * (1.0 + erf_v(a * 0.7071067811865475))"; @@ -111,4 +159,4 @@ constexpr const char GeluExpr[] = "0.5 * a * (1.0 + erf_v(a * 0.7071067811865475 constexpr const char FastGeluExpr[] = "a * (0.5 + 0.5 * tanh_v(a * (0.035677408136300125 * a * a + 0.7978845608028654)))"; } // namespace webgpu -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.h b/onnxruntime/core/providers/webgpu/webgpu_utils.h new file mode 100644 index 0000000000000..4f9018646905d --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.h @@ -0,0 +1,20 @@ +// Licensed under the MIT License. + +#pragma once + +#include + +namespace onnxruntime { +namespace webgpu { + +inline int64_t GetMaxComponents(int64_t size) { + if (size % 4 == 0) { + return 4; + } else if (size % 2 == 0) { + return 2; + } + return 1; +} + +} // namespace webgpu +} // namespace onnxruntime \ No newline at end of file