Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/bias_split_gelu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "contrib_ops/webgpu/bert/bias_split_gelu.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
#include "core/providers/webgpu/webgpu_utils.h"
#include "core/providers/webgpu/math/unary_elementwise_ops.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

// Registers the BiasSplitGelu kernel class for the BiasSplitGelu operator in kMSDomain
// (version argument 1) on kWebGpuExecutionProvider. The "T" type constraint is whatever
// WebGpuSupportedFloatTypes() returns.
ONNX_OPERATOR_KERNEL_EX(
BiasSplitGelu,
kMSDomain,
1,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedFloatTypes()),
BiasSplitGelu);

// Emits the WGSL for BiasSplitGelu.
//
// Each invocation handles the output element at `global_idx`. The last axis of `input` is
// treated as two halves of `uniforms.channels / 2` elements each; the matching `bias` entry is
// added to both halves, GELU (erf form) is applied to the right half, and the result is
// multiplied by the left half:
//   output[global_idx] = (left + bias_left) * gelu(right + bias_right)
//
// Uniforms read: `output_size` (bounds guard) and `channels` (length of the input's last axis,
// counted in elements of the bound input type).
Status BiasSplitGeluProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const ShaderVariableHelper& input = shader.AddInput("input");
  const ShaderVariableHelper& bias = shader.AddInput("bias");
  const ShaderVariableHelper& output = shader.AddOutput("output");

  // erf_v() used by the GELU expression below comes from ErfImpl.
  shader.AdditionalImplementation() << ErfImpl;

  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
                            // Left untyped on purpose: an abstract-float constant converts to the
                            // element type it is combined with, whereas an explicit `f32` would not
                            // type-check against f16 values.
                            << "const M_SQRT2 = sqrt(2.0);\n"
                            // Must be `let`: a WGSL `const` requires a const-expression and a
                            // uniform is only known at run time.
                            << "let halfChannels = uniforms.channels / 2u;\n"
                            << "let biasIdx = global_idx % halfChannels;\n"
                            << "let batchIndex = global_idx / halfChannels;\n"
                            // Input rows are twice as long as output rows.
                            << "let inputOffset = biasIdx + batchIndex * halfChannels * 2;\n"
                            << "let valueLeft = " << input.GetByOffset("inputOffset") << " + " << bias.GetByOffset("biasIdx") << ";\n"
                            << "let valueRight = " << input.GetByOffset("inputOffset + halfChannels") << " + " << bias.GetByOffset("biasIdx + halfChannels") << ";\n"
                            << "let geluRight = valueRight * 0.5 * (erf_v(valueRight / M_SQRT2) + 1);\n"
                            << output.SetByOffset("global_idx", "valueLeft * geluRight");

  return Status::OK();
}

// Validates the inputs, allocates the output and dispatches BiasSplitGeluProgram.
//
// Inputs:  0 - rank-3 tensor (N, S, D) with D even; 1 - bias, rank-1 tensor of size D.
// Output:  0 - tensor of shape (N, S, D / 2).
// Returns INVALID_ARGUMENT if the input is not rank 3, D is odd, or the bias is not a
// rank-1 tensor of size D; otherwise the status of RunProgram.
//
// All sizes handed to the program are in elements of the bound tensors, because no component
// count is supplied to AddInputs/AddOutput below.
// NOTE(review): this assumes bindings without an explicit component count are addressed per
// scalar element - confirm against ProgramInput/ProgramOutput before re-introducing vectorization.
Status BiasSplitGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
  const auto* input = context.Input(0);
  const auto* bias = context.Input(1);

  const TensorShape& input_shape = input->Shape();
  if (input_shape.NumDimensions() != 3) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BiasSplitGelu input should have 3 dimensions.");
  }

  // Validate against the real channel count; it must not be rescaled before these checks.
  const int64_t channels = input_shape[2];
  if (channels % 2 != 0) {
    // The last axis is split into two equal halves, so an odd length cannot be processed.
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BiasSplitGelu input should have an even number of channels.");
  }

  const TensorShape& bias_shape = bias->Shape();
  if (bias_shape.NumDimensions() != 1 || bias_shape[0] != channels) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BiasSplitGelu bias should have 1 dimension with size equal to the number of channels.");
  }

  // (N, S, D) -> (N, S, D / 2)
  TensorShape output_shape = input_shape;
  output_shape[2] = channels / 2;

  auto* output = context.Output(0, output_shape);
  const int64_t output_size = output->Shape().Size();
  if (output_size == 0) {
    // Nothing to compute; avoid dispatching zero workgroups.
    return Status::OK();
  }

  BiasSplitGeluProgram program{};
  program.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank},
                     {bias}})
      .AddOutput({output})
      .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
      .AddUniformVariables({{static_cast<uint32_t>(output_size)},
                            {static_cast<uint32_t>(channels)}});
  return context.RunProgram(program);
}

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
32 changes: 32 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/bias_split_gelu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;

