From e603613733d2ad937117bb322ccb3751797ae29f Mon Sep 17 00:00:00 2001 From: gs Date: Wed, 25 Feb 2026 09:31:05 -0800 Subject: [PATCH 1/2] softplus support for webgpu --- .../core/providers/webgpu/math/unary_elementwise_ops.cc | 4 ++++ .../core/providers/webgpu/webgpu_execution_provider.cc | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index e16327b9facad..5406dd6053abd 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -291,5 +291,9 @@ WEBGPU_ELEMENTWISE_KERNEL(LeakyRelu, 16, WebGpuSupportedFloatTypes()) WEBGPU_LU_IMPL(ThresholdedRelu, "select(vec4(0), a, a > vec4(uniforms.attr))", "", 1.0f) WEBGPU_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_IMPL(Softplus, "select(log(1.0 + exp(a)), a + log(1.0 + exp(-a)), a > x_value_t(0))", "", ShaderUsage::UseValueTypeAlias) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Softplus, 1, 21, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Softplus, 22, WebGpuSupportedFloatTypes()) + } // namespace webgpu } // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 1891775c45057..c5447465d814e 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -146,6 +146,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, LeakyRelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, ThresholdedRelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 20, Gelu); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 21, Softplus); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 22, Softplus); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMax); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, ReduceMax); @@ -496,6 +498,8 @@ std::unique_ptr RegisterKernels(bool enable_graph_capture = fals KERNEL_CREATE_INFO(16, LeakyRelu), KERNEL_CREATE_INFO(10, ThresholdedRelu), KERNEL_CREATE_INFO(20, Gelu), + KERNEL_CREATE_INFO_VERSIONED(1, 21, Softplus), + KERNEL_CREATE_INFO(22, Softplus), // // binary - math KERNEL_CREATE_INFO_VERSIONED(7, 12, Add), From c2bdb8003f8c8566ad6d5cf39457829cf9befd25 Mon Sep 17 00:00:00 2001 From: gs Date: Thu, 26 Feb 2026 09:23:49 -0800 Subject: [PATCH 2/2] handle overflow for float32 and float16 --- .../webgpu/math/unary_elementwise_ops.cc | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 5406dd6053abd..7db65a42eac4d 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -291,9 +291,31 @@ WEBGPU_ELEMENTWISE_KERNEL(LeakyRelu, 16, WebGpuSupportedFloatTypes()) WEBGPU_LU_IMPL(ThresholdedRelu, "select(vec4(0), a, a > vec4(uniforms.attr))", "", 1.0f) WEBGPU_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, WebGpuSupportedFloatTypes()) -WEBGPU_ELEMENTWISE_IMPL(Softplus, "select(log(1.0 + exp(a)), a + log(1.0 + exp(-a)), a > x_value_t(0))", "", ShaderUsage::UseValueTypeAlias) +// For large a, softplus(a) = log(1 + exp(a)) ≈ a. Use a threshold to return a directly, +// avoiding unnecessary exp/log computation and potential overflow. +// PyTorch uses threshold=20 for float32. For float16, exp overflows at ~11.09 so use 11. +class Softplus final : public UnaryElementwise { + public: + Softplus(const OpKernelInfo& info) + : UnaryElementwise{info, "Softplus", + "select(" + "select(log(1.0 + exp(a)), a + log(1.0 + exp(-a)), a > x_value_t(0))," + "a," + "a > x_value_t(x_element_t(uniforms.attr))" + ")", + "", + ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias} {} + + Status ConfigureProgram(const ComputeContext& context, UnaryElementwiseProgram& program) const override { + const auto* input_tensor = context.Input(0); + float threshold = input_tensor->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 ? 11.0f : 20.0f; + program.AddUniformVariables({threshold}); + return Status::OK(); + } +}; + WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Softplus, 1, 21, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Softplus, 22, WebGpuSupportedFloatTypes()) } // namespace webgpu -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime