diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_functions.cu b/onnxruntime/core/providers/cuda/reduction/reduction_functions.cu
index 51c80d272bb96..62801c8da1e5f 100644
--- a/onnxruntime/core/providers/cuda/reduction/reduction_functions.cu
+++ b/onnxruntime/core/providers/cuda/reduction/reduction_functions.cu
@@ -209,7 +209,7 @@ __device__ void reduce_all(
   // the size of shared_memory equals to the number of warps.
 #pragma unroll
   for (int stride = MAX_NUM_WARPS_PER_BLOCK / 2; stride > 0; stride /= 2) {
-    if (tid_in_block + stride < num_warps_in_block) {
+    if (tid_in_block < stride && tid_in_block + stride < num_warps_in_block) {
       shared_memory[tid_in_block] += shared_memory[tid_in_block + stride];
     }
     __syncthreads();
diff --git a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu
index a96d4c82a7fdc..963fa020d033a 100644
--- a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu
+++ b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu
@@ -585,6 +585,13 @@ size_t CalcResizeBufferSize(const onnxruntime::UpsampleMode upsample_mode,
              static_cast<size_t>(std::accumulate(output_dims.begin(), output_dims.end(), (int64_t)0));
     case UpsampleMode::LINEAR:
+      // For LINEAR mode:
+      //   - bilinear (2-D/4-D) uses mapping for [H, W]
+      //   - trilinear (3-D/5-D) uses mapping for [D, H, W]
+      if (output_dims.size() == 3 || output_dims.size() == 5) {
+        return sizeof(LinearMappingInfo) *
+               static_cast<size_t>(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 3, (int64_t)0));
+      }
       return sizeof(LinearMappingInfo) *
              static_cast<size_t>(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 2, (int64_t)0));
     case UpsampleMode::CUBIC:
diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc
index 200a1aded8204..8fd994baec713 100644
--- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc
@@ -1015,6 +1015,38 @@ TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_5DTrilinear_pytorch_half_pixel) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  // TensorRT: results mismatch
 }
 
+TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_5DTrilinear_CudaRegression) {
+  auto cuda_ep = DefaultCudaExecutionProvider();
+  if (!cuda_ep) {
+    GTEST_SKIP() << "CUDA EP not available";
+  }
+
+  OpTester test("Resize", 13);
+  std::vector<float> roi{};
+  std::vector<float> scales{1.0f, 1.0f, 2.0f, 2.0f, 2.0f};
+
+  test.AddAttribute("mode", "linear");
+  test.AddAttribute("coordinate_transformation_mode", "pytorch_half_pixel");
+
+  constexpr int64_t N = 1, C = 1, D = 3, H = 4, W = 5;
+  std::vector<float> X(static_cast<size_t>(N * C * D * H * W), 1.0f);
+
+  test.AddInput<float>("X", {N, C, D, H, W}, X);
+  test.AddInput<float>("roi", {0}, roi);
+  test.AddInput<float>("scales", {5}, scales);
+
+  constexpr int64_t out_D = D * 2;
+  constexpr int64_t out_H = H * 2;
+  constexpr int64_t out_W = W * 2;
+  std::vector<float> Y(static_cast<size_t>(N * C * out_D * out_H * out_W), 1.0f);
+
+  test.AddOutput<float>("Y", {N, C, out_D, out_H, out_W}, Y);
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(std::move(cuda_ep));
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
 TEST(ResizeOpTest, ResizeOpLinearScalesNoOpTest) {
   // To test NNAPI EP, we need the scales/sizes to be in initializers
   auto run_test = [](bool scales_in_initializer) {
diff --git a/onnxruntime/test/providers/cuda/test_cases/reduction_functions_test.cc b/onnxruntime/test/providers/cuda/test_cases/reduction_functions_test.cc
index ec7e98528504e..593255b9e9c23 100644
--- a/onnxruntime/test/providers/cuda/test_cases/reduction_functions_test.cc
+++ b/onnxruntime/test/providers/cuda/test_cases/reduction_functions_test.cc
@@ -177,6 +177,35 @@ void TestReduceColumnsToColumn(int m, int n, float relative_error_tolerance = 1e
 
   CheckDeviceValues(m, d_out.get(), expected_column.data(), relative_error_tolerance);
 }
+
+void TestReduceColumnsToColumnRepeated(int m, int n, int iterations, float relative_error_tolerance = 1e-4f) {
+  SCOPED_TRACE(MakeString("m: ", m, ", n:", n, ", iterations: ", iterations));
+
+  const TensorShape shape{m, n};
+  RandomValueGenerator random{};
+  const auto values = random.Uniform<float>(shape.GetDims(), 1.0f, 10.0f);
+  const auto expected_column = ExpectedReduceMatrixColumnsOutput(m, n, values);
+
+  auto d_in = AllocateDeviceMemory<float>(m * n);
+  auto d_out = AllocateDeviceMemory<float>(m);
+
+  cudaMemcpy(d_in.get(), values.data(), m * n * sizeof(float), cudaMemcpyHostToDevice);
+
+  size_t buffer_size_in_bytes =
+      compute_reduce_matrix_columns_buffer_size<float>(m, n);
+  auto d_buffer = AllocateDeviceMemory<std::byte>(buffer_size_in_bytes);
+
+  for (int i = 0; i < iterations; ++i) {
+    ASSERT_STATUS_OK(reduce_matrix_columns(
+        0,
+        d_in.get(), d_out.get(),
+        m, n,
+        d_buffer.get(), buffer_size_in_bytes));
+
+    ASSERT_TRUE(CUDA_CALL(cudaDeviceSynchronize()).IsOK());
+    CheckDeviceValues(m, d_out.get(), expected_column.data(), relative_error_tolerance);
+  }
+}
 }  // namespace
 
 TEST(ReductionFunctionsTest, ReduceRowToScalar) {
@@ -205,6 +234,10 @@ TEST(ReductionFunctionsTest, ReduceColumnsToColumn) {
   }
 }
 
+TEST(ReductionFunctionsTest, ReduceColumnsToColumnRepeated) {
+  TestReduceColumnsToColumnRepeated(17, 8192, 100, 2e-4f);
+}
+
 TEST(ReductionFunctionsTest, BufferOffsets) {
   const int m = 2048;
   const int n = 1024;