Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@
#include <string>

#include "core/providers/common.h"
#include "core/providers/webgpu/math/binary_elementwise_ops.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/string_macros.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

namespace onnxruntime {
namespace webgpu {
Status BinaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
const auto& c = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);

shader.AdditionalImplementation() << additional_impl_;

shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size");

// check whether can use element-wise mode.
Expand Down Expand Up @@ -142,8 +145,15 @@
}

uint32_t vec_size = onnxruntime::narrow<uint32_t>((size + 3) / 4);

std::string additional_impl;
if (get_additional_impl_) {
additional_impl = get_additional_impl_(lhs_tensor->GetElementType(), rhs_tensor->GetElementType());
}

BinaryElementwiseProgram program{kernel_name_,
expression_,
additional_impl,
is_broadcast,
is_lhs_scalar,
is_rhs_scalar,
Expand Down Expand Up @@ -273,7 +283,28 @@
// Sub kernel registrations (version arguments 13-13 and 14).
// NOTE: Pow's WEBGPU_BINARY_IMPL declaration is intentionally not repeated here;
// it appears once, after GetPowImpl, so that it can reference that helper.
WEBGPU_BINARY_VERSIONED_KERNEL(Sub, 13, 13, Sub, WebGpuSupportedNumberTypes())
WEBGPU_BINARY_KERNEL(Sub, 14, Sub, WebGpuSupportedNumberTypes())

std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) {
SS(s, 1024);
std::string round_str;

Check warning on line 288 in onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc:288: Add #include <string> for string [build/include_what_you_use] [4]
if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
round_str = "round";
}

s << "fn pow_custom(a : input_a_element_t, b : f32) -> input_a_element_t {\n"
" if (b == 0.0) {\n"
" return input_a_element_t(1.0);\n"
" } else if (a < input_a_element_t(0.0) && b != floor(b)) {\n"
" return input_a_element_t(pow(f32(a), b)); // NaN\n"
" }\n"
<< " return select(sign(a), input_a_element_t(1.0), round(abs(b) % 2.0) != 1.0) * input_a_element_t(" << round_str << "(pow(f32(abs(a)), b)));\n"
<< "}\n"
"fn pow_v(a : vec4<input_a_element_t>, b : vec4<input_b_element_t>) -> vec4<input_a_element_t> {\n"
" return vec4<input_a_element_t>(pow_custom(a.x, f32(b.x)), pow_custom(a.y, f32(b.y)), pow_custom(a.z, f32(b.z)), pow_custom(a.w, f32(b.w)));\n"
"}\n";
return SS_GET(s);
}

// Pow: the shader expression is "pow_v(a, b)"; pow_v itself is defined by the WGSL
// text that GetPowImpl returns, which is why GetPowImpl is passed alongside it.
// Kernel registrations follow for version ranges 7-11, 12-12 and 13-14; the _2 macro
// variant takes a second type-list argument (presumably constraining the exponent
// input separately -- confirm against the macro definition).
WEBGPU_BINARY_IMPL(Pow, "pow_v(a, b)", GetPowImpl)
WEBGPU_BINARY_VERSIONED_KERNEL(Pow, 7, 11, Pow, WebGpuSupportedNumberTypes())
WEBGPU_BINARY_VERSIONED_KERNEL_2(Pow, 12, 12, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes())
WEBGPU_BINARY_VERSIONED_KERNEL_2(Pow, 13, 14, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes())
Expand Down
16 changes: 12 additions & 4 deletions onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ class BinaryElementwiseProgram final : public Program<BinaryElementwiseProgram>
public:
BinaryElementwiseProgram(const std::string& kernel_name,
const std::string& expression,
const std::string& additional_impl,
const bool is_broadcast,
const bool is_lhs_scalar,
const bool is_rhs_scalar,
const bool vectorize) : Program{kernel_name},
expression_{expression},
additional_impl_{additional_impl},
is_broadcast_{is_broadcast},
is_lhs_scalar_{is_lhs_scalar},
is_rhs_scalar_{is_rhs_scalar},
Expand All @@ -29,7 +31,8 @@ class BinaryElementwiseProgram final : public Program<BinaryElementwiseProgram>
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32});

private:
std::string expression_;
std::string_view expression_;
std::string_view additional_impl_;
bool is_broadcast_;
bool is_lhs_scalar_;
bool is_rhs_scalar_;
Expand All @@ -38,18 +41,23 @@ class BinaryElementwiseProgram final : public Program<BinaryElementwiseProgram>

// Kernel base for binary element-wise operators on the WebGPU provider.
//
// Each instance carries the kernel name, the shader expression used to combine the
// two inputs, and an optional callback that produces extra shader source depending
// on the element types of the two inputs.
class BinaryElementwise : public WebGpuKernel {
 public:
  // Callback returning additional shader source for the given ONNX element types
  // of the left-hand and right-hand inputs.
  using GetAdditionalImplementationFunction = std::string (*)(int lhs_element_type, int rhs_element_type);

  // @param info                 Kernel construction info, forwarded to WebGpuKernel.
  // @param kernel_name          Name of the kernel; copied into this object.
  // @param expression           Shader expression for the operator; copied into this object.
  // @param get_additional_impl  Optional callback supplying extra shader source;
  //                             nullptr (the default) means no extra source.
  BinaryElementwise(const OpKernelInfo& info,
                    const std::string& kernel_name,
                    const std::string& expression,
                    const GetAdditionalImplementationFunction get_additional_impl = nullptr) : WebGpuKernel{info},
                                                                                               kernel_name_{kernel_name},
                                                                                               expression_{expression},
                                                                                               get_additional_impl_{get_additional_impl} {}

 protected:
  Status ComputeInternal(ComputeContext& context) const final;

 private:
  std::string kernel_name_;
  std::string expression_;
  // May be nullptr; callers must check before invoking.
  const GetAdditionalImplementationFunction get_additional_impl_;
};

} // namespace webgpu
Expand Down
Loading