diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
index 1b4e524e2aef6..4568c0785db96 100644
--- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
+++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
@@ -154,7 +154,7 @@ Status ReduceKernel::ReduceKernelShared(
           m, n, false);
     }
     case ApplicableMatrixReduction::Columns:
-      // don't call reduce_matrix_columns() since it will reset initial output data
+      // don't call reduce_matrix_columns() since it will reset initial output data
     default:
       break;
   }
@@ -600,24 +600,30 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr
       }
     }
   } else {  // For ArgMax & ArgMin ops, use the indicies as the output with int64 type
-    if (temp_X) {
-      auto temp_output = cuda_ep.GetScratchBuffer<float>(output_count);
-      CUDNN_RETURN_IF_ERROR(cudnnReduceTensor(
-          cuda_ep.PerThreadCudnnHandle(), reduce_desc, indices_cuda.get(), indices_bytes,
-          workspace_cuda.get(), workspace_bytes,
-          &one, input_tensor, temp_X.get(),
-          &zero, output_tensor, temp_output.get()));
+    // cudnnReduceTensor has an issue when input and output have the same size, which happens when the axis to be reduced has a dim value of 1:
+    // the output is zeros of the output size.
+    if (input_count == output_count) {
+      CUDA_RETURN_IF_ERROR(cudaMemsetAsync(output.template MutableData<int64_t>(), static_cast<int64_t>(0), output_count * sizeof(int64_t)));
     } else {
-      auto temp_output = cuda_ep.GetScratchBuffer<CudaT>(output_count);
-      CUDNN_RETURN_IF_ERROR(cudnnReduceTensor(
-          cuda_ep.PerThreadCudnnHandle(), reduce_desc, indices_cuda.get(), indices_bytes,
-          workspace_cuda.get(), workspace_bytes,
-          &one, input_tensor, reinterpret_cast<const CudaT*>(input.template Data<T>()),
-          &zero, output_tensor, temp_output.get()));
-    }
+      if (temp_X) {
+        auto temp_output = cuda_ep.GetScratchBuffer<float>(output_count);
+        CUDNN_RETURN_IF_ERROR(cudnnReduceTensor(
+            cuda_ep.PerThreadCudnnHandle(), reduce_desc, indices_cuda.get(), indices_bytes,
+            workspace_cuda.get(), workspace_bytes,
+            &one, input_tensor, temp_X.get(),
+            &zero, output_tensor, temp_output.get()));
+      } else {
+        auto temp_output = cuda_ep.GetScratchBuffer<CudaT>(output_count);
+        CUDNN_RETURN_IF_ERROR(cudnnReduceTensor(
+            cuda_ep.PerThreadCudnnHandle(), reduce_desc, indices_cuda.get(), indices_bytes,
+            workspace_cuda.get(), workspace_bytes,
+            &one, input_tensor, reinterpret_cast<const CudaT*>(input.template Data<T>()),
+            &zero, output_tensor, temp_output.get()));
+      }
 
-    // CUDA reduction index is uint32_t for now, cast it to int64_t according to ONNX spec
-    Impl_Cast<uint32_t, int64_t>(reinterpret_cast<uint32_t*>(indices_cuda.get()), output.template MutableData<int64_t>(), output_count);
+      // CUDA reduction index is uint32_t for now, cast it to int64_t according to ONNX spec
+      Impl_Cast<uint32_t, int64_t>(reinterpret_cast<uint32_t*>(indices_cuda.get()), output.template MutableData<int64_t>(), output_count);
+    }
   }
 
   if (calculate_log) {
diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc
index 0241f731ab0b9..47507d35dae61 100644
--- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc
+++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc
@@ -2060,6 +2060,18 @@ TEST(ReductionOpTest, ArgMax2D_select_last) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }
 
+TEST(ReductionOpTest, ArgMax2D_dim1) {
+  OpTester test("ArgMax", 11);
+  test.AddAttribute("axis", (int64_t)1);
+  test.AddInput<float>("data", {3, 1},
+                       {1.0f,
+                        6.0f,
+                        9.0f});
+  test.AddOutput<int64_t>("reduced", {3, 1},
+                          {0, 0, 0});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+}
+
 TEST(ReductionOpTest, ArgMin) {
   OpTester test("ArgMin");
   test.AddAttribute("axis", (int64_t)0);
diff --git a/tools/ci_build/github/pai/pai-excluded-tests.txt b/tools/ci_build/github/pai/pai-excluded-tests.txt
index 6ed3d24ce7bd2..db3a89d5c3f7a 100644
--- a/tools/ci_build/github/pai/pai-excluded-tests.txt
+++ b/tools/ci_build/github/pai/pai-excluded-tests.txt
@@ -114,6 +114,7 @@ ReductionOpTest.ArgMax_Double_Type
 ReductionOpTest.ArgMax_do_not_keepdims
 ReductionOpTest.ArgMax_do_not_keepdims_2
 ReductionOpTest.ArgMax2D
+ReductionOpTest.ArgMax2D_dim1
 ReductionOpTest.ArgMin
 ReductionOpTest.ArgMin_Double_Type
 ReductionOpTest.ArgMin_Double_Precision
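
Note (not part of the patch): the input_count == output_count shortcut added in reduction_ops.cc is valid because ArgMax/ArgMin over an axis whose dim value is 1 always selects index 0, so the int64 index output can simply be zero-filled instead of going through cudnnReduceTensor. Below is a minimal CPU-side sketch of that equivalence; the helper reference_argmax_last_axis is illustrative only and does not exist in ONNX Runtime.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Reference ArgMax over the last axis of a row-major {rows, axis_len} tensor.
std::vector<int64_t> reference_argmax_last_axis(const std::vector<float>& data,
                                                std::size_t rows, std::size_t axis_len) {
  std::vector<int64_t> reduced(rows, 0);
  for (std::size_t r = 0; r < rows; ++r) {
    for (std::size_t i = 1; i < axis_len; ++i) {
      if (data[r * axis_len + i] > data[r * axis_len + reduced[r]]) {
        reduced[r] = static_cast<int64_t>(i);
      }
    }
  }
  return reduced;
}

int main() {
  // Mirrors the new ArgMax2D_dim1 test: shape {3, 1}, reduce over axis 1.
  std::vector<float> data = {1.0f, 6.0f, 9.0f};
  std::vector<int64_t> reference = reference_argmax_last_axis(data, 3, 1);

  // The shortcut taken by the fix: when input_count == output_count the
  // int64 index output is simply zero-filled (cudaMemsetAsync on the GPU,
  // memset here for illustration).
  std::vector<int64_t> shortcut(3);
  std::memset(shortcut.data(), 0, shortcut.size() * sizeof(int64_t));

  assert(reference == shortcut);  // both are {0, 0, 0}
  return 0;
}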