diff --git a/onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc b/onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc index 255ad9cdf66c6..1a56cafdb3952 100644 --- a/onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc @@ -38,12 +38,50 @@ REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 11, 11); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 12, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 13, 17); -REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMax, 18); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 18, 19); +REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMax, 20); + +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 11, 11); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 12, 12); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 13, 17); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 18, 19); +REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMin, 20); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceSum, 1, 10); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceSum, 11, 12); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceSum, 13); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceProd, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceProd, 11, 12); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceProd, 13, 17); +REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceProd, 18); + +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL1, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL1, 11, 12); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL1, 13, 17); +REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceL1, 18); + +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL2, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL2, 11, 12); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL2, 13, 17); +REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceL2, 18); + +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSum, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSum, 11, 12); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSum, 13, 17); +REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceLogSum, 18); + +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceSumSquare, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceSumSquare, 11, 12); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceSumSquare, 13, 17); +REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceSumSquare, 18); + +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSumExp, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSumExp, 11, 12); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSumExp, 13, 17); +REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceLogSumExp, 18); + Status ReduceKernelProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); if (is_input_empty_) { @@ -126,14 +164,28 @@ Status ReduceKernel::ComputeInternal(ComputeContext& context) if (input_axes.empty()) { if (noop_with_empty_axes_ || rank == 0) { // If axes is empty and noop_with_empty_axes_ is true, it is a no-op according to the spec - // If input tensor is a scalar, return the input tensor as is. - // This is not correct for ReduceLogSum and ReduceSumSquare - // TODO handle these cases separately. - auto output = context.Output(0, input_tensor->Shape()); - if (output->DataRaw() != input_tensor->DataRaw()) { - ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*input_tensor, *output)); + // If input tensor is a scalar and it's not a ReduceLogSum or ReduceSumSquare, return the input tensor as is. + if (rank == 0 && (name_ == "ReduceLogSum" || name_ == "ReduceSumSquare" || name_ == "ReduceL1" || name_ == "ReduceL2")) { + // For ReduceLogSum with scalar input, output = log(input) + // For ReduceSumSquare with scalar input, output = input * input + auto output = context.Output(0, input_tensor->Shape()); + // We need to run the operation even for scalar inputs for these ops + const auto code = GetOpSpecificCode(input_tensor); + ReduceKernelProgram program(name_, keepdims_, noop_with_empty_axes_, input_axes, code, false); + std::vector reduce_axes = {0}; + program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank}) + .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank}) + .SetDispatchGroupSize(1) + .AddUniformVariables({{1}, {static_cast(noop_with_empty_axes_ ? 1 : 0)}, {reduce_axes}}); + return context.RunProgram(program); + } else { + // For other ops, or when axes is empty with noop_with_empty_axes_ true, just copy the input + auto output = context.Output(0, input_tensor->Shape()); + if (output->DataRaw() != input_tensor->DataRaw()) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*input_tensor, *output)); + } + return Status::OK(); } - return Status::OK(); } else { // If axes is empty and noop_with_empty_axes_ is false, it is a reduction over all axes input_axes.resize(rank); @@ -211,6 +263,14 @@ ReduceOpSpecificCode ReduceMax::GetOpSpecificCode(const Tensor* input_tensor) co ReduceOpSpecificCode code({loop_header, loop_body, loop_footer}); return code; } +ReduceOpSpecificCode ReduceMin::GetOpSpecificCode(const Tensor* input_tensor) const { + ORT_UNUSED_PARAMETER(input_tensor); + std::string loop_header = "var min_element = first_element;"; + std::string loop_body = "min_element = min(min_element, current_element);"; + std::string loop_footer = "let output_value = output_value_t(min_element);"; + ReduceOpSpecificCode code({loop_header, loop_body, loop_footer}); + return code; +} ReduceOpSpecificCode ReduceSum::GetOpSpecificCode(const Tensor* input_tensor) const { ORT_UNUSED_PARAMETER(input_tensor); std::string loop_header = "var sum = f32(0);"; @@ -219,6 +279,54 @@ ReduceOpSpecificCode ReduceSum::GetOpSpecificCode(const Tensor* input_tensor) co ReduceOpSpecificCode code({loop_header, loop_body, loop_footer}); return code; } +ReduceOpSpecificCode ReduceProd::GetOpSpecificCode(const Tensor* input_tensor) const { + ORT_UNUSED_PARAMETER(input_tensor); + std::string loop_header = "var prod = f32(1);"; + std::string loop_body = "prod *= f32(current_element);"; + std::string loop_footer = "let output_value = output_value_t(prod);"; + ReduceOpSpecificCode code({loop_header, loop_body, loop_footer}); + return code; +} +ReduceOpSpecificCode ReduceL1::GetOpSpecificCode(const Tensor* input_tensor) const { + ORT_UNUSED_PARAMETER(input_tensor); + std::string loop_header = "var l1 = f32(0);"; + std::string loop_body = "l1 += abs(f32(current_element));"; + std::string loop_footer = "let output_value = output_value_t(l1);"; + ReduceOpSpecificCode code({loop_header, loop_body, loop_footer}); + return code; +} +ReduceOpSpecificCode ReduceL2::GetOpSpecificCode(const Tensor* input_tensor) const { + ORT_UNUSED_PARAMETER(input_tensor); + std::string loop_header = "var l2 = f32(0);"; + std::string loop_body = "let t = f32(current_element); l2 += (t * t);"; + std::string loop_footer = "l2 = sqrt(l2); let output_value = output_value_t(l2);"; + ReduceOpSpecificCode code({loop_header, loop_body, loop_footer}); + return code; +} +ReduceOpSpecificCode ReduceLogSum::GetOpSpecificCode(const Tensor* input_tensor) const { + ORT_UNUSED_PARAMETER(input_tensor); + std::string loop_header = "var sum = f32(0);"; + std::string loop_body = "sum += f32(current_element);"; + std::string loop_footer = "let log_sum = log(sum); let output_value = output_value_t(log_sum);"; + ReduceOpSpecificCode code({loop_header, loop_body, loop_footer}); + return code; +} +ReduceOpSpecificCode ReduceSumSquare::GetOpSpecificCode(const Tensor* input_tensor) const { + ORT_UNUSED_PARAMETER(input_tensor); + std::string loop_header = "var sum_square = f32(0);"; + std::string loop_body = "let t = f32(current_element); sum_square += (t * t);"; + std::string loop_footer = "let output_value = output_value_t(sum_square);"; + ReduceOpSpecificCode code({loop_header, loop_body, loop_footer}); + return code; +} +ReduceOpSpecificCode ReduceLogSumExp::GetOpSpecificCode(const Tensor* input_tensor) const { + ORT_UNUSED_PARAMETER(input_tensor); + std::string loop_header = "var sum_exp = f32(0);"; + std::string loop_body = "sum_exp += exp(f32(current_element));"; + std::string loop_footer = "let log_sum_exp = log(sum_exp); let output_value = output_value_t(log_sum_exp);"; + ReduceOpSpecificCode code({loop_header, loop_body, loop_footer}); + return code; +} } // namespace webgpu -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/reduction/reduction_ops.h b/onnxruntime/core/providers/webgpu/reduction/reduction_ops.h index 1c7dba89b7144..291d931f41c05 100644 --- a/onnxruntime/core/providers/webgpu/reduction/reduction_ops.h +++ b/onnxruntime/core/providers/webgpu/reduction/reduction_ops.h @@ -71,11 +71,53 @@ class ReduceMax final : public ReduceKernel { ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor) const override; }; +class ReduceMin final : public ReduceKernel { + public: + ReduceMin(const OpKernelInfo& info) : ReduceKernel(info, "ReduceMin") {} + ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor) const override; +}; + class ReduceSum final : public ReduceKernel { public: ReduceSum(const OpKernelInfo& info) : ReduceKernel(info, "ReduceSum", true) {} ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor) const override; }; +class ReduceProd final : public ReduceKernel { + public: + ReduceProd(const OpKernelInfo& info) : ReduceKernel(info, "ReduceProd", true) {} + ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor) const override; +}; + +class ReduceL1 final : public ReduceKernel { + public: + ReduceL1(const OpKernelInfo& info) : ReduceKernel(info, "ReduceL1", true) {} + ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor) const override; +}; + +class ReduceL2 final : public ReduceKernel { + public: + ReduceL2(const OpKernelInfo& info) : ReduceKernel(info, "ReduceL2", true) {} + ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor) const override; +}; + +class ReduceLogSum final : public ReduceKernel { + public: + ReduceLogSum(const OpKernelInfo& info) : ReduceKernel(info, "ReduceLogSum", true) {} + ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor) const override; +}; + +class ReduceSumSquare final : public ReduceKernel { + public: + ReduceSumSquare(const OpKernelInfo& info) : ReduceKernel(info, "ReduceSumSquare", true) {} + ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor) const override; +}; + +class ReduceLogSumExp final : public ReduceKernel { + public: + ReduceLogSumExp(const OpKernelInfo& info) : ReduceKernel(info, "ReduceLogSumExp", true) {} + ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor) const override; +}; + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 0ff07f0581475..71d46bc4fb2cc 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -144,7 +144,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, ReduceMax); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, ReduceMax); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceMax); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, 19, ReduceMax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 20, ReduceMax); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMean); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceMean); @@ -155,7 +156,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, ReduceMin); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, ReduceMin); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceMin); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, 19, ReduceMin); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 20, ReduceMin); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceProd); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceProd); @@ -517,7 +519,8 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -528,45 +531,46 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, KERNEL_CREATE_INFO_VERSIONED(9, 15, Where), KERNEL_CREATE_INFO(16, Where), diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index 61a16d41e3e59..4bc97d035c7f7 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -6044,6 +6044,7 @@ void test_empty_set(const std::string& op, int opset, bool axes_as_input, float kQnnExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider, + kWebGpuExecutionProvider, }); }