Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ __device__ void reduce_all(
// the size of shared_memory equals to the number of warps.
#pragma unroll
for (int stride = MAX_NUM_WARPS_PER_BLOCK / 2; stride > 0; stride /= 2) {
if (tid_in_block + stride < num_warps_in_block) {
if (tid_in_block < stride && tid_in_block + stride < num_warps_in_block) {
shared_memory[tid_in_block] += shared_memory[tid_in_block + stride];
}
__syncthreads();
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/resize_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,13 @@ size_t CalcResizeBufferSize(const onnxruntime::UpsampleMode upsample_mode,
static_cast<size_t>(std::accumulate(output_dims.begin(),
output_dims.end(), (int64_t)0));
case UpsampleMode::LINEAR:
// For LINEAR mode:
// - bilinear (2-D/4-D) uses mapping for [H, W]
// - trilinear (3-D/5-D) uses mapping for [D, H, W]
if (output_dims.size() == 3 || output_dims.size() == 5) {
return sizeof(LinearMappingInfo) *
static_cast<size_t>(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 3, (int64_t)0));
}
return sizeof(LinearMappingInfo) *
static_cast<size_t>(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 2, (int64_t)0));
case UpsampleMode::CUBIC:
Expand Down
32 changes: 32 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/resize_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,38 @@ TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_5DTrilinear_pytorch_half_pixel) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: results mismatch
}

// Regression test for 5-D LINEAR (trilinear) Resize on the CUDA EP.
// Upsamples a constant 1x1x3x4x5 tensor by 2x along D/H/W; trilinear
// interpolation of a constant field must yield a constant output, so the
// expected tensor is all ones. Skips when the CUDA EP is unavailable.
TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_5DTrilinear_CudaRegression) {
  auto cuda_ep = DefaultCudaExecutionProvider();
  if (!cuda_ep) {
    GTEST_SKIP() << "CUDA EP not available";
  }

  constexpr int64_t N = 1, C = 1, D = 3, H = 4, W = 5;
  constexpr int64_t out_D = 2 * D;
  constexpr int64_t out_H = 2 * H;
  constexpr int64_t out_W = 2 * W;

  OpTester test("Resize", 13);
  test.AddAttribute("mode", "linear");
  test.AddAttribute("coordinate_transformation_mode", "pytorch_half_pixel");

  const std::vector<float> empty_roi;
  const std::vector<float> scale_factors{1.0f, 1.0f, 2.0f, 2.0f, 2.0f};
  const std::vector<float> input_data(static_cast<size_t>(N * C * D * H * W), 1.0f);
  const std::vector<float> expected_output(static_cast<size_t>(N * C * out_D * out_H * out_W), 1.0f);

  test.AddInput<float>("X", {N, C, D, H, W}, input_data);
  test.AddInput<float>("roi", {0}, empty_roi);
  test.AddInput<float>("scales", {5}, scale_factors);
  test.AddOutput<float>("Y", {N, C, out_D, out_H, out_W}, expected_output);

  // Pin execution to the CUDA EP so the test exercises its resize kernels.
  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
  execution_providers.push_back(std::move(cuda_ep));
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

TEST(ResizeOpTest, ResizeOpLinearScalesNoOpTest) {
// To test NNAPI EP, we need the scales/sizes to be in initializers
auto run_test = [](bool scales_in_initializer) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,35 @@ void TestReduceColumnsToColumn(int m, int n, float relative_error_tolerance = 1e

CheckDeviceValues(m, d_out.get(), expected_column.data(), relative_error_tolerance);
}

// Runs the columns-to-column matrix reduction `iterations` times over the same
// m x n input, reusing one scratch buffer across calls, and validates d_out
// against the host-computed expected column after every call. Repeating the
// reduction with a shared buffer is intended to surface state left behind by a
// previous call — TODO confirm this matches the kernel bug being regressed.
void TestReduceColumnsToColumnRepeated(int m, int n, int iterations, float relative_error_tolerance = 1e-4f) {
  SCOPED_TRACE(MakeString("m: ", m, ", n:", n, ", iterations: ", iterations));

  const TensorShape shape{m, n};
  RandomValueGenerator random{};
  const auto values = random.Uniform<float>(shape.GetDims(), 1.0f, 10.0f);
  const auto expected_column = ExpectedReduceMatrixColumnsOutput(m, n, values);

  // Widen before multiplying so large m*n cannot overflow int.
  const size_t num_elements = static_cast<size_t>(m) * static_cast<size_t>(n);
  auto d_in = AllocateDeviceMemory<float>(num_elements);
  auto d_out = AllocateDeviceMemory<float>(m);

  // Check the copy status, consistent with the other CUDA calls in this file;
  // an ignored failure here would make every iteration compare garbage.
  ASSERT_TRUE(CUDA_CALL(cudaMemcpy(d_in.get(), values.data(), num_elements * sizeof(float),
                                   cudaMemcpyHostToDevice))
                  .IsOK());

  size_t buffer_size_in_bytes =
      compute_reduce_matrix_columns_buffer_size<float>(m, n);
  auto d_buffer = AllocateDeviceMemory<char>(buffer_size_in_bytes);

  for (int i = 0; i < iterations; ++i) {
    // First argument is 0 as in the sibling tests — presumably the default
    // stream; confirm against reduce_matrix_columns' declaration.
    ASSERT_STATUS_OK(reduce_matrix_columns(
        0,
        d_in.get(), d_out.get(),
        m, n,
        d_buffer.get(), buffer_size_in_bytes));

    // Synchronize so asynchronous kernel failures surface before comparing.
    ASSERT_TRUE(CUDA_CALL(cudaDeviceSynchronize()).IsOK());
    CheckDeviceValues(m, d_out.get(), expected_column.data(), relative_error_tolerance);
  }
}
} // namespace

TEST(ReductionFunctionsTest, ReduceRowToScalar) {
Expand Down Expand Up @@ -205,6 +234,10 @@ TEST(ReductionFunctionsTest, ReduceColumnsToColumn) {
}
}

// Repeatedly (100 iterations) reduces the same 17 x 8192 matrix while reusing
// the scratch buffer across calls; uses a slightly looser tolerance (2e-4)
// than the single-shot reduction tests.
TEST(ReductionFunctionsTest, ReduceColumnsToColumnRepeated) {
  TestReduceColumnsToColumnRepeated(17, 8192, 100, 2e-4f);
}

TEST(ReductionFunctionsTest, BufferOffsets) {
const int m = 2048;
const int n = 1024;
Expand Down
Loading