// Patch summary (preamble text; everything from "diff --git" onward is the unmodified patch payload).
//
// Files touched:
//   1. onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc
//   2. onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
//
// reduction_ops.cc:
//   * Renames the registration macros
//       REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL -> REGISTER_REDUCE_VERSIONED_KERNEL
//       REGISTER_UNARY_ELEMENTWISE_KERNEL           -> REGISTER_REDUCE_KERNEL
//     and updates every ReduceMean/Max/Min/Sum/Prod/L1/L2/LogSum/SumSquare/LogSumExp
//     registration to the new names.
//   * Adds REGISTER_REDUCE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT(ReduceOp, begin, end): a versioned
//     registration whose kernel def additionally calls .InputMemoryType(OrtMemTypeCPUInput, 1),
//     the same call the non-versioned REGISTER_REDUCE_KERNEL context lines already carry.
//   * Only two registrations change macro rather than just name: ReduceMax 18-19 and
//     ReduceMin 18-19 now use the _WITH_AXIS_IN_INPUT variant.
//   * NOTE(review): presumably input 1 is the `axes` tensor (going by the macro name) and must be
//     CPU-resident for these opset ranges; confirm against the ONNX ReduceMax/ReduceMin schemas.
//
// onnx_backend_test_series_filters.jsonc:
//   * Appends "^test_reduce_max_empty_set" and "^test_reduce_min_empty_set" to the filter list that
//     closes immediately before "current_failing_tests_pure_DML". The list's own key is outside
//     the hunk, and the patch gives no reason for excluding these tests.
//
// NOTE(review): the line breaks inside the hunks are missing in this copy (each hunk header declares
// multi-line ranges, yet the hunk bodies are run together), so the text presumably needs its
// original line structure restored before a patch tool will accept it.
diff --git a/onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc b/onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc index 1a56cafdb3952..a0213f63494d3 100644 --- a/onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc @@ -11,7 +11,7 @@ namespace onnxruntime { namespace webgpu { -#define REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceOp, begin, end) \ +#define REGISTER_REDUCE_VERSIONED_KERNEL(ReduceOp, begin, end) \ ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ ReduceOp, \ kOnnxDomain, \ @@ -20,7 +20,16 @@ namespace webgpu { (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedNumberTypes()), \ ReduceOp); -#define REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceOp, version) \ +#define REGISTER_REDUCE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT(ReduceOp, begin, end) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + ReduceOp, \ + kOnnxDomain, \ + begin, end, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedNumberTypes()).InputMemoryType(OrtMemTypeCPUInput, 1), \ + ReduceOp); + +#define REGISTER_REDUCE_KERNEL(ReduceOp, version) \ ONNX_OPERATOR_KERNEL_EX( \ ReduceOp, \ kOnnxDomain, \ @@ -29,58 +38,58 @@ namespace webgpu { (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedNumberTypes()).InputMemoryType(OrtMemTypeCPUInput, 1), \ ReduceOp); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 1, 10); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 11, 12); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 13, 17); -REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMean, 18); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceMean, 1, 10); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceMean, 11, 12); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceMean, 13, 17); +REGISTER_REDUCE_KERNEL(ReduceMean, 18); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 1, 10); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 11, 11); 
-REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 12, 12); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 13, 17); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 18, 19); -REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMax, 20); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceMax, 1, 10); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceMax, 11, 11); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceMax, 12, 12); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceMax, 13, 17); +REGISTER_REDUCE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT(ReduceMax, 18, 19); +REGISTER_REDUCE_KERNEL(ReduceMax, 20); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 1, 10); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 11, 11); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 12, 12); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 13, 17); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 18, 19); -REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMin, 20); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceMin, 1, 10); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceMin, 11, 11); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceMin, 12, 12); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceMin, 13, 17); +REGISTER_REDUCE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT(ReduceMin, 18, 19); +REGISTER_REDUCE_KERNEL(ReduceMin, 20); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceSum, 1, 10); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceSum, 11, 12); -REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceSum, 13); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceSum, 1, 10); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceSum, 11, 12); +REGISTER_REDUCE_KERNEL(ReduceSum, 13); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceProd, 1, 10); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceProd, 11, 12); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceProd, 13, 17); -REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceProd, 18); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceProd, 1, 10); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceProd, 11, 12); 
+REGISTER_REDUCE_VERSIONED_KERNEL(ReduceProd, 13, 17); +REGISTER_REDUCE_KERNEL(ReduceProd, 18); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL1, 1, 10); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL1, 11, 12); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL1, 13, 17); -REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceL1, 18); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceL1, 1, 10); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceL1, 11, 12); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceL1, 13, 17); +REGISTER_REDUCE_KERNEL(ReduceL1, 18); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL2, 1, 10); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL2, 11, 12); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL2, 13, 17); -REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceL2, 18); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceL2, 1, 10); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceL2, 11, 12); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceL2, 13, 17); +REGISTER_REDUCE_KERNEL(ReduceL2, 18); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSum, 1, 10); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSum, 11, 12); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSum, 13, 17); -REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceLogSum, 18); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceLogSum, 1, 10); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceLogSum, 11, 12); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceLogSum, 13, 17); +REGISTER_REDUCE_KERNEL(ReduceLogSum, 18); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceSumSquare, 1, 10); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceSumSquare, 11, 12); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceSumSquare, 13, 17); -REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceSumSquare, 18); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceSumSquare, 1, 10); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceSumSquare, 11, 12); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceSumSquare, 13, 17); +REGISTER_REDUCE_KERNEL(ReduceSumSquare, 18); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSumExp, 
1, 10); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSumExp, 11, 12); -REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSumExp, 13, 17); -REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceLogSumExp, 18); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceLogSumExp, 1, 10); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceLogSumExp, 11, 12); +REGISTER_REDUCE_VERSIONED_KERNEL(ReduceLogSumExp, 13, 17); +REGISTER_REDUCE_KERNEL(ReduceLogSumExp, 18); Status ReduceKernelProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index e261d66a0d22a..d62ffe644e4cc 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -694,7 +694,9 @@ "^test_gelu_tanh_2_expanded_cpu", "^test_dynamicquantizelinear_expanded_cpu", "^test_center_crop_pad_crop_negative_axes_hwc*", // failed due to new types or shape infer with negative axis for CenterCropPad. - "^test_center_crop_pad_crop_negative_axes_hwc_expanded*" // failed due to new types or shape infer with negative axis for CenterCropPad. + "^test_center_crop_pad_crop_negative_axes_hwc_expanded*", // failed due to new types or shape infer with negative axis for CenterCropPad. + "^test_reduce_max_empty_set", + "^test_reduce_min_empty_set" ], "current_failing_tests_pure_DML": [ "^test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_cpu",