diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc index 13004af25726d..6891b8159b090 100644 --- a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc @@ -4,15 +4,18 @@ #include "core/providers/common.h" #include "core/providers/webgpu/math/binary_elementwise_ops.h" #include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/string_macros.h" #include "core/providers/webgpu/webgpu_supported_types.h" namespace onnxruntime { namespace webgpu { Status BinaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); const auto& c = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.AdditionalImplementation() << additional_impl_; + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"); // check whether can use element-wise mode. 
@@ -142,8 +145,15 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const { } uint32_t vec_size = onnxruntime::narrow((size + 3) / 4); + + std::string additional_impl; /* stays empty unless this kernel was given a get_additional_impl_ hook */ + if (get_additional_impl_) { + additional_impl = get_additional_impl_(lhs_tensor->GetElementType(), rhs_tensor->GetElementType()); /* hook builds op-specific shader source from the two input element types */ + } + BinaryElementwiseProgram program{kernel_name_, expression_, + additional_impl, is_broadcast, is_lhs_scalar, is_rhs_scalar, @@ -273,7 +283,28 @@ WEBGPU_BINARY_VERSIONED_KERNEL(Sub, 7, 12, Sub, WebGpuSupportedNumberTypes()) WEBGPU_BINARY_VERSIONED_KERNEL(Sub, 13, 13, Sub, WebGpuSupportedNumberTypes()) WEBGPU_BINARY_KERNEL(Sub, 14, Sub, WebGpuSupportedNumberTypes()) -WEBGPU_BINARY_IMPL(Pow, "output_value_t(pow(vec4(a), vec4(b)))") +std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) { /* Returns shader source text that defines pow_custom (scalar helper) and pow_v (4-lane wrapper that calls pow_custom per component). The rhs element type is not consulted. */ + SS(s, 1024); + std::string round_str; /* left empty for every lhs type except INT32 */ + if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) { + round_str = "round"; /* for INT32 the f32 pow result is wrapped in round() before the cast back to the element type. NOTE(review): no other integer type gets this rounding -- confirm which types WebGpuSupportedNumberTypes() covers. */ + } + /* pow_custom as emitted below: b == 0 yields 1; a negative base with a non-integer exponent returns pow(f32(a), b) directly (annotated NaN in the shader text); otherwise the magnitude is pow(f32(abs(a)), b) and the select() term supplies the sign factor. Assuming WGSL select(f, t, cond) returns t when cond is true, the factor is 1 when round(abs(b) % 2.0) != 1.0 and sign(a) otherwise. */ + s << "fn pow_custom(a : input_a_element_t, b : f32) -> input_a_element_t {\n" + " if (b == 0.0) {\n" + " return input_a_element_t(1.0);\n" + " } else if (a < input_a_element_t(0.0) && b != floor(b)) {\n" + " return input_a_element_t(pow(f32(a), b)); // NaN\n" + " }\n" + << " return select(sign(a), input_a_element_t(1.0), round(abs(b) % 2.0) != 1.0) * input_a_element_t(" << round_str << "(pow(f32(abs(a)), b)));\n" + << "}\n" + "fn pow_v(a : vec4, b : vec4) -> vec4 {\n" + " return vec4(pow_custom(a.x, f32(b.x)), pow_custom(a.y, f32(b.y)), pow_custom(a.z, f32(b.z)), pow_custom(a.w, f32(b.w)));\n" + "}\n"; + return SS_GET(s); /* NOTE(review): the pow_v text above shows bare vec4 with no element type, and narrow( earlier in this patch has no template argument; angle-bracket arguments were presumably lost when this diff was extracted -- confirm against the upstream file. */ +} + +WEBGPU_BINARY_IMPL(Pow, "pow_v(a, b)", GetPowImpl) /* Pow expression now calls pow_v; GetPowImpl is presumably forwarded as the get_additional_impl hook -- macro body not shown here, TODO confirm */ WEBGPU_BINARY_VERSIONED_KERNEL(Pow, 7, 11, Pow, WebGpuSupportedNumberTypes()) WEBGPU_BINARY_VERSIONED_KERNEL_2(Pow, 12, 12, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes()) WEBGPU_BINARY_VERSIONED_KERNEL_2(Pow, 13, 14, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes()) diff --git 
a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h index 84cbcdf3244d8..f80accfb934f8 100644 --- a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h @@ -14,11 +14,13 @@ class BinaryElementwiseProgram final : public Program public: BinaryElementwiseProgram(const std::string& kernel_name, const std::string& expression, + const std::string& additional_impl, /* extra shader source; GenerateShaderCode streams it into shader.AdditionalImplementation() */ const bool is_broadcast, const bool is_lhs_scalar, const bool is_rhs_scalar, const bool vectorize) : Program{kernel_name}, expression_{expression}, + additional_impl_{additional_impl}, is_broadcast_{is_broadcast}, is_lhs_scalar_{is_lhs_scalar}, is_rhs_scalar_{is_rhs_scalar}, @@ -29,7 +31,8 @@ class BinaryElementwiseProgram final : public Program WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); private: - std::string expression_; + std::string_view expression_; + std::string_view additional_impl_; /* NOTE(review): expression_ and additional_impl_ are now non-owning views of the const std::string& constructor arguments, so those strings must outlive this program object. The ComputeInternal hunk in this patch passes the kernel member expression_ and a function-local additional_impl; a temporary argument would leave a dangling view -- confirm no caller passes one and that the program is not used after ComputeInternal returns. */ bool is_broadcast_; bool is_lhs_scalar_; bool is_rhs_scalar_; @@ -38,11 +41,15 @@ class BinaryElementwiseProgram final : public Program class BinaryElementwise : public WebGpuKernel { public: + using GetAdditionalImplementationFunction = std::string (*)(int lhs_element_type, int rhs_element_type); /* optional callback: given the lhs and rhs tensor element types, returns extra shader source as a string */ + BinaryElementwise(const OpKernelInfo& info, const std::string& kernel_name, - const std::string& expression) : WebGpuKernel{info}, - kernel_name_{kernel_name}, - expression_{expression} {} + const std::string& expression, + const GetAdditionalImplementationFunction get_additional_impl = nullptr) : WebGpuKernel{info}, /* defaults to nullptr; ComputeInternal only calls the hook when it is non-null */ + kernel_name_{kernel_name}, + expression_{expression}, + get_additional_impl_{get_additional_impl} {} protected: Status ComputeInternal(ComputeContext& context) const final; @@ -50,6 +57,7 @@ class BinaryElementwise : public WebGpuKernel { private: std::string kernel_name_; std::string expression_; + const 
GetAdditionalImplementationFunction get_additional_impl_; }; } // namespace webgpu