Check warning on line 13 in onnxruntime/contrib_ops/webgpu/bert/bias_split_gelu.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5] Raw Output: onnxruntime/contrib_ops/webgpu/bert/bias_split_gelu.h:13: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]
using onnxruntime::webgpu::ComputeContext;

// Shader program for the BiasSplitGelu kernel. It is constructed with the program name
// "BiasSplitGelu" and declares two Uint32 uniforms, "output_size" and "channels", in that order.
class BiasSplitGeluProgram final : public Program<BiasSplitGeluProgram> {
public:
BiasSplitGeluProgram() : Program{"BiasSplitGelu"} {}
// Writes the shader source for this program into `sh`.
Status GenerateShaderCode(ShaderHelper& sh) const override;
// Uniform order here is the order callers must follow when adding uniform values.
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32},
{"channels", ProgramUniformVariableDataType::Uint32});
};

// WebGPU kernel class for the BiasSplitGelu operator.
class BiasSplitGelu final : public WebGpuKernel {
 public:
  // explicit: an OpKernelInfo must never convert silently into a kernel object.
  explicit BiasSplitGelu(const OpKernelInfo& info) : WebGpuKernel(info) {}

  // Runs the operator for one invocation using the inputs/outputs exposed by `context`.
  Status ComputeInternal(ComputeContext& context) const override;
};

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
27 changes: 27 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/gelu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/math/unary_elementwise_ops.h" // contains Gelu definition
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;

Check warning on line 13 in onnxruntime/contrib_ops/webgpu/bert/gelu.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5] Raw Output: onnxruntime/contrib_ops/webgpu/bert/gelu.cc:13: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]
using onnxruntime::webgpu::ComputeContext;

// Registers the Gelu kernel class for the Gelu operator in kMSDomain (version argument 1) on
// kWebGpuExecutionProvider. The "T" type constraint is whatever WebGpuSupportedFloatTypes() returns.
ONNX_OPERATOR_KERNEL_EX(
Gelu,
kMSDomain,
1,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedFloatTypes()),
Gelu);

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
27 changes: 27 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/quick_gelu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/math/unary_elementwise_ops.h" // contains QuickGelu definition
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;

Check warning on line 13 in onnxruntime/contrib_ops/webgpu/bert/quick_gelu.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5] Raw Output: onnxruntime/contrib_ops/webgpu/bert/quick_gelu.cc:13: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]
using onnxruntime::webgpu::ComputeContext;

// Registers the QuickGelu kernel class for the QuickGelu operator in kMSDomain (version
// argument 1) on kWebGpuExecutionProvider. The "T" type constraint is whatever
// WebGpuSupportedFloatTypes() returns.
ONNX_OPERATOR_KERNEL_EX(
QuickGelu,
kMSDomain,
1,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedFloatTypes()),
QuickGelu);

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipLayerNormalization)>,
// LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it
Expand Down
44 changes: 12 additions & 32 deletions onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,26 +256,6 @@
// activation
//

// Base for unary element-wise kernels that take a single float "alpha" attribute.
// The attribute is read once at construction and handed to the shader program as a
// uniform value each time the program is configured.
class LinearUnit : public UnaryElementwise {
public:
// `kernel_name`, `expression` and `additional_impl` are forwarded to UnaryElementwise together
// with ShaderUsage::UseElementTypeAlias. `default_alpha` is used when the node has no "alpha"
// attribute.
LinearUnit(const OpKernelInfo& info,
const std::string& kernel_name,
const std::string& expression,
const std::string& additional_impl,
float default_alpha)
: UnaryElementwise{info, kernel_name, expression, additional_impl, ShaderUsage::UseElementTypeAlias} {
info.GetAttrOrDefault("alpha", &alpha_, default_alpha);
}

// Adds alpha_ as the program's uniform value; always returns Status::OK().
Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const override {
program.AddUniformVariables({alpha_});
return Status::OK();
}

protected:
// Value of the "alpha" attribute (or the constructor's default_alpha).
float alpha_;
};

#define WEBGPU_LU_IMPL(OP_TYPE, ...) \
class OP_TYPE final : public LinearUnit { \
public: \
Expand All @@ -285,17 +265,17 @@
WEBGPU_LU_IMPL(Elu, "elu_v(a)", EluImpl, 1.0)
WEBGPU_ELEMENTWISE_KERNEL(Elu, 6, WebGpuSupportedFloatTypes())

// Gelu kernel. The "approximate" attribute (default "none") selects the implementation:
// "tanh" uses FastGeluExpr with TanhImpl, anything else uses GeluExpr with ErfImpl.
class Gelu : public UnaryElementwise {
 public:
  // explicit: an OpKernelInfo must never convert silently into a kernel object.
  explicit Gelu(const OpKernelInfo& info)
      : UnaryElementwise{info,
                         "Gelu",
                         info.GetAttrOrDefault<std::string>("approximate", "none") == "tanh" ? FastGeluExpr : GeluExpr,
                         info.GetAttrOrDefault<std::string>("approximate", "none") == "tanh" ? TanhImpl : ErfImpl,
                         ShaderUsage::UseValueTypeAlias} {
    // The attribute value is stored as the cache hint so the two variants are told apart.
    cache_hint = info.GetAttrOrDefault<std::string>("approximate", "none");
  }
};
Gelu::Gelu(const OpKernelInfo& info)
: UnaryElementwise{info,
"Gelu",
info.GetAttrOrDefault<std::string>("approximate", "none") == "tanh" ? FastGeluExpr : GeluExpr,
info.GetAttrOrDefault<std::string>("approximate", "none") == "tanh" ? TanhImpl : ErfImpl,
ShaderUsage::UseValueTypeAlias} {
cache_hint = info.GetAttrOrDefault<std::string>("approximate", "none");

Check warning on line 274 in onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc:274: Add #include <string> for string [build/include_what_you_use] [4]
}

// Builds QuickGelu as a LinearUnit named "QuickGelu" whose expression is "quick_gelu_v(a)",
// implemented by QuickGeluImpl, with 1.702f as the default alpha.
QuickGelu::QuickGelu(const OpKernelInfo& info)
: LinearUnit{info, "QuickGelu", "quick_gelu_v(a)", QuickGeluImpl, 1.702f} {}

WEBGPU_ELEMENTWISE_KERNEL(Gelu, 20, WebGpuSupportedFloatTypes())

Expand All @@ -312,4 +292,4 @@
WEBGPU_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, WebGpuSupportedFloatTypes())

} // namespace webgpu
} // namespace onnxruntime
} // namespace onnxruntime
50 changes: 49 additions & 1 deletion onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,36 @@
ShaderUsage additional_usage_;
};

// Gelu kernel built on UnaryElementwise. The constructor is defined out of line.
class Gelu : public UnaryElementwise {
 public:
  // explicit: an OpKernelInfo must never convert silently into a kernel object.
  explicit Gelu(const OpKernelInfo& info);
};

class LinearUnit : public UnaryElementwise {
public:
LinearUnit(const OpKernelInfo& info,
const std::string& kernel_name,
const std::string& expression,
const std::string& additional_impl,

Check warning on line 73 in onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h:73: Add #include <string> for string [build/include_what_you_use] [4]
float default_alpha)
: UnaryElementwise{info, kernel_name, expression, additional_impl, ShaderUsage::UseElementTypeAlias} {
info.GetAttrOrDefault("alpha", &alpha_, default_alpha);
}

Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const override {
program.AddUniformVariables({alpha_});
return Status::OK();
}

protected:
float alpha_;
};

// QuickGelu kernel built on LinearUnit. The constructor is defined out of line.
class QuickGelu : public LinearUnit {
 public:
  // explicit: an OpKernelInfo must never convert silently into a kernel object.
  explicit QuickGelu(const OpKernelInfo& info);
};

constexpr const char ErfImpl[] = R"(
const r0 = 0.3275911;
const r1 = 0.254829592;
Expand Down Expand Up @@ -104,11 +134,29 @@
}
)";

// WGSL helper for QuickGelu: quick_gelu_v(a) = a * sigmoid(alpha * a), evaluated per lane.
// alpha is read from `uniforms.attr`. The sigmoid is computed with exp() of a non-positive
// argument on both branches.
//
// Every scalar constant is constructed as `x_element_t` so the function type-checks whatever
// the element type is; a bare `let one = 1.0;` would be f32 and could not be mixed with f16
// lanes. The uniform is converted to `x_element_t` before being splatted for the same reason.
constexpr const char QuickGeluImpl[] = R"(
fn quick_gelu_v(a: vec4<x_element_t>) -> vec4<x_element_t> {
  let one = x_element_t(1.0);
  let zero = x_element_t(0.0);
  let alpha_vec = vec4<x_element_t>(x_element_t(uniforms.attr));
  let v = a * alpha_vec;
  var x1 : vec4<x_element_t>;
  for (var i = 0; i < 4; i = i + 1) {
    if (v[i] >= zero) {
      x1[i] = one / (one + exp(-v[i]));
    } else {
      x1[i] = one - one / (one + exp(v[i]));
    }
  }
  return a * x1;
}
)";

// Default GELU expression: 0.5 * a * (1 + erf(a / sqrt(2))). It calls erf_v, so ErfImpl must
// be supplied alongside it. 0.7071067811865475 is 1 / sqrt(2).
constexpr const char GeluExpr[] = "0.5 * a * (1.0 + erf_v(a * 0.7071067811865475))";

// Fast (tanh-approximated) GELU expression. It calls tanh_v, so TanhImpl must be supplied
// alongside it. 0.7978845608028654 is sqrt(2 / pi) and 0.035677408136300125 is
// 0.044715 * sqrt(2 / pi).
constexpr const char FastGeluExpr[] = "a * (0.5 + 0.5 * tanh_v(a * (0.035677408136300125 * a * a + 0.7978845608028654)))";

} // namespace webgpu
} // namespace onnxruntime
} // namespace onnxruntime
20 changes: 20 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Licensed under the MIT License.

#pragma once

#include <cstdint>

namespace onnxruntime {
namespace webgpu {

// Returns the widest component count out of 4, 2 and 1 that divides `size` evenly.
// A `size` of 0 is divisible by everything and therefore yields 4.
inline int64_t GetMaxComponents(int64_t size) {
  const bool divisible_by_four = (size % 4) == 0;
  const bool divisible_by_two = (size % 2) == 0;
  return divisible_by_four ? 4 : (divisible_by_two ? 2 : 1);
}

} // namespace webgpu
} // namespace onnxruntime
Loading