diff --git a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu index 5e00ce6857d45..6e0586e772334 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu @@ -153,6 +153,70 @@ __global__ void _ResizeNearestMappingKernel2D( } } +template +__global__ void _ResizeNearestMappingKernel3D( + const int input_depth, const int input_height, const int input_width, + const int output_depth, const int output_height, const int output_width, + const float scales_depth, const float scales_height, const float scales_width, + const float roi_start_depth, const float roi_end_depth, + const float roi_start_height, const float roi_end_height, + const float roi_start_width, const float roi_end_width, + const bool extrapolation_enabled, + const CudaFunctionOriginalCoordinate& transform_coordinate, + const CudaFunctionNearestPixel& calc_nearest_pixel, + NearestMappingInfo* dims_mapping) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, output_depth + output_height + output_width); + if (id < output_depth) { // for Depth + int dim = id; + if (scales_depth == 1.0f) { + dims_mapping[id].extrapolate_ = 0; + } else { + float orig_coord = transform_coordinate(static_cast(dim), scales_depth, + static_cast(output_depth), + static_cast(input_depth), + roi_start_depth, roi_end_depth); + dims_mapping[id].extrapolate_ = static_cast( + extrapolation_enabled && (orig_coord < 0.f || orig_coord > static_cast(input_depth - 1))); + dim = calc_nearest_pixel(orig_coord, scales_depth < 1); + if (dim >= input_depth) dim = input_depth - 1; + if (dim < 0) dim = 0; + } + dims_mapping[id].origin_ = dim; + } else if (id < output_depth + output_height) { // for Height + int dim = id - output_depth; + if (scales_height == 1.0f) { + dims_mapping[id].extrapolate_ = 0; + } else { + float orig_coord = transform_coordinate(static_cast(dim), scales_height, + static_cast(output_height), + static_cast(input_height), + roi_start_height, roi_end_height); + dims_mapping[id].extrapolate_ = static_cast( + extrapolation_enabled && (orig_coord < 0.f || orig_coord > static_cast(input_height - 1))); + dim = calc_nearest_pixel(orig_coord, scales_height < 1); + if (dim >= input_height) dim = input_height - 1; + if (dim < 0) dim = 0; + } + dims_mapping[id].origin_ = dim; + } else { // for Width + int dim = id - output_depth - output_height; + if (scales_width == 1.0f) { + dims_mapping[id].extrapolate_ = 0; + } else { + float orig_coord = transform_coordinate(static_cast(dim), scales_width, + static_cast(output_width), + static_cast(input_width), + roi_start_width, roi_end_width); + dims_mapping[id].extrapolate_ = static_cast( + extrapolation_enabled && (orig_coord < 0.f || orig_coord > static_cast(input_width - 1))); + dim = calc_nearest_pixel(orig_coord, scales_width < 1); + if (dim >= input_width) dim = input_width - 1; + if (dim < 0) dim = 0; + } + dims_mapping[id].origin_ = dim; + } +} + template __global__ void _ResizeNearestMappingKernel( const size_t rank, @@ -221,6 +285,34 @@ __global__ void _ResizeNearestKernel2D( output_data[id] = input_data[input_index]; } +template +__global__ void _ResizeNearestKernel3D( + const int64_t output_depth, const int64_t output_height, const int64_t output_width, + const int64_t input_stride_image, const int64_t input_stride_depth, const int input_stride_row, + const fast_divmod output_stride_image, const fast_divmod output_stride_depth, const fast_divmod output_stride_row, + const T* input_data, T* output_data, const size_t N, + const T extrapolation_value, const NearestMappingInfo* dims_mapping) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + + int imageid, d, h, w, output_index, temp; + output_stride_image.divmod(static_cast(id), imageid, output_index); + output_stride_depth.divmod(output_index, d, temp); + output_stride_row.divmod(temp, h, w); + if (UseExtrapolation) { + if (dims_mapping[d].extrapolate_ + + dims_mapping[output_depth + h].extrapolate_ + + dims_mapping[output_depth + output_height + w].extrapolate_) { + output_data[id] = extrapolation_value; + return; + } + } + int input_index = static_cast(input_stride_image) * imageid + + static_cast(input_stride_depth) * dims_mapping[d].origin_ + + input_stride_row * dims_mapping[output_depth + h].origin_ + + dims_mapping[output_depth + output_height + w].origin_; + output_data[id] = input_data[input_index]; +} + template __global__ void _ResizeNearestKernel( const int rank, @@ -667,6 +759,56 @@ void ResizeNearestImpl( return; } + // Check if we can use the optimized 3D path: rank >= 3, not TF_CROP_AND_RESIZE, + // and all outer dimensions (except last 3) have scale == 1.0 + bool could3d = rank >= 3 && + transform_coordinate != ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE && + std::all_of(scales_vals.Data(), scales_vals.Data() + (rank - 3), [](float v) { return v == 1.0; }); + if (could3d) { + int64_t output_depth = output_shape[rank - 3]; + int64_t output_height = output_shape[rank - 2]; + int64_t output_width = output_shape[rank - 1]; + fast_divmod div_output_image = (rank > 3) ? output_div_pitches[rank - 4] + : fast_divmod(static_cast(output_depth * output_height * output_width)); + int blocksPerDimsMappingGrid = static_cast(ceil((output_depth + output_height + output_width) / 32.0)); + + DISPATCH_RESIZE_COORDINATE_TRANSFORMATION_MODE(transform_coordinate, [&]() { + DISPATCH_RESIZE_NEAREST_MODE(calc_nearest_pixel, [&]() { + _ResizeNearestMappingKernel3D<<>>( + static_cast(input_shape[rank - 3]), static_cast(input_shape[rank - 2]), + static_cast(input_shape[rank - 1]), + static_cast(output_depth), static_cast(output_height), static_cast(output_width), + scales_vals[rank - 3], scales_vals[rank - 2], scales_vals[rank - 1], + roi_vals[rank - 3], roi_vals[rank - 3 + rank], + roi_vals[rank - 2], roi_vals[rank - 2 + rank], + roi_vals[rank - 1], roi_vals[rank - 1 + rank], + extrapolation_enabled, coord_t(), nearest_t(), + dims_mapping); + }); + }); + + int64_t input_stride_depth = input_shape[rank - 2] * input_shape[rank - 1]; + int64_t input_stride_image = input_shape[rank - 3] * input_stride_depth; + if (extrapolation_enabled) { + _ResizeNearestKernel3D<<>>( + output_depth, output_height, output_width, + input_stride_image, input_stride_depth, static_cast(input_shape[rank - 1]), + div_output_image, output_div_pitches[rank - 3], output_div_pitches[rank - 2], + input_data, output_data, N, + extrapolation_value, + dims_mapping); + } else { + _ResizeNearestKernel3D<<>>( + output_depth, output_height, output_width, + input_stride_image, input_stride_depth, static_cast(input_shape[rank - 1]), + div_output_image, output_div_pitches[rank - 3], output_div_pitches[rank - 2], + input_data, output_data, N, + extrapolation_value, + dims_mapping); + } + return; + } + int64_t total_dim_sum = std::accumulate(output_shape.Data(), output_shape.Data() + rank, (int64_t)0); int blocksPerDimsMappingGrid = (int)(ceil(static_cast(total_dim_sum) / 32)); DISPATCH_RESIZE_COORDINATE_TRANSFORMATION_MODE(transform_coordinate, [&]() { diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 8fd994baec713..3129476b1b505 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -1316,6 +1316,87 @@ TEST(ResizeOpTest, ResizeOpNearestUpSample5dTest_WithSizes_CeilMode) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); } +TEST(ResizeOpTest, ResizeOpNearestUpSampleTest_5D_CudaRegression_Optimized3DMapping) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + GTEST_SKIP() << "CUDA EP not available"; + } + + OpTester test("Resize", 13); + std::vector roi{}; + std::vector scales{1.0f, 1.0f, 1.5f, 1.5f, 1.5f}; + + test.AddAttribute("mode", "nearest"); + test.AddAttribute("coordinate_transformation_mode", "asymmetric"); + test.AddAttribute("nearest_mode", "floor"); + + constexpr int64_t N = 1, C = 1, D = 2, H = 2, W = 2; + std::vector X = { + 1.0f, 2.0f, + 3.0f, 4.0f, + 5.0f, 6.0f, + 7.0f, 8.0f}; + + test.AddInput("X", {N, C, D, H, W}, X); + test.AddInput("roi", {0}, roi); + test.AddInput("scales", {5}, scales); + + std::vector Y = { + 1.0f, 1.0f, 2.0f, + 1.0f, 1.0f, 2.0f, + 3.0f, 3.0f, 4.0f, + + 1.0f, 1.0f, 2.0f, + 1.0f, 1.0f, 2.0f, + 3.0f, 3.0f, 4.0f, + + 5.0f, 5.0f, 6.0f, + 5.0f, 5.0f, 6.0f, + 7.0f, 7.0f, 8.0f}; + + test.AddOutput("Y", {N, C, 3, 3, 3}, Y); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(ResizeOpTest, ResizeOpNearestDownSampleTest_5D_CudaRegression_Optimized3DMapping) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + GTEST_SKIP() << "CUDA EP not available"; + } + + OpTester test("Resize", 13); + std::vector roi{}; + std::vector scales{1.0f, 1.0f, 0.5f, 0.5f, 0.5f}; + + test.AddAttribute("mode", "nearest"); + test.AddAttribute("coordinate_transformation_mode", "asymmetric"); + test.AddAttribute("nearest_mode", "floor"); + + constexpr int64_t N = 1, C = 1, D = 4, H = 4, W = 4; + std::vector X(64); + std::iota(X.begin(), X.end(), 1.0f); + + test.AddInput("X", {N, C, D, H, W}, X); + test.AddInput("roi", {0}, roi); + test.AddInput("scales", {5}, scales); + + std::vector Y = { + 1.0f, 3.0f, + 9.0f, 11.0f, + + 33.0f, 35.0f, + 41.0f, 43.0f}; + + test.AddOutput("Y", {N, C, 2, 2, 2}, Y); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + TEST(ResizeOpTest, ResizeOpNearestUpSample_Floor_Align_Corners) { OpTester test("Resize", 13);