From 4790b0949a432432c44ddeff5e8959462f5ab32f Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Wed, 18 Feb 2026 00:23:14 +0800 Subject: [PATCH 01/58] Add deform conv 2d cpu execution provider support --- .../providers/cpu/cpu_execution_provider.cc | 4 + .../core/providers/cpu/nn/deform_conv.cc | 336 ++++++++++++++++ .../core/providers/cpu/nn/deform_conv.h | 48 +++ .../cpu/nn/deform_conv_expected_gen.py | 105 +++++ .../providers/cpu/nn/deform_conv_op_test.cc | 366 ++++++++++++++++++ 5 files changed, 859 insertions(+) create mode 100644 onnxruntime/core/providers/cpu/nn/deform_conv.cc create mode 100644 onnxruntime/core/providers/cpu/nn/deform_conv.h create mode 100644 onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py create mode 100644 onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 74b8f8e468097..71bd94e9d8a55 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -1220,6 +1220,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, uint8_t, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 20, Scan); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 20, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 21, DeformConv); // Opset 20 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, 20, ConstantOfShape); @@ -1316,6 +1317,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Ac class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Atanh); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Conv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, ConvTranspose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, DeformConv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Det); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, float_float, Dropout); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, float_double, Dropout); @@ -3277,6 +3279,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { Resize)>, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 20 BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.cc b/onnxruntime/core/providers/cpu/nn/deform_conv.cc new file mode 100644 index 0000000000000..7bca8f3e48301 --- /dev/null +++ b/onnxruntime/core/providers/cpu/nn/deform_conv.cc @@ -0,0 +1,336 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// CPU implementation of DeformConv (deformable convolution 2D). + +#include "deform_conv.h" + +#include +#include + +#include "core/common/common.h" +#include "core/common/narrow.h" +#include "core/common/safeint.h" +#include "core/util/math.h" + +namespace onnxruntime { + +namespace { + +// Bilinear interpolation at (h, w). 
Returns 0 if out of bounds. +template +T BilinearInterpolate(const T* in, int64_t height, int64_t width, T h, T w) { + // Check boundaries + if (h <= static_cast(-1) || h >= height || w <= static_cast(-1) || w >= width) { + return static_cast(0); + } + + const int h_low = static_cast(std::floor(h)); + const int w_low = static_cast(std::floor(w)); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const T lh = h - static_cast(h_low); + const T lw = w - static_cast(w_low); + const T hh = static_cast(1) - lh; + const T hw = static_cast(1) - lw; + + const T v1 = (h_low >= 0 && w_low >= 0) ? in[h_low * width + w_low] : static_cast(0); + const T v2 = (h_low >= 0 && w_high < width) ? in[h_low * width + w_high] : static_cast(0); + const T v3 = (h_high < height && w_low >= 0) ? in[h_high * width + w_low] : static_cast(0); + const T v4 = (h_high < height && w_high < width) ? in[h_high * width + w_high] : static_cast(0); + + const T w1 = hh * hw; + const T w2 = hh * lw; + const T w3 = lh * hw; + const T w4 = lh * lw; + + return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; +} + +// Deformable Im2Col for a SINGLE image. +// Converts the input image into a matrix suitable for GEMM. +// Output 'data_col' shape: [C_in * kH * kW, H_out * W_out] +template +void DeformableIm2col( + const T* data_im, // Input image [C, H, W] + const T* data_offset, // Offset [offset_groups * 2 * kH * kW, H_out, W_out] + const T* data_mask, // Mask [offset_groups * kH * kW, H_out, W_out] (optional) + int64_t height, int64_t width, // Input dimensions + int64_t kernel_h, int64_t kernel_w, // Kernel dimensions + int64_t pad_h, int64_t pad_w, // Padding + int64_t stride_h, int64_t stride_w, // Stride + int64_t dilation_h, int64_t dilation_w, // Dilation + int64_t channels, // Input channels + int64_t offset_groups, // Number of offset groups + int64_t height_col, int64_t width_col, // Output dimensions + bool use_mask, + T* data_col) { // Output buffer + + const int64_t channel_per_offset_group = channels / offset_groups; + + // We iterate over the output matrix columns (spatial locations) + // and fill the matrix rows (channels * kernels). + // Note: Parallelization can be applied here over 'c_col' (spatial index). + + for (int64_t c_col = 0; c_col < height_col * width_col; ++c_col) { + const int64_t w_col = c_col % width_col; + const int64_t h_col = c_col / width_col; + + // For each spatial location (h_col, w_col), we iterate over all input channels + for (int64_t c_im = 0; c_im < channels; ++c_im) { + const int64_t offset_grp = c_im / channel_per_offset_group; + + // Iterate over kernel window + for (int64_t i = 0; i < kernel_h; ++i) { + for (int64_t j = 0; j < kernel_w; ++j) { + + // Calculate the index in the offset/mask tensors. + // The offset tensor is organized as: (offset_groups, 2 * kH * kW, H_out, W_out). + // Flattened offset channel index relative to the start of the tensor: + // base = offset_grp * (2 * kH * kW). + // specific = 2 * (i * kW + j). 
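+          // Worked example (illustrative, not from the original patch): with
+          // kH = kW = 3 and offset_grp = 0, the h-offset for kernel tap
+          // (i = 1, j = 2) lives in offset channel 2 * (1 * 3 + 2) = 10 and
+          // its w-offset in channel 11; the matching mask value lives in
+          // mask channel 1 * 3 + 2 = 5.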
+ + const int64_t data_offset_h_ptr = + ((offset_grp * (2 * kernel_h * kernel_w) + 2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + + const int64_t data_offset_w_ptr = + ((offset_grp * (2 * kernel_h * kernel_w) + 2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + + const int64_t data_mask_ptr = + ((offset_grp * (kernel_h * kernel_w) + (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + + const T offset_h = data_offset[data_offset_h_ptr]; + const T offset_w = data_offset[data_offset_w_ptr]; + + T val = static_cast(0); + T mask_val = static_cast(1); + if (use_mask) { + mask_val = data_mask[data_mask_ptr]; + } + + // Only compute interpolation if mask is not zero (optimization) + if (mask_val != 0) { + const T h_im = h_col * stride_h - pad_h + i * dilation_h + offset_h; + const T w_im = w_col * stride_w - pad_w + j * dilation_w + offset_w; + + // Map (c_im, h_im, w_im) back to input + // data_im is [C, H, W] + const T* data_im_ptr = data_im + c_im * (height * width); + val = BilinearInterpolate(data_im_ptr, height, width, h_im, w_im); + } + + // Assign to data_col + // The layout of data_col row is: [Channel, KernelH, KernelW] flattened. + // Row index: c_im * (kH * kW) + i * kW + j + const int64_t col_row_idx = (c_im * kernel_h * kernel_w) + (i * kernel_w + j); + + data_col[col_row_idx * (height_col * width_col) + c_col] = val * mask_val; + } + } + } + } +} + +} // namespace + +template +Status DeformConv::Compute(OpKernelContext* context) const { + const auto* X = context->Input(0); + const auto* W = context->Input(1); + const auto* offset = context->Input(2); + const auto* B = context->Input(3); // optional + const auto* mask = context->Input(4); // optional + + const auto& X_shape = X->Shape(); + const auto& W_shape = W->Shape(); + const auto& offset_shape = offset->Shape(); + + // Validate Input Shapes + const int64_t N = X_shape[0]; + const int64_t C = X_shape[1]; + const int64_t H = X_shape[2]; + const int64_t W_in = X_shape[3]; + + const int64_t M = W_shape[0]; // out channels + // Handle kernel shape inference + const int64_t kH = attrs_.kernel_shape.size() >= 1 ? attrs_.kernel_shape[0] : W_shape[2]; + const int64_t kW = attrs_.kernel_shape.size() >= 2 ? attrs_.kernel_shape[1] : W_shape[3]; + + int64_t pad_h = 0; + int64_t pad_w = 0; + int64_t pad_h_end = 0; + int64_t pad_w_end = 0; + if (attrs_.pads.size() >= 4) { + pad_h = attrs_.pads[0]; + pad_w = attrs_.pads[1]; + pad_h_end = attrs_.pads[2]; + pad_w_end = attrs_.pads[3]; + } + + const int64_t stride_h = attrs_.strides.empty() ? 1 : attrs_.strides[0]; + const int64_t stride_w = attrs_.strides.size() < 2 ? 1 : attrs_.strides[1]; + const int64_t dilation_h = attrs_.dilations.empty() ? 1 : attrs_.dilations[0]; + const int64_t dilation_w = attrs_.dilations.size() < 2 ? 
1 : attrs_.dilations[1]; + const int64_t group = attrs_.group; + const int64_t offset_group = attrs_.offset_group; + + const int64_t out_h = (H + pad_h + pad_h_end - dilation_h * (kH - 1) - 1) / stride_h + 1; + const int64_t out_w = (W_in + pad_w + pad_w_end - dilation_w * (kW - 1) - 1) / stride_w + 1; + + // Checks + ORT_RETURN_IF_NOT(W_shape.NumDimensions() == 4, "Weight must be 4D."); + ORT_RETURN_IF_NOT(offset_shape.NumDimensions() == 4, "Offset must be 4D."); + ORT_RETURN_IF_NOT(offset_shape[1] == offset_group * 2 * kH * kW, + "Offset channel count must be offset_group * 2 * kH * kW."); + ORT_RETURN_IF_NOT(offset_shape[2] == out_h, "Offset spatial height must match output oH."); + ORT_RETURN_IF_NOT(offset_shape[3] == out_w, "Offset spatial width must match output oW."); + ORT_RETURN_IF_NOT(C % offset_group == 0, "Input channels must be divisible by offset_group."); + ORT_RETURN_IF_NOT(C == W_shape[1] * group, "Input channels must match weight in channels * group."); + ORT_RETURN_IF_NOT(M % group == 0, "Output channels must be divisible by group."); + + const bool use_mask = (mask != nullptr); + if (use_mask) { + ORT_RETURN_IF_NOT(mask->Shape().NumDimensions() == 4, "Mask must be 4D."); + ORT_RETURN_IF_NOT(mask->Shape()[1] == offset_group * kH * kW, "Mask channel count must be offset_group * kH * kW."); + ORT_RETURN_IF_NOT(mask->Shape()[2] == out_h, "Mask spatial height must match output oH."); + ORT_RETURN_IF_NOT(mask->Shape()[3] == out_w, "Mask spatial width must match output oW."); + } + + // Allocate Output + const TensorShape Y_shape({N, M, out_h, out_w}); + Tensor* Y = context->Output(0, Y_shape); + if (Y->Shape().Size() == 0) { + return Status::OK(); + } + + // Common sizes + const int64_t kernel_size = kH * kW; + const int64_t output_image_size = out_h * out_w; + const int64_t input_image_size = H * W_in; + const int64_t kernel_dim = C / group * kernel_size; // The "K" dimension for GEMM (per group) + + // Total col buffer size: (C * kH * kW) * (out_h * out_w) + // We allocate this per image to save memory compared to batch allocation if N is large, + // or simply because Im2Col is easier to implement per-image. + const int64_t col_buffer_size = (C * kernel_size) * output_image_size; + + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); + auto col_buffer = IAllocator::MakeUniquePtr(alloc, SafeInt(col_buffer_size)); + + const T* Xdata = X->Data(); + const T* Wdata = W->Data(); + const T* offset_data = offset->Data(); + const T* mask_data = use_mask ? mask->Data() : nullptr; + T* Ydata = Y->MutableData(); + const T* Bdata = (B != nullptr) ? B->Data() : nullptr; + + concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool(); + + // Main Loop: Iterate over Batch + for (int64_t n = 0; n < N; ++n) { + + // 1. Perform Im2Col for the current image n + // Pointers for current image + const T* X_curr = Xdata + n * (C * input_image_size); + const T* offset_curr = offset_data + n * (offset_group * 2 * kernel_size * output_image_size); + const T* mask_curr = use_mask ? (mask_data + n * (offset_group * kernel_size * output_image_size)) : nullptr; + T* col_buffer_ptr = col_buffer.get(); + + DeformableIm2col( + X_curr, + offset_curr, + mask_curr, + H, W_in, + kH, kW, + pad_h, pad_w, + stride_h, stride_w, + dilation_h, dilation_w, + C, + offset_group, + out_h, out_w, + use_mask, + col_buffer_ptr); + + // 2. Perform GEMM for each group + for (int64_t g = 0; g < group; ++g) { + // Weight pointer for group g + // Weight shape: [M, C/group, kH, kW]. 
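+      // Worked example (illustrative): M = 4, C = 4, group = 2, kH = kW = 3
+      // gives kernel_dim = (4 / 2) * 9 = 18, so the weights for g = 1 start
+      // (4 / 2) * 18 = 36 elements into Wdata.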
+ // Stride for group g is (M/group) * (C/group * kH * kW). + const T* weight_g = Wdata + g * (M / group) * kernel_dim; + + // Col buffer pointer for group g + // Col buffer shape: [C * kH * kW, output_image_size] + // We need the rows corresponding to group g. + // Row stride: output_image_size + // Group stride: (C/group * kH * kW) * output_image_size + const T* col_g = col_buffer_ptr + g * kernel_dim * output_image_size; + + // Output pointer for group g + // Output shape: [N, M, out_h, out_w] + // Current image offset: n * M * output_image_size + // Group offset: g * (M/group) * output_image_size + T* Y_g = Ydata + n * M * output_image_size + g * (M / group) * output_image_size; + + // Y = W * Col + // W matrix: [M/group, kernel_dim] + // Col matrix: [kernel_dim, output_image_size] + // Y matrix: [M/group, output_image_size] + math::GemmEx( + CblasNoTrans, + CblasNoTrans, + narrow(M / group), // M + narrow(output_image_size),// N + narrow(kernel_dim), // K + static_cast(1), // alpha + weight_g, // A + narrow(kernel_dim), // lda + col_g, // B + narrow(output_image_size), // ldb + static_cast(0), // beta + Y_g, // C + narrow(output_image_size), // ldc + thread_pool); + } + } + + // 3. Add Bias if present + if (Bdata != nullptr) { + for (int64_t n = 0; n < N; ++n) { + T* Y_curr = Ydata + n * M * output_image_size; + for (int64_t m = 0; m < M; ++m) { + T bias_val = Bdata[m]; + for (int64_t i = 0; i < output_image_size; ++i) { + Y_curr[m * output_image_size + i] += bias_val; + } + } + } + } + + return Status::OK(); +} + +// Explicit template instantiation for float and double +template class DeformConv; + +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( + DeformConv, + 19, + 21, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 2) // offset + .InputMemoryType(OrtMemTypeCPUInput, 4), // optional mask + DeformConv); + +ONNX_CPU_OPERATOR_KERNEL( + DeformConv, + 22, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 2) + .InputMemoryType(OrtMemTypeCPUInput, 4), + DeformConv); + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.h b/onnxruntime/core/providers/cpu/nn/deform_conv.h new file mode 100644 index 0000000000000..47e310da004c3 --- /dev/null +++ b/onnxruntime/core/providers/cpu/nn/deform_conv.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/framework/op_node_proto_helper.h" +#include "core/framework/tensor_shape.h" + +namespace onnxruntime { + +// Attributes for ONNX DeformConv (opset 19+). 
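+// Note: the ONNX spec marks strides, pads, and dilations as optional with
+// defaults (1s, 0s, and 1s respectively); this kernel requires them to be
+// present and enforces that in the constructor below.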
+// See https://onnx.ai/onnx/operators/onnx__DeformConv.html +struct DeformConvAttributes { + explicit DeformConvAttributes(const OpKernelInfo& info) { + Status status = info.GetAttrs("kernel_shape", kernel_shape); + ORT_ENFORCE(status.IsOK(), "Attribute kernel_shape is not set."); + status = info.GetAttrs("strides", strides); + ORT_ENFORCE(status.IsOK(), "Attribute strides is not set."); + status = info.GetAttrs("pads", pads); + ORT_ENFORCE(status.IsOK(), "Attribute pads is not set."); + status = info.GetAttrs("dilations", dilations); + ORT_ENFORCE(status.IsOK(), "Attribute dilations is not set."); + group = info.GetAttrOrDefault("group", 1); + offset_group = info.GetAttrOrDefault("offset_group", 1); + } + + TensorShapeVector kernel_shape; + TensorShapeVector strides; + TensorShapeVector pads; + TensorShapeVector dilations; + int64_t group{1}; + int64_t offset_group{1}; +}; + +template +class DeformConv : public OpKernel { + public: + explicit DeformConv(const OpKernelInfo& info) : OpKernel(info), attrs_(info) {} + + Status Compute(OpKernelContext* context) const override; + + private: + DeformConvAttributes attrs_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py b/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py new file mode 100644 index 0000000000000..5af0613c081e8 --- /dev/null +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 +""" +Generate expected outputs for DeformConv tests using torchvision.ops.deform_conv2d. +Run with: .venv/bin/python onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py +Outputs C++-friendly std::vector initializer lists for pasting into deform_conv_op_test.cc +""" +import torch +import torchvision.ops + +def _pair(x): + if isinstance(x, int): + return (x, x) + return x + +def to_cpp_list(t: torch.Tensor, fmt="{:.6f}") -> str: + """Flatten tensor in NCHW order and format as C++ initializer list.""" + t = t.detach().float().contiguous() + return ", ".join(fmt.format(x) + "f" for x in t.flatten().tolist()) + +def run_case(name: str, batch_sz: int, n_in: int, n_out: int, n_weight_grps: int, n_offset_grps: int, + kernel_h: int, kernel_w: int, stride: tuple, pad: tuple, dilation: tuple, + in_h: int, in_w: int, seed: int = 42): + """Build inputs with seed, run deform_conv2d, print C++ snippets.""" + torch.manual_seed(seed) + stride_h, stride_w = _pair(stride) + pad_h, pad_w = _pair(pad) + dil_h, dil_w = _pair(dilation) + + out_h = (in_h + 2 * pad_h - (dil_h * (kernel_h - 1) + 1)) // stride_h + 1 + out_w = (in_w + 2 * pad_w - (dil_w * (kernel_w - 1) + 1)) // stride_w + 1 + + x = torch.rand(batch_sz, n_in, in_h, in_w, dtype=torch.float32) + offset = torch.randn(batch_sz, n_offset_grps * 2 * kernel_h * kernel_w, out_h, out_w, dtype=torch.float32) + mask = torch.randn(batch_sz, n_offset_grps * kernel_h * kernel_w, out_h, out_w, dtype=torch.float32) + weight = torch.randn(n_out, n_in // n_weight_grps, kernel_h, kernel_w, dtype=torch.float32) + bias = torch.randn(n_out, dtype=torch.float32) + + # Standard answer from torchvision + out = torchvision.ops.deform_conv2d( + x, offset, weight, bias=bias, + stride=(stride_h, stride_w), padding=(pad_h, pad_w), dilation=(dil_h, dil_w), mask=mask + ) + + # ONNX pads = [top, left, bottom, right] + pads_onnx = [pad_h, pad_w, pad_h, pad_w] + + print(f"// --- {name} (seed={seed}) ---") + print(f"// Shapes: X({batch_sz},{n_in},{in_h},{in_w}) 
W({n_out},{n_in//n_weight_grps},{kernel_h},{kernel_w})") + print(f"// stride=({stride_h},{stride_w}) pad=({pad_h},{pad_w}) dilation=({dil_h},{dil_w})") + print(f"// out_h={out_h} out_w={out_w}") + print() + print("std::vector X = {" + to_cpp_list(x) + "};") + print("std::vector W = {" + to_cpp_list(weight) + "};") + print("std::vector offset = {" + to_cpp_list(offset) + "};") + print("std::vector B = {" + to_cpp_list(bias) + "};") + print("std::vector mask = {" + to_cpp_list(mask) + "};") + print("std::vector expected_Y = {" + to_cpp_list(out) + "};") + print() + print("// Params: kernel_shape={" + f"{kernel_h}, {kernel_w}" + "}, stride={" + f"{stride_h}, {stride_w}" + "}, pads={" + ", ".join(map(str, pads_onnx)) + "}, dilations={" + f"{dil_h}, {dil_w}" + "}, group=" + str(n_weight_grps) + ", offset_group=" + str(n_offset_grps)) + print() + return out + +def main(): + print("// Generated by deform_conv_expected_gen.py (torchvision.ops.deform_conv2d)") + print() + + # Case 1: Same config as PyTorch TestDeformConv.get_fn_args (small batch for readability) + run_case( + "PyTorch get_fn_args style (batch=1)", + batch_sz=1, + n_in=6, n_out=2, n_weight_grps=2, n_offset_grps=3, + kernel_h=3, kernel_w=2, + stride=(2, 1), pad=(1, 0), dilation=(2, 1), + in_h=5, in_w=4, + seed=42, + ) + + # Case 2: No mask (mask optional) - same config, then expected with mask=None + torch.manual_seed(42) + n_in, n_out = 6, 2 + n_weight_grps, n_offset_grps = 2, 3 + kH, kW = 3, 2 + stride_h, stride_w = 2, 1 + pad_h, pad_w = 1, 0 + dil_h, dil_w = 2, 1 + in_h, in_w = 5, 4 + batch_sz = 1 + out_h = (in_h + 2 * pad_h - (dil_h * (kH - 1) + 1)) // stride_h + 1 + out_w = (in_w + 2 * pad_w - (dil_w * (kW - 1) + 1)) // stride_w + 1 + + x = torch.rand(batch_sz, n_in, in_h, in_w, dtype=torch.float32) + offset = torch.randn(batch_sz, n_offset_grps * 2 * kH * kW, out_h, out_w, dtype=torch.float32) + weight = torch.randn(n_out, n_in // n_weight_grps, kH, kW, dtype=torch.float32) + bias = torch.randn(n_out, dtype=torch.float32) + + out_no_mask = torchvision.ops.deform_conv2d( + x, offset, weight, bias=bias, + stride=(stride_h, stride_w), padding=(pad_h, pad_w), dilation=(dil_h, dil_w), mask=None + ) + print("// --- Same inputs, no mask (expected_Y when mask is omitted) ---") + print("std::vector expected_Y_no_mask = {" + to_cpp_list(out_no_mask) + "};") + print() + +if __name__ == "__main__": + main() diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc new file mode 100644 index 0000000000000..a3ca282bfa0b2 --- /dev/null +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc @@ -0,0 +1,366 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Unit tests for DeformConv (CPU), aligned with PyTorch Vision deform_conv2d tests. +// Reference: https://github.com/pytorch/vision/blob/main/test/test_ops.py (TestDeformConv) + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +namespace { + +// Parameters similar to PyTorch TestDeformConv::get_fn_args (smaller for speed). 
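+// Output spatial sizes in these tests follow the usual convolution arithmetic
+// (see RunDeformConvTest below):
+//   out = (in + pad_begin + pad_end - dilation * (k - 1) - 1) / stride + 1
+// e.g. in_h = 5, pads 1/1, kH = 3, dilation_h = 2, stride_h = 2 gives
+// (5 + 2 - 4 - 1) / 2 + 1 = 2.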
+struct DeformConvTestParams { + int64_t batch_sz; + int64_t n_in_channels; + int64_t n_out_channels; + int64_t n_weight_grps; + int64_t n_offset_grps; + std::vector kernel_shape; // {kH, kW} + std::vector stride; + std::vector pad; + std::vector dilation; + int64_t in_h; + int64_t in_w; +}; + +void RunDeformConvTest(const DeformConvTestParams& params, + const std::vector& X, + const std::vector& W, + const std::vector& offset, + const std::vector& B, + const std::vector* mask, + const std::vector& expected_Y, + int opset = 19, + float rtol = 1e-5f, + float atol = 1e-5f) { + const int64_t kH = params.kernel_shape[0]; + const int64_t kW = params.kernel_shape[1]; + const int64_t out_h = (params.in_h + params.pad[0] + params.pad[2] - + params.dilation[0] * (kH - 1) - 1) / params.stride[0] + 1; + const int64_t out_w = (params.in_w + params.pad[1] + params.pad[3] - + params.dilation[1] * (kW - 1) - 1) / params.stride[1] + 1; + + OpTester test("DeformConv", opset); + test.AddAttribute("kernel_shape", params.kernel_shape); + test.AddAttribute("strides", params.stride); + test.AddAttribute("pads", params.pad); + test.AddAttribute("dilations", params.dilation); + test.AddAttribute("group", params.n_weight_grps); + test.AddAttribute("offset_group", params.n_offset_grps); + + const std::vector X_shape = {params.batch_sz, params.n_in_channels, params.in_h, params.in_w}; + const std::vector W_shape = {params.n_out_channels, params.n_in_channels / params.n_weight_grps, kH, kW}; + const std::vector offset_shape = {params.batch_sz, params.n_offset_grps * 2 * kH * kW, out_h, out_w}; + const std::vector Y_shape = {params.batch_sz, params.n_out_channels, out_h, out_w}; + + test.AddInput("X", X_shape, X); + test.AddInput("W", W_shape, W); + test.AddInput("offset", offset_shape, offset); + test.AddInput("B", {params.n_out_channels}, B); + if (mask != nullptr) { + const std::vector mask_shape = {params.batch_sz, params.n_offset_grps * kH * kW, out_h, out_w}; + test.AddInput("mask", mask_shape, *mask); + } else { + test.AddOptionalInputEdge(); + } + + test.AddOutput("Y", Y_shape, expected_Y, false, rtol, atol); + + std::unordered_set excluded = {kTensorrtExecutionProvider, kCudaExecutionProvider, + kOpenVINOExecutionProvider, kQnnExecutionProvider}; + test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); +} + +} // namespace + +// Minimal case: 1x1 kernel, 2x2 input, one output position with fractional offset (bilinear). +// At (0,0) offset (0.5, 0.5) samples center of [1,2;3,4] -> 2.5. +TEST(DeformConvTest, MinimalBilinear) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {1, 1}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 2; + p.in_w = 2; + + std::vector X = {1.f, 2.f, 3.f, 4.f}; // NCHW + std::vector W = {1.f}; + // offset (1, 2, 2, 2): ch0=offset_h, ch1=offset_w per output position. (0,0):(0.5,0)->2.5, (0,1):(0.5,-1)->1 + std::vector offset = { + 0.5f, 0.f, 0.f, 0.f, + 0.5f, -1.0f, 0.f, 0.f + }; + std::vector B = {0.f}; + std::vector mask = {1.f, 1.f, 1.f, 1.f}; // (1,1,2,2) + std::vector expected_Y = {2.5f, 1.f, 3.f, 4.f}; + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Forward with mask and bias: 2 batches, 2 groups, zero offset -> behaves like grouped conv. +// With offset=0 and mask=1, Y = Conv(X,W) + B. Use small inputs and compute expected. 
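+// (Illustrative helper, not part of the original patch: with offset = 0 and
+// mask = 1, DeformConv on uniform inputs reduces to a grouped convolution, so
+// the expected value at every output position is simply
+// (C_in / group) * kH * kW * x * w + bias, as the next few tests compute.)
+[[maybe_unused]] static float UniformZeroOffsetExpected(int64_t c_in_per_group,
+                                                        int64_t kernel_h, int64_t kernel_w,
+                                                        float x, float w, float bias) {
+  return static_cast<float>(c_in_per_group * kernel_h * kernel_w) * x * w + bias;
+}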
+TEST(DeformConvTest, ForwardWithMaskAndBias) { + DeformConvTestParams p = {}; + p.batch_sz = 2; + p.n_in_channels = 4; + p.n_out_channels = 2; + p.n_weight_grps = 2; + p.n_offset_grps = 2; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + const size_t x_size = static_cast(p.batch_sz * p.n_in_channels * p.in_h * p.in_w); + const size_t w_size = static_cast(p.n_out_channels * (p.n_in_channels / p.n_weight_grps) * 2 * 2); + const size_t offset_size = static_cast(p.batch_sz * p.n_offset_grps * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(p.batch_sz * p.n_offset_grps * 2 * 2 * out_h * out_w); + + std::vector X(x_size, 0.1f); + std::vector W(w_size, 0.1f); + std::vector offset(offset_size, 0.f); // zero offset -> regular grid sampling + std::vector mask(mask_size, 1.f); + std::vector B = {0.5f, -0.5f}; + + // With offset=0, mask=1: deform_conv equals grouped conv. Per ONNX, group 0 -> output ch 0, group 1 -> ch 1. + // Uniform X=0.1, W=0.1, 2x2 kernel -> 0.08 + B per channel; Y[:,0,:,:]=0.58, Y[:,1,:,:]=-0.42. + const size_t y_size = static_cast(p.batch_sz * p.n_out_channels * out_h * out_w); + std::vector expected_Y(y_size); + for (int64_t b = 0; b < p.batch_sz; ++b) { + for (int64_t c = 0; c < p.n_out_channels; ++c) { + float val = (c % 2 == 0) ? 0.58f : -0.42f; + for (int64_t i = 0; i < out_h * out_w; ++i) { + expected_Y[b * p.n_out_channels * out_h * out_w + c * out_h * out_w + i] = val; + } + } + } + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 19, 1e-4f, 1e-4f); +} + +// No mask (optional): same as above but mask omitted; compare to run with ones mask via tolerance. +TEST(DeformConvTest, ForwardNoMask) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 2; + p.n_out_channels = 2; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + const size_t x_size = 1 * 2 * 3 * 3; + const size_t w_size = 2 * 2 * 2 * 2; + const size_t offset_size = 1 * 2 * 2 * 2 * out_h * out_w; + const size_t y_size = 1 * 2 * out_h * out_w; + + std::vector X(x_size, 0.1f); + std::vector W(w_size, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector B(2, 0.f); + // No mask => mask=1. Zero offset => same as conv. Y = 4*2*0.1*0.1 = 0.08 per position. 
+  std::vector<float> expected_Y(y_size, 0.08f);
+
+  OpTester test("DeformConv", 19);
+  test.AddAttribute("kernel_shape", p.kernel_shape);
+  test.AddAttribute("strides", p.stride);
+  test.AddAttribute("pads", p.pad);
+  test.AddAttribute("dilations", p.dilation);
+  test.AddAttribute("group", p.n_weight_grps);
+  test.AddAttribute("offset_group", p.n_offset_grps);
+  const std::vector<int64_t> X_shape = {p.batch_sz, p.n_in_channels, p.in_h, p.in_w};
+  const std::vector<int64_t> W_shape = {p.n_out_channels, p.n_in_channels / p.n_weight_grps, 2, 2};
+  const std::vector<int64_t> offset_shape = {p.batch_sz, p.n_offset_grps * 2 * 2 * 2, out_h, out_w};
+  const std::vector<int64_t> Y_shape = {p.batch_sz, p.n_out_channels, out_h, out_w};
+
+  test.AddInput<float>("X", X_shape, X);
+  test.AddInput<float>("W", W_shape, W);
+  test.AddInput<float>("offset", offset_shape, offset);
+  test.AddInput<float>("B", {p.n_out_channels}, B);
+  test.AddOptionalInputEdge<float>();  // no mask
+  test.AddOutput<float>("Y", Y_shape, expected_Y, false, 1e-4f, 1e-4f);
+  std::unordered_set<std::string> excluded = {kTensorrtExecutionProvider, kCudaExecutionProvider,
+                                              kOpenVINOExecutionProvider, kQnnExecutionProvider};
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded);
+}
+
+// Empty batch (like PyTorch batch_sz=0): zero-sized output, kernel returns early.
+TEST(DeformConvTest, EmptyBatch) {
+  DeformConvTestParams p = {};
+  p.batch_sz = 0;
+  p.n_in_channels = 2;
+  p.n_out_channels = 2;
+  p.n_weight_grps = 1;
+  p.n_offset_grps = 1;
+  p.kernel_shape = {2, 2};
+  p.stride = {1, 1};
+  p.pad = {0, 0, 0, 0};
+  p.dilation = {1, 1};
+  p.in_h = 3;
+  p.in_w = 3;
+  const int64_t out_h = 2;
+  const int64_t out_w = 2;
+
+  std::vector<float> X;
+  std::vector<float> W = std::vector<float>(2 * 2 * 2 * 2, 0.1f);
+  std::vector<float> offset;
+  std::vector<float> B(2, 0.f);
+  std::vector<float> expected_Y;
+
+  OpTester test("DeformConv", 19);
+  test.AddAttribute("kernel_shape", p.kernel_shape);
+  test.AddAttribute("strides", p.stride);
+  test.AddAttribute("pads", p.pad);
+  test.AddAttribute("dilations", p.dilation);
+  test.AddAttribute("group", p.n_weight_grps);
+  test.AddAttribute("offset_group", p.n_offset_grps);
+  const std::vector<int64_t> X_shape = {0, p.n_in_channels, p.in_h, p.in_w};
+  const std::vector<int64_t> W_shape = {p.n_out_channels, p.n_in_channels / p.n_weight_grps, 2, 2};
+  const std::vector<int64_t> offset_shape = {0, p.n_offset_grps * 2 * 2 * 2, out_h, out_w};
+  const std::vector<int64_t> Y_shape = {0, p.n_out_channels, out_h, out_w};
+
+  test.AddInput<float>("X", X_shape, X);
+  test.AddInput<float>("W", W_shape, W);
+  test.AddInput<float>("offset", offset_shape, offset);
+  test.AddInput<float>("B", {p.n_out_channels}, B);
+  test.AddOptionalInputEdge<float>();
+  test.AddOutput<float>("Y", Y_shape, expected_Y);
+  std::unordered_set<std::string> excluded = {kTensorrtExecutionProvider, kCudaExecutionProvider,
+                                              kOpenVINOExecutionProvider, kQnnExecutionProvider};
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded);
+}
+
+// Wrong offset channel count -> expect failure (like PyTorch test_wrong_sizes).
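+// (Illustrative arithmetic: the valid offset channel count for the test below
+// would be offset_group * 2 * kH * kW = 1 * 2 * 2 * 2 = 8; it passes only 2.)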
+TEST(DeformConvTest, WrongOffsetShape) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 2; + p.n_out_channels = 2; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + std::vector X(1 * 2 * 3 * 3, 0.1f); + std::vector W(2 * 2 * 2 * 2, 0.1f); + std::vector wrong_offset(1 * 2 * out_h * out_w); // wrong: only 2 channels instead of 8 + std::vector B(2, 0.f); + std::vector expected_Y(1 * 2 * out_h * out_w, 0.f); + + const std::vector offset_shape_wrong = {1, 2, out_h, out_w}; + const std::vector Y_shape_wrong = {1, 2, out_h, out_w}; + + OpTester test("DeformConv", 19); + test.AddAttribute("kernel_shape", p.kernel_shape); + test.AddAttribute("strides", p.stride); + test.AddAttribute("pads", p.pad); + test.AddAttribute("dilations", p.dilation); + test.AddAttribute("group", p.n_weight_grps); + test.AddAttribute("offset_group", p.n_offset_grps); + test.AddInput("X", {1, 2, 3, 3}, X); + test.AddInput("W", {2, 2, 2, 2}, W); + test.AddInput("offset", offset_shape_wrong, wrong_offset); // invalid channels + test.AddInput("B", {2}, B); + test.AddOptionalInputEdge(); + test.AddOutput("Y", Y_shape_wrong, expected_Y); + test.Run(OpTester::ExpectResult::kExpectFailure, "Offset channel count must be offset_group * 2 * kH * kW"); +} + +// Wrong mask channel count -> expect failure. +TEST(DeformConvTest, WrongMaskShape) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 2; + p.n_out_channels = 2; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + std::vector X(1 * 2 * 3 * 3, 0.1f); + std::vector W(2 * 2 * 2 * 2, 0.1f); + const size_t offset_size = static_cast( + p.batch_sz * p.n_offset_grps * 2 * p.kernel_shape[0] * p.kernel_shape[1] * out_h * out_w); + std::vector offset(offset_size, 0.f); + std::vector B(2, 0.f); + std::vector wrong_mask(1 * 2 * out_h * out_w); // wrong: 2 instead of 4 + std::vector expected_Y(1 * 2 * out_h * out_w, 0.f); + + const std::vector mask_shape_wrong = {1, 2, out_h, out_w}; + const std::vector Y_shape_mask = {1, 2, out_h, out_w}; + + OpTester test("DeformConv", 19); + test.AddAttribute("kernel_shape", p.kernel_shape); + test.AddAttribute("strides", p.stride); + test.AddAttribute("pads", p.pad); + test.AddAttribute("dilations", p.dilation); + test.AddAttribute("group", p.n_weight_grps); + test.AddAttribute("offset_group", p.n_offset_grps); + test.AddInput("X", {1, 2, 3, 3}, X); + test.AddInput("W", {2, 2, 2, 2}, W); + test.AddInput("offset", {1, 8, out_h, out_w}, offset); + test.AddInput("B", {2}, B); + test.AddInput("mask", mask_shape_wrong, wrong_mask); + test.AddOutput("Y", Y_shape_mask, expected_Y); + test.Run(OpTester::ExpectResult::kExpectFailure, "Mask channel count"); +} + +// Opset 22 (same behavior, different opset). 
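+// (Expected values: position (0,0) samples (0.5, 0.5), the mean of the four
+// pixels, (1 + 2 + 3 + 4) / 4 = 2.5; the remaining positions have zero offset
+// and pass through the 1x1 kernel unchanged.)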
+TEST(DeformConvTest, Opset22) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {1, 1}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 2; + p.in_w = 2; + + std::vector X = {1.f, 2.f, 3.f, 4.f}; + std::vector W = {1.f}; + std::vector offset = {0.5f, 0.f, 0.f, 0.f, 0.5f, 0.f, 0.f, 0.f}; + std::vector B = {0.f}; + std::vector mask = {1.f, 1.f, 1.f, 1.f}; + std::vector expected_Y = {2.5f, 2.f, 3.f, 4.f}; + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 22); +} + +} // namespace test +} // namespace onnxruntime From abfec39e33a5834107ad8cfe7ab85c6132da00a2 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Wed, 18 Feb 2026 05:09:39 +0800 Subject: [PATCH 02/58] Add more tests --- .../cpu/nn/deform_conv_expected_gen.py | 11 + .../providers/cpu/nn/deform_conv_op_test.cc | 270 ++++++++++++++++++ 2 files changed, 281 insertions(+) diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py b/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py index 5af0613c081e8..0dffd5f337c61 100644 --- a/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py @@ -101,5 +101,16 @@ def main(): print("std::vector expected_Y_no_mask = {" + to_cpp_list(out_no_mask) + "};") print() + # Case 3: groups=2, offset_group=2, non-zero offset (for GroupsWithNonZeroOffset test) + run_case( + "Groups with non-zero offset (batch=1, 2 groups)", + batch_sz=1, + n_in=4, n_out=2, n_weight_grps=2, n_offset_grps=2, + kernel_h=2, kernel_w=2, + stride=(1, 1), pad=(0, 0), dilation=(1, 1), + in_h=3, in_w=3, + seed=123, + ) + if __name__ == "__main__": main() diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc index a3ca282bfa0b2..1c42501805d0b 100644 --- a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc @@ -362,5 +362,275 @@ TEST(DeformConvTest, Opset22) { RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 22); } +// Non-square kernel (kH != kW): 2x3 kernel, zero offset -> same as standard conv. +TEST(DeformConvTest, NonSquareKernel) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 3}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 4; + p.in_w = 5; + // ONNX output size: out_h = (4 - 1*(2-1) - 1)/1 + 1 = 3, out_w = (5 - 1*(3-1) - 1)/1 + 1 = 3 + const int64_t out_h = 3; + const int64_t out_w = 3; + + const size_t x_size = static_cast(1 * 1 * 4 * 5); + const size_t w_size = static_cast(1 * 1 * 2 * 3); + const size_t offset_size = static_cast(1 * 1 * 2 * 2 * 3 * out_h * out_w); // n_offset_grps * 2 * kH * kW * out_h * out_w + const size_t mask_size = static_cast(1 * 1 * 2 * 3 * out_h * out_w); // n_offset_grps * kH * kW * out_h * out_w + + std::vector X(x_size, 0.1f); + std::vector W(w_size, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f}; + // With offset=0, mask=1: each output = 6 * 0.1 * 0.1 = 0.06 (9 positions) + std::vector expected_Y(static_cast(out_h * out_w), 0.06f); + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Asymmetric stride (stride_h != stride_w): stride=(2,1), zero offset. 
+TEST(DeformConvTest, AsymmetricStride) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {2, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 5; + p.in_w = 4; + // out_h = (5 - 1*(2-1) - 1) / 2 + 1 = 2, out_w = (4 - 1*(2-1) - 1) / 1 + 1 = 3 + const int64_t out_h = 2; + const int64_t out_w = 3; + + const size_t x_size = static_cast(1 * 1 * 5 * 4); + const size_t w_size = static_cast(1 * 1 * 2 * 2); + const size_t offset_size = static_cast(1 * 1 * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(1 * 1 * 2 * 2 * out_h * out_w); + + std::vector X(x_size, 0.1f); + std::vector W(w_size, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f}; + std::vector expected_Y(static_cast(out_h * out_w), 0.04f); + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// groups > 0 and non-zero offset; expected from deform_conv_expected_gen.py (seed=123). +TEST(DeformConvTest, GroupsWithNonZeroOffset) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 4; + p.n_out_channels = 2; + p.n_weight_grps = 2; + p.n_offset_grps = 2; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + + std::vector X = {0.296112f, 0.516562f, 0.251671f, 0.688557f, 0.073972f, 0.866522f, 0.136580f, 0.102479f, 0.184056f, 0.726447f, 0.315254f, 0.687107f, 0.075635f, 0.196638f, 0.316412f, 0.401740f, 0.118568f, 0.827395f, 0.382084f, 0.660494f, 0.853572f, 0.593153f, 0.636725f, 0.982629f, 0.274495f, 0.658376f, 0.277542f, 0.857325f, 0.899328f, 0.039014f, 0.926823f, 0.738757f, 0.717884f, 0.705837f, 0.915650f, 0.433980f}; + std::vector W = {-1.182045f, -0.287745f, -0.604301f, 0.600237f, -1.420473f, -0.223828f, 0.430555f, -0.898857f, -0.017858f, 0.426403f, -0.765741f, -0.054514f, -0.732053f, 1.234742f, 1.186221f, -0.220099f}; + std::vector offset = {-0.388483f, -0.934346f, -0.499144f, -1.086653f, 0.962421f, 0.249208f, -0.484502f, -2.092915f, 0.098284f, -0.093507f, 0.266215f, -0.585035f, -0.343038f, -0.682148f, -0.988689f, -1.701830f, -1.220290f, 1.313853f, 1.053300f, 0.138805f, -0.204445f, -2.268529f, -0.913328f, -0.420363f, -0.659559f, -0.797928f, 0.183831f, 0.229347f, 0.617743f, -0.287578f, 0.821824f, 0.151178f, -0.044382f, 1.623557f, -2.322871f, 1.087831f, -0.063545f, -0.448641f, -1.278470f, -1.144004f, -0.152640f, 0.116741f, 0.440260f, -1.446546f, -0.558082f, -0.051696f, -0.908273f, 0.350683f, -0.394809f, 0.489227f, -0.216815f, -1.747165f, 1.722842f, 0.773806f, 0.404630f, -1.646126f, -0.595084f, -0.711218f, 0.622965f, -1.372881f, -0.128065f, -1.283835f, -0.290120f, 1.276741f}; + std::vector B = {0.983955f, 0.204512f}; + std::vector mask = {-0.031861f, -0.478956f, 0.766809f, 0.027468f, 0.047470f, -0.923866f, -1.060737f, -2.324446f, -2.062818f, 0.006375f, -0.989555f, 0.701609f, -0.982238f, 0.277031f, 0.645495f, -0.895681f, 0.492753f, -0.014078f, -0.274663f, -0.764091f, -0.587157f, 1.195165f, -1.209575f, -0.556008f, -0.077105f, 1.277377f, -1.459629f, -2.159528f, -0.706709f, -0.922245f, 3.895372f, -0.602697f}; + std::vector expected_Y = {0.971546f, 1.139858f, 0.452817f, 1.863882f, -0.565266f, 1.423187f, -2.462833f, -0.104923f}; + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 19, 1e-4f, 1e-4f); +} + +// Sampling out of bounds: offset pushes sampling to (-5,-5), BilinearInterpolate returns 0. 
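+// (This matches the guard in the CPU BilinearInterpolate: any sample with
+// h <= -1, h >= height, w <= -1, or w >= width contributes exactly 0.)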
+TEST(DeformConvTest, OutOfBoundsSampling) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {1, 1}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 2; + p.in_w = 2; + + std::vector X = {1.f, 2.f, 3.f, 4.f}; + std::vector W = {1.f}; + // out_h=out_w=2 (2x2 output), offset shape [1, 2, 2, 2] = 8 values. All (-5,-5) -> OOB -> 0 + std::vector offset = {-5.f, -5.f, -5.f, -5.f, -5.f, -5.f, -5.f, -5.f}; + std::vector B = {0.f}; + std::vector mask = {1.f, 1.f, 1.f, 1.f}; + std::vector expected_Y = {0.f, 0.f, 0.f, 0.f}; + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Dilation > 1: 2x2 kernel with dilation (2,2), zero offset -> 4 sample points with stride 2. +TEST(DeformConvTest, DilationGt1) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {2, 2}; + p.in_h = 5; + p.in_w = 5; + // out_h = (5 - 2*(2-1) - 1)/1 + 1 = 3, out_w = 3 + const int64_t out_h = 3; + const int64_t out_w = 3; + + const size_t x_size = 25; + const size_t w_size = 4; + const size_t offset_size = static_cast(1 * 1 * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(1 * 1 * 2 * 2 * out_h * out_w); + + std::vector X(x_size, 0.1f); + std::vector W(w_size, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f}; + // Each output: 4 samples at (0,0),(0,2),(2,0),(2,2) -> 4 * 0.1 * 0.1 = 0.04 + std::vector expected_Y(static_cast(out_h * out_w), 0.04f); + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Decoupled groups: group=2, offset_group=1 (one offset map shared by all input channels). +TEST(DeformConvTest, DecoupledGroups) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 4; + p.n_out_channels = 2; + p.n_weight_grps = 2; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + const size_t x_size = static_cast(1 * 4 * 3 * 3); + const size_t w_size = static_cast(2 * 2 * 2 * 2); + const size_t offset_size = static_cast(1 * 1 * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(1 * 1 * 2 * 2 * out_h * out_w); + + std::vector X(x_size, 0.1f); + std::vector W(w_size, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f, 0.f}; + // Zero offset -> grouped conv. Per output ch: 2 in_ch * 4 kernel * 0.01 = 0.08 + std::vector expected_Y(static_cast(1 * 2 * out_h * out_w), 0.08f); + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Asymmetric padding: pads [top=1, left=0, bottom=0, right=1]; output 3x3, some positions have OOB samples. 
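+// (Each in-bounds tap contributes 0.1 * 0.1 = 0.01, so every expected value
+// below is 0.01 times the number of valid samples at that position.)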
+TEST(DeformConvTest, AsymmetricPadding) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {1, 0, 0, 1}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + // out_h = (3+1+0-1*(2-1)-1)/1+1 = 3, out_w = (3+0+1-1-1)/1+1 = 3 + const int64_t out_h = 3; + const int64_t out_w = 3; + + const size_t offset_size = static_cast(1 * 1 * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(1 * 1 * 2 * 2 * out_h * out_w); + + std::vector X(1 * 1 * 3 * 3, 0.1f); + std::vector W(1 * 1 * 2 * 2, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f}; + // Row 0: (0,0),(0,1) 2 valid -> 0.02; (0,2) only (0,2) in, (0,3) OOB -> 1 valid -> 0.01. Row 1/2: as before. + std::vector expected_Y = {0.02f, 0.02f, 0.01f, 0.04f, 0.04f, 0.02f, 0.04f, 0.04f, 0.02f}; + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Tiny offset (near zero): offset (1e-6, 1e-6), sample ~(0,0) -> bilinear ≈ X[0,0]. Use 1x1 input for 1 output. +TEST(DeformConvTest, TinyOffset) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {1, 1}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 1; + p.in_w = 1; + + std::vector X = {1.f}; + std::vector W = {1.f}; + std::vector offset = {1e-6f, 1e-6f}; + std::vector B = {0.f}; + std::vector mask = {1.f}; + std::vector expected_Y = {1.f}; // bilinear at (1e-6, 1e-6) ≈ 1 + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 19, 1e-4f, 1e-4f); +} + +// Offset (0.5, 0.5) at each kernel point: sampling at (i+0.5, j+0.5) -> (0.5,0.5),(0.5,1.5),(1.5,0.5),(1.5,1.5). +// Only (0.5,0.5) is fully in-bounds for 2x2 input; others hit boundary (OOB gives 0). Result = 1.6875. 
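+// (Worked values: the four samples are 2.5, 1.5, 1.75, and 1.0; with all
+// weights 0.25 the output is 0.25 * (2.5 + 1.5 + 1.75 + 1.0) = 1.6875.)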
+TEST(DeformConvTest, OffsetAtPixelCenters) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 2; + p.in_w = 2; + + std::vector X = {1.f, 2.f, 3.f, 4.f}; + std::vector W = {0.25f, 0.25f, 0.25f, 0.25f}; + std::vector offset = { + 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f + }; + std::vector B = {0.f}; + std::vector mask = {1.f, 1.f, 1.f, 1.f}; + std::vector expected_Y = {1.6875f}; // op output: one center sample 2.5 + boundary samples + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + } // namespace test } // namespace onnxruntime From a0c506041a548a0660ca0c1e00419bfdc0ac929d Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Wed, 18 Feb 2026 05:57:30 +0800 Subject: [PATCH 03/58] Add cuda support for deformconv2d --- .../providers/cuda/cuda_execution_provider.cc | 16 + .../core/providers/cuda/nn/deform_conv.cc | 255 +++++++++++ .../core/providers/cuda/nn/deform_conv.h | 51 +++ .../providers/cuda/nn/deform_conv_impl.cu | 424 ++++++++++++++++++ .../core/providers/cuda/nn/deform_conv_impl.h | 62 +++ .../providers/cpu/nn/deform_conv_op_test.cc | 127 +++++- 6 files changed, 931 insertions(+), 4 deletions(-) create mode 100644 onnxruntime/core/providers/cuda/nn/deform_conv.cc create mode 100644 onnxruntime/core/providers/cuda/nn/deform_conv.h create mode 100644 onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu create mode 100644 onnxruntime/core/providers/cuda/nn/deform_conv_impl.h diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index b87cf8cbc16c1..fdc23b5277370 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1455,6 +1455,10 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, float, AveragePool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, double, AveragePool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, MLFloat16, AveragePool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, float, DeformConv); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, double, DeformConv); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, MLFloat16, DeformConv); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, BFloat16, DeformConv); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, float, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, double, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, MLFloat16, Cast); @@ -1596,6 +1600,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, Conv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, Conv); 
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, Conv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, DeformConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, DeformConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, DeformConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, DeformConv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, HardSigmoid); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, HardSigmoid); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, HardSigmoid); @@ -2574,6 +2582,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2706,6 +2718,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc new file mode 100644 index 0000000000000..2df4d25cd7892 --- /dev/null +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc @@ -0,0 +1,255 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// CUDA implementation of DeformConv (deformable convolution 2D). + +#include "core/providers/shared_library/provider_api.h" +#include "deform_conv.h" +#include "deform_conv_impl.h" + +#include "core/common/narrow.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" + +namespace onnxruntime { +namespace cuda { + +namespace { + +constexpr int kMaxParallelImgs = 32; + +int GetGreatestDivisorBelowBound(int n, int bound) { + for (int k = bound; k > 1; --k) { + if (n % k == 0) { + return k; + } + } + return 1; +} + +} // namespace + +template +Status DeformConv::ComputeInternal(OpKernelContext* context) const { + typedef typename ToCudaType::MappedType CudaT; + + const auto* X = context->Input(0); + const auto* W = context->Input(1); + const auto* offset = context->Input(2); + const auto* B = context->Input(3); + const auto* mask = context->Input(4); + + const auto& X_shape = X->Shape(); + const auto& W_shape = W->Shape(); + const auto& offset_shape = offset->Shape(); + + const int64_t N = X_shape[0]; + const int64_t C = X_shape[1]; + const int64_t H = X_shape[2]; + const int64_t W_in = X_shape[3]; + + const int64_t M = W_shape[0]; + const int64_t kH = attrs_.kernel_shape.size() >= 1 ? attrs_.kernel_shape[0] : W_shape[2]; + const int64_t kW = attrs_.kernel_shape.size() >= 2 ? attrs_.kernel_shape[1] : W_shape[3]; + + int64_t pad_h = 0, pad_w = 0, pad_h_end = 0, pad_w_end = 0; + if (attrs_.pads.size() >= 4) { + pad_h = attrs_.pads[0]; + pad_w = attrs_.pads[1]; + pad_h_end = attrs_.pads[2]; + pad_w_end = attrs_.pads[3]; + } + + const int64_t stride_h = attrs_.strides.empty() ? 
1 : attrs_.strides[0]; + const int64_t stride_w = attrs_.strides.size() < 2 ? 1 : attrs_.strides[1]; + const int64_t dilation_h = attrs_.dilations.empty() ? 1 : attrs_.dilations[0]; + const int64_t dilation_w = attrs_.dilations.size() < 2 ? 1 : attrs_.dilations[1]; + const int64_t group = attrs_.group; + const int64_t offset_group = attrs_.offset_group; + + const int64_t out_h = (H + pad_h + pad_h_end - dilation_h * (kH - 1) - 1) / stride_h + 1; + const int64_t out_w = (W_in + pad_w + pad_w_end - dilation_w * (kW - 1) - 1) / stride_w + 1; + + ORT_RETURN_IF_NOT(W_shape.NumDimensions() == 4, "Weight must be 4D."); + ORT_RETURN_IF_NOT(offset_shape.NumDimensions() == 4, "Offset must be 4D."); + ORT_RETURN_IF_NOT(offset_shape[1] == offset_group * 2 * kH * kW, + "Offset channel count must be offset_group * 2 * kH * kW."); + ORT_RETURN_IF_NOT(offset_shape[2] == out_h && offset_shape[3] == out_w, + "Offset spatial dims must match output."); + ORT_RETURN_IF_NOT(C % offset_group == 0, "Input channels must be divisible by offset_group."); + ORT_RETURN_IF_NOT(C == W_shape[1] * group, "Input channels must match weight in channels * group."); + ORT_RETURN_IF_NOT(M % group == 0, "Output channels must be divisible by group."); + + const bool use_mask = (mask != nullptr); + if (use_mask) { + ORT_RETURN_IF_NOT(mask->Shape().NumDimensions() == 4, "Mask must be 4D."); + ORT_RETURN_IF_NOT(mask->Shape()[1] == offset_group * kH * kW, "Mask channel count invalid."); + ORT_RETURN_IF_NOT(mask->Shape()[2] == out_h && mask->Shape()[3] == out_w, "Mask spatial dims must match output."); + } + + Tensor* Y = context->Output(0, {N, M, out_h, out_w}); + if (Y->Shape().Size() == 0) { + return Status::OK(); + } + + const int64_t kernel_size = kH * kW; + const int64_t output_image_size = out_h * out_w; + const int64_t input_image_size = H * W_in; + const int64_t kernel_dim = (C / group) * kernel_size; + + // Calculate memory usage per image to avoid OOM with large images + // col_buffer: C * kernel_size * output_image_size + // gemm_output_buffer: (M / group) * output_image_size + // We use a safe max(1, ...) for bytes_per_image to avoid division by zero in edge cases + const size_t bytes_per_image = SafeInt(output_image_size) * (C * kernel_size + M / group) * sizeof(T); + + // Heuristic: limit temp memory to 256MB per chunk to balance parallelism and memory usage. + // For small images, this allows up to kMaxParallelImgs (32). + // For large images (4K/8K), this restricts parallelism to 1 to prevent OOM. + constexpr size_t kMaxTempMemSize = 256 * 1024 * 1024; + const int max_parallel_imgs_mem = std::max(1, static_cast(kMaxTempMemSize / std::max(size_t(1), bytes_per_image))); + const int target_parallel_imgs = std::min(kMaxParallelImgs, max_parallel_imgs_mem); + + const int n_parallel_imgs = GetGreatestDivisorBelowBound(static_cast(N), target_parallel_imgs); + const int64_t col_stride = static_cast(n_parallel_imgs) * output_image_size; + const int64_t col_buffer_size = (C * kernel_size) * col_stride; + + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); + auto col_buffer = IAllocator::MakeUniquePtr(alloc, SafeInt(col_buffer_size)); + // Removed col_transposed allocation as we avoid physical transpose. + auto gemm_output_buffer = IAllocator::MakeUniquePtr(alloc, SafeInt((M / group) * col_stride)); + + const T* Xdata = X->Data(); + const T* Wdata = W->Data(); + const T* offset_data = offset->Data(); + const T* mask_data = use_mask ? 
mask->Data() : nullptr; + T* Ydata = Y->MutableData(); + const T* Bdata = (B != nullptr) ? B->Data() : nullptr; + + cudaStream_t stream = Stream(context); + cublasHandle_t cublas = GetCublasHandle(context); + const cudaDeviceProp& device_prop = GetDeviceProp(); + CudaT alpha = ToCudaType::FromFloat(1.0f); + CudaT beta = ToCudaType::FromFloat(0.0f); + + for (int64_t b = 0; b < N; b += n_parallel_imgs) { + const int cur_parallel = static_cast(std::min(static_cast(n_parallel_imgs), N - b)); + const int64_t cur_out_size = static_cast(cur_parallel) * output_image_size; + + const T* X_block = Xdata + b * (C * input_image_size); + const T* offset_block = offset_data + b * (offset_group * 2 * kernel_size * output_image_size); + const T* mask_block = use_mask ? (mask_data + b * (offset_group * kernel_size * output_image_size)) : nullptr; + + DeformConvIm2ColImpl( + stream, + X_block, + offset_block, + mask_block, + col_buffer.get(), + cur_parallel, + C, + H, + W_in, + kH, + kW, + out_h, + out_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + offset_group, + use_mask); + + for (int64_t g = 0; g < group; ++g) { + const T* W_g = Wdata + g * (M / group) * kernel_dim; + const T* col_g = col_buffer.get() + g * kernel_dim * col_stride; + T* Y_g = Ydata + b * M * output_image_size + g * (M / group) * output_image_size; + + // Avoid physical transpose by using cuBLAS OP_N/OP_N logic. + // We want Y = W * Col. + // W is [M/group, kernel_dim] (Row-Major). + // Col is [kernel_dim, cur_out_size] (Row-Major). + // We compute Y^T = Col^T * W^T. + // Col^T (Col-Major [cur_out_size, kernel_dim]) is exactly Col (Row-Major [kernel_dim, cur_out_size]) in memory. + // W^T (Col-Major [kernel_dim, M/group]) is exactly W (Row-Major [M/group, kernel_dim]) in memory. + // Result Y^T is Col-Major [cur_out_size, M/group]. + // In memory, Y^T (Col-Major) is exactly Y (Row-Major [M/group, cur_out_size]). + // So we get Y in Row-Major layout. + + // A = Col (Row-Major [kernel_dim, cur_out_size]) -> interpreted as Col-Major [cur_out_size, kernel_dim]. + // B = W (Row-Major [M/group, kernel_dim]) -> interpreted as Col-Major [kernel_dim, M/group]. + // C = A * B = Col^T * W^T = Y^T. + // C is Col-Major [cur_out_size, M/group]. + // m = cur_out_size, n = M/group, k = kernel_dim. + // lda = cur_out_size. + // ldb = kernel_dim. + // ldc = cur_out_size. + + CUBLAS_RETURN_IF_ERROR((cublasGemmHelper( + cublas, + CUBLAS_OP_N, + CUBLAS_OP_N, + narrow(cur_out_size), + narrow(M / group), + narrow(kernel_dim), + &alpha, + reinterpret_cast(col_g), + narrow(cur_out_size), + reinterpret_cast(W_g), + narrow(kernel_dim), + &beta, + reinterpret_cast(gemm_output_buffer.get()), + narrow(cur_out_size), + device_prop, + UseTF32()))); + + // The output gemm_output_buffer is now Row-Major [M/group, cur_out_size]. + // We need to copy it to Y_g (NCHW). 
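+      // A sketch of the mapping with hypothetical sizes M/group = 2, cur_parallel = 2,
+      // output_image_size = 4: row c of the GEMM output holds
+      // [img0 pix0..3 | img1 pix0..3] for output channel c, while Y stores each image's
+      // channels contiguously ([c0 pix0..3 | c1 pix0..3] per image), hence the strided
+      // copy kernel below instead of a plain memcpy.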
+ DeformConvCopyGemmOutputRowMajorToNCHW( + stream, + gemm_output_buffer.get(), + Y_g, + M, + M / group, + output_image_size, + cur_parallel); + } + } + + if (Bdata != nullptr) { + DeformConvAddBiasImpl(stream, Ydata, Bdata, N, M, out_h, out_w); + } + + return Status::OK(); +} + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + DeformConv, \ + kOnnxDomain, \ + 19, \ + 21, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + DeformConv); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + DeformConv, \ + kOnnxDomain, \ + 22, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + DeformConv); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(double) +REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.h b/onnxruntime/core/providers/cuda/nn/deform_conv.h new file mode 100644 index 0000000000000..44eb7237b327e --- /dev/null +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.h @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/framework/tensor_shape.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace cuda { + +// Attributes for ONNX DeformConv (opset 19+). Mirrors CPU for consistency. +// See https://onnx.ai/onnx/operators/onnx__DeformConv.html +struct DeformConvAttributes { + explicit DeformConvAttributes(const OpKernelInfo& info) { + Status status = info.GetAttrs("kernel_shape", kernel_shape); + ORT_ENFORCE(status.IsOK(), "Attribute kernel_shape is not set."); + status = info.GetAttrs("strides", strides); + ORT_ENFORCE(status.IsOK(), "Attribute strides is not set."); + status = info.GetAttrs("pads", pads); + ORT_ENFORCE(status.IsOK(), "Attribute pads is not set."); + status = info.GetAttrs("dilations", dilations); + ORT_ENFORCE(status.IsOK(), "Attribute dilations is not set."); + group = info.GetAttrOrDefault("group", 1); + offset_group = info.GetAttrOrDefault("offset_group", 1); + } + + TensorShapeVector kernel_shape; + TensorShapeVector strides; + TensorShapeVector pads; + TensorShapeVector dilations; + int64_t group{1}; + int64_t offset_group{1}; +}; + +template +class DeformConv final : public CudaKernel { + public: + explicit DeformConv(const OpKernelInfo& info) : CudaKernel(info), attrs_(info) {} + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + DeformConvAttributes attrs_; + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DeformConv); +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu new file mode 100644 index 0000000000000..595f07a9c1dc3 --- /dev/null +++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu @@ -0,0 +1,424 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// CUDA implementation of DeformConv: deformable im2col kernel + bilinear interpolation. +// Reference: torchvision deform_conv2d_kernel.cu, ONNX DeformConv spec. 
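+//
+// Rough pipeline, as driven by deform_conv.cc: DeformConvIm2ColImpl gathers
+// bilinear-sampled input patches into a column matrix, the caller runs a GEMM
+// against the weights, and DeformConvCopyGemmOutputRowMajorToNCHW plus
+// DeformConvAddBiasImpl turn the result into the final NCHW output.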
+ +#include "deform_conv_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/shared_inc/fast_divmod.h" +#include "core/common/float16.h" +#include +#include +#include + +namespace onnxruntime { +namespace cuda { + +namespace { + +constexpr int kDeformConvThreadsPerBlock = 256; + +// Calculate grid size with a safety limit to prevent overflow. +// Since we use grid-stride loops in kernels, limiting the grid size is safe. +inline int GetGridSize(size_t n, size_t threads_per_block) { + size_t blocks_needed = (n + threads_per_block - 1) / threads_per_block; + return static_cast(std::min(blocks_needed, static_cast(std::numeric_limits::max()))); +} + +// Bilinear interpolation at (h, w). Returns 0 if out of bounds (ONNX spec). +template +__device__ __inline__ T BilinearInterpolate( + const T* in, + int64_t height, + int64_t width, + T h, + T w) { + if (h <= static_cast(-1) || h >= height || w <= static_cast(-1) || w >= width) { + return static_cast(0); + } + int h_low = static_cast(_Floor(h)); + int w_low = static_cast(_Floor(w)); + int h_high = h_low + 1; + int w_high = w_low + 1; + + T lh = h - static_cast(h_low); + T lw = w - static_cast(w_low); + T hh = static_cast(1) - lh; + T hw = static_cast(1) - lw; + + T v1 = (h_low >= 0 && w_low >= 0) ? __ldg(in + h_low * width + w_low) : static_cast(0); + T v2 = (h_low >= 0 && w_high < width) ? __ldg(in + h_low * width + w_high) : static_cast(0); + T v3 = (h_high < height && w_low >= 0) ? __ldg(in + h_high * width + w_low) : static_cast(0); + T v4 = (h_high < height && w_high < width) ? __ldg(in + h_high * width + w_high) : static_cast(0); + + T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; +} + +// FP16/BF16: coordinate and weight math in float to avoid precision loss. +template +struct DeformConvUseFloatCoords : std::false_type {}; +template <> +struct DeformConvUseFloatCoords : std::true_type {}; +template <> +struct DeformConvUseFloatCoords : std::true_type {}; + +// __ldg has no overload for BFloat16*; use 16-bit load + FromBits. Other types use __ldg directly. +template +__device__ __inline__ T DeformConvLdg(const T* p) { + return __ldg(p); +} +template <> +__device__ __inline__ BFloat16 DeformConvLdg(const BFloat16* p) { + return BFloat16::FromBits(__ldg(reinterpret_cast(p))); +} + +__device__ __inline__ half BilinearInterpolate( + const half* in, + int64_t height, + int64_t width, + float h, + float w) { + if (h <= -1.0f || h >= height || w <= -1.0f || w >= width) { + return __float2half(0.0f); + } + int h_low = static_cast(floorf(h)); + int w_low = static_cast(floorf(w)); + int h_high = h_low + 1; + int w_high = w_low + 1; + + float lh = h - static_cast(h_low); + float lw = w - static_cast(w_low); + float hh = 1.0f - lh; + float hw = 1.0f - lw; + + float v1 = (h_low >= 0 && w_low >= 0) ? __half2float(__ldg(in + h_low * width + w_low)) : 0.0f; + float v2 = (h_low >= 0 && w_high < width) ? __half2float(__ldg(in + h_low * width + w_high)) : 0.0f; + float v3 = (h_high < height && w_low >= 0) ? __half2float(__ldg(in + h_high * width + w_low)) : 0.0f; + float v4 = (h_high < height && w_high < width) ? 
__half2float(__ldg(in + h_high * width + w_high)) : 0.0f; + + float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + return __float2half(w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); +} + +__device__ __inline__ BFloat16 BilinearInterpolate( + const BFloat16* in, + int64_t height, + int64_t width, + float h, + float w) { + if (h <= -1.0f || h >= height || w <= -1.0f || w >= width) { + return BFloat16(0.0f); + } + int h_low = static_cast(floorf(h)); + int w_low = static_cast(floorf(w)); + int h_high = h_low + 1; + int w_high = w_low + 1; + + float lh = h - static_cast(h_low); + float lw = w - static_cast(w_low); + float hh = 1.0f - lh; + float hw = 1.0f - lw; + + float v1 = (h_low >= 0 && w_low >= 0) ? static_cast(DeformConvLdg(in + h_low * width + w_low)) : 0.0f; + float v2 = (h_low >= 0 && w_high < width) ? static_cast(DeformConvLdg(in + h_low * width + w_high)) : 0.0f; + float v3 = (h_high < height && w_low >= 0) ? static_cast(DeformConvLdg(in + h_high * width + w_low)) : 0.0f; + float v4 = (h_high < height && w_high < width) ? static_cast(DeformConvLdg(in + h_high * width + w_high)) : 0.0f; + + float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + return BFloat16(w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); +} + +// 1D parallel: each thread handles (in_c, out_y, out_x, out_b), inner loop over kH x kW. +// num_kernels = C * out_h * out_w * parallel_imgs. +// Col layout row-major: rows = C*kH*kW, cols = parallel_imgs*out_h*out_w. +// data_col[col_row_idx * col_stride + c_col] with col_stride = parallel_imgs*out_h*out_w. +template +__global__ void DeformableIm2ColKernel( + IndexT num_kernels, + const T* __restrict__ input, + const T* __restrict__ offset, + const T* __restrict__ mask, + int64_t height, + int64_t width, + int64_t weight_h, + int64_t weight_w, + int64_t pad_h, + int64_t pad_w, + int64_t stride_h, + int64_t stride_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t channels, + int64_t offset_group, + DivMod out_h_div, + DivMod out_w_div, + DivMod parallel_imgs_div, + DivMod channel_per_offset_grp_div, + bool use_mask, + T* __restrict__ data_col) { + + // Reconstruct dimensions from DivMod objects + const int64_t out_h = out_h_div.d_; + const int64_t out_w = out_w_div.d_; + const int64_t parallel_imgs = parallel_imgs_div.d_; + + const int64_t out_size = out_h * out_w; + const int64_t col_stride = parallel_imgs * out_size; + + for (IndexT index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; index += blockDim.x * gridDim.x) { + IndexT val = index; + IndexT out_x, out_y, out_b, in_c; + + // Fast division/modulo to recover coordinates + out_w_div.divmod(val, val, out_x); + out_h_div.divmod(val, val, out_y); + parallel_imgs_div.divmod(val, in_c, out_b); + + IndexT offset_grp, dummy; + channel_per_offset_grp_div.divmod(in_c, offset_grp, dummy); + + const T* input_ptr = input + out_b * (channels * height * width) + in_c * (height * width); + const T* offset_ptr = offset + out_b * (offset_group * 2 * weight_h * weight_w * out_size) + + offset_grp * (2 * weight_h * weight_w * out_size); + const T* mask_ptr = use_mask ? (mask + out_b * (offset_group * weight_h * weight_w * out_size) + + offset_grp * (weight_h * weight_w * out_size)) + : nullptr; + + const int64_t c_col = out_b * out_size + out_y * out_w + out_x; + + // Use float for coordinate math when T is half or BFloat16 to avoid precision loss. 
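+    // For instance, adding a fractional offset like 0.1 to a base coordinate of
+    // 1000 directly in half rounds back to 1000 (half's ulp there is 0.5), so the
+    // fractional sampling position would be lost without the float widening.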
+ using CoordT = typename std::conditional::value, float, T>::type; + +#pragma unroll + for (int64_t i = 0; i < weight_h; ++i) { +#pragma unroll + for (int64_t j = 0; j < weight_w; ++j) { + const int64_t mask_idx = i * weight_w + j; + const int64_t offset_idx = 2 * mask_idx; + + T mask_val = static_cast(1); + if (use_mask) { + mask_val = DeformConvLdg(mask_ptr + mask_idx * out_size + out_y * out_w + out_x); + } + + const int64_t offset_h_idx = (offset_idx)*out_size + out_y * out_w + out_x; + const int64_t offset_w_idx = (offset_idx + 1) * out_size + out_y * out_w + out_x; + const CoordT offset_h = static_cast(DeformConvLdg(offset_ptr + offset_h_idx)); + const CoordT offset_w = static_cast(DeformConvLdg(offset_ptr + offset_w_idx)); + + const CoordT h_im = static_cast(out_y * stride_h - pad_h + i * dilation_h) + offset_h; + const CoordT w_im = static_cast(out_x * stride_w - pad_w + j * dilation_w) + offset_w; + + T val = static_cast(0); + if (mask_val != static_cast(0)) { + val = BilinearInterpolate(input_ptr, height, width, h_im, w_im); + } + + const int64_t col_row_idx = (in_c * weight_h * weight_w) + (i * weight_w + j); + data_col[col_row_idx * col_stride + c_col] = val * mask_val; + } + } + } +} + +// Bias add: Y[n,m,oh,ow] += B[m]. Layout NCHW. +template +__global__ void DeformConvAddBiasKernel(T* Y, const T* B, int64_t N, int64_t M, int64_t out_h, int64_t out_w) { + int64_t out_size = out_h * out_w; + int64_t total = N * M * out_size; + for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; idx += blockDim.x * gridDim.x) { + int64_t m = (idx / out_size) % M; + Y[idx] += DeformConvLdg(B + m); + } +} + +// Copy GEMM output (row-major [M_per_group, cur_parallel*output_image_size]) into NCHW Y_g. +// src(c, j) with j = b_idx*output_image_size + pos -> dst[b_idx*M*output_image_size + c*output_image_size + pos]. 
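+// Worked example with hypothetical sizes M_per_group = 2, output_image_size = 3,
+// cur_parallel = 2: src(c=1, j=4) decodes to b_idx = 1, pos = 1 and is written to
+// dst[1 * M * 3 + 1 * 3 + 1], i.e. pixel 1 of channel 1 in image 1.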
+template +__global__ void CopyGemmOutputRowMajorToNCHWKernel( + const T* __restrict__ src, + T* __restrict__ dst, + int64_t M, + int64_t M_per_group, + int64_t output_image_size, + int64_t cur_parallel) { + int64_t total = cur_parallel * M_per_group * output_image_size; + for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; idx += blockDim.x * gridDim.x) { + int64_t pos = idx % output_image_size; + int64_t c = (idx / output_image_size) % M_per_group; + int64_t b_idx = idx / (output_image_size * M_per_group); + int64_t j = b_idx * output_image_size + pos; + // src index for row-major: c * (cur_parallel * output_image_size) + j + dst[b_idx * M * output_image_size + c * output_image_size + pos] = src[c * (cur_parallel * output_image_size) + j]; + } +} + +} // namespace + +template +void DeformConvAddBiasImpl(cudaStream_t stream, T* Y, const T* B, int64_t N, int64_t M, int64_t out_h, int64_t out_w) { + int64_t total = N * M * out_h * out_w; + if (total <= 0) return; + int blocks = GetGridSize(static_cast(total), kDeformConvThreadsPerBlock); + DeformConvAddBiasKernel<<>>(Y, B, N, M, out_h, out_w); +} + +template +void DeformConvCopyGemmOutputRowMajorToNCHW( + cudaStream_t stream, + const T* gemm_output, + T* Y_g, + int64_t M, + int64_t M_per_group, + int64_t output_image_size, + int64_t cur_parallel) { + int64_t total = cur_parallel * M_per_group * output_image_size; + if (total <= 0) return; + int blocks = GetGridSize(static_cast(total), kDeformConvThreadsPerBlock); + CopyGemmOutputRowMajorToNCHWKernel<<>>( + gemm_output, Y_g, M, M_per_group, output_image_size, cur_parallel); +} + +template +void DeformConvIm2ColImpl( + cudaStream_t stream, + const T* input, + const T* offset, + const T* mask, + T* col_buffer, + int64_t parallel_imgs, + int64_t C, + int64_t H, + int64_t W, + int64_t kH, + int64_t kW, + int64_t out_h, + int64_t out_w, + int64_t pad_h, + int64_t pad_w, + int64_t stride_h, + int64_t stride_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t offset_group, + bool use_mask) { + const int64_t num_kernels = static_cast(C) * out_h * out_w * parallel_imgs; + if (num_kernels <= 0) { + return; + } + + const int64_t col_numel = static_cast(C) * kH * kW * parallel_imgs * out_h * out_w; + const bool use_64bit = (num_kernels > static_cast(std::numeric_limits::max())) || + (col_numel > static_cast(std::numeric_limits::max())); + + int blocks = GetGridSize(static_cast(num_kernels), kDeformConvThreadsPerBlock); + + if (use_64bit) { + DeformableIm2ColKernel<<>>( + num_kernels, + input, + offset, + mask, + H, + W, + kH, + kW, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + C, // channels is C + offset_group, + DivMod(out_h), + DivMod(out_w), + DivMod(parallel_imgs), + DivMod(C / offset_group), + use_mask, + col_buffer); + } else { + DeformableIm2ColKernel<<>>( + static_cast(num_kernels), + input, + offset, + mask, + H, + W, + kH, + kW, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + C, // channels is C + offset_group, + DivMod(static_cast(out_h)), + DivMod(static_cast(out_w)), + DivMod(static_cast(parallel_imgs)), + DivMod(static_cast(C / offset_group)), + use_mask, + col_buffer); + } +} + +#define INST_DeformConvIm2ColImpl(T) \ + template void DeformConvIm2ColImpl(cudaStream_t, const T*, const T*, const T*, T*, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, bool); + +INST_DeformConvIm2ColImpl(float) +INST_DeformConvIm2ColImpl(double) 
+INST_DeformConvIm2ColImpl(half) +INST_DeformConvIm2ColImpl(BFloat16) + +template void DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const float*, float*, int64_t, int64_t, int64_t, int64_t); +template void DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const double*, double*, int64_t, int64_t, int64_t, int64_t); +template void DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const half*, half*, int64_t, int64_t, int64_t, int64_t); +template void DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const BFloat16*, BFloat16*, int64_t, int64_t, int64_t, int64_t); + +template void DeformConvAddBiasImpl(cudaStream_t, float*, const float*, int64_t, int64_t, int64_t, int64_t); +template void DeformConvAddBiasImpl(cudaStream_t, double*, const double*, int64_t, int64_t, int64_t, int64_t); +template void DeformConvAddBiasImpl(cudaStream_t, half*, const half*, int64_t, int64_t, int64_t, int64_t); +template void DeformConvAddBiasImpl(cudaStream_t, BFloat16*, const BFloat16*, int64_t, int64_t, int64_t, int64_t); + +// Delegate ORT type to CUDA type (e.g. MLFloat16 -> half); avoids repeating three identical specializations. +#define DELEGATE_DEFORM_CONV_IMPL(ORT_T, CUDA_T) \ + template <> \ + void DeformConvIm2ColImpl(cudaStream_t stream, const ORT_T* input, \ + const ORT_T* offset, const ORT_T* mask, ORT_T* col_buffer, \ + int64_t parallel_imgs, int64_t C, int64_t H, int64_t W, \ + int64_t kH, int64_t kW, int64_t out_h, int64_t out_w, \ + int64_t pad_h, int64_t pad_w, int64_t stride_h, int64_t stride_w, \ + int64_t dilation_h, int64_t dilation_w, int64_t offset_group, bool use_mask) { \ + DeformConvIm2ColImpl(stream, reinterpret_cast(input), \ + reinterpret_cast(offset), \ + mask ? reinterpret_cast(mask) : nullptr, \ + reinterpret_cast(col_buffer), \ + parallel_imgs, C, H, W, kH, kW, out_h, out_w, \ + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, \ + offset_group, use_mask); \ + } \ + template <> \ + void DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t stream, \ + const ORT_T* gemm_output, ORT_T* Y_g, \ + int64_t M, int64_t M_per_group, \ + int64_t output_image_size, int64_t cur_parallel) { \ + DeformConvCopyGemmOutputRowMajorToNCHW(stream, \ + reinterpret_cast(gemm_output), \ + reinterpret_cast(Y_g), \ + M, M_per_group, output_image_size, cur_parallel); \ + } \ + template <> \ + void DeformConvAddBiasImpl(cudaStream_t stream, ORT_T* Y, const ORT_T* B, \ + int64_t N, int64_t M, int64_t out_h, int64_t out_w) { \ + DeformConvAddBiasImpl(stream, reinterpret_cast(Y), \ + reinterpret_cast(B), N, M, out_h, out_w); \ + } + +DELEGATE_DEFORM_CONV_IMPL(MLFloat16, half) + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.h b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.h new file mode 100644 index 0000000000000..55f0b0eccf54d --- /dev/null +++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.h @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +namespace onnxruntime { +namespace cuda { + +// Adds bias to output: Y[n,m,oh,ow] += B[m]. Y is [N, M, out_h, out_w], B is [M]. +// T may be float, double, MLFloat16 (FP16), or BFloat16. +template +void DeformConvAddBiasImpl( + cudaStream_t stream, + T* Y, + const T* B, + int64_t N, + int64_t M, + int64_t out_h, + int64_t out_w); + +// Copies GEMM output (row-major [M_per_group, cur_parallel*output_image_size]) to NCHW slice at Y_g. 
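+// Note M is the per-image channel stride inside Y_g (the full output channel count);
+// only M_per_group rows are copied per call.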
+// T may be float, double, MLFloat16 (FP16), or BFloat16. +template +void DeformConvCopyGemmOutputRowMajorToNCHW( + cudaStream_t stream, + const T* gemm_output, + T* Y_g, + int64_t M, + int64_t M_per_group, + int64_t output_image_size, + int64_t cur_parallel); + +// Fills col_buffer with deformable im2col. col_buffer layout: row-major [C*kH*kW, parallel_imgs*out_h*out_w]. +// Called once per batch block; caller does GEMM and bias. T may be float, double, MLFloat16 (FP16), or BFloat16. +template +void DeformConvIm2ColImpl( + cudaStream_t stream, + const T* input, // [parallel_imgs, C, H, W] + const T* offset, // [parallel_imgs, offset_group*2*kH*kW, out_h, out_w] + const T* mask, // [parallel_imgs, offset_group*kH*kW, out_h, out_w] or nullptr + T* col_buffer, // [C*kH*kW, parallel_imgs*out_h*out_w] + int64_t parallel_imgs, + int64_t C, + int64_t H, + int64_t W, + int64_t kH, + int64_t kW, + int64_t out_h, + int64_t out_w, + int64_t pad_h, + int64_t pad_w, + int64_t stride_h, + int64_t stride_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t offset_group, + bool use_mask); + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc index 1c42501805d0b..69ba6badac4bf 100644 --- a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc @@ -6,6 +6,7 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" +#include "test/unittest_util/conversion.h" namespace onnxruntime { namespace test { @@ -70,7 +71,56 @@ void RunDeformConvTest(const DeformConvTestParams& params, test.AddOutput("Y", Y_shape, expected_Y, false, rtol, atol); - std::unordered_set excluded = {kTensorrtExecutionProvider, kCudaExecutionProvider, + std::unordered_set excluded = {kTensorrtExecutionProvider, + kOpenVINOExecutionProvider, kQnnExecutionProvider}; + test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); +} + +void RunDeformConvTestFP16(const DeformConvTestParams& params, + const std::vector& X, + const std::vector& W, + const std::vector& offset, + const std::vector& B, + const std::vector* mask, + const std::vector& expected_Y, + int opset = 19, + float rtol = 1e-2f, // FP16 requires looser tolerance + float atol = 1e-2f) { + const int64_t kH = params.kernel_shape[0]; + const int64_t kW = params.kernel_shape[1]; + const int64_t out_h = (params.in_h + params.pad[0] + params.pad[2] - + params.dilation[0] * (kH - 1) - 1) / params.stride[0] + 1; + const int64_t out_w = (params.in_w + params.pad[1] + params.pad[3] - + params.dilation[1] * (kW - 1) - 1) / params.stride[1] + 1; + + OpTester test("DeformConv", opset); + test.AddAttribute("kernel_shape", params.kernel_shape); + test.AddAttribute("strides", params.stride); + test.AddAttribute("pads", params.pad); + test.AddAttribute("dilations", params.dilation); + test.AddAttribute("group", params.n_weight_grps); + test.AddAttribute("offset_group", params.n_offset_grps); + + const std::vector X_shape = {params.batch_sz, params.n_in_channels, params.in_h, params.in_w}; + const std::vector W_shape = {params.n_out_channels, params.n_in_channels / params.n_weight_grps, kH, kW}; + const std::vector offset_shape = {params.batch_sz, params.n_offset_grps * 2 * kH * kW, out_h, out_w}; + const std::vector Y_shape = {params.batch_sz, params.n_out_channels, out_h, out_w}; + + test.AddInput("X", X_shape, FloatsToMLFloat16s(X)); + test.AddInput("W", W_shape, 
FloatsToMLFloat16s(W)); + test.AddInput("offset", offset_shape, FloatsToMLFloat16s(offset)); + test.AddInput("B", {params.n_out_channels}, FloatsToMLFloat16s(B)); + if (mask != nullptr) { + const std::vector mask_shape = {params.batch_sz, params.n_offset_grps * kH * kW, out_h, out_w}; + test.AddInput("mask", mask_shape, FloatsToMLFloat16s(*mask)); + } else { + test.AddOptionalInputEdge(); + } + + test.AddOutput("Y", Y_shape, FloatsToMLFloat16s(expected_Y), false, rtol, atol); + + // Exclude CPU provider as it likely doesn't support FP16 DeformConv + std::unordered_set excluded = {kCpuExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider}; test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); } @@ -107,7 +157,76 @@ TEST(DeformConvTest, MinimalBilinear) { RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); } -// Forward with mask and bias: 2 batches, 2 groups, zero offset -> behaves like grouped conv. +// Minimal case FP16: Same as MinimalBilinear but in FP16. +// Validates CUDA FP16 implementation (specifically coordinate precision logic). +TEST(DeformConvTest, MinimalBilinearFP16) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {1, 1}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 2; + p.in_w = 2; + + std::vector X = {1.f, 2.f, 3.f, 4.f}; + std::vector W = {1.f}; + std::vector offset = { + 0.5f, 0.f, 0.f, 0.f, + 0.5f, -1.0f, 0.f, 0.f + }; + std::vector B = {0.f}; + std::vector mask = {1.f, 1.f, 1.f, 1.f}; + std::vector expected_Y = {2.5f, 1.f, 3.f, 4.f}; + + RunDeformConvTestFP16(p, X, W, offset, B, &mask, expected_Y); +} + +// Forward with mask and bias FP16 +TEST(DeformConvTest, ForwardWithMaskAndBiasFP16) { + DeformConvTestParams p = {}; + p.batch_sz = 2; + p.n_in_channels = 4; + p.n_out_channels = 2; + p.n_weight_grps = 2; + p.n_offset_grps = 2; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + const size_t x_size = static_cast(p.batch_sz * p.n_in_channels * p.in_h * p.in_w); + const size_t w_size = static_cast(p.n_out_channels * (p.n_in_channels / p.n_weight_grps) * 2 * 2); + const size_t offset_size = static_cast(p.batch_sz * p.n_offset_grps * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(p.batch_sz * p.n_offset_grps * 2 * 2 * out_h * out_w); + + std::vector X(x_size, 0.1f); + std::vector W(w_size, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.5f, -0.5f}; + + const size_t y_size = static_cast(p.batch_sz * p.n_out_channels * out_h * out_w); + std::vector expected_Y(y_size); + for (int64_t b = 0; b < p.batch_sz; ++b) { + for (int64_t c = 0; c < p.n_out_channels; ++c) { + float val = (c % 2 == 0) ? 0.58f : -0.42f; + for (int64_t i = 0; i < out_h * out_w; ++i) { + expected_Y[b * p.n_out_channels * out_h * out_w + c * out_h * out_w + i] = val; + } + } + } + + RunDeformConvTestFP16(p, X, W, offset, B, &mask, expected_Y); +} // With offset=0 and mask=1, Y = Conv(X,W) + B. Use small inputs and compute expected. 
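+// (For all-0.1 inputs this reduces to (C/group) * kH * kW * 0.01 + bias per output
+// element, the same arithmetic used by the FP16 variant above.)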
TEST(DeformConvTest, ForwardWithMaskAndBias) { DeformConvTestParams p = {}; @@ -199,7 +318,7 @@ TEST(DeformConvTest, ForwardNoMask) { test.AddInput("B", {p.n_out_channels}, B); test.AddOptionalInputEdge(); // no mask test.AddOutput("Y", Y_shape, expected_Y, false, 1e-4f, 1e-4f); - std::unordered_set excluded = {kTensorrtExecutionProvider, kCudaExecutionProvider, + std::unordered_set excluded = {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider}; test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); } @@ -245,7 +364,7 @@ TEST(DeformConvTest, EmptyBatch) { test.AddInput("B", {p.n_out_channels}, B); test.AddOptionalInputEdge(); test.AddOutput("Y", Y_shape, expected_Y); - std::unordered_set excluded = {kTensorrtExecutionProvider, kCudaExecutionProvider, + std::unordered_set excluded = {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider}; test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); } From dd8e7f1cfac0b31c5807d45402c3099b80b795a9 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Thu, 19 Feb 2026 20:28:36 +0800 Subject: [PATCH 04/58] Improve deformconv cuda pref --- .../providers/cuda/nn/deform_conv_impl.cu | 143 +++++++++++++----- 1 file changed, 106 insertions(+), 37 deletions(-) diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu index 595f07a9c1dc3..3092b5b56a9ab 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu @@ -130,10 +130,8 @@ __device__ __inline__ BFloat16 BilinearInterpolate( return BFloat16(w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); } -// 1D parallel: each thread handles (in_c, out_y, out_x, out_b), inner loop over kH x kW. -// num_kernels = C * out_h * out_w * parallel_imgs. -// Col layout row-major: rows = C*kH*kW, cols = parallel_imgs*out_h*out_w. -// data_col[col_row_idx * col_stride + c_col] with col_stride = parallel_imgs*out_h*out_w. +// 1D parallel: each thread handles one output pixel (out_b, out_y, out_x) for a specific channel (in_c). +// Optimized memory access patterns and removed redundant calculations. template __global__ void DeformableIm2ColKernel( IndexT num_kernels, @@ -158,66 +156,103 @@ __global__ void DeformableIm2ColKernel( DivMod channel_per_offset_grp_div, bool use_mask, T* __restrict__ data_col) { - + // Reconstruct dimensions from DivMod objects const int64_t out_h = out_h_div.d_; const int64_t out_w = out_w_div.d_; const int64_t parallel_imgs = parallel_imgs_div.d_; - + const int64_t out_size = out_h * out_w; + // The stride for data_col is (batch * out_h * out_w) const int64_t col_stride = parallel_imgs * out_size; for (IndexT index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; index += blockDim.x * gridDim.x) { IndexT val = index; IndexT out_x, out_y, out_b, in_c; - + // Fast division/modulo to recover coordinates out_w_div.divmod(val, val, out_x); out_h_div.divmod(val, val, out_y); parallel_imgs_div.divmod(val, in_c, out_b); - IndexT offset_grp, dummy; - channel_per_offset_grp_div.divmod(in_c, offset_grp, dummy); + // [Optimization 3] Avoid expensive division if offset_group is 1 (very common case). 
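+    // (With a single offset group every input channel maps to group 0, so the
+    // divmod result is known up front and the fast division is skipped entirely.)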
+ IndexT offset_grp = 0; + if (offset_group > 1) { + IndexT dummy; + channel_per_offset_grp_div.divmod(in_c, offset_grp, dummy); + } + + // [Optimization 2] Common Subexpression Elimination (CSE) & Pointer Arithmetic + // Pre-calculate base pointers to reduce integer arithmetic inside the inner loops. + // 1. Input pointer base for this batch and channel. const T* input_ptr = input + out_b * (channels * height * width) + in_c * (height * width); - const T* offset_ptr = offset + out_b * (offset_group * 2 * weight_h * weight_w * out_size) + - offset_grp * (2 * weight_h * weight_w * out_size); - const T* mask_ptr = use_mask ? (mask + out_b * (offset_group * weight_h * weight_w * out_size) + - offset_grp * (weight_h * weight_w * out_size)) - : nullptr; - const int64_t c_col = out_b * out_size + out_y * out_w + out_x; + // 2. Spatial index in the output feature map. + const int64_t spatial_idx = out_y * out_w + out_x; + + // 3. Offset pointer base calculation. + // Layout: (N, offset_groups, 2*KH*KW, OH, OW) + // We pre-calculate the pointer to the start of the specific (n, g) block, plus spatial_idx. + const int64_t offset_group_block_size = 2 * weight_h * weight_w * out_size; + const T* offset_ptr_base = offset + (out_b * offset_group + offset_grp) * offset_group_block_size + spatial_idx; + + // 4. Mask pointer base calculation (if used). + // Layout: (N, offset_groups, KH*KW, OH, OW) + const T* mask_ptr_base = nullptr; + if (use_mask) { + const int64_t mask_group_block_size = weight_h * weight_w * out_size; + mask_ptr_base = mask + (out_b * offset_group + offset_grp) * mask_group_block_size + spatial_idx; + } + + // 5. Output pointer base calculation. + // data_col Layout: (C * KH * KW, N * OH * OW) + // The current thread writes to the column `c_col` = (b * OH * OW) + spatial_idx. + // The starting row for this channel is `in_c * KH * KW`. + const int64_t c_col = out_b * out_size + spatial_idx; + T* data_col_ptr_base = data_col + (in_c * weight_h * weight_w) * col_stride + c_col; + // 6. Pre-calculate invariant coordinate parts. // Use float for coordinate math when T is half or BFloat16 to avoid precision loss. using CoordT = typename std::conditional::value, float, T>::type; + const CoordT base_h_im = static_cast(out_y * stride_h - pad_h); + const CoordT base_w_im = static_cast(out_x * stride_w - pad_w); #pragma unroll for (int64_t i = 0; i < weight_h; ++i) { #pragma unroll for (int64_t j = 0; j < weight_w; ++j) { - const int64_t mask_idx = i * weight_w + j; - const int64_t offset_idx = 2 * mask_idx; + const int64_t kernel_idx = i * weight_w + j; T mask_val = static_cast(1); if (use_mask) { - mask_val = DeformConvLdg(mask_ptr + mask_idx * out_size + out_y * out_w + out_x); + // Access mask using pre-calculated base and stride. + mask_val = DeformConvLdg(mask_ptr_base + kernel_idx * out_size); + + // [Optimization 1] Early Exit / Pruning + // If mask is 0, the contribution is 0. Skip expensive offset load and interpolation. + // Note: casting to float for comparison is safe for standard floating point types. 
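+          // (Only an exact zero takes this shortcut; any nonzero mask value, however
+          // small, still runs the offset loads and interpolation below, so the
+          // numerical result is unchanged.)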
+ if (static_cast(mask_val) == 0.0f) { + data_col_ptr_base[kernel_idx * col_stride] = static_cast(0); + continue; + } } - const int64_t offset_h_idx = (offset_idx)*out_size + out_y * out_w + out_x; - const int64_t offset_w_idx = (offset_idx + 1) * out_size + out_y * out_w + out_x; - const CoordT offset_h = static_cast(DeformConvLdg(offset_ptr + offset_h_idx)); - const CoordT offset_w = static_cast(DeformConvLdg(offset_ptr + offset_w_idx)); + // Calculate offset pointers relative to the base. + // The offset tensor stores (y_offset, x_offset) pairs for each kernel weight. + // Stride between y_offset and x_offset is `out_size`. + const int64_t offset_offset_idx = (2 * kernel_idx) * out_size; - const CoordT h_im = static_cast(out_y * stride_h - pad_h + i * dilation_h) + offset_h; - const CoordT w_im = static_cast(out_x * stride_w - pad_w + j * dilation_w) + offset_w; + const CoordT offset_h = static_cast(DeformConvLdg(offset_ptr_base + offset_offset_idx)); + const CoordT offset_w = static_cast(DeformConvLdg(offset_ptr_base + offset_offset_idx + out_size)); - T val = static_cast(0); - if (mask_val != static_cast(0)) { - val = BilinearInterpolate(input_ptr, height, width, h_im, w_im); - } + const CoordT h_im = base_h_im + static_cast(i * dilation_h) + offset_h; + const CoordT w_im = base_w_im + static_cast(j * dilation_w) + offset_w; - const int64_t col_row_idx = (in_c * weight_h * weight_w) + (i * weight_w + j); - data_col[col_row_idx * col_stride + c_col] = val * mask_val; + T val = BilinearInterpolate(input_ptr, height, width, h_im, w_im); + + // Write result to data_col using pre-calculated base. + data_col_ptr_base[kernel_idx * col_stride] = val * mask_val; } } } @@ -225,12 +260,30 @@ __global__ void DeformableIm2ColKernel( // Bias add: Y[n,m,oh,ow] += B[m]. Layout NCHW. template -__global__ void DeformConvAddBiasKernel(T* Y, const T* B, int64_t N, int64_t M, int64_t out_h, int64_t out_w) { - int64_t out_size = out_h * out_w; - int64_t total = N * M * out_size; - for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; idx += blockDim.x * gridDim.x) { - int64_t m = (idx / out_size) % M; - Y[idx] += DeformConvLdg(B + m); +__global__ void DeformConvAddBiasKernel( + T* Y, + const T* B, + DivMod spatial_div, // For dividing by (H * W) + DivMod channel_div, // For dividing by M (channel count) + int64_t total_elements) { + + for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total_elements; idx += blockDim.x * gridDim.x) { + int64_t val = idx; + int64_t batch_channel_idx, pixel_idx; + + // 1. First decomposition: decompose idx into (batch_channel_idx, pixel_idx) + // 等价于: batch_channel_idx = idx / (H*W); pixel_idx = idx % (H*W); + spatial_div.divmod(val, batch_channel_idx, pixel_idx); + + int64_t batch_idx, channel_idx; + + // 2. Second decomposition: decompose batch_channel_idx into (batch_idx, channel_idx) + // Equivalent to: channel_idx = batch_channel_idx % M; + // We only need channel_idx (i.e. m) + channel_div.divmod(batch_channel_idx, batch_idx, channel_idx); + + // channel_idx is what we need (i.e. m) + Y[idx] += DeformConvLdg(B + channel_idx); } } @@ -261,8 +314,24 @@ template void DeformConvAddBiasImpl(cudaStream_t stream, T* Y, const T* B, int64_t N, int64_t M, int64_t out_h, int64_t out_w) { int64_t total = N * M * out_h * out_w; if (total <= 0) return; + + // 1. Prepare divisor + int64_t out_size = out_h * out_w; + + // 2. 
Create FastDivMod object (note: ensure int64_t version of DivMod is used here) + DivMod spatial_div(out_size); + DivMod channel_div(M); + int blocks = GetGridSize(static_cast(total), kDeformConvThreadsPerBlock); - DeformConvAddBiasKernel<<>>(Y, B, N, M, out_h, out_w); + + // 3. Pass DivMod objects + DeformConvAddBiasKernel<<>>( + Y, + B, + spatial_div, + channel_div, + total + ); } template From c5bd48af6ed39691ad60112e91e0c322618b30f4 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Thu, 19 Feb 2026 21:44:19 +0800 Subject: [PATCH 05/58] Add more test cases --- .../providers/cpu/nn/deform_conv_op_test.cc | 299 +++++++++++++++++- 1 file changed, 298 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc index 69ba6badac4bf..ca35361078c61 100644 --- a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc @@ -1,13 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// Unit tests for DeformConv (CPU), aligned with PyTorch Vision deform_conv2d tests. +// Unit tests for DeformConv (CPU and Cuda), aligned with PyTorch Vision deform_conv2d tests. // Reference: https://github.com/pytorch/vision/blob/main/test/test_ops.py (TestDeformConv) #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" #include "test/unittest_util/conversion.h" +#if defined(USE_CUDA) +#include "test/common/cuda_op_test_utils.h" +#endif + namespace onnxruntime { namespace test { @@ -125,6 +129,61 @@ void RunDeformConvTestFP16(const DeformConvTestParams& params, test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); } +void RunDeformConvTestDouble(const DeformConvTestParams& params, + const std::vector& X, + const std::vector& W, + const std::vector& offset, + const std::vector& B, + const std::vector* mask, + const std::vector& expected_Y, + int opset = 19, + double rtol = 1e-8, + double atol = 1e-8) { + const int64_t kH = params.kernel_shape[0]; + const int64_t kW = params.kernel_shape[1]; + const int64_t out_h = (params.in_h + params.pad[0] + params.pad[2] - + params.dilation[0] * (kH - 1) - 1) / params.stride[0] + 1; + const int64_t out_w = (params.in_w + params.pad[1] + params.pad[3] - + params.dilation[1] * (kW - 1) - 1) / params.stride[1] + 1; + + OpTester test("DeformConv", opset); + test.AddAttribute("kernel_shape", params.kernel_shape); + test.AddAttribute("strides", params.stride); + test.AddAttribute("pads", params.pad); + test.AddAttribute("dilations", params.dilation); + test.AddAttribute("group", params.n_weight_grps); + test.AddAttribute("offset_group", params.n_offset_grps); + + const std::vector X_shape = {params.batch_sz, params.n_in_channels, params.in_h, params.in_w}; + const std::vector W_shape = {params.n_out_channels, params.n_in_channels / params.n_weight_grps, kH, kW}; + const std::vector offset_shape = {params.batch_sz, params.n_offset_grps * 2 * kH * kW, out_h, out_w}; + const std::vector Y_shape = {params.batch_sz, params.n_out_channels, out_h, out_w}; + + std::vector X_d(X.begin(), X.end()); + std::vector W_d(W.begin(), W.end()); + std::vector offset_d(offset.begin(), offset.end()); + std::vector B_d(B.begin(), B.end()); + std::vector expected_Y_d(expected_Y.begin(), expected_Y.end()); + + test.AddInput("X", X_shape, X_d); + test.AddInput("W", W_shape, W_d); + test.AddInput("offset", offset_shape, offset_d); + test.AddInput("B", 
{params.n_out_channels}, B_d); + if (mask != nullptr) { + const std::vector mask_shape = {params.batch_sz, params.n_offset_grps * kH * kW, out_h, out_w}; + std::vector mask_d(mask->begin(), mask->end()); + test.AddInput("mask", mask_shape, mask_d); + } else { + test.AddOptionalInputEdge(); + } + + test.AddOutput("Y", Y_shape, expected_Y_d, false, rtol, atol); + + std::unordered_set excluded = {kCpuExecutionProvider, kTensorrtExecutionProvider, + kOpenVINOExecutionProvider, kQnnExecutionProvider}; + test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); +} + } // namespace // Minimal case: 1x1 kernel, 2x2 input, one output position with fractional offset (bilinear). @@ -186,6 +245,118 @@ TEST(DeformConvTest, MinimalBilinearFP16) { RunDeformConvTestFP16(p, X, W, offset, B, &mask, expected_Y); } +// Minimal case BFloat16: Same as MinimalBilinear but in BFloat16 (CUDA BFloat16 coordinate precision). +#if defined(USE_CUDA) +void RunDeformConvTestBFloat16(const DeformConvTestParams& params, + const std::vector& X, + const std::vector& W, + const std::vector& offset, + const std::vector& B, + const std::vector* mask, + const std::vector& expected_Y, + int opset = 19, + float rtol = 1e-2f, + float atol = 1e-2f) { + const int64_t kH = params.kernel_shape[0]; + const int64_t kW = params.kernel_shape[1]; + const int64_t out_h = (params.in_h + params.pad[0] + params.pad[2] - + params.dilation[0] * (kH - 1) - 1) / params.stride[0] + 1; + const int64_t out_w = (params.in_w + params.pad[1] + params.pad[3] - + params.dilation[1] * (kW - 1) - 1) / params.stride[1] + 1; + + OpTester test("DeformConv", opset); + test.AddAttribute("kernel_shape", params.kernel_shape); + test.AddAttribute("strides", params.stride); + test.AddAttribute("pads", params.pad); + test.AddAttribute("dilations", params.dilation); + test.AddAttribute("group", params.n_weight_grps); + test.AddAttribute("offset_group", params.n_offset_grps); + + const std::vector X_shape = {params.batch_sz, params.n_in_channels, params.in_h, params.in_w}; + const std::vector W_shape = {params.n_out_channels, params.n_in_channels / params.n_weight_grps, kH, kW}; + const std::vector offset_shape = {params.batch_sz, params.n_offset_grps * 2 * kH * kW, out_h, out_w}; + const std::vector Y_shape = {params.batch_sz, params.n_out_channels, out_h, out_w}; + + test.AddInput("X", X_shape, FloatsToBFloat16s(X)); + test.AddInput("W", W_shape, FloatsToBFloat16s(W)); + test.AddInput("offset", offset_shape, FloatsToBFloat16s(offset)); + test.AddInput("B", {params.n_out_channels}, FloatsToBFloat16s(B)); + if (mask != nullptr) { + const std::vector mask_shape = {params.batch_sz, params.n_offset_grps * kH * kW, out_h, out_w}; + test.AddInput("mask", mask_shape, FloatsToBFloat16s(*mask)); + } else { + test.AddOptionalInputEdge(); + } + + test.AddOutput("Y", Y_shape, FloatsToBFloat16s(expected_Y), false, rtol, atol); + + std::unordered_set excluded = {kCpuExecutionProvider, kTensorrtExecutionProvider, + kOpenVINOExecutionProvider, kQnnExecutionProvider}; + test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); +} + +TEST(DeformConvTest, MinimalBilinearBFloat16) { + int min_cuda_architecture = 800; + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "Hardware does NOT support BF16"; + return; + } + + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {1, 1}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; 
+ p.in_h = 2; + p.in_w = 2; + + std::vector X = {1.f, 2.f, 3.f, 4.f}; + std::vector W = {1.f}; + std::vector offset = { + 0.5f, 0.f, 0.f, 0.f, + 0.5f, -1.0f, 0.f, 0.f + }; + std::vector B = {0.f}; + std::vector mask = {1.f, 1.f, 1.f, 1.f}; + std::vector expected_Y = {2.5f, 1.f, 3.f, 4.f}; + + RunDeformConvTestBFloat16(p, X, W, offset, B, &mask, expected_Y); +} +#endif + +// Minimal case Double (FP64): Same as MinimalBilinear in double precision. +TEST(DeformConvTest, MinimalBilinearDouble) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {1, 1}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 2; + p.in_w = 2; + + std::vector X = {1.f, 2.f, 3.f, 4.f}; + std::vector W = {1.f}; + std::vector offset = { + 0.5f, 0.f, 0.f, 0.f, + 0.5f, -1.0f, 0.f, 0.f + }; + std::vector B = {0.f}; + std::vector mask = {1.f, 1.f, 1.f, 1.f}; + std::vector expected_Y = {2.5f, 1.f, 3.f, 4.f}; + + RunDeformConvTestDouble(p, X, W, offset, B, &mask, expected_Y); +} + // Forward with mask and bias FP16 TEST(DeformConvTest, ForwardWithMaskAndBiasFP16) { DeformConvTestParams p = {}; @@ -751,5 +922,131 @@ TEST(DeformConvTest, OffsetAtPixelCenters) { RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); } +// Large batch (N=64) to trigger CUDA ComputeInternal chunking loop (b += n_parallel_imgs). +TEST(DeformConvTest, LargeBatchSize) { + const int64_t N = 64; + DeformConvTestParams p = {}; + p.batch_sz = N; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + const size_t x_size = static_cast(N * 1 * 3 * 3); + const size_t offset_size = static_cast(N * 1 * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(N * 1 * 2 * 2 * out_h * out_w); + const size_t y_size = static_cast(N * 1 * out_h * out_w); + + std::vector X(x_size, 0.1f); + std::vector W(1 * 1 * 2 * 2, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f}; + std::vector expected_Y(y_size, 0.04f); // 4 * 0.1 * 0.1 per position + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// group=1, offset_group=2: weights not grouped, offset/mask grouped. +TEST(DeformConvTest, Group1OffsetGroup2) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 4; // C must be divisible by offset_group + p.n_out_channels = 2; + p.n_weight_grps = 1; + p.n_offset_grps = 2; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + const size_t x_size = static_cast(1 * 4 * 3 * 3); + const size_t w_size = static_cast(2 * 4 * 2 * 2); + const size_t offset_size = static_cast(1 * 2 * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(1 * 2 * 2 * 2 * out_h * out_w); + + std::vector X(x_size, 0.1f); + std::vector W(w_size, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f, 0.f}; + // group=1: full conv. 
Each output: 4 in_ch * 4 kernel = 16 * 0.01 = 0.16 per channel, 2 out ch -> 0.16 each + std::vector expected_Y(static_cast(1 * 2 * out_h * out_w), 0.16f); + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Mask with zeros: exercises CUDA early-exit when mask_val == 0. +TEST(DeformConvTest, MaskWithZeros) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + std::vector X(1 * 1 * 3 * 3, 0.1f); + std::vector W(1 * 1 * 2 * 2, 0.1f); + const size_t offset_size = static_cast(1 * 1 * 2 * 2 * 2 * out_h * out_w); + std::vector offset(offset_size, 0.f); + // mask: (1, 4, 2, 2). Set all to 0 -> output should be 0. + std::vector mask(static_cast(1 * 1 * 2 * 2 * out_h * out_w), 0.f); + std::vector B = {0.f}; + std::vector expected_Y(static_cast(out_h * out_w), 0.f); + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Extreme aspect ratio (1x100): thin horizontal strip to verify coordinate indexing. +TEST(DeformConvTest, ExtremeAspectRatio) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {1, 3}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 1; + p.in_w = 100; + // out_h = 1, out_w = (100 - 1*(3-1) - 1)/1 + 1 = 98 + const int64_t out_h = 1; + const int64_t out_w = 98; + + std::vector X(100, 0.1f); + std::vector W(1 * 1 * 1 * 3, 0.1f); + const size_t offset_size = static_cast(1 * 1 * 2 * 1 * 3 * out_h * out_w); + const size_t mask_size = static_cast(1 * 1 * 1 * 3 * out_h * out_w); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f}; + // Each output: 3 * 0.1 * 0.1 = 0.03 + std::vector expected_Y(static_cast(out_h * out_w), 0.03f); + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + } // namespace test } // namespace onnxruntime From 952b3a1276db6212c079a327536f6aac8655d801 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Fri, 20 Feb 2026 13:35:32 +0800 Subject: [PATCH 06/58] Fix copilot suggestions --- onnxruntime/core/providers/cpu/nn/deform_conv.cc | 9 +++++++-- onnxruntime/core/providers/cuda/nn/deform_conv.cc | 8 ++++++++ onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu | 2 +- onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc | 4 ++-- 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.cc b/onnxruntime/core/providers/cpu/nn/deform_conv.cc index 7bca8f3e48301..289af84aa3bc4 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cpu/nn/deform_conv.cc @@ -6,11 +6,9 @@ #include "deform_conv.h" #include -#include #include "core/common/common.h" #include "core/common/narrow.h" -#include "core/common/safeint.h" #include "core/util/math.h" namespace onnxruntime { @@ -175,6 +173,13 @@ Status DeformConv::Compute(OpKernelContext* context) const { const int64_t group = attrs_.group; const int64_t offset_group = attrs_.offset_group; + // Validate input shapes + ORT_RETURN_IF_NOT(stride_h > 0 && stride_w > 0, "Strides must be positive."); + ORT_RETURN_IF_NOT(dilation_h > 0 && dilation_w > 0, "Dilations must be positive."); + ORT_RETURN_IF_NOT(kH > 0 && kW > 0, "Kernel shape must 
be positive."); + ORT_RETURN_IF_NOT(group > 0, "group must be positive"); + ORT_RETURN_IF_NOT(offset_group > 0, "offset_group must be positive"); + const int64_t out_h = (H + pad_h + pad_h_end - dilation_h * (kH - 1) - 1) / stride_h + 1; const int64_t out_w = (W_in + pad_w + pad_w_end - dilation_w * (kW - 1) - 1) / stride_w + 1; diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc index 2df4d25cd7892..40eec903a41b3 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc @@ -67,9 +67,17 @@ Status DeformConv::ComputeInternal(OpKernelContext* context) const { const int64_t group = attrs_.group; const int64_t offset_group = attrs_.offset_group; + // Validate input shapes + ORT_RETURN_IF_NOT(stride_h > 0 && stride_w > 0, "Strides must be positive."); + ORT_RETURN_IF_NOT(dilation_h > 0 && dilation_w > 0, "Dilations must be positive."); + ORT_RETURN_IF_NOT(kH > 0 && kW > 0, "Kernel shape must be positive."); + ORT_RETURN_IF_NOT(group > 0, "group must be positive"); + ORT_RETURN_IF_NOT(offset_group > 0, "offset_group must be positive"); + const int64_t out_h = (H + pad_h + pad_h_end - dilation_h * (kH - 1) - 1) / stride_h + 1; const int64_t out_w = (W_in + pad_w + pad_w_end - dilation_w * (kW - 1) - 1) / stride_w + 1; + // Checks ORT_RETURN_IF_NOT(W_shape.NumDimensions() == 4, "Weight must be 4D."); ORT_RETURN_IF_NOT(offset_shape.NumDimensions() == 4, "Offset must be 4D."); ORT_RETURN_IF_NOT(offset_shape[1] == offset_group * 2 * kH * kW, diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu index 3092b5b56a9ab..240f741139210 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu @@ -272,7 +272,7 @@ __global__ void DeformConvAddBiasKernel( int64_t batch_channel_idx, pixel_idx; // 1. First decomposition: decompose idx into (batch_channel_idx, pixel_idx) - // 等价于: batch_channel_idx = idx / (H*W); pixel_idx = idx % (H*W); + // Equivalent to: batch_channel_idx = idx / (H*W); pixel_idx = idx % (H*W); spatial_div.divmod(val, batch_channel_idx, pixel_idx); int64_t batch_idx, channel_idx; diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc index ca35361078c61..4d2e2b64dbce3 100644 --- a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc @@ -494,7 +494,7 @@ TEST(DeformConvTest, ForwardNoMask) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); } -// Empty batch (like PyTorch batch_sz=0). (like PyTorch batch_sz=0). +// Empty batch (like PyTorch batch_sz=0). TEST(DeformConvTest, EmptyBatch) { DeformConvTestParams p = {}; p.batch_sz = 0; @@ -540,7 +540,7 @@ TEST(DeformConvTest, EmptyBatch) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); } -// Wrong offset channel count -> expect failure (like PyTorch test_wrong_sizes). -> expect failure (like PyTorch test_wrong_sizes). +// Wrong offset channel count -> expect failure (like PyTorch test_wrong_sizes). 
TEST(DeformConvTest, WrongOffsetShape) { DeformConvTestParams p = {}; p.batch_sz = 1; From e5c043c635a11d8ec2fcefaffca9b5880ac68bf2 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Sat, 21 Feb 2026 01:35:53 +0800 Subject: [PATCH 07/58] Fix default attrs value of DeformConv --- onnxruntime/core/providers/cpu/nn/deform_conv.h | 14 ++++++-------- onnxruntime/core/providers/cuda/nn/deform_conv.h | 14 ++++++-------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.h b/onnxruntime/core/providers/cpu/nn/deform_conv.h index 47e310da004c3..ee4c2981b7573 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv.h +++ b/onnxruntime/core/providers/cpu/nn/deform_conv.h @@ -14,14 +14,12 @@ namespace onnxruntime { // See https://onnx.ai/onnx/operators/onnx__DeformConv.html struct DeformConvAttributes { explicit DeformConvAttributes(const OpKernelInfo& info) { - Status status = info.GetAttrs("kernel_shape", kernel_shape); - ORT_ENFORCE(status.IsOK(), "Attribute kernel_shape is not set."); - status = info.GetAttrs("strides", strides); - ORT_ENFORCE(status.IsOK(), "Attribute strides is not set."); - status = info.GetAttrs("pads", pads); - ORT_ENFORCE(status.IsOK(), "Attribute pads is not set."); - status = info.GetAttrs("dilations", dilations); - ORT_ENFORCE(status.IsOK(), "Attribute dilations is not set."); + // Optional attributes. + // If not present, they will be empty/default, and handled in Compute. + (void)info.GetAttrs("kernel_shape", kernel_shape); + (void)info.GetAttrs("strides", strides); + (void)info.GetAttrs("pads", pads); + (void)info.GetAttrs("dilations", dilations); group = info.GetAttrOrDefault("group", 1); offset_group = info.GetAttrOrDefault("offset_group", 1); } diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.h b/onnxruntime/core/providers/cuda/nn/deform_conv.h index 44eb7237b327e..7243d19f71585 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv.h +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.h @@ -15,14 +15,12 @@ namespace cuda { // See https://onnx.ai/onnx/operators/onnx__DeformConv.html struct DeformConvAttributes { explicit DeformConvAttributes(const OpKernelInfo& info) { - Status status = info.GetAttrs("kernel_shape", kernel_shape); - ORT_ENFORCE(status.IsOK(), "Attribute kernel_shape is not set."); - status = info.GetAttrs("strides", strides); - ORT_ENFORCE(status.IsOK(), "Attribute strides is not set."); - status = info.GetAttrs("pads", pads); - ORT_ENFORCE(status.IsOK(), "Attribute pads is not set."); - status = info.GetAttrs("dilations", dilations); - ORT_ENFORCE(status.IsOK(), "Attribute dilations is not set."); + // Optional attributes. + // If not present, they will be empty/default, and handled in Compute/ComputeInternal. 
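+    // (The (void) casts deliberately discard the returned Status: a missing
+    // optional attribute is expected here, not an error.)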
+ (void)info.GetAttrs("kernel_shape", kernel_shape); + (void)info.GetAttrs("strides", strides); + (void)info.GetAttrs("pads", pads); + (void)info.GetAttrs("dilations", dilations); group = info.GetAttrOrDefault("group", 1); offset_group = info.GetAttrOrDefault("offset_group", 1); } From eee517da396de161132f8f4171678eeb1b5c7407 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Mon, 23 Feb 2026 23:25:14 +0800 Subject: [PATCH 08/58] Fix schema definition for DeformConv op --- .../providers/cpu/cpu_execution_provider.cc | 12 ++-- .../core/providers/cpu/nn/deform_conv.cc | 41 ++++++------ .../providers/cuda/cuda_execution_provider.cc | 2 - .../core/providers/cuda/nn/deform_conv.cc | 14 ++-- onnxruntime/core/util/math_cpu.cc | 65 +++++++++++++++++++ .../providers/cpu/nn/deform_conv_op_test.cc | 8 ++- 6 files changed, 108 insertions(+), 34 deletions(-) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 71bd94e9d8a55..9f19a20a2e680 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -1220,7 +1220,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, uint8_t, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 20, Scan); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 20, Shape); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 21, DeformConv); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 21, float, DeformConv); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 21, double, DeformConv); // Opset 20 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, 20, ConstantOfShape); @@ -1317,7 +1318,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Ac class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Atanh); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Conv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, ConvTranspose); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, DeformConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, float, DeformConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, double, DeformConv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Det); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, float_float, Dropout); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, float_double, Dropout); @@ -3279,7 +3281,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { Resize)>, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 20 BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.cc 
b/onnxruntime/core/providers/cpu/nn/deform_conv.cc index 289af84aa3bc4..15f73af7d460b 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cpu/nn/deform_conv.cc @@ -318,24 +318,27 @@ Status DeformConv::Compute(OpKernelContext* context) const { // Explicit template instantiation for float and double template class DeformConv; - -ONNX_CPU_OPERATOR_VERSIONED_KERNEL( - DeformConv, - 19, - 21, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) - .InputMemoryType(OrtMemTypeCPUInput, 2) // offset - .InputMemoryType(OrtMemTypeCPUInput, 4), // optional mask - DeformConv); - -ONNX_CPU_OPERATOR_KERNEL( - DeformConv, - 22, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) - .InputMemoryType(OrtMemTypeCPUInput, 2) - .InputMemoryType(OrtMemTypeCPUInput, 4), - DeformConv); +template class DeformConv; + +#define REGISTER_DEFORMCONV_KERNEL_TYPED(T) \ + ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \ + DeformConv, 19, 21, T, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) /* offset */ \ + .InputMemoryType(OrtMemTypeCPUInput, 4), /* optional mask */ \ + DeformConv); \ + ONNX_CPU_OPERATOR_TYPED_KERNEL( \ + DeformConv, 22, T, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 4), \ + DeformConv) + +REGISTER_DEFORMCONV_KERNEL_TYPED(float) +REGISTER_DEFORMCONV_KERNEL_TYPED(double) + +#undef REGISTER_DEFORMCONV_KERNEL_TYPED } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index fdc23b5277370..4f36789191bb2 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1458,7 +1458,6 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, float, DeformConv); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, double, DeformConv); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, MLFloat16, DeformConv); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 21, BFloat16, DeformConv); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, float, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, double, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, MLFloat16, Cast); @@ -2585,7 +2584,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc index 40eec903a41b3..a589497303fd0 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc @@ -235,7 +235,7 @@ Status DeformConv::ComputeInternal(OpKernelContext* context) const { return Status::OK(); } -#define REGISTER_KERNEL_TYPED(T) \ +#define 
REGISTER_DEFORMCONV_KERNEL_TYPED(T) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ DeformConv, \ kOnnxDomain, \ @@ -250,14 +250,16 @@ Status DeformConv::ComputeInternal(OpKernelContext* context) const { kOnnxDomain, \ 22, \ T, \ - kCudaExecutionProvider, \ + kCudaExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ DeformConv); -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(double) -REGISTER_KERNEL_TYPED(MLFloat16) -REGISTER_KERNEL_TYPED(BFloat16) +REGISTER_DEFORMCONV_KERNEL_TYPED(float) +REGISTER_DEFORMCONV_KERNEL_TYPED(double) +REGISTER_DEFORMCONV_KERNEL_TYPED(MLFloat16) +REGISTER_DEFORMCONV_KERNEL_TYPED(BFloat16) + +#undef REGISTER_DEFORMCONV_KERNEL_TYPED } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/util/math_cpu.cc b/onnxruntime/core/util/math_cpu.cc index 25b868b76bce4..9cf8f619c7e1b 100644 --- a/onnxruntime/core/util/math_cpu.cc +++ b/onnxruntime/core/util/math_cpu.cc @@ -201,6 +201,71 @@ void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const double* A, cons EIGEN_MATMUL_FUNCTION(double) #endif +#ifdef MLAS_SUPPORTS_GEMM_DOUBLE +template <> +void GemmEx(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, + double alpha, const double* A, int lda, const double* B, int ldb, double beta, double* C, + int ldc, ThreadPool* threadpool) { + MlasGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, threadpool); +} +#else +template <> +void GemmEx(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, + double alpha, const double* A, int lda, const double* B, int ldb, double beta, double* C, + int ldc, ThreadPool*) { + auto C_mat = EigenMatrixMapWithStrides(C, N, M, Eigen::Stride(ldc, 1)); + if (beta == 0) { + C_mat.setZero(); + } else { + C_mat *= beta; + } + switch (TransA) { + case CblasNoTrans: { + switch (TransB) { + case CblasNoTrans: + C_mat.noalias() += alpha * (ConstEigenMatrixMapWithStrides( + B, N, K, Eigen::Stride(ldb, 1)) * + ConstEigenMatrixMapWithStrides( + A, K, M, Eigen::Stride(lda, 1))); + return; + case CblasTrans: + C_mat.noalias() += alpha * (ConstEigenMatrixMapWithStrides( + B, K, N, Eigen::Stride(ldb, 1)) + .transpose() * + ConstEigenMatrixMapWithStrides( + A, K, M, Eigen::Stride(lda, 1))); + return; + default: + ORT_THROW("CblasNoTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB); + } + } + case CblasTrans: { + switch (TransB) { + case CblasNoTrans: + C_mat.noalias() += alpha * (ConstEigenMatrixMapWithStrides( + B, N, K, Eigen::Stride(ldb, 1)) * + ConstEigenMatrixMapWithStrides( + A, M, K, Eigen::Stride(lda, 1)) + .transpose()); + return; + case CblasTrans: + C_mat.noalias() += alpha * (ConstEigenMatrixMapWithStrides( + B, K, N, Eigen::Stride(ldb, 1)) + .transpose() * + ConstEigenMatrixMapWithStrides( + A, M, K, Eigen::Stride(lda, 1)) + .transpose()); + return; + default: + ORT_THROW("CblasTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB); + } + } + default: + ORT_THROW("Unexpected CBLAS_TRANSPOSE for TransA of ", TransA); + } +} +#endif + template <> void GemmEx(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, float alpha, const float* A, int lda, const float* B, int ldb, float beta, float* C, diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc index 4d2e2b64dbce3..d0f563ea9b79d 100644 --- a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc +++ 
b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc @@ -325,7 +325,7 @@ TEST(DeformConvTest, MinimalBilinearBFloat16) { std::vector mask = {1.f, 1.f, 1.f, 1.f}; std::vector expected_Y = {2.5f, 1.f, 3.f, 4.f}; - RunDeformConvTestBFloat16(p, X, W, offset, B, &mask, expected_Y); + RunDeformConvTestBFloat16(p, X, W, offset, B, &mask, expected_Y, 22); } #endif @@ -579,7 +579,8 @@ TEST(DeformConvTest, WrongOffsetShape) { test.AddInput("B", {2}, B); test.AddOptionalInputEdge(); test.AddOutput("Y", Y_shape_wrong, expected_Y); - test.Run(OpTester::ExpectResult::kExpectFailure, "Offset channel count must be offset_group * 2 * kH * kW"); + std::unordered_set excluded = {kTensorrtExecutionProvider}; + test.Run(OpTester::ExpectResult::kExpectFailure, "Offset channel count must be offset_group * 2 * kH * kW", excluded); } // Wrong mask channel count -> expect failure. @@ -624,7 +625,8 @@ TEST(DeformConvTest, WrongMaskShape) { test.AddInput("B", {2}, B); test.AddInput("mask", mask_shape_wrong, wrong_mask); test.AddOutput("Y", Y_shape_mask, expected_Y); - test.Run(OpTester::ExpectResult::kExpectFailure, "Mask channel count"); + std::unordered_set excluded = {kTensorrtExecutionProvider}; + test.Run(OpTester::ExpectResult::kExpectFailure, "Mask channel count", excluded); } // Opset 22 (same behavior, different opset). From 12b19c8a4342bbace86e8356849fec00ff263baf Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Mon, 23 Feb 2026 23:44:09 +0800 Subject: [PATCH 09/58] Refactor DeformConv test cases --- .../core/providers/cuda/nn/deform_conv.cc | 2 +- .../providers/cpu/nn/deform_conv_op_test.cc | 269 ++++++------------ 2 files changed, 87 insertions(+), 184 deletions(-) diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc index a589497303fd0..fa84c44161fd3 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc @@ -235,7 +235,7 @@ Status DeformConv::ComputeInternal(OpKernelContext* context) const { return Status::OK(); } -#define REGISTER_DEFORMCONV_KERNEL_TYPED(T) \ +#define REGISTER_DEFORMCONV_KERNEL_TYPED(T) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ DeformConv, \ kOnnxDomain, \ diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc index d0f563ea9b79d..486f577f9791f 100644 --- a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc @@ -32,6 +32,57 @@ struct DeformConvTestParams { int64_t in_w; }; +// Traits for type-specific DeformConv test behavior. 
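// The refactor below collapses four near-identical Run* helpers into one
// template parameterized by a traits struct. A minimal sketch of the pattern,
// reduced to its essentials (names here are illustrative, not the test's):

#include <vector>

template <typename T>
struct TolTraits;  // primary template left undefined: unsupported T fails to compile

template <>
struct TolTraits<float> {
  static std::vector<float> Convert(const std::vector<float>& v) { return v; }
  static constexpr float Tol() { return 1e-5f; }
};

template <typename T, typename CheckFn>
void RunTyped(const std::vector<float>& data, CheckFn&& check) {
  // One body for every type; per-type conversion and tolerance come from traits.
  check(TolTraits<T>::Convert(data), TolTraits<T>::Tol());
}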
+template +struct DeformConvTestTraits; + +template <> +struct DeformConvTestTraits { + static std::vector Convert(const std::vector& v) { return v; } + static std::unordered_set ExcludedProviders() { + return {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider}; + } + static constexpr float DefaultRtol() { return 1e-5f; } + static constexpr float DefaultAtol() { return 1e-5f; } +}; + +template <> +struct DeformConvTestTraits { + static std::vector Convert(const std::vector& v) { return FloatsToMLFloat16s(v); } + static std::unordered_set ExcludedProviders() { + return {kCpuExecutionProvider, kTensorrtExecutionProvider, + kOpenVINOExecutionProvider, kQnnExecutionProvider}; + } + static constexpr float DefaultRtol() { return 1e-2f; } + static constexpr float DefaultAtol() { return 1e-2f; } +}; + +template <> +struct DeformConvTestTraits { + static std::vector Convert(const std::vector& v) { + return std::vector(v.begin(), v.end()); + } + static std::unordered_set ExcludedProviders() { + return {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider}; + } + static constexpr double DefaultRtol() { return 1e-8; } + static constexpr double DefaultAtol() { return 1e-8; } +}; + +#if defined(USE_CUDA) +template <> +struct DeformConvTestTraits { + static std::vector Convert(const std::vector& v) { return FloatsToBFloat16s(v); } + static std::unordered_set ExcludedProviders() { + return {kCpuExecutionProvider, kTensorrtExecutionProvider, + kOpenVINOExecutionProvider, kQnnExecutionProvider}; + } + static constexpr float DefaultRtol() { return 1e-2f; } + static constexpr float DefaultAtol() { return 1e-2f; } +}; +#endif + +template void RunDeformConvTest(const DeformConvTestParams& params, const std::vector& X, const std::vector& W, @@ -40,8 +91,8 @@ void RunDeformConvTest(const DeformConvTestParams& params, const std::vector* mask, const std::vector& expected_Y, int opset = 19, - float rtol = 1e-5f, - float atol = 1e-5f) { + decltype(DeformConvTestTraits::DefaultRtol()) rtol = DeformConvTestTraits::DefaultRtol(), + decltype(DeformConvTestTraits::DefaultAtol()) atol = DeformConvTestTraits::DefaultAtol()) { const int64_t kH = params.kernel_shape[0]; const int64_t kW = params.kernel_shape[1]; const int64_t out_h = (params.in_h + params.pad[0] + params.pad[2] - @@ -62,126 +113,26 @@ void RunDeformConvTest(const DeformConvTestParams& params, const std::vector offset_shape = {params.batch_sz, params.n_offset_grps * 2 * kH * kW, out_h, out_w}; const std::vector Y_shape = {params.batch_sz, params.n_out_channels, out_h, out_w}; - test.AddInput("X", X_shape, X); - test.AddInput("W", W_shape, W); - test.AddInput("offset", offset_shape, offset); - test.AddInput("B", {params.n_out_channels}, B); - if (mask != nullptr) { - const std::vector mask_shape = {params.batch_sz, params.n_offset_grps * kH * kW, out_h, out_w}; - test.AddInput("mask", mask_shape, *mask); - } else { - test.AddOptionalInputEdge(); - } - - test.AddOutput("Y", Y_shape, expected_Y, false, rtol, atol); - - std::unordered_set excluded = {kTensorrtExecutionProvider, - kOpenVINOExecutionProvider, kQnnExecutionProvider}; - test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); -} - -void RunDeformConvTestFP16(const DeformConvTestParams& params, - const std::vector& X, - const std::vector& W, - const std::vector& offset, - const std::vector& B, - const std::vector* mask, - const std::vector& expected_Y, - int opset = 19, - float rtol = 1e-2f, // FP16 requires looser tolerance - float atol = 
1e-2f) { - const int64_t kH = params.kernel_shape[0]; - const int64_t kW = params.kernel_shape[1]; - const int64_t out_h = (params.in_h + params.pad[0] + params.pad[2] - - params.dilation[0] * (kH - 1) - 1) / params.stride[0] + 1; - const int64_t out_w = (params.in_w + params.pad[1] + params.pad[3] - - params.dilation[1] * (kW - 1) - 1) / params.stride[1] + 1; - - OpTester test("DeformConv", opset); - test.AddAttribute("kernel_shape", params.kernel_shape); - test.AddAttribute("strides", params.stride); - test.AddAttribute("pads", params.pad); - test.AddAttribute("dilations", params.dilation); - test.AddAttribute("group", params.n_weight_grps); - test.AddAttribute("offset_group", params.n_offset_grps); - - const std::vector X_shape = {params.batch_sz, params.n_in_channels, params.in_h, params.in_w}; - const std::vector W_shape = {params.n_out_channels, params.n_in_channels / params.n_weight_grps, kH, kW}; - const std::vector offset_shape = {params.batch_sz, params.n_offset_grps * 2 * kH * kW, out_h, out_w}; - const std::vector Y_shape = {params.batch_sz, params.n_out_channels, out_h, out_w}; + auto X_t = DeformConvTestTraits::Convert(X); + auto W_t = DeformConvTestTraits::Convert(W); + auto offset_t = DeformConvTestTraits::Convert(offset); + auto B_t = DeformConvTestTraits::Convert(B); + auto expected_Y_t = DeformConvTestTraits::Convert(expected_Y); - test.AddInput("X", X_shape, FloatsToMLFloat16s(X)); - test.AddInput("W", W_shape, FloatsToMLFloat16s(W)); - test.AddInput("offset", offset_shape, FloatsToMLFloat16s(offset)); - test.AddInput("B", {params.n_out_channels}, FloatsToMLFloat16s(B)); + test.AddInput("X", X_shape, X_t); + test.AddInput("W", W_shape, W_t); + test.AddInput("offset", offset_shape, offset_t); + test.AddInput("B", {params.n_out_channels}, B_t); if (mask != nullptr) { const std::vector mask_shape = {params.batch_sz, params.n_offset_grps * kH * kW, out_h, out_w}; - test.AddInput("mask", mask_shape, FloatsToMLFloat16s(*mask)); + test.AddInput("mask", mask_shape, DeformConvTestTraits::Convert(*mask)); } else { - test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); } - test.AddOutput("Y", Y_shape, FloatsToMLFloat16s(expected_Y), false, rtol, atol); + test.AddOutput("Y", Y_shape, expected_Y_t, false, rtol, atol); - // Exclude CPU provider as it likely doesn't support FP16 DeformConv - std::unordered_set excluded = {kCpuExecutionProvider, kTensorrtExecutionProvider, - kOpenVINOExecutionProvider, kQnnExecutionProvider}; - test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); -} - -void RunDeformConvTestDouble(const DeformConvTestParams& params, - const std::vector& X, - const std::vector& W, - const std::vector& offset, - const std::vector& B, - const std::vector* mask, - const std::vector& expected_Y, - int opset = 19, - double rtol = 1e-8, - double atol = 1e-8) { - const int64_t kH = params.kernel_shape[0]; - const int64_t kW = params.kernel_shape[1]; - const int64_t out_h = (params.in_h + params.pad[0] + params.pad[2] - - params.dilation[0] * (kH - 1) - 1) / params.stride[0] + 1; - const int64_t out_w = (params.in_w + params.pad[1] + params.pad[3] - - params.dilation[1] * (kW - 1) - 1) / params.stride[1] + 1; - - OpTester test("DeformConv", opset); - test.AddAttribute("kernel_shape", params.kernel_shape); - test.AddAttribute("strides", params.stride); - test.AddAttribute("pads", params.pad); - test.AddAttribute("dilations", params.dilation); - test.AddAttribute("group", params.n_weight_grps); - test.AddAttribute("offset_group", params.n_offset_grps); - - const 
std::vector X_shape = {params.batch_sz, params.n_in_channels, params.in_h, params.in_w}; - const std::vector W_shape = {params.n_out_channels, params.n_in_channels / params.n_weight_grps, kH, kW}; - const std::vector offset_shape = {params.batch_sz, params.n_offset_grps * 2 * kH * kW, out_h, out_w}; - const std::vector Y_shape = {params.batch_sz, params.n_out_channels, out_h, out_w}; - - std::vector X_d(X.begin(), X.end()); - std::vector W_d(W.begin(), W.end()); - std::vector offset_d(offset.begin(), offset.end()); - std::vector B_d(B.begin(), B.end()); - std::vector expected_Y_d(expected_Y.begin(), expected_Y.end()); - - test.AddInput("X", X_shape, X_d); - test.AddInput("W", W_shape, W_d); - test.AddInput("offset", offset_shape, offset_d); - test.AddInput("B", {params.n_out_channels}, B_d); - if (mask != nullptr) { - const std::vector mask_shape = {params.batch_sz, params.n_offset_grps * kH * kW, out_h, out_w}; - std::vector mask_d(mask->begin(), mask->end()); - test.AddInput("mask", mask_shape, mask_d); - } else { - test.AddOptionalInputEdge(); - } - - test.AddOutput("Y", Y_shape, expected_Y_d, false, rtol, atol); - - std::unordered_set excluded = {kCpuExecutionProvider, kTensorrtExecutionProvider, - kOpenVINOExecutionProvider, kQnnExecutionProvider}; - test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", DeformConvTestTraits::ExcludedProviders()); } } // namespace @@ -213,7 +164,7 @@ TEST(DeformConvTest, MinimalBilinear) { std::vector mask = {1.f, 1.f, 1.f, 1.f}; // (1,1,2,2) std::vector expected_Y = {2.5f, 1.f, 3.f, 4.f}; - RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); } // Minimal case FP16: Same as MinimalBilinear but in FP16. @@ -242,59 +193,11 @@ TEST(DeformConvTest, MinimalBilinearFP16) { std::vector mask = {1.f, 1.f, 1.f, 1.f}; std::vector expected_Y = {2.5f, 1.f, 3.f, 4.f}; - RunDeformConvTestFP16(p, X, W, offset, B, &mask, expected_Y); + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); } // Minimal case BFloat16: Same as MinimalBilinear but in BFloat16 (CUDA BFloat16 coordinate precision). 
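// Why BFloat16 gets the loose 1e-2 tolerance: it keeps only 8 mantissa bits,
// roughly 2-3 significant decimal digits. A sketch of round-to-nearest-even
// float -> bf16 conversion illustrating the quantization (NaN handling omitted;
// illustration only, not ORT's BFloat16 class):

#include <cstdint>
#include <cstring>

uint16_t FloatToBF16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  const uint32_t rounding = 0x7FFFu + ((bits >> 16) & 1u);  // round to nearest even
  return static_cast<uint16_t>((bits + rounding) >> 16);    // keep top 16 bits
}

float BF16ToFloat(uint16_t h) {
  const uint32_t bits = static_cast<uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;  // e.g. 2.7182817f survives only as 2.71875f
}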
#if defined(USE_CUDA) -void RunDeformConvTestBFloat16(const DeformConvTestParams& params, - const std::vector& X, - const std::vector& W, - const std::vector& offset, - const std::vector& B, - const std::vector* mask, - const std::vector& expected_Y, - int opset = 19, - float rtol = 1e-2f, - float atol = 1e-2f) { - const int64_t kH = params.kernel_shape[0]; - const int64_t kW = params.kernel_shape[1]; - const int64_t out_h = (params.in_h + params.pad[0] + params.pad[2] - - params.dilation[0] * (kH - 1) - 1) / params.stride[0] + 1; - const int64_t out_w = (params.in_w + params.pad[1] + params.pad[3] - - params.dilation[1] * (kW - 1) - 1) / params.stride[1] + 1; - - OpTester test("DeformConv", opset); - test.AddAttribute("kernel_shape", params.kernel_shape); - test.AddAttribute("strides", params.stride); - test.AddAttribute("pads", params.pad); - test.AddAttribute("dilations", params.dilation); - test.AddAttribute("group", params.n_weight_grps); - test.AddAttribute("offset_group", params.n_offset_grps); - - const std::vector X_shape = {params.batch_sz, params.n_in_channels, params.in_h, params.in_w}; - const std::vector W_shape = {params.n_out_channels, params.n_in_channels / params.n_weight_grps, kH, kW}; - const std::vector offset_shape = {params.batch_sz, params.n_offset_grps * 2 * kH * kW, out_h, out_w}; - const std::vector Y_shape = {params.batch_sz, params.n_out_channels, out_h, out_w}; - - test.AddInput("X", X_shape, FloatsToBFloat16s(X)); - test.AddInput("W", W_shape, FloatsToBFloat16s(W)); - test.AddInput("offset", offset_shape, FloatsToBFloat16s(offset)); - test.AddInput("B", {params.n_out_channels}, FloatsToBFloat16s(B)); - if (mask != nullptr) { - const std::vector mask_shape = {params.batch_sz, params.n_offset_grps * kH * kW, out_h, out_w}; - test.AddInput("mask", mask_shape, FloatsToBFloat16s(*mask)); - } else { - test.AddOptionalInputEdge(); - } - - test.AddOutput("Y", Y_shape, FloatsToBFloat16s(expected_Y), false, rtol, atol); - - std::unordered_set excluded = {kCpuExecutionProvider, kTensorrtExecutionProvider, - kOpenVINOExecutionProvider, kQnnExecutionProvider}; - test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); -} - TEST(DeformConvTest, MinimalBilinearBFloat16) { int min_cuda_architecture = 800; if (!HasCudaEnvironment(min_cuda_architecture)) { @@ -325,7 +228,7 @@ TEST(DeformConvTest, MinimalBilinearBFloat16) { std::vector mask = {1.f, 1.f, 1.f, 1.f}; std::vector expected_Y = {2.5f, 1.f, 3.f, 4.f}; - RunDeformConvTestBFloat16(p, X, W, offset, B, &mask, expected_Y, 22); + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 22); } #endif @@ -354,7 +257,7 @@ TEST(DeformConvTest, MinimalBilinearDouble) { std::vector mask = {1.f, 1.f, 1.f, 1.f}; std::vector expected_Y = {2.5f, 1.f, 3.f, 4.f}; - RunDeformConvTestDouble(p, X, W, offset, B, &mask, expected_Y); + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); } // Forward with mask and bias FP16 @@ -396,7 +299,7 @@ TEST(DeformConvTest, ForwardWithMaskAndBiasFP16) { } } - RunDeformConvTestFP16(p, X, W, offset, B, &mask, expected_Y); + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); } // With offset=0 and mask=1, Y = Conv(X,W) + B. Use small inputs and compute expected. 
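// The expected values for this kind of test can be reproduced with a naive
// reference convolution, since zero offsets plus an all-ones mask make
// DeformConv sample exactly on the regular grid. Single-channel sketch,
// stride 1, no padding (hypothetical helper, not part of the test file):

#include <cstdint>
#include <vector>

std::vector<float> NaiveConv2D(const std::vector<float>& x, int64_t H, int64_t W,
                               const std::vector<float>& w, int64_t kH, int64_t kW,
                               float bias) {
  const int64_t out_h = H - kH + 1, out_w = W - kW + 1;
  std::vector<float> y(static_cast<size_t>(out_h * out_w), bias);  // start from bias
  for (int64_t p = 0; p < out_h; ++p)
    for (int64_t q = 0; q < out_w; ++q)
      for (int64_t i = 0; i < kH; ++i)
        for (int64_t j = 0; j < kW; ++j)
          y[p * out_w + q] += x[(p + i) * W + (q + j)] * w[i * kW + j];
  return y;
}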
TEST(DeformConvTest, ForwardWithMaskAndBias) { @@ -439,7 +342,7 @@ TEST(DeformConvTest, ForwardWithMaskAndBias) { } } - RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 19, 1e-4f, 1e-4f); + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 19, 1e-4f, 1e-4f); } // No mask (optional): same as above but mask omitted; compare to run with ones mask via tolerance. @@ -651,7 +554,7 @@ TEST(DeformConvTest, Opset22) { std::vector mask = {1.f, 1.f, 1.f, 1.f}; std::vector expected_Y = {2.5f, 2.f, 3.f, 4.f}; - RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 22); + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 22); } // Non-square kernel (kH != kW): 2x3 kernel, zero offset -> same as standard conv. @@ -685,7 +588,7 @@ TEST(DeformConvTest, NonSquareKernel) { // With offset=0, mask=1: each output = 6 * 0.1 * 0.1 = 0.06 (9 positions) std::vector expected_Y(static_cast(out_h * out_w), 0.06f); - RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); } // Asymmetric stride (stride_h != stride_w): stride=(2,1), zero offset. @@ -718,7 +621,7 @@ TEST(DeformConvTest, AsymmetricStride) { std::vector B = {0.f}; std::vector expected_Y(static_cast(out_h * out_w), 0.04f); - RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); } // groups > 0 and non-zero offset; expected from deform_conv_expected_gen.py (seed=123). @@ -743,7 +646,7 @@ TEST(DeformConvTest, GroupsWithNonZeroOffset) { std::vector mask = {-0.031861f, -0.478956f, 0.766809f, 0.027468f, 0.047470f, -0.923866f, -1.060737f, -2.324446f, -2.062818f, 0.006375f, -0.989555f, 0.701609f, -0.982238f, 0.277031f, 0.645495f, -0.895681f, 0.492753f, -0.014078f, -0.274663f, -0.764091f, -0.587157f, 1.195165f, -1.209575f, -0.556008f, -0.077105f, 1.277377f, -1.459629f, -2.159528f, -0.706709f, -0.922245f, 3.895372f, -0.602697f}; std::vector expected_Y = {0.971546f, 1.139858f, 0.452817f, 1.863882f, -0.565266f, 1.423187f, -2.462833f, -0.104923f}; - RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 19, 1e-4f, 1e-4f); + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 19, 1e-4f, 1e-4f); } // Sampling out of bounds: offset pushes sampling to (-5,-5), BilinearInterpolate returns 0. @@ -769,7 +672,7 @@ TEST(DeformConvTest, OutOfBoundsSampling) { std::vector mask = {1.f, 1.f, 1.f, 1.f}; std::vector expected_Y = {0.f, 0.f, 0.f, 0.f}; - RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); } // Dilation > 1: 2x2 kernel with dilation (2,2), zero offset -> 4 sample points with stride 2. @@ -803,7 +706,7 @@ TEST(DeformConvTest, DilationGt1) { // Each output: 4 samples at (0,0),(0,2),(2,0),(2,2) -> 4 * 0.1 * 0.1 = 0.04 std::vector expected_Y(static_cast(out_h * out_w), 0.04f); - RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); } // Decoupled groups: group=2, offset_group=1 (one offset map shared by all input channels). @@ -836,7 +739,7 @@ TEST(DeformConvTest, DecoupledGroups) { // Zero offset -> grouped conv. Per output ch: 2 in_ch * 4 kernel * 0.01 = 0.08 std::vector expected_Y(static_cast(1 * 2 * out_h * out_w), 0.08f); - RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); } // Asymmetric padding: pads [top=1, left=0, bottom=0, right=1]; output 3x3, some positions have OOB samples. 
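// ONNX orders 2-D pads as [h_begin, w_begin, h_end, w_end], i.e. [top, left,
// bottom, right]. Assuming the 3x3 input and 2x2 kernel that the expected
// values below imply, the output size works out to 3x3:
//   out_h = (3 + 1 + 0 - 1*(2-1) - 1) / 1 + 1 = 3
//   out_w = (3 + 0 + 1 - 1*(2-1) - 1) / 1 + 1 = 3
// and with all-0.1 data and weights an interior output is 4 * 0.1 * 0.1 = 0.04,
// dropping to 0.02 or 0.01 as two or three of the four taps fall into the
// zero padding.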
@@ -868,7 +771,7 @@ TEST(DeformConvTest, AsymmetricPadding) { // Row 0: (0,0),(0,1) 2 valid -> 0.02; (0,2) only (0,2) in, (0,3) OOB -> 1 valid -> 0.01. Row 1/2: as before. std::vector expected_Y = {0.02f, 0.02f, 0.01f, 0.04f, 0.04f, 0.02f, 0.04f, 0.04f, 0.02f}; - RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); } // Tiny offset (near zero): offset (1e-6, 1e-6), sample ~(0,0) -> bilinear ≈ X[0,0]. Use 1x1 input for 1 output. @@ -893,7 +796,7 @@ TEST(DeformConvTest, TinyOffset) { std::vector mask = {1.f}; std::vector expected_Y = {1.f}; // bilinear at (1e-6, 1e-6) ≈ 1 - RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 19, 1e-4f, 1e-4f); + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 19, 1e-4f, 1e-4f); } // Offset (0.5, 0.5) at each kernel point: sampling at (i+0.5, j+0.5) -> (0.5,0.5),(0.5,1.5),(1.5,0.5),(1.5,1.5). @@ -921,7 +824,7 @@ TEST(DeformConvTest, OffsetAtPixelCenters) { std::vector mask = {1.f, 1.f, 1.f, 1.f}; std::vector expected_Y = {1.6875f}; // op output: one center sample 2.5 + boundary samples - RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); } // Large batch (N=64) to trigger CUDA ComputeInternal chunking loop (b += n_parallel_imgs). @@ -954,7 +857,7 @@ TEST(DeformConvTest, LargeBatchSize) { std::vector B = {0.f}; std::vector expected_Y(y_size, 0.04f); // 4 * 0.1 * 0.1 per position - RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); } // group=1, offset_group=2: weights not grouped, offset/mask grouped. @@ -987,7 +890,7 @@ TEST(DeformConvTest, Group1OffsetGroup2) { // group=1: full conv. Each output: 4 in_ch * 4 kernel = 16 * 0.01 = 0.16 per channel, 2 out ch -> 0.16 each std::vector expected_Y(static_cast(1 * 2 * out_h * out_w), 0.16f); - RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); } // Mask with zeros: exercises CUDA early-exit when mask_val == 0. @@ -1016,7 +919,7 @@ TEST(DeformConvTest, MaskWithZeros) { std::vector B = {0.f}; std::vector expected_Y(static_cast(out_h * out_w), 0.f); - RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); } // Extreme aspect ratio (1x100): thin horizontal strip to verify coordinate indexing. 
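// The 1x100 strip makes H and W maximally different, so height/width mix-ups
// that square inputs mask show up immediately: e.g. a bounds check mistakenly
// written against height instead of width would reject every sample with
// w >= 1 and zero the output instead of producing the expected
// 3 * 0.1 * 0.1 = 0.03 per position.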
@@ -1047,7 +950,7 @@ TEST(DeformConvTest, ExtremeAspectRatio) { // Each output: 3 * 0.1 * 0.1 = 0.03 std::vector expected_Y(static_cast(out_h * out_w), 0.03f); - RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); } } // namespace test From d6c19be5f8dd2481c1bec9bf17d496d745edae57 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Sun, 1 Mar 2026 23:26:24 +0800 Subject: [PATCH 10/58] Fix OrtMemTypeCPUInput issue and add cuda error check --- .../core/providers/cpu/nn/deform_conv.cc | 11 ++--- .../core/providers/cuda/nn/deform_conv.cc | 10 ++--- .../providers/cuda/nn/deform_conv_impl.cu | 45 ++++++++++--------- .../core/providers/cuda/nn/deform_conv_impl.h | 7 +-- onnxruntime/core/util/math_cpu.cc | 7 ++- 5 files changed, 42 insertions(+), 38 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.cc b/onnxruntime/core/providers/cpu/nn/deform_conv.cc index 15f73af7d460b..93ba26ead6c84 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cpu/nn/deform_conv.cc @@ -296,7 +296,8 @@ Status DeformConv::Compute(OpKernelContext* context) const { static_cast(0), // beta Y_g, // C narrow(output_image_size), // ldc - thread_pool); + thread_pool, + nullptr); // mlas_backend_kernel_selector_config } } @@ -324,16 +325,12 @@ template class DeformConv; ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \ DeformConv, 19, 21, T, \ KernelDefBuilder() \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .InputMemoryType(OrtMemTypeCPUInput, 2) /* offset */ \ - .InputMemoryType(OrtMemTypeCPUInput, 4), /* optional mask */ \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ DeformConv); \ ONNX_CPU_OPERATOR_TYPED_KERNEL( \ DeformConv, 22, T, \ KernelDefBuilder() \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .InputMemoryType(OrtMemTypeCPUInput, 2) \ - .InputMemoryType(OrtMemTypeCPUInput, 4), \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ DeformConv) REGISTER_DEFORMCONV_KERNEL_TYPED(float) diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc index fa84c44161fd3..b8b3508c468c5 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc @@ -149,7 +149,7 @@ Status DeformConv::ComputeInternal(OpKernelContext* context) const { const T* offset_block = offset_data + b * (offset_group * 2 * kernel_size * output_image_size); const T* mask_block = use_mask ? (mask_data + b * (offset_group * kernel_size * output_image_size)) : nullptr; - DeformConvIm2ColImpl( + ORT_RETURN_IF_ERROR(DeformConvIm2ColImpl( stream, X_block, offset_block, @@ -170,7 +170,7 @@ Status DeformConv::ComputeInternal(OpKernelContext* context) const { dilation_h, dilation_w, offset_group, - use_mask); + use_mask)); for (int64_t g = 0; g < group; ++g) { const T* W_g = Wdata + g * (M / group) * kernel_dim; @@ -217,19 +217,19 @@ Status DeformConv::ComputeInternal(OpKernelContext* context) const { // The output gemm_output_buffer is now Row-Major [M/group, cur_out_size]. // We need to copy it to Y_g (NCHW). 
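// Index arithmetic of that copy, as a plain CPU sketch (names hypothetical):
// the source is row-major [M_per_group, cur_parallel * output_image_size],
// while the NCHW destination keeps consecutive images of one channel
// M * output_image_size elements apart.

#include <cstdint>

void CopyRowMajorToNCHW(const float* gemm_out, float* y_group_base,
                        int64_t M, int64_t M_per_group,
                        int64_t out_size, int64_t n_imgs) {
  for (int64_t m = 0; m < M_per_group; ++m)
    for (int64_t b = 0; b < n_imgs; ++b)
      for (int64_t p = 0; p < out_size; ++p)
        // src: row m, column b * out_size + p; dst: image b, channel m, pixel p.
        y_group_base[(b * M + m) * out_size + p] =
            gemm_out[m * n_imgs * out_size + b * out_size + p];
}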
- DeformConvCopyGemmOutputRowMajorToNCHW( + ORT_RETURN_IF_ERROR(DeformConvCopyGemmOutputRowMajorToNCHW( stream, gemm_output_buffer.get(), Y_g, M, M / group, output_image_size, - cur_parallel); + cur_parallel)); } } if (Bdata != nullptr) { - DeformConvAddBiasImpl(stream, Ydata, Bdata, N, M, out_h, out_w); + ORT_RETURN_IF_ERROR(DeformConvAddBiasImpl(stream, Ydata, Bdata, N, M, out_h, out_w)); } return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu index 240f741139210..728dec1a3beb6 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu @@ -311,9 +311,9 @@ __global__ void CopyGemmOutputRowMajorToNCHWKernel( } // namespace template -void DeformConvAddBiasImpl(cudaStream_t stream, T* Y, const T* B, int64_t N, int64_t M, int64_t out_h, int64_t out_w) { +Status DeformConvAddBiasImpl(cudaStream_t stream, T* Y, const T* B, int64_t N, int64_t M, int64_t out_h, int64_t out_w) { int64_t total = N * M * out_h * out_w; - if (total <= 0) return; + if (total <= 0) return Status::OK(); // 1. Prepare divisor int64_t out_size = out_h * out_w; @@ -332,10 +332,11 @@ void DeformConvAddBiasImpl(cudaStream_t stream, T* Y, const T* B, int64_t N, int channel_div, total ); + return CUDA_CALL(cudaGetLastError()); } template -void DeformConvCopyGemmOutputRowMajorToNCHW( +Status DeformConvCopyGemmOutputRowMajorToNCHW( cudaStream_t stream, const T* gemm_output, T* Y_g, @@ -344,14 +345,15 @@ void DeformConvCopyGemmOutputRowMajorToNCHW( int64_t output_image_size, int64_t cur_parallel) { int64_t total = cur_parallel * M_per_group * output_image_size; - if (total <= 0) return; + if (total <= 0) return Status::OK(); int blocks = GetGridSize(static_cast(total), kDeformConvThreadsPerBlock); CopyGemmOutputRowMajorToNCHWKernel<<>>( gemm_output, Y_g, M, M_per_group, output_image_size, cur_parallel); + return CUDA_CALL(cudaGetLastError()); } template -void DeformConvIm2ColImpl( +Status DeformConvIm2ColImpl( cudaStream_t stream, const T* input, const T* offset, @@ -375,7 +377,7 @@ void DeformConvIm2ColImpl( bool use_mask) { const int64_t num_kernels = static_cast(C) * out_h * out_w * parallel_imgs; if (num_kernels <= 0) { - return; + return Status::OK(); } const int64_t col_numel = static_cast(C) * kH * kW * parallel_imgs * out_h * out_w; @@ -433,36 +435,37 @@ void DeformConvIm2ColImpl( use_mask, col_buffer); } + return CUDA_CALL(cudaGetLastError()); } #define INST_DeformConvIm2ColImpl(T) \ - template void DeformConvIm2ColImpl(cudaStream_t, const T*, const T*, const T*, T*, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, bool); + template Status DeformConvIm2ColImpl(cudaStream_t, const T*, const T*, const T*, T*, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, bool); INST_DeformConvIm2ColImpl(float) INST_DeformConvIm2ColImpl(double) INST_DeformConvIm2ColImpl(half) INST_DeformConvIm2ColImpl(BFloat16) -template void DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const float*, float*, int64_t, int64_t, int64_t, int64_t); -template void DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const double*, double*, int64_t, int64_t, int64_t, int64_t); -template void DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const half*, half*, int64_t, int64_t, int64_t, int64_t); -template void 
DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const BFloat16*, BFloat16*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const float*, float*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const double*, double*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const half*, half*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const BFloat16*, BFloat16*, int64_t, int64_t, int64_t, int64_t); -template void DeformConvAddBiasImpl(cudaStream_t, float*, const float*, int64_t, int64_t, int64_t, int64_t); -template void DeformConvAddBiasImpl(cudaStream_t, double*, const double*, int64_t, int64_t, int64_t, int64_t); -template void DeformConvAddBiasImpl(cudaStream_t, half*, const half*, int64_t, int64_t, int64_t, int64_t); -template void DeformConvAddBiasImpl(cudaStream_t, BFloat16*, const BFloat16*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvAddBiasImpl(cudaStream_t, float*, const float*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvAddBiasImpl(cudaStream_t, double*, const double*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvAddBiasImpl(cudaStream_t, half*, const half*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvAddBiasImpl(cudaStream_t, BFloat16*, const BFloat16*, int64_t, int64_t, int64_t, int64_t); // Delegate ORT type to CUDA type (e.g. MLFloat16 -> half); avoids repeating three identical specializations. #define DELEGATE_DEFORM_CONV_IMPL(ORT_T, CUDA_T) \ template <> \ - void DeformConvIm2ColImpl(cudaStream_t stream, const ORT_T* input, \ + Status DeformConvIm2ColImpl(cudaStream_t stream, const ORT_T* input, \ const ORT_T* offset, const ORT_T* mask, ORT_T* col_buffer, \ int64_t parallel_imgs, int64_t C, int64_t H, int64_t W, \ int64_t kH, int64_t kW, int64_t out_h, int64_t out_w, \ int64_t pad_h, int64_t pad_w, int64_t stride_h, int64_t stride_w, \ int64_t dilation_h, int64_t dilation_w, int64_t offset_group, bool use_mask) { \ - DeformConvIm2ColImpl(stream, reinterpret_cast(input), \ + return DeformConvIm2ColImpl(stream, reinterpret_cast(input), \ reinterpret_cast(offset), \ mask ? 
reinterpret_cast(mask) : nullptr, \ reinterpret_cast(col_buffer), \ @@ -471,19 +474,19 @@ template void DeformConvAddBiasImpl(cudaStream_t, BFloat16*, const BFl offset_group, use_mask); \ } \ template <> \ - void DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t stream, \ + Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t stream, \ const ORT_T* gemm_output, ORT_T* Y_g, \ int64_t M, int64_t M_per_group, \ int64_t output_image_size, int64_t cur_parallel) { \ - DeformConvCopyGemmOutputRowMajorToNCHW(stream, \ + return DeformConvCopyGemmOutputRowMajorToNCHW(stream, \ reinterpret_cast(gemm_output), \ reinterpret_cast(Y_g), \ M, M_per_group, output_image_size, cur_parallel); \ } \ template <> \ - void DeformConvAddBiasImpl(cudaStream_t stream, ORT_T* Y, const ORT_T* B, \ + Status DeformConvAddBiasImpl(cudaStream_t stream, ORT_T* Y, const ORT_T* B, \ int64_t N, int64_t M, int64_t out_h, int64_t out_w) { \ - DeformConvAddBiasImpl(stream, reinterpret_cast(Y), \ + return DeformConvAddBiasImpl(stream, reinterpret_cast(Y), \ reinterpret_cast(B), N, M, out_h, out_w); \ } diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.h b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.h index 55f0b0eccf54d..2f38d6ef18a7c 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.h +++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.h @@ -4,6 +4,7 @@ #pragma once #include +#include "core/common/status.h" namespace onnxruntime { namespace cuda { @@ -11,7 +12,7 @@ namespace cuda { // Adds bias to output: Y[n,m,oh,ow] += B[m]. Y is [N, M, out_h, out_w], B is [M]. // T may be float, double, MLFloat16 (FP16), or BFloat16. template -void DeformConvAddBiasImpl( +Status DeformConvAddBiasImpl( cudaStream_t stream, T* Y, const T* B, @@ -23,7 +24,7 @@ void DeformConvAddBiasImpl( // Copies GEMM output (row-major [M_per_group, cur_parallel*output_image_size]) to NCHW slice at Y_g. // T may be float, double, MLFloat16 (FP16), or BFloat16. template -void DeformConvCopyGemmOutputRowMajorToNCHW( +Status DeformConvCopyGemmOutputRowMajorToNCHW( cudaStream_t stream, const T* gemm_output, T* Y_g, @@ -35,7 +36,7 @@ void DeformConvCopyGemmOutputRowMajorToNCHW( // Fills col_buffer with deformable im2col. col_buffer layout: row-major [C*kH*kW, parallel_imgs*out_h*out_w]. // Called once per batch block; caller does GEMM and bias. T may be float, double, MLFloat16 (FP16), or BFloat16. 
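// How the caller consumes this buffer: with group=1, a single GEMM of
//   [M, C*kH*kW] x [C*kH*kW, parallel_imgs*out_h*out_w]
// yields every output pixel of every parallel image at once. Rough dimension
// check with assumed sizes C=64, M=128, 3x3 kernel, 8 parallel 56x56 outputs:
//   A: 128 x 576,  B: 576 x 25088,  Y: 128 x 25088.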
template -void DeformConvIm2ColImpl( +Status DeformConvIm2ColImpl( cudaStream_t stream, const T* input, // [parallel_imgs, C, H, W] const T* offset, // [parallel_imgs, offset_group*2*kH*kW, out_h, out_w] diff --git a/onnxruntime/core/util/math_cpu.cc b/onnxruntime/core/util/math_cpu.cc index 9cf8f619c7e1b..30642f1fd377c 100644 --- a/onnxruntime/core/util/math_cpu.cc +++ b/onnxruntime/core/util/math_cpu.cc @@ -205,14 +205,17 @@ EIGEN_MATMUL_FUNCTION(double) template <> void GemmEx(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, double alpha, const double* A, int lda, const double* B, int ldb, double beta, double* C, - int ldc, ThreadPool* threadpool) { + int ldc, ThreadPool* threadpool, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config) { + ORT_UNUSED_PARAMETER(mlas_backend_kernel_selector_config); + // DGEMM in MLAS has no BackendKernelSelectorConfig parameter MlasGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, threadpool); } #else template <> void GemmEx(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, double alpha, const double* A, int lda, const double* B, int ldb, double beta, double* C, - int ldc, ThreadPool*) { + int ldc, ThreadPool*, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config) { + ORT_UNUSED_PARAMETER(mlas_backend_kernel_selector_config); auto C_mat = EigenMatrixMapWithStrides(C, N, M, Eigen::Stride(ldc, 1)); if (beta == 0) { C_mat.setZero(); From 12fd042b4db8d1917083de46b0d66a995df0a3e9 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Mon, 2 Mar 2026 01:16:30 +0800 Subject: [PATCH 11/58] Remove GemmEx double specialization --- .../core/providers/cpu/nn/deform_conv.cc | 7 +- onnxruntime/core/util/math_cpu.cc | 68 ------------------- 2 files changed, 2 insertions(+), 73 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.cc b/onnxruntime/core/providers/cpu/nn/deform_conv.cc index 93ba26ead6c84..4d2bb69ec43f1 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cpu/nn/deform_conv.cc @@ -282,7 +282,7 @@ Status DeformConv::Compute(OpKernelContext* context) const { // W matrix: [M/group, kernel_dim] // Col matrix: [kernel_dim, output_image_size] // Y matrix: [M/group, output_image_size] - math::GemmEx( + math::Gemm( CblasNoTrans, CblasNoTrans, narrow(M / group), // M @@ -290,14 +290,11 @@ Status DeformConv::Compute(OpKernelContext* context) const { narrow(kernel_dim), // K static_cast(1), // alpha weight_g, // A - narrow(kernel_dim), // lda col_g, // B - narrow(output_image_size), // ldb static_cast(0), // beta Y_g, // C - narrow(output_image_size), // ldc thread_pool, - nullptr); // mlas_backend_kernel_selector_config + nullptr); // mlas_backend_kernel_selector_config } } diff --git a/onnxruntime/core/util/math_cpu.cc b/onnxruntime/core/util/math_cpu.cc index 30642f1fd377c..25b868b76bce4 100644 --- a/onnxruntime/core/util/math_cpu.cc +++ b/onnxruntime/core/util/math_cpu.cc @@ -201,74 +201,6 @@ void MatMul(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const double* A, cons EIGEN_MATMUL_FUNCTION(double) #endif -#ifdef MLAS_SUPPORTS_GEMM_DOUBLE -template <> -void GemmEx(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, - double alpha, const double* A, int lda, const double* B, int ldb, double beta, double* C, - int ldc, ThreadPool* threadpool, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config) { 
- ORT_UNUSED_PARAMETER(mlas_backend_kernel_selector_config); - // DGEMM in MLAS has no BackendKernelSelectorConfig parameter - MlasGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, threadpool); -} -#else -template <> -void GemmEx(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, - double alpha, const double* A, int lda, const double* B, int ldb, double beta, double* C, - int ldc, ThreadPool*, const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config) { - ORT_UNUSED_PARAMETER(mlas_backend_kernel_selector_config); - auto C_mat = EigenMatrixMapWithStrides(C, N, M, Eigen::Stride(ldc, 1)); - if (beta == 0) { - C_mat.setZero(); - } else { - C_mat *= beta; - } - switch (TransA) { - case CblasNoTrans: { - switch (TransB) { - case CblasNoTrans: - C_mat.noalias() += alpha * (ConstEigenMatrixMapWithStrides( - B, N, K, Eigen::Stride(ldb, 1)) * - ConstEigenMatrixMapWithStrides( - A, K, M, Eigen::Stride(lda, 1))); - return; - case CblasTrans: - C_mat.noalias() += alpha * (ConstEigenMatrixMapWithStrides( - B, K, N, Eigen::Stride(ldb, 1)) - .transpose() * - ConstEigenMatrixMapWithStrides( - A, K, M, Eigen::Stride(lda, 1))); - return; - default: - ORT_THROW("CblasNoTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB); - } - } - case CblasTrans: { - switch (TransB) { - case CblasNoTrans: - C_mat.noalias() += alpha * (ConstEigenMatrixMapWithStrides( - B, N, K, Eigen::Stride(ldb, 1)) * - ConstEigenMatrixMapWithStrides( - A, M, K, Eigen::Stride(lda, 1)) - .transpose()); - return; - case CblasTrans: - C_mat.noalias() += alpha * (ConstEigenMatrixMapWithStrides( - B, K, N, Eigen::Stride(ldb, 1)) - .transpose() * - ConstEigenMatrixMapWithStrides( - A, M, K, Eigen::Stride(lda, 1)) - .transpose()); - return; - default: - ORT_THROW("CblasTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB); - } - } - default: - ORT_THROW("Unexpected CBLAS_TRANSPOSE for TransA of ", TransA); - } -} -#endif - template <> void GemmEx(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, float alpha, const float* A, int lda, const float* B, int ldb, float beta, float* C, From 9b069e33443cef40accec8a5ff1c9e41591287c9 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Mon, 2 Mar 2026 01:22:15 +0800 Subject: [PATCH 12/58] Fix potential integer overflow in CUDA DeformableIm2ColKernel --- .../core/providers/cuda/nn/deform_conv_impl.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu index 728dec1a3beb6..c9e44eb214566 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu @@ -186,31 +186,31 @@ __global__ void DeformableIm2ColKernel( // Pre-calculate base pointers to reduce integer arithmetic inside the inner loops. // 1. Input pointer base for this batch and channel. - const T* input_ptr = input + out_b * (channels * height * width) + in_c * (height * width); + const T* input_ptr = input + static_cast(out_b) * (channels * height * width) + static_cast(in_c) * (height * width); // 2. Spatial index in the output feature map. - const int64_t spatial_idx = out_y * out_w + out_x; + const int64_t spatial_idx = static_cast(out_y) * out_w + static_cast(out_x); // 3. Offset pointer base calculation. 
  // Layout: (N, offset_groups, 2*KH*KW, OH, OW)
   // We pre-calculate the pointer to the start of the specific (n, g) block, plus spatial_idx.
   const int64_t offset_group_block_size = 2 * weight_h * weight_w * out_size;
-  const T* offset_ptr_base = offset + (out_b * offset_group + offset_grp) * offset_group_block_size + spatial_idx;
+  const T* offset_ptr_base = offset + (static_cast<int64_t>(out_b) * offset_group + static_cast<int64_t>(offset_grp)) * offset_group_block_size + spatial_idx;
 
   // 4. Mask pointer base calculation (if used).
   // Layout: (N, offset_groups, KH*KW, OH, OW)
   const T* mask_ptr_base = nullptr;
   if (use_mask) {
     const int64_t mask_group_block_size = weight_h * weight_w * out_size;
-    mask_ptr_base = mask + (out_b * offset_group + offset_grp) * mask_group_block_size + spatial_idx;
+    mask_ptr_base = mask + (static_cast<int64_t>(out_b) * offset_group + static_cast<int64_t>(offset_grp)) * mask_group_block_size + spatial_idx;
   }
 
   // 5. Output pointer base calculation.
   // data_col Layout: (C * KH * KW, N * OH * OW)
   // The current thread writes to the column `c_col` = (b * OH * OW) + spatial_idx.
   // The starting row for this channel is `in_c * KH * KW`.
-  const int64_t c_col = out_b * out_size + spatial_idx;
-  T* data_col_ptr_base = data_col + (in_c * weight_h * weight_w) * col_stride + c_col;
+  const int64_t c_col = static_cast<int64_t>(out_b) * out_size + spatial_idx;
+  T* data_col_ptr_base = data_col + (static_cast<int64_t>(in_c) * weight_h * weight_w) * col_stride + c_col;
 
   // 6. Pre-calculate invariant coordinate parts.
   // Use float for coordinate math when T is half or BFloat16 to avoid precision loss.

From cbadf1312dddca2406052ac23f623a202b7d3e95 Mon Sep 17 00:00:00 2001
From: Shirasawa <764798966@qq.com>
Date: Mon, 2 Mar 2026 01:27:45 +0800
Subject: [PATCH 13/58] Optimize CPU DeformableIm2Col loop order for better cache locality

---
 .../core/providers/cpu/nn/deform_conv.cc | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.cc b/onnxruntime/core/providers/cpu/nn/deform_conv.cc
index 4d2bb69ec43f1..a5793e5ce9cf0 100644
--- a/onnxruntime/core/providers/cpu/nn/deform_conv.cc
+++ b/onnxruntime/core/providers/cpu/nn/deform_conv.cc
@@ -67,17 +67,17 @@ void DeformableIm2col(
 
   const int64_t channel_per_offset_group = channels / offset_groups;
 
-  // We iterate over the output matrix columns (spatial locations)
-  // and fill the matrix rows (channels * kernels).
-  // Note: Parallelization can be applied here over 'c_col' (spatial index).
+  // Loop order optimized for cache locality:
+  //   Outer loop: Channels
+  //   Inner loop: Spatial locations (c_col)
+  // This ensures sequential access to data_col and better locality for data_im.
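// Concretely: data_col is row-major [C*kH*kW, out_size], and each (i, j) tap of
// a fixed channel writes row r = c_im*kH*kW + i*kW + j at column c_col. With the
// channel loop outermost, only that channel's kH*kW row cursors are active and
// each advances one column per step; with the spatial loop outermost, all
// C*kH*kW rows are touched at every location, and the single input plane at
// data_im + c_im*H*W no longer stays hot in cache.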
- for (int64_t c_col = 0; c_col < height_col * width_col; ++c_col) { - const int64_t w_col = c_col % width_col; - const int64_t h_col = c_col / width_col; + for (int64_t c_im = 0; c_im < channels; ++c_im) { + const int64_t offset_grp = c_im / channel_per_offset_group; - // For each spatial location (h_col, w_col), we iterate over all input channels - for (int64_t c_im = 0; c_im < channels; ++c_im) { - const int64_t offset_grp = c_im / channel_per_offset_group; + for (int64_t c_col = 0; c_col < height_col * width_col; ++c_col) { + const int64_t w_col = c_col % width_col; + const int64_t h_col = c_col / width_col; // Iterate over kernel window for (int64_t i = 0; i < kernel_h; ++i) { From a9515683f6294ef27524c80b3ba819b697f4a11a Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Mon, 2 Mar 2026 01:38:37 +0800 Subject: [PATCH 14/58] Parallelize CPU DeformConv Im2Col and bias addition --- .../core/providers/cpu/nn/deform_conv.cc | 141 ++++++++++-------- 1 file changed, 76 insertions(+), 65 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.cc b/onnxruntime/core/providers/cpu/nn/deform_conv.cc index a5793e5ce9cf0..1c9f078940e8a 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cpu/nn/deform_conv.cc @@ -63,7 +63,8 @@ void DeformableIm2col( int64_t offset_groups, // Number of offset groups int64_t height_col, int64_t width_col, // Output dimensions bool use_mask, - T* data_col) { // Output buffer + T* data_col, // Output buffer + concurrency::ThreadPool* thread_pool) { const int64_t channel_per_offset_group = channels / offset_groups; @@ -72,62 +73,66 @@ void DeformableIm2col( // Inner loop: Spatial locations (c_col) // This ensures sequential access to data_col and better locality for data_im. - for (int64_t c_im = 0; c_im < channels; ++c_im) { - const int64_t offset_grp = c_im / channel_per_offset_group; - - for (int64_t c_col = 0; c_col < height_col * width_col; ++c_col) { - const int64_t w_col = c_col % width_col; - const int64_t h_col = c_col / width_col; - - // Iterate over kernel window - for (int64_t i = 0; i < kernel_h; ++i) { - for (int64_t j = 0; j < kernel_w; ++j) { - - // Calculate the index in the offset/mask tensors. - // The offset tensor is organized as: (offset_groups, 2 * kH * kW, H_out, W_out). - // Flattened offset channel index relative to the start of the tensor: - // base = offset_grp * (2 * kH * kW). - // specific = 2 * (i * kW + j). 
- - const int64_t data_offset_h_ptr = - ((offset_grp * (2 * kernel_h * kernel_w) + 2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - - const int64_t data_offset_w_ptr = - ((offset_grp * (2 * kernel_h * kernel_w) + 2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; - - const int64_t data_mask_ptr = - ((offset_grp * (kernel_h * kernel_w) + (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - - const T offset_h = data_offset[data_offset_h_ptr]; - const T offset_w = data_offset[data_offset_w_ptr]; - - T val = static_cast(0); - T mask_val = static_cast(1); - if (use_mask) { - mask_val = data_mask[data_mask_ptr]; - } - - // Only compute interpolation if mask is not zero (optimization) - if (mask_val != 0) { - const T h_im = h_col * stride_h - pad_h + i * dilation_h + offset_h; - const T w_im = w_col * stride_w - pad_w + j * dilation_w + offset_w; - - // Map (c_im, h_im, w_im) back to input - // data_im is [C, H, W] - const T* data_im_ptr = data_im + c_im * (height * width); - val = BilinearInterpolate(data_im_ptr, height, width, h_im, w_im); + concurrency::ThreadPool::TryParallelFor( + thread_pool, channels, 1.0, + [&](ptrdiff_t c_im_start, ptrdiff_t c_im_end) { + for (int64_t c_im = c_im_start; c_im < c_im_end; ++c_im) { + const int64_t offset_grp = c_im / channel_per_offset_group; + + for (int64_t c_col = 0; c_col < height_col * width_col; ++c_col) { + const int64_t w_col = c_col % width_col; + const int64_t h_col = c_col / width_col; + + // Iterate over kernel window + for (int64_t i = 0; i < kernel_h; ++i) { + for (int64_t j = 0; j < kernel_w; ++j) { + + // Calculate the index in the offset/mask tensors. + // The offset tensor is organized as: (offset_groups, 2 * kH * kW, H_out, W_out). + // Flattened offset channel index relative to the start of the tensor: + // base = offset_grp * (2 * kH * kW). + // specific = 2 * (i * kW + j). + + const int64_t data_offset_h_ptr = + ((offset_grp * (2 * kernel_h * kernel_w) + 2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + + const int64_t data_offset_w_ptr = + ((offset_grp * (2 * kernel_h * kernel_w) + 2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + + const int64_t data_mask_ptr = + ((offset_grp * (kernel_h * kernel_w) + (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + + const T offset_h = data_offset[data_offset_h_ptr]; + const T offset_w = data_offset[data_offset_w_ptr]; + + T val = static_cast(0); + T mask_val = static_cast(1); + if (use_mask) { + mask_val = data_mask[data_mask_ptr]; + } + + // Only compute interpolation if mask is not zero (optimization) + if (mask_val != 0) { + const T h_im = h_col * stride_h - pad_h + i * dilation_h + offset_h; + const T w_im = w_col * stride_w - pad_w + j * dilation_w + offset_w; + + // Map (c_im, h_im, w_im) back to input + // data_im is [C, H, W] + const T* data_im_ptr = data_im + c_im * (height * width); + val = BilinearInterpolate(data_im_ptr, height, width, h_im, w_im); + } + + // Assign to data_col + // The layout of data_col row is: [Channel, KernelH, KernelW] flattened. + // Row index: c_im * (kH * kW) + i * kW + j + const int64_t col_row_idx = (c_im * kernel_h * kernel_w) + (i * kernel_w + j); + + data_col[col_row_idx * (height_col * width_col) + c_col] = val * mask_val; + } + } } - - // Assign to data_col - // The layout of data_col row is: [Channel, KernelH, KernelW] flattened. 
-          // Row index: c_im * (kH * kW) + i * kW + j
-          const int64_t col_row_idx = (c_im * kernel_h * kernel_w) + (i * kernel_w + j);
-
-          data_col[col_row_idx * (height_col * width_col) + c_col] = val * mask_val;
         }
-      }
-    }
-  }
+      });
 }
 
 }  // namespace
@@ -256,7 +261,8 @@ Status DeformConv<T>::Compute(OpKernelContext* context) const {
                       offset_group,
                       out_h, out_w,
                       use_mask,
-                      col_buffer_ptr);
+                      col_buffer_ptr,
+                      thread_pool);
 
     // 2. Perform GEMM for each group
     for (int64_t g = 0; g < group; ++g) {
@@ -300,15 +306,20 @@ Status DeformConv<T>::Compute(OpKernelContext* context) const {
 
   // 3. Add Bias if present
   if (Bdata != nullptr) {
-    for (int64_t n = 0; n < N; ++n) {
-      T* Y_curr = Ydata + n * M * output_image_size;
-      for (int64_t m = 0; m < M; ++m) {
-        T bias_val = Bdata[m];
-        for (int64_t i = 0; i < output_image_size; ++i) {
-          Y_curr[m * output_image_size + i] += bias_val;
-        }
-      }
-    }
+    int64_t total_work = N * M;
+    concurrency::ThreadPool::TryParallelFor(
+        thread_pool, total_work, static_cast<double>(output_image_size),
+        [&](ptrdiff_t first, ptrdiff_t last) {
+          for (ptrdiff_t idx = first; idx < last; ++idx) {
+            int64_t n = idx / M;
+            int64_t m = idx % M;
+            T* Y_ptr = Ydata + n * M * output_image_size + m * output_image_size;
+            T bias_val = Bdata[m];
+            for (int64_t i = 0; i < output_image_size; ++i) {
+              Y_ptr[i] += bias_val;
+            }
+          }
+        });
   }
 
   return Status::OK();

From f1a9832554b8c629a88fc8aa09ec4ab378634f33 Mon Sep 17 00:00:00 2001
From: Shirasawa <764798966@qq.com>
Date: Mon, 2 Mar 2026 02:21:10 +0800
Subject: [PATCH 15/58] Use GPU free memory in DeformConv temp memory heuristic

---
 .../core/providers/cuda/nn/deform_conv.cc | 30 +++++++++++++++----
 1 file changed, 25 insertions(+), 5 deletions(-)

diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc
index b8b3508c468c5..ea694cdd9de4d 100644
--- a/onnxruntime/core/providers/cuda/nn/deform_conv.cc
+++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc
@@ -111,11 +111,31 @@ Status DeformConv<T>::ComputeInternal(OpKernelContext* context) const {
   // We use a safe max(1, ...) for bytes_per_image to avoid division by zero in edge cases
   const size_t bytes_per_image = SafeInt<size_t>(output_image_size) * (C * kernel_size + M / group) * sizeof(T);
 
-  // Heuristic: limit temp memory to 256MB per chunk to balance parallelism and memory usage.
-  // For small images, this allows up to kMaxParallelImgs (32).
-  // For large images (4K/8K), this restricts parallelism to 1 to prevent OOM.
-  constexpr size_t kMaxTempMemSize = 256 * 1024 * 1024;
-  const int max_parallel_imgs_mem = std::max(1, static_cast<int>(kMaxTempMemSize / std::max(size_t(1), bytes_per_image)));
+  // Heuristic: limit temp memory per chunk to balance parallelism and memory usage.
+  // Mirrors Conv's approach (conv_8.h): use 90% of free memory (10% fragmentation buffer).
+  // Tiered cap based on free memory: larger GPUs get higher limits for better parallelism.
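+  // Worked example (assumed numbers): ~6GB free → 90% ≈ 5.4GB → 4-8GB tier → 512MB cap; at ~80MB
+  // of temp buffers per image this allows 6 parallel images, further clamped by kMaxParallelImgs.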
+  size_t effective_max_temp = 256ULL * 1024 * 1024;        // default fallback
+  constexpr size_t kMinTempMemSize = 32ULL * 1024 * 1024;  // 32MB floor
+  {
+    size_t free_mem = 0, total_mem = 0;
+    if (cudaMemGetInfo(&free_mem, &total_mem) == cudaSuccess && free_mem > 0) {
+      free_mem = static_cast<size_t>(static_cast<double>(free_mem) * 0.9);  // 10% fragmentation buffer
+      size_t kMaxTempMemSize;
+      if (free_mem > 16ULL * 1024 * 1024 * 1024) {
+        kMaxTempMemSize = 2ULL * 1024 * 1024 * 1024;  // 16GB+ free → 2GB
+      } else if (free_mem > 8ULL * 1024 * 1024 * 1024) {
+        kMaxTempMemSize = 1ULL * 1024 * 1024 * 1024;  // 8-16GB → 1GB
+      } else if (free_mem > 4ULL * 1024 * 1024 * 1024) {
+        kMaxTempMemSize = 512ULL * 1024 * 1024;  // 4-8GB → 512MB
+      } else if (free_mem > 2ULL * 1024 * 1024 * 1024) {
+        kMaxTempMemSize = 256ULL * 1024 * 1024;  // 2-4GB → 256MB
+      } else {
+        kMaxTempMemSize = 128ULL * 1024 * 1024;  // <2GB → 128MB
+      }
+      effective_max_temp = std::max(kMinTempMemSize, std::min(kMaxTempMemSize, free_mem));
+    }
+  }
+  const int max_parallel_imgs_mem = std::max(1, static_cast<int>(effective_max_temp / std::max(size_t(1), bytes_per_image)));
 
   const int target_parallel_imgs = std::min(kMaxParallelImgs, max_parallel_imgs_mem);
   const int n_parallel_imgs = GetGreatestDivisorBelowBound(static_cast<int>(N), target_parallel_imgs);

From d99994ffd423d0ed98fec24d877b432b23044930 Mon Sep 17 00:00:00 2001
From: Shirasawa <764798966@qq.com>
Date: Mon, 2 Mar 2026 02:22:03 +0800
Subject: [PATCH 16/58] Extract DeformConvAttributes to shared header

---
 .../core/providers/cpu/nn/deform_conv.h       | 24 +------------
 .../providers/cpu/nn/deform_conv_attributes.h | 35 +++++++++++++++++++
 .../core/providers/cuda/nn/deform_conv.h      | 24 +------------
 .../providers/cpu/nn/deform_conv_op_test.cc   |  1 +
 4 files changed, 38 insertions(+), 46 deletions(-)
 create mode 100644 onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h

diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.h b/onnxruntime/core/providers/cpu/nn/deform_conv.h
index ee4c2981b7573..c8d7763e58bcb 100644
--- a/onnxruntime/core/providers/cpu/nn/deform_conv.h
+++ b/onnxruntime/core/providers/cpu/nn/deform_conv.h
@@ -6,32 +6,10 @@
 #include "core/common/common.h"
 #include "core/framework/op_kernel.h"
 #include "core/framework/op_node_proto_helper.h"
-#include "core/framework/tensor_shape.h"
+#include "deform_conv_attributes.h"
 
 namespace onnxruntime {
 
-// Attributes for ONNX DeformConv (opset 19+).
-// See https://onnx.ai/onnx/operators/onnx__DeformConv.html
-struct DeformConvAttributes {
-  explicit DeformConvAttributes(const OpKernelInfo& info) {
-    // Optional attributes.
-    // If not present, they will be empty/default, and handled in Compute.
-    (void)info.GetAttrs("kernel_shape", kernel_shape);
-    (void)info.GetAttrs("strides", strides);
-    (void)info.GetAttrs("pads", pads);
-    (void)info.GetAttrs("dilations", dilations);
-    group = info.GetAttrOrDefault<int64_t>("group", 1);
-    offset_group = info.GetAttrOrDefault<int64_t>("offset_group", 1);
-  }
-
-  TensorShapeVector kernel_shape;
-  TensorShapeVector strides;
-  TensorShapeVector pads;
-  TensorShapeVector dilations;
-  int64_t group{1};
-  int64_t offset_group{1};
-};
-
 template <typename T>
 class DeformConv : public OpKernel {
  public:
diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h
new file mode 100644
index 0000000000000..2e8c13f43ae56
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h
@@ -0,0 +1,35 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/framework/tensor_shape.h" + +namespace onnxruntime { + +// Shared attributes for ONNX DeformConv (opset 19+). +// See https://onnx.ai/onnx/operators/onnx__DeformConv.html +// Used by both CPU and CUDA implementations (CUDA includes from here). +struct DeformConvAttributes { + explicit DeformConvAttributes(const OpKernelInfo& info) { + // Optional attributes. + // If not present, they will be empty/default, and handled in Compute/ComputeInternal. + (void)info.GetAttrs("kernel_shape", kernel_shape); + (void)info.GetAttrs("strides", strides); + (void)info.GetAttrs("pads", pads); + (void)info.GetAttrs("dilations", dilations); + group = info.GetAttrOrDefault("group", 1); + offset_group = info.GetAttrOrDefault("offset_group", 1); + } + + TensorShapeVector kernel_shape; + TensorShapeVector strides; + TensorShapeVector pads; + TensorShapeVector dilations; + int64_t group{1}; + int64_t offset_group{1}; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.h b/onnxruntime/core/providers/cuda/nn/deform_conv.h index 7243d19f71585..fa564641d4b98 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv.h +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.h @@ -5,34 +5,12 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" -#include "core/framework/tensor_shape.h" +#include "core/providers/cpu/nn/deform_conv_attributes.h" #include "core/providers/cuda/cuda_kernel.h" namespace onnxruntime { namespace cuda { -// Attributes for ONNX DeformConv (opset 19+). Mirrors CPU for consistency. -// See https://onnx.ai/onnx/operators/onnx__DeformConv.html -struct DeformConvAttributes { - explicit DeformConvAttributes(const OpKernelInfo& info) { - // Optional attributes. - // If not present, they will be empty/default, and handled in Compute/ComputeInternal. 
- (void)info.GetAttrs("kernel_shape", kernel_shape); - (void)info.GetAttrs("strides", strides); - (void)info.GetAttrs("pads", pads); - (void)info.GetAttrs("dilations", dilations); - group = info.GetAttrOrDefault("group", 1); - offset_group = info.GetAttrOrDefault("offset_group", 1); - } - - TensorShapeVector kernel_shape; - TensorShapeVector strides; - TensorShapeVector pads; - TensorShapeVector dilations; - int64_t group{1}; - int64_t offset_group{1}; -}; - template class DeformConv final : public CudaKernel { public: diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc index 486f577f9791f..78106503344b6 100644 --- a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc @@ -95,6 +95,7 @@ void RunDeformConvTest(const DeformConvTestParams& params, decltype(DeformConvTestTraits::DefaultAtol()) atol = DeformConvTestTraits::DefaultAtol()) { const int64_t kH = params.kernel_shape[0]; const int64_t kW = params.kernel_shape[1]; + // ONNX pads format: [pad_top, pad_left, pad_bottom, pad_right] = [pad[0], pad[1], pad[2], pad[3]] const int64_t out_h = (params.in_h + params.pad[0] + params.pad[2] - params.dilation[0] * (kH - 1) - 1) / params.stride[0] + 1; const int64_t out_w = (params.in_w + params.pad[1] + params.pad[3] - From 7d7f66ea1a96a45e22dc116193bcd8205d215921 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Mon, 2 Mar 2026 02:33:19 +0800 Subject: [PATCH 17/58] DeformConv op shared attributes and validation --- .../core/providers/cpu/nn/deform_conv.cc | 82 +++-------- .../providers/cpu/nn/deform_conv_attributes.h | 109 ++++++++++++++ .../core/providers/cuda/nn/deform_conv.cc | 136 +++++++----------- 3 files changed, 182 insertions(+), 145 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.cc b/onnxruntime/core/providers/cpu/nn/deform_conv.cc index 1c9f078940e8a..cdd67c4117cbe 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cpu/nn/deform_conv.cc @@ -145,67 +145,27 @@ Status DeformConv::Compute(OpKernelContext* context) const { const auto* B = context->Input(3); // optional const auto* mask = context->Input(4); // optional - const auto& X_shape = X->Shape(); - const auto& W_shape = W->Shape(); - const auto& offset_shape = offset->Shape(); - - // Validate Input Shapes - const int64_t N = X_shape[0]; - const int64_t C = X_shape[1]; - const int64_t H = X_shape[2]; - const int64_t W_in = X_shape[3]; - - const int64_t M = W_shape[0]; // out channels - // Handle kernel shape inference - const int64_t kH = attrs_.kernel_shape.size() >= 1 ? attrs_.kernel_shape[0] : W_shape[2]; - const int64_t kW = attrs_.kernel_shape.size() >= 2 ? attrs_.kernel_shape[1] : W_shape[3]; - - int64_t pad_h = 0; - int64_t pad_w = 0; - int64_t pad_h_end = 0; - int64_t pad_w_end = 0; - if (attrs_.pads.size() >= 4) { - pad_h = attrs_.pads[0]; - pad_w = attrs_.pads[1]; - pad_h_end = attrs_.pads[2]; - pad_w_end = attrs_.pads[3]; - } - - const int64_t stride_h = attrs_.strides.empty() ? 1 : attrs_.strides[0]; - const int64_t stride_w = attrs_.strides.size() < 2 ? 1 : attrs_.strides[1]; - const int64_t dilation_h = attrs_.dilations.empty() ? 1 : attrs_.dilations[0]; - const int64_t dilation_w = attrs_.dilations.size() < 2 ? 
1 : attrs_.dilations[1]; - const int64_t group = attrs_.group; - const int64_t offset_group = attrs_.offset_group; - - // Validate input shapes - ORT_RETURN_IF_NOT(stride_h > 0 && stride_w > 0, "Strides must be positive."); - ORT_RETURN_IF_NOT(dilation_h > 0 && dilation_w > 0, "Dilations must be positive."); - ORT_RETURN_IF_NOT(kH > 0 && kW > 0, "Kernel shape must be positive."); - ORT_RETURN_IF_NOT(group > 0, "group must be positive"); - ORT_RETURN_IF_NOT(offset_group > 0, "offset_group must be positive"); - - const int64_t out_h = (H + pad_h + pad_h_end - dilation_h * (kH - 1) - 1) / stride_h + 1; - const int64_t out_w = (W_in + pad_w + pad_w_end - dilation_w * (kW - 1) - 1) / stride_w + 1; - - // Checks - ORT_RETURN_IF_NOT(W_shape.NumDimensions() == 4, "Weight must be 4D."); - ORT_RETURN_IF_NOT(offset_shape.NumDimensions() == 4, "Offset must be 4D."); - ORT_RETURN_IF_NOT(offset_shape[1] == offset_group * 2 * kH * kW, - "Offset channel count must be offset_group * 2 * kH * kW."); - ORT_RETURN_IF_NOT(offset_shape[2] == out_h, "Offset spatial height must match output oH."); - ORT_RETURN_IF_NOT(offset_shape[3] == out_w, "Offset spatial width must match output oW."); - ORT_RETURN_IF_NOT(C % offset_group == 0, "Input channels must be divisible by offset_group."); - ORT_RETURN_IF_NOT(C == W_shape[1] * group, "Input channels must match weight in channels * group."); - ORT_RETURN_IF_NOT(M % group == 0, "Output channels must be divisible by group."); - - const bool use_mask = (mask != nullptr); - if (use_mask) { - ORT_RETURN_IF_NOT(mask->Shape().NumDimensions() == 4, "Mask must be 4D."); - ORT_RETURN_IF_NOT(mask->Shape()[1] == offset_group * kH * kW, "Mask channel count must be offset_group * kH * kW."); - ORT_RETURN_IF_NOT(mask->Shape()[2] == out_h, "Mask spatial height must match output oH."); - ORT_RETURN_IF_NOT(mask->Shape()[3] == out_w, "Mask spatial width must match output oW."); - } + DeformConvParams params; + ORT_RETURN_IF_ERROR(DeformConvValidateAndParse(attrs_, X->Shape(), W->Shape(), offset->Shape(), mask ? &mask->Shape() : nullptr, params)); + + const int64_t N = params.N; + const int64_t C = params.C; + const int64_t H = params.H; + const int64_t W_in = params.W_in; + const int64_t M = params.M; + const int64_t kH = params.kH; + const int64_t kW = params.kW; + const int64_t pad_h = params.pad_h; + const int64_t pad_w = params.pad_w; + const int64_t stride_h = params.stride_h; + const int64_t stride_w = params.stride_w; + const int64_t dilation_h = params.dilation_h; + const int64_t dilation_w = params.dilation_w; + const int64_t group = params.group; + const int64_t offset_group = params.offset_group; + const int64_t out_h = params.out_h; + const int64_t out_w = params.out_w; + const bool use_mask = params.use_mask; // Allocate Output const TensorShape Y_shape({N, M, out_h, out_w}); diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h index 2e8c13f43ae56..f2a8fdb58bd64 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h +++ b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h @@ -32,4 +32,113 @@ struct DeformConvAttributes { int64_t offset_group{1}; }; +// Parsed and validated parameters from DeformConv inputs. +// Used by both CPU and CUDA implementations. 
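+// Example (hypothetical shapes): X = [1, 4, 8, 8], W = [2, 2, 3, 3], group = 2 with default
+// strides/dilations/pads gives N=1, C=4, M=2, kH=kW=3 and out_h = out_w = (8 - 1*(3-1) - 1)/1 + 1 = 6.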
+// Field names align with ONNX DeformConv spec: https://onnx.ai/onnx/operators/onnx__DeformConv.html +struct DeformConvParams { + // Input X shape (N, C, H, W) + int64_t N{0}; // Batch size + int64_t C{0}; // Number of input channels + int64_t H{0}; // Input height + int64_t W_in{0}; // Input width (W_in to avoid collision with weight W) + + // Weight W shape (oC, C/group, kH, kW) + int64_t M{0}; // Number of output channels (oC) + int64_t kH{0}; // Kernel height + int64_t kW{0}; // Kernel width + + // Pads [x1_begin, x2_begin, x1_end, x2_end] for spatial axes H, W + int64_t pad_h{0}; + int64_t pad_w{0}; + int64_t pad_h_end{0}; + int64_t pad_w_end{0}; + + // Strides and dilations along each spatial axis (default 1) + int64_t stride_h{1}; + int64_t stride_w{1}; + int64_t dilation_h{1}; + int64_t dilation_w{1}; + + // Attributes: C and oC must be divisible by group; C must be divisible by offset_group + int64_t group{1}; // Number of groups for input/output channels + int64_t offset_group{1}; // Number of groups of offset + + // Output Y shape (N, oC, oH, oW) + int64_t out_h{0}; // Output height (oH) + int64_t out_w{0}; // Output width (oW) + + bool use_mask{false}; // Whether optional mask input is provided +}; + +// Validates inputs and parses attributes into params. +// Returns Status::OK() on success; on failure, params may be partially filled. +inline Status DeformConvValidateAndParse( + const DeformConvAttributes& attrs, + const TensorShape& X_shape, + const TensorShape& W_shape, + const TensorShape& offset_shape, + const TensorShape* mask_shape, + DeformConvParams& params) { + // Parse input shapes + params.N = X_shape[0]; + params.C = X_shape[1]; + params.H = X_shape[2]; + params.W_in = X_shape[3]; + params.M = W_shape[0]; + + // Handle kernel shape inference + params.kH = attrs.kernel_shape.size() >= 1 ? attrs.kernel_shape[0] : W_shape[2]; + params.kW = attrs.kernel_shape.size() >= 2 ? attrs.kernel_shape[1] : W_shape[3]; + + params.pad_h = params.pad_w = params.pad_h_end = params.pad_w_end = 0; + if (attrs.pads.size() >= 4) { + params.pad_h = attrs.pads[0]; + params.pad_w = attrs.pads[1]; + params.pad_h_end = attrs.pads[2]; + params.pad_w_end = attrs.pads[3]; + } + + params.stride_h = attrs.strides.empty() ? 1 : attrs.strides[0]; + params.stride_w = attrs.strides.size() < 2 ? 1 : attrs.strides[1]; + params.dilation_h = attrs.dilations.empty() ? 1 : attrs.dilations[0]; + params.dilation_w = attrs.dilations.size() < 2 ? 
1 : attrs.dilations[1]; + params.group = attrs.group; + params.offset_group = attrs.offset_group; + params.use_mask = (mask_shape != nullptr); + + // Validate attributes + ORT_RETURN_IF_NOT(params.stride_h > 0 && params.stride_w > 0, "Strides must be positive."); + ORT_RETURN_IF_NOT(params.dilation_h > 0 && params.dilation_w > 0, "Dilations must be positive."); + ORT_RETURN_IF_NOT(params.kH > 0 && params.kW > 0, "Kernel shape must be positive."); + ORT_RETURN_IF_NOT(params.group > 0, "group must be positive"); + ORT_RETURN_IF_NOT(params.offset_group > 0, "offset_group must be positive"); + + params.out_h = (params.H + params.pad_h + params.pad_h_end - params.dilation_h * (params.kH - 1) - 1) / params.stride_h + 1; + params.out_w = (params.W_in + params.pad_w + params.pad_w_end - params.dilation_w * (params.kW - 1) - 1) / params.stride_w + 1; + + // Validate tensor shapes + ORT_RETURN_IF_NOT(W_shape.NumDimensions() == 4, "Weight must be 4D."); + ORT_RETURN_IF_NOT(offset_shape.NumDimensions() == 4, "Offset must be 4D."); + ORT_RETURN_IF_NOT( + offset_shape[1] == params.offset_group * 2 * params.kH * params.kW, + "Offset channel count must be offset_group * 2 * kH * kW."); + ORT_RETURN_IF_NOT(offset_shape[2] == params.out_h, "Offset spatial height must match output oH."); + ORT_RETURN_IF_NOT(offset_shape[3] == params.out_w, "Offset spatial width must match output oW."); + ORT_RETURN_IF_NOT(params.C % params.offset_group == 0, "Input channels must be divisible by offset_group."); + ORT_RETURN_IF_NOT(params.C == W_shape[1] * params.group, "Input channels must match weight in channels * group."); + ORT_RETURN_IF_NOT(params.M % params.group == 0, "Output channels must be divisible by group."); + + // Validate mask if present + if (params.use_mask) { + ORT_RETURN_IF_NOT(mask_shape->NumDimensions() == 4, "Mask must be 4D."); + ORT_RETURN_IF_NOT( + (*mask_shape)[1] == params.offset_group * params.kH * params.kW, + "Mask channel count must be offset_group * kH * kW."); + ORT_RETURN_IF_NOT((*mask_shape)[2] == params.out_h, "Mask spatial height must match output oH."); + ORT_RETURN_IF_NOT((*mask_shape)[3] == params.out_w, "Mask spatial width must match output oW."); + } + + return Status::OK(); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc index ea694cdd9de4d..f83e15801defb 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc @@ -27,6 +27,35 @@ int GetGreatestDivisorBelowBound(int n, int bound) { return 1; } +// Returns effective max temp memory (bytes) for DeformConv batching. +// Uses 90% of free GPU memory with tiered cap; fallback 256MB if cudaMemGetInfo fails. +// Mirrors Conv's approach (conv_8.h); tiered limits avoid OOM on smaller GPUs. 
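+// Example: 12GB free → 90% ≈ 10.8GB → 8-16GB tier → 1GB cap; if cudaMemGetInfo fails, 256MB is assumed.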
+size_t GetDeformConvEffectiveMaxTempBytes() {
+  constexpr size_t kDefaultFallback = 256ULL * 1024 * 1024;
+  constexpr size_t kMinTempMemSize = 32ULL * 1024 * 1024;
+
+  size_t free_mem = 0, total_mem = 0;
+  if (cudaMemGetInfo(&free_mem, &total_mem) != cudaSuccess || free_mem == 0) {
+    return kDefaultFallback;
+  }
+  free_mem = static_cast<size_t>(static_cast<double>(free_mem) * 0.9);  // 10% fragmentation buffer
+
+  size_t tier_cap;
+  if (free_mem > 16ULL * 1024 * 1024 * 1024) {
+    tier_cap = 2ULL * 1024 * 1024 * 1024;  // 16GB+ free → 2GB
+  } else if (free_mem > 8ULL * 1024 * 1024 * 1024) {
+    tier_cap = 1ULL * 1024 * 1024 * 1024;  // 8-16GB → 1GB
+  } else if (free_mem > 4ULL * 1024 * 1024 * 1024) {
+    tier_cap = 512ULL * 1024 * 1024;  // 4-8GB → 512MB
+  } else if (free_mem > 2ULL * 1024 * 1024 * 1024) {
+    tier_cap = 256ULL * 1024 * 1024;  // 2-4GB → 256MB
+  } else {
+    tier_cap = 128ULL * 1024 * 1024;  // <2GB → 128MB
+  }
+
+  return std::max(kMinTempMemSize, std::min(tier_cap, free_mem));
+}
+
 }  // namespace
 
 template <typename T>
@@ -39,61 +68,27 @@ Status DeformConv<T>::ComputeInternal(OpKernelContext* context) const {
   const auto* B = context->Input<Tensor>(3);
   const auto* mask = context->Input<Tensor>(4);
 
-  const auto& X_shape = X->Shape();
-  const auto& W_shape = W->Shape();
-  const auto& offset_shape = offset->Shape();
-
-  const int64_t N = X_shape[0];
-  const int64_t C = X_shape[1];
-  const int64_t H = X_shape[2];
-  const int64_t W_in = X_shape[3];
-
-  const int64_t M = W_shape[0];
-  const int64_t kH = attrs_.kernel_shape.size() >= 1 ? attrs_.kernel_shape[0] : W_shape[2];
-  const int64_t kW = attrs_.kernel_shape.size() >= 2 ? attrs_.kernel_shape[1] : W_shape[3];
-
-  int64_t pad_h = 0, pad_w = 0, pad_h_end = 0, pad_w_end = 0;
-  if (attrs_.pads.size() >= 4) {
-    pad_h = attrs_.pads[0];
-    pad_w = attrs_.pads[1];
-    pad_h_end = attrs_.pads[2];
-    pad_w_end = attrs_.pads[3];
-  }
-
-  const int64_t stride_h = attrs_.strides.empty() ? 1 : attrs_.strides[0];
-  const int64_t stride_w = attrs_.strides.size() < 2 ? 1 : attrs_.strides[1];
-  const int64_t dilation_h = attrs_.dilations.empty() ? 1 : attrs_.dilations[0];
-  const int64_t dilation_w = attrs_.dilations.size() < 2 ?
1 : attrs_.dilations[1]; - const int64_t group = attrs_.group; - const int64_t offset_group = attrs_.offset_group; - - // Validate input shapes - ORT_RETURN_IF_NOT(stride_h > 0 && stride_w > 0, "Strides must be positive."); - ORT_RETURN_IF_NOT(dilation_h > 0 && dilation_w > 0, "Dilations must be positive."); - ORT_RETURN_IF_NOT(kH > 0 && kW > 0, "Kernel shape must be positive."); - ORT_RETURN_IF_NOT(group > 0, "group must be positive"); - ORT_RETURN_IF_NOT(offset_group > 0, "offset_group must be positive"); - - const int64_t out_h = (H + pad_h + pad_h_end - dilation_h * (kH - 1) - 1) / stride_h + 1; - const int64_t out_w = (W_in + pad_w + pad_w_end - dilation_w * (kW - 1) - 1) / stride_w + 1; - - // Checks - ORT_RETURN_IF_NOT(W_shape.NumDimensions() == 4, "Weight must be 4D."); - ORT_RETURN_IF_NOT(offset_shape.NumDimensions() == 4, "Offset must be 4D."); - ORT_RETURN_IF_NOT(offset_shape[1] == offset_group * 2 * kH * kW, - "Offset channel count must be offset_group * 2 * kH * kW."); - ORT_RETURN_IF_NOT(offset_shape[2] == out_h && offset_shape[3] == out_w, - "Offset spatial dims must match output."); - ORT_RETURN_IF_NOT(C % offset_group == 0, "Input channels must be divisible by offset_group."); - ORT_RETURN_IF_NOT(C == W_shape[1] * group, "Input channels must match weight in channels * group."); - ORT_RETURN_IF_NOT(M % group == 0, "Output channels must be divisible by group."); - - const bool use_mask = (mask != nullptr); - if (use_mask) { - ORT_RETURN_IF_NOT(mask->Shape().NumDimensions() == 4, "Mask must be 4D."); - ORT_RETURN_IF_NOT(mask->Shape()[1] == offset_group * kH * kW, "Mask channel count invalid."); - ORT_RETURN_IF_NOT(mask->Shape()[2] == out_h && mask->Shape()[3] == out_w, "Mask spatial dims must match output."); - } + DeformConvParams params; + ORT_RETURN_IF_ERROR(DeformConvValidateAndParse(attrs_, X->Shape(), W->Shape(), offset->Shape(), mask ? &mask->Shape() : nullptr, params)); + + const int64_t N = params.N; + const int64_t C = params.C; + const int64_t H = params.H; + const int64_t W_in = params.W_in; + const int64_t M = params.M; + const int64_t kH = params.kH; + const int64_t kW = params.kW; + const int64_t pad_h = params.pad_h; + const int64_t pad_w = params.pad_w; + const int64_t stride_h = params.stride_h; + const int64_t stride_w = params.stride_w; + const int64_t dilation_h = params.dilation_h; + const int64_t dilation_w = params.dilation_w; + const int64_t group = params.group; + const int64_t offset_group = params.offset_group; + const int64_t out_h = params.out_h; + const int64_t out_w = params.out_w; + const bool use_mask = params.use_mask; Tensor* Y = context->Output(0, {N, M, out_h, out_w}); if (Y->Shape().Size() == 0) { @@ -105,36 +100,9 @@ Status DeformConv::ComputeInternal(OpKernelContext* context) const { const int64_t input_image_size = H * W_in; const int64_t kernel_dim = (C / group) * kernel_size; - // Calculate memory usage per image to avoid OOM with large images - // col_buffer: C * kernel_size * output_image_size - // gemm_output_buffer: (M / group) * output_image_size - // We use a safe max(1, ...) for bytes_per_image to avoid division by zero in edge cases + // col_buffer: C * kernel_size * output_image_size; gemm_output_buffer: (M/group) * output_image_size const size_t bytes_per_image = SafeInt(output_image_size) * (C * kernel_size + M / group) * sizeof(T); - - // Heuristic: limit temp memory per chunk to balance parallelism and memory usage. - // Mirrors Conv's approach (conv_8.h): use 90% of free memory (10% fragmentation buffer). 
-  // Tiered cap based on free memory: larger GPUs get higher limits for better parallelism.
-  size_t effective_max_temp = 256ULL * 1024 * 1024;        // default fallback
-  constexpr size_t kMinTempMemSize = 32ULL * 1024 * 1024;  // 32MB floor
-  {
-    size_t free_mem = 0, total_mem = 0;
-    if (cudaMemGetInfo(&free_mem, &total_mem) == cudaSuccess && free_mem > 0) {
-      free_mem = static_cast<size_t>(static_cast<double>(free_mem) * 0.9);  // 10% fragmentation buffer
-      size_t kMaxTempMemSize;
-      if (free_mem > 16ULL * 1024 * 1024 * 1024) {
-        kMaxTempMemSize = 2ULL * 1024 * 1024 * 1024;  // 16GB+ free → 2GB
-      } else if (free_mem > 8ULL * 1024 * 1024 * 1024) {
-        kMaxTempMemSize = 1ULL * 1024 * 1024 * 1024;  // 8-16GB → 1GB
-      } else if (free_mem > 4ULL * 1024 * 1024 * 1024) {
-        kMaxTempMemSize = 512ULL * 1024 * 1024;  // 4-8GB → 512MB
-      } else if (free_mem > 2ULL * 1024 * 1024 * 1024) {
-        kMaxTempMemSize = 256ULL * 1024 * 1024;  // 2-4GB → 256MB
-      } else {
-        kMaxTempMemSize = 128ULL * 1024 * 1024;  // <2GB → 128MB
-      }
-      effective_max_temp = std::max(kMinTempMemSize, std::min(kMaxTempMemSize, free_mem));
-    }
-  }
+  const size_t effective_max_temp = GetDeformConvEffectiveMaxTempBytes();
   const int max_parallel_imgs_mem = std::max(1, static_cast<int>(effective_max_temp / std::max(size_t(1), bytes_per_image)));
 
   const int target_parallel_imgs = std::min(kMaxParallelImgs, max_parallel_imgs_mem);
   const int n_parallel_imgs = GetGreatestDivisorBelowBound(static_cast<int>(N), target_parallel_imgs);

From 8b5a13f5b5a3abf9f3554cb2415dc968ea57b894 Mon Sep 17 00:00:00 2001
From: Shirasawa <764798966@qq.com>
Date: Mon, 2 Mar 2026 03:02:34 +0800
Subject: [PATCH 18/58] Refactor attributes/validation and optimize CUDA DeformConvIm2Col kernel

---
 .../providers/cuda/nn/deform_conv_impl.cu | 231 +++++++++---------
 1 file changed, 114 insertions(+), 117 deletions(-)

diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu
index c9e44eb214566..a55538ca1ddc2 100644
--- a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu
+++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu
@@ -19,6 +19,9 @@ namespace {
 
 constexpr int kDeformConvThreadsPerBlock = 256;
 
+template <int N>
+struct DeformConvKSize { static constexpr int value = N; };
+
 // Calculate grid size with a safety limit to prevent overflow.
 // Since we use grid-stride loops in kernels, limiting the grid size is safe.
 inline int GetGridSize(size_t n, size_t threads_per_block) {
@@ -130,9 +133,8 @@ __device__ __inline__ BFloat16 BilinearInterpolate(
   return BFloat16(w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
 }
 
-// 1D parallel: each thread handles one output pixel (out_b, out_y, out_x) for a specific channel (in_c).
-// Optimized memory access patterns and removed redundant calculations.
-template <typename T, typename IndexT>
+// kH/kW = -1 means dynamic (runtime); >= 0 means compile-time constant for loop unrolling.
+template <typename T, typename IndexT, int kH, int kW>
 __global__ void DeformableIm2ColKernel(
     IndexT num_kernels,
     const T* __restrict__ input,
@@ -157,15 +159,21 @@ __global__ void DeformableIm2ColKernel(
     bool use_mask,
     T* __restrict__ data_col) {
+  constexpr bool is_fixed = (kH >= 0 && kW >= 0);
+  const int64_t h_dim = is_fixed ? kH : weight_h;
+  const int64_t w_dim = is_fixed ?
kW : weight_w; + // Reconstruct dimensions from DivMod objects const int64_t out_h = out_h_div.d_; const int64_t out_w = out_w_div.d_; const int64_t parallel_imgs = parallel_imgs_div.d_; const int64_t out_size = out_h * out_w; - // The stride for data_col is (batch * out_h * out_w) + // The stride for data_col is (parallel_imgs * out_h * out_w) const int64_t col_stride = parallel_imgs * out_size; + using CoordT = typename std::conditional::value, float, T>::type; + for (IndexT index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; index += blockDim.x * gridDim.x) { IndexT val = index; IndexT out_x, out_y, out_b, in_c; @@ -178,8 +186,8 @@ __global__ void DeformableIm2ColKernel( // [Optimization 3] Avoid expensive division if offset_group is 1 (very common case). IndexT offset_grp = 0; if (offset_group > 1) { - IndexT dummy; - channel_per_offset_grp_div.divmod(in_c, offset_grp, dummy); + IndexT dummy; + channel_per_offset_grp_div.divmod(in_c, offset_grp, dummy); } // [Optimization 2] Common Subexpression Elimination (CSE) & Pointer Arithmetic @@ -194,15 +202,15 @@ __global__ void DeformableIm2ColKernel( // 3. Offset pointer base calculation. // Layout: (N, offset_groups, 2*KH*KW, OH, OW) // We pre-calculate the pointer to the start of the specific (n, g) block, plus spatial_idx. - const int64_t offset_group_block_size = 2 * weight_h * weight_w * out_size; + const int64_t offset_group_block_size = 2 * h_dim * w_dim * out_size; const T* offset_ptr_base = offset + (static_cast(out_b) * offset_group + static_cast(offset_grp)) * offset_group_block_size + spatial_idx; // 4. Mask pointer base calculation (if used). // Layout: (N, offset_groups, KH*KW, OH, OW) const T* mask_ptr_base = nullptr; if (use_mask) { - const int64_t mask_group_block_size = weight_h * weight_w * out_size; - mask_ptr_base = mask + (static_cast(out_b) * offset_group + static_cast(offset_grp)) * mask_group_block_size + spatial_idx; + const int64_t mask_group_block_size = h_dim * w_dim * out_size; + mask_ptr_base = mask + (static_cast(out_b) * offset_group + static_cast(offset_grp)) * mask_group_block_size + spatial_idx; } // 5. Output pointer base calculation. @@ -210,49 +218,59 @@ __global__ void DeformableIm2ColKernel( // The current thread writes to the column `c_col` = (b * OH * OW) + spatial_idx. // The starting row for this channel is `in_c * KH * KW`. const int64_t c_col = static_cast(out_b) * out_size + spatial_idx; - T* data_col_ptr_base = data_col + (static_cast(in_c) * weight_h * weight_w) * col_stride + c_col; + T* data_col_ptr_base = data_col + (static_cast(in_c) * h_dim * w_dim) * col_stride + c_col; // 6. Pre-calculate invariant coordinate parts. // Use float for coordinate math when T is half or BFloat16 to avoid precision loss. - using CoordT = typename std::conditional::value, float, T>::type; const CoordT base_h_im = static_cast(out_y * stride_h - pad_h); const CoordT base_w_im = static_cast(out_x * stride_w - pad_w); -#pragma unroll - for (int64_t i = 0; i < weight_h; ++i) { -#pragma unroll - for (int64_t j = 0; j < weight_w; ++j) { - const int64_t kernel_idx = i * weight_w + j; - - T mask_val = static_cast(1); - if (use_mask) { - // Access mask using pre-calculated base and stride. - mask_val = DeformConvLdg(mask_ptr_base + kernel_idx * out_size); - - // [Optimization 1] Early Exit / Pruning - // If mask is 0, the contribution is 0. Skip expensive offset load and interpolation. - // Note: casting to float for comparison is safe for standard floating point types. 
- if (static_cast(mask_val) == 0.0f) { - data_col_ptr_base[kernel_idx * col_stride] = static_cast(0); - continue; - } + auto process_kernel_point = [&](int64_t i, int64_t j) { + const int64_t kernel_idx = i * w_dim + j; + T mask_val = static_cast(1); + if (use_mask) { + // Access mask using pre-calculated base and stride. + mask_val = DeformConvLdg(mask_ptr_base + kernel_idx * out_size); + + // [Optimization 1] Early Exit / Pruning + // If mask is 0, the contribution is 0. Skip expensive offset load and interpolation. + // Note: casting to float for comparison is safe for standard floating point types. + if (static_cast(mask_val) == 0.0f) { + data_col_ptr_base[kernel_idx * col_stride] = static_cast(0); + return; } + } + + // Calculate offset pointers relative to the base. + // The offset tensor stores (y_offset, x_offset) pairs for each kernel weight. + // Stride between y_offset and x_offset is `out_size`. + const int64_t offset_offset_idx = (2 * kernel_idx) * out_size; - // Calculate offset pointers relative to the base. - // The offset tensor stores (y_offset, x_offset) pairs for each kernel weight. - // Stride between y_offset and x_offset is `out_size`. - const int64_t offset_offset_idx = (2 * kernel_idx) * out_size; + const CoordT offset_h = static_cast(DeformConvLdg(offset_ptr_base + offset_offset_idx)); + const CoordT offset_w = static_cast(DeformConvLdg(offset_ptr_base + offset_offset_idx + out_size)); - const CoordT offset_h = static_cast(DeformConvLdg(offset_ptr_base + offset_offset_idx)); - const CoordT offset_w = static_cast(DeformConvLdg(offset_ptr_base + offset_offset_idx + out_size)); + const CoordT h_im = base_h_im + static_cast(i * dilation_h) + offset_h; + const CoordT w_im = base_w_im + static_cast(j * dilation_w) + offset_w; - const CoordT h_im = base_h_im + static_cast(i * dilation_h) + offset_h; - const CoordT w_im = base_w_im + static_cast(j * dilation_w) + offset_w; + T val = BilinearInterpolate(input_ptr, height, width, h_im, w_im); - T val = BilinearInterpolate(input_ptr, height, width, h_im, w_im); + // Write result to data_col using pre-calculated base. + data_col_ptr_base[kernel_idx * col_stride] = val * mask_val; + }; - // Write result to data_col using pre-calculated base. 
- data_col_ptr_base[kernel_idx * col_stride] = val * mask_val; + if constexpr (is_fixed) { +#pragma unroll + for (int i = 0; i < kH; ++i) { +#pragma unroll + for (int j = 0; j < kW; ++j) { + process_kernel_point(i, j); + } + } + } else { + for (int64_t i = 0; i < weight_h; ++i) { + for (int64_t j = 0; j < weight_w; ++j) { + process_kernel_point(i, j); + } } } } @@ -386,54 +404,33 @@ Status DeformConvIm2ColImpl( int blocks = GetGridSize(static_cast(num_kernels), kDeformConvThreadsPerBlock); - if (use_64bit) { - DeformableIm2ColKernel<<>>( - num_kernels, - input, - offset, - mask, - H, - W, - kH, - kW, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - C, // channels is C - offset_group, - DivMod(out_h), - DivMod(out_w), - DivMod(parallel_imgs), - DivMod(C / offset_group), - use_mask, - col_buffer); + auto launch = [&](auto kH_tag, auto kW_tag) { + constexpr int KH = decltype(kH_tag)::value; + constexpr int KW = decltype(kW_tag)::value; + if (use_64bit) { + DeformableIm2ColKernel<<>>( + num_kernels, input, offset, mask, H, W, kH, kW, pad_h, pad_w, + stride_h, stride_w, dilation_h, dilation_w, C, offset_group, + DivMod(out_h), DivMod(out_w), DivMod(parallel_imgs), + DivMod(C / offset_group), use_mask, col_buffer); + } else { + DeformableIm2ColKernel<<>>( + static_cast(num_kernels), input, offset, mask, H, W, kH, kW, pad_h, pad_w, + stride_h, stride_w, dilation_h, dilation_w, C, offset_group, + DivMod(static_cast(out_h)), + DivMod(static_cast(out_w)), + DivMod(static_cast(parallel_imgs)), + DivMod(static_cast(C / offset_group)), + use_mask, col_buffer); + } + }; + + if (kH == 3 && kW == 3) { + launch(DeformConvKSize<3>{}, DeformConvKSize<3>{}); + } else if (kH == 5 && kW == 5) { + launch(DeformConvKSize<5>{}, DeformConvKSize<5>{}); } else { - DeformableIm2ColKernel<<>>( - static_cast(num_kernels), - input, - offset, - mask, - H, - W, - kH, - kW, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - C, // channels is C - offset_group, - DivMod(static_cast(out_h)), - DivMod(static_cast(out_w)), - DivMod(static_cast(parallel_imgs)), - DivMod(static_cast(C / offset_group)), - use_mask, - col_buffer); + launch(DeformConvKSize<-1>{}, DeformConvKSize<-1>{}); } return CUDA_CALL(cudaGetLastError()); } @@ -457,37 +454,37 @@ template Status DeformConvAddBiasImpl(cudaStream_t, half*, const half*, in template Status DeformConvAddBiasImpl(cudaStream_t, BFloat16*, const BFloat16*, int64_t, int64_t, int64_t, int64_t); // Delegate ORT type to CUDA type (e.g. MLFloat16 -> half); avoids repeating three identical specializations. 
-#define DELEGATE_DEFORM_CONV_IMPL(ORT_T, CUDA_T) \ - template <> \ - Status DeformConvIm2ColImpl(cudaStream_t stream, const ORT_T* input, \ - const ORT_T* offset, const ORT_T* mask, ORT_T* col_buffer, \ - int64_t parallel_imgs, int64_t C, int64_t H, int64_t W, \ - int64_t kH, int64_t kW, int64_t out_h, int64_t out_w, \ - int64_t pad_h, int64_t pad_w, int64_t stride_h, int64_t stride_w, \ +#define DELEGATE_DEFORM_CONV_IMPL(ORT_T, CUDA_T) \ + template <> \ + Status DeformConvIm2ColImpl(cudaStream_t stream, const ORT_T* input, \ + const ORT_T* offset, const ORT_T* mask, ORT_T* col_buffer, \ + int64_t parallel_imgs, int64_t C, int64_t H, int64_t W, \ + int64_t kH, int64_t kW, int64_t out_h, int64_t out_w, \ + int64_t pad_h, int64_t pad_w, int64_t stride_h, int64_t stride_w, \ int64_t dilation_h, int64_t dilation_w, int64_t offset_group, bool use_mask) { \ - return DeformConvIm2ColImpl(stream, reinterpret_cast(input), \ - reinterpret_cast(offset), \ - mask ? reinterpret_cast(mask) : nullptr, \ - reinterpret_cast(col_buffer), \ - parallel_imgs, C, H, W, kH, kW, out_h, out_w, \ - pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, \ - offset_group, use_mask); \ - } \ - template <> \ - Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t stream, \ - const ORT_T* gemm_output, ORT_T* Y_g, \ - int64_t M, int64_t M_per_group, \ - int64_t output_image_size, int64_t cur_parallel) { \ - return DeformConvCopyGemmOutputRowMajorToNCHW(stream, \ - reinterpret_cast(gemm_output), \ - reinterpret_cast(Y_g), \ - M, M_per_group, output_image_size, cur_parallel); \ - } \ - template <> \ - Status DeformConvAddBiasImpl(cudaStream_t stream, ORT_T* Y, const ORT_T* B, \ - int64_t N, int64_t M, int64_t out_h, int64_t out_w) { \ - return DeformConvAddBiasImpl(stream, reinterpret_cast(Y), \ - reinterpret_cast(B), N, M, out_h, out_w); \ + return DeformConvIm2ColImpl(stream, reinterpret_cast(input), \ + reinterpret_cast(offset), \ + mask ? 
reinterpret_cast(mask) : nullptr, \ + reinterpret_cast(col_buffer), \ + parallel_imgs, C, H, W, kH, kW, out_h, out_w, \ + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, \ + offset_group, use_mask); \ + } \ + template <> \ + Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t stream, \ + const ORT_T* gemm_output, ORT_T* Y_g, \ + int64_t M, int64_t M_per_group, \ + int64_t output_image_size, int64_t cur_parallel) { \ + return DeformConvCopyGemmOutputRowMajorToNCHW(stream, \ + reinterpret_cast(gemm_output), \ + reinterpret_cast(Y_g), \ + M, M_per_group, output_image_size, cur_parallel); \ + } \ + template <> \ + Status DeformConvAddBiasImpl(cudaStream_t stream, ORT_T* Y, const ORT_T* B, \ + int64_t N, int64_t M, int64_t out_h, int64_t out_w) { \ + return DeformConvAddBiasImpl(stream, reinterpret_cast(Y), \ + reinterpret_cast(B), N, M, out_h, out_w); \ } DELEGATE_DEFORM_CONV_IMPL(MLFloat16, half) From e5ec6defdc09aeabf36046c09c14d8c37bba4c76 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Tue, 3 Mar 2026 00:51:01 +0800 Subject: [PATCH 19/58] Add DeformConv OnnxModelTest with reference ONNX model --- .../cpu/nn/deform_conv_expected_gen.py | 6 +- .../providers/cpu/nn/deform_conv_op_test.cc | 28 ++- .../test/testdata/deform_conv_test.onnx | Bin 0 -> 353 bytes .../test/testdata/deform_conv_test_data.inc | 10 + .../test/testdata/deform_conv_test_data.npz | Bin 0 -> 2092 bytes .../test/testdata/nn/deform_conv_test_gen.py | 178 ++++++++++++++++++ 6 files changed, 220 insertions(+), 2 deletions(-) create mode 100644 onnxruntime/test/testdata/deform_conv_test.onnx create mode 100644 onnxruntime/test/testdata/deform_conv_test_data.inc create mode 100644 onnxruntime/test/testdata/deform_conv_test_data.npz create mode 100644 onnxruntime/test/testdata/nn/deform_conv_test_gen.py diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py b/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py index 0dffd5f337c61..860e2f0322e72 100644 --- a/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py @@ -3,6 +3,10 @@ Generate expected outputs for DeformConv tests using torchvision.ops.deform_conv2d. Run with: .venv/bin/python onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py Outputs C++-friendly std::vector initializer lists for pasting into deform_conv_op_test.cc + +Limitation: Uses symmetric padding only. PyTorch padding=(pad_h, pad_w) and ONNX pads +[pad_h, pad_w, pad_h, pad_w] are derived from a single (pad_h, pad_w) pair. Asymmetric +pads (e.g. pad_top != pad_bottom) would require PyTorch API support and are not generated. 
""" import torch import torchvision.ops @@ -41,7 +45,7 @@ def run_case(name: str, batch_sz: int, n_in: int, n_out: int, n_weight_grps: int stride=(stride_h, stride_w), padding=(pad_h, pad_w), dilation=(dil_h, dil_w), mask=mask ) - # ONNX pads = [top, left, bottom, right] + # ONNX pads = [top, left, bottom, right] (symmetric: single pad_h, pad_w expanded) pads_onnx = [pad_h, pad_w, pad_h, pad_w] print(f"// --- {name} (seed={seed}) ---") diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc index 78106503344b6..fa5f6a0945c15 100644 --- a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc @@ -6,6 +6,7 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" +#include "test/testdata/deform_conv_test_data.inc" #include "test/unittest_util/conversion.h" #if defined(USE_CUDA) @@ -156,7 +157,9 @@ TEST(DeformConvTest, MinimalBilinear) { std::vector X = {1.f, 2.f, 3.f, 4.f}; // NCHW std::vector W = {1.f}; - // offset (1, 2, 2, 2): ch0=offset_h, ch1=offset_w per output position. (0,0):(0.5,0)->2.5, (0,1):(0.5,-1)->1 + // offset shape [N, 2*kH*kW, out_h, out_w] = [1, 2, 2, 2]: ch0=offset_h, ch1=offset_w (for kernel pt 0) + // Layout: offset[n,c,oh,ow]. Flattened (NCHW): [ch0@00, ch0@01, ch0@10, ch0@11, ch1@00, ch1@01, ch1@10, ch1@11] + // (0,0): (0.5, 0.5)->center of [1,2;3,4]->2.5; (0,1): (0,-1)->(0,0)->1; (1,0): (0,0)->3; (1,1): (0,0)->4 std::vector offset = { 0.5f, 0.f, 0.f, 0.f, 0.5f, -1.0f, 0.f, 0.f @@ -954,5 +957,28 @@ TEST(DeformConvTest, ExtremeAspectRatio) { RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); } +// ONNX model data test: fixed inputs from deform_conv_test_gen.py (torchvision ref, seed=123). +// Validates output matches torch reference. The .onnx/.npz can be used for standalone model zoo validation. 
+TEST(DeformConvTest, OnnxModelTest) { + OpTester test("DeformConv", 19); + test.AddAttribute("kernel_shape", std::vector{2, 2}); + test.AddAttribute("strides", std::vector{1, 1}); + test.AddAttribute("pads", std::vector{0, 0, 0, 0}); + test.AddAttribute("dilations", std::vector{1, 1}); + test.AddAttribute("group", static_cast(2)); + test.AddAttribute("offset_group", static_cast(2)); + + test.AddInput("X", {1, 4, 3, 3}, kDeformConvOnnxTest_X); + test.AddInput("W", {2, 2, 2, 2}, kDeformConvOnnxTest_W); + test.AddInput("offset", {1, 16, 2, 2}, kDeformConvOnnxTest_offset); + test.AddInput("B", {2}, kDeformConvOnnxTest_B); + test.AddInput("mask", {1, 8, 2, 2}, kDeformConvOnnxTest_mask); + test.AddReferenceOutputs("testdata/deform_conv_test.onnx", 1e-4f); + + std::unordered_set excluded = {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, + kQnnExecutionProvider}; + test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/deform_conv_test.onnx b/onnxruntime/test/testdata/deform_conv_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..b643014e44acbe4ee5d557f5fdf1b7eeca1951f0 GIT binary patch literal 353 zcmZ9IL2kk@5Jed`CGJ-W!L+DSrK;*OM_^@9j(|kT2BnHaBOs3Aw0G%_Q*j>b(m+7g z-^~1Zf5!ZNyl40&&8|Nz#t| zD8TKi(%sv5lL zo#SW9)bX=jRgCb!NrYgWtURk5C)b>}n#>kYieH=aS`IhvFn_MNZx0s$w`|8`@yq`= VT;}m+;M3+Uu4t#ciHA-&JOF_iJKX>P literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/deform_conv_test_data.inc b/onnxruntime/test/testdata/deform_conv_test_data.inc new file mode 100644 index 0000000000000..7f9901c823d55 --- /dev/null +++ b/onnxruntime/test/testdata/deform_conv_test_data.inc @@ -0,0 +1,10 @@ +// Auto-generated by deform_conv_test_gen.py - do not edit + +#include + +static const std::vector kDeformConvOnnxTest_X = {0.296111941, 0.516562283, 0.251670718, 0.68855679, 0.0739724636, 0.866521955, 0.136579871, 0.102479041, 0.184056461, 0.726446748, 0.315253913, 0.687106669, 0.075635314, 0.196638167, 0.316411972, 0.401740134, 0.118568301, 0.82739538, 0.382084429, 0.660493851, 0.853571773, 0.593153, 0.636725366, 0.982629359, 0.274495304, 0.658375621, 0.277541935, 0.857324839, 0.899328232, 0.0390138626, 0.926822901, 0.738757193, 0.717883527, 0.705837429, 0.915649533, 0.433980227}; +static const std::vector kDeformConvOnnxTest_W = {-1.18204546, -0.287744999, -0.604300678, 0.600236714, -1.42047262, -0.223827749, 0.430554837, -0.89885664, -0.0178579595, 0.426403075, -0.765740693, -0.0545141846, -0.732052684, 1.23474216, 1.18622088, -0.220098898}; +static const std::vector kDeformConvOnnxTest_offset = {-0.388483077, -0.934345901, -0.499144107, -1.08665264, 0.962421, 0.249208495, -0.484502077, -2.09291434, 0.0982837752, -0.0935074314, 0.266214728, -0.585035503, -0.343037993, -0.682147384, -0.988689423, -1.70183039, -1.2202903, 1.31385386, 1.05329967, 0.138805181, -0.204444751, -2.26852894, -0.913327932, -0.420362711, -0.659559608, -0.797927678, 0.18383126, 0.229347408, 0.617742658, -0.287577927, 0.821824312, 0.151177585, -0.0443819836, 1.62355745, -2.32287097, 1.08783054, -0.0635453761, -0.448640704, -1.27846932, -1.14400387, -0.152640373, 0.116741188, 0.44026047, -1.44654655, -0.558081627, -0.0516963229, -0.90827328, 0.350683212, -0.394808769, 0.489227712, -0.216814891, -1.74716449, 1.72284174, 0.773806036, 0.404629797, -1.64612663, -0.59508425, -0.711217523, 0.622964859, -1.37288189, -0.128064156, -1.28383458, -0.290120065, 1.27674019}; +static const std::vector kDeformConvOnnxTest_B = 
{0.983955026, 0.204511523}; +static const std::vector kDeformConvOnnxTest_mask = {-0.0318612382, -0.478955716, 0.766808629, 0.0274681915, 0.0474699028, -0.92386651, -1.06073678, -2.32444572, -2.06281757, 0.00637452863, -0.989554703, 0.701609194, -0.982237995, 0.277030349, 0.645495057, -0.895680785, 0.492752999, -0.0140781598, -0.274662733, -0.764091492, -0.58715719, 1.1951654, -1.20957518, -0.556007624, -0.0771045536, 1.27737665, -1.45962942, -2.15952778, -0.70670861, -0.92224431, 3.89537215, -0.602696717}; +static const std::vector kDeformConvOnnxTest_expected_Y = {0.971546292, 1.1398586, 0.452816963, 1.86388242, -0.565265715, 1.42318761, -2.46283293, -0.104923099}; diff --git a/onnxruntime/test/testdata/deform_conv_test_data.npz b/onnxruntime/test/testdata/deform_conv_test_data.npz new file mode 100644 index 0000000000000000000000000000000000000000..68639f753501dc778a4a237855d3266f25edd252 GIT binary patch literal 2092 zcmbtVe@qj16u;8qNJkpBR75sMH&}|ZfVqaMW;|J^MJ>TzwhLI;p#rQQkp(uuAKG49Q8K;P5DjYhdPS9mL zHxrgRO(r{f$?t>Hba;O<)YyB7!hbn@ad-!e-&_I~Dj1F*Erw4kt4U&2D9kKOkoC`}qB$DYulHa^vH^{sX2 z<9ibPhw7oxXk=W|9@iv}3293?_E^kjt0`Z|-R7i_U_Hv6qfkWjeEtheUb^A-B}Nxv z!D}BpWGqXnP-85B!>isSt$q)fEz{ntX!8|%(1wl#eFeF7aWQ^=x}AwW6^3kJ3D#Zz z0gED=aF?nT$_L6}`UM&JDt|-Nr5VrV)ZjPzXBmc&CuwwRt((HGeQQdM*4d znTYAN#L{JlvQTsB2%hVavubY%n>EqD?(O6fY-@^w`_^K-Gjx=^1TVrioe6{DXUXDL z^MwX`HDhax!+Tr&G3=iPcw_JcmTi;af4_X}ke4nK zABe7n7k7l03dqSg zKNPR{5@eq&WA981ISe|H*GkIuP~*=nUcH^U@MeAx%^h*>F8lyzxh1^1x@_V$nH%Ts zF8uK0cjQ&dxd~zu5ijboQRwc+Hy+ str: + """Format numpy array as C++ initializer list.""" + flat = arr.flatten().tolist() + vals = ", ".join(f"{x:.9g}" for x in flat) + return f"static const std::vector {name} = {{{vals}}};" + + +def _write_cpp_inc(data: dict, inc_path: Path) -> None: + """Write C++ include file with test data.""" + lines = [ + "// Auto-generated by deform_conv_test_gen.py - do not edit", + "", + "#include ", + "", + _to_cpp_array("kDeformConvOnnxTest_X", data["X"]), + _to_cpp_array("kDeformConvOnnxTest_W", data["W"]), + _to_cpp_array("kDeformConvOnnxTest_offset", data["offset"]), + _to_cpp_array("kDeformConvOnnxTest_B", data["B"]), + _to_cpp_array("kDeformConvOnnxTest_mask", data["mask"]), + _to_cpp_array("kDeformConvOnnxTest_expected_Y", data["expected_Y"]), + "", + ] + inc_path.write_text("\n".join(lines), encoding="utf-8") + + +def main(): + # Output to testdata/ root (same as layernorm.onnx, attention_past_state.onnx, etc.) 
+ script_dir = Path(__file__).resolve().parent + assert script_dir.name == "nn", "Script must live in testdata/nn/" + testdata_root = script_dir.parent + model_path = testdata_root / "deform_conv_test.onnx" + data_path = testdata_root / "deform_conv_test_data.npz" + inc_path = testdata_root / "deform_conv_test_data.inc" + + print("Generating reference via torchvision.ops.deform_conv2d...") + data = _generate_reference() + + print("Building ONNX model...") + model = _build_onnx_model() + onnx.save(model, str(model_path)) + print(f" Saved {model_path}") + + np.savez(str(data_path), **data) + print(f" Saved {data_path}") + + _write_cpp_inc(data, inc_path) + print(f" Saved {inc_path}") + + # Validate with onnxruntime if available + try: + import onnxruntime as ort + + print("Validating with ONNX Runtime...") + sess = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"]) + ort_out = sess.run( + ["Y"], + { + "X": data["X"], + "W": data["W"], + "offset": data["offset"], + "B": data["B"], + "mask": data["mask"], + }, + )[0] + + rtol, atol = 1e-4, 1e-4 + if np.allclose(ort_out, data["expected_Y"], rtol=rtol, atol=atol): + print(" PASS: ORT output matches reference.") + else: + diff = np.abs(ort_out.astype(np.float64) - data["expected_Y"].astype(np.float64)) + print(f" FAIL: max |diff|={diff.max()}, mean={diff.mean()}") + except ImportError: + print(" (onnxruntime not installed; skip validation)") + + +if __name__ == "__main__": + main() From 14cf455d99c9dde0d5f1c0df26c0f2d1e01ee221 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Tue, 3 Mar 2026 01:08:02 +0800 Subject: [PATCH 20/58] Optimize GetGreatestDivisorBelowBound in CUDA DeformConv --- .../core/providers/cuda/nn/deform_conv.cc | 24 ++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc index f83e15801defb..fcd24bb6abe38 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc @@ -18,11 +18,29 @@ namespace { constexpr int kMaxParallelImgs = 32; +// Returns the greatest divisor of n that is <= bound. Used to choose uniform batch chunk sizes. +// Fast path: if n % bound == 0 (common for batch 32/64/128), return immediately. +// When n >= bound^2, linear scan from bound down is O(bound). Otherwise divisor enumeration +// from 1 to sqrt(n) is O(sqrt(n)). Uses integer comparison (no sqrt) for branch decision. 
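+// Example: n=48, bound=32 → no fast path (48 % 32 != 0); since 48 < 32*32, divisors are enumerated
+// up to sqrt(48) and the best divisor <= 32 is 24, so a batch of 48 runs as two chunks of 24.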
 int GetGreatestDivisorBelowBound(int n, int bound) {
-  for (int k = bound; k > 1; --k) {
-    if (n % k == 0) {
-      return k;
+  if (bound <= 0 || n <= 0) return 1;
+  if (n % bound == 0) return bound;  // Fast path: batch is multiple of target
+
+  // n >= bound^2 <=> bound <= sqrt(n) => linear scan is cheaper
+  if (static_cast<int64_t>(n) >= static_cast<int64_t>(bound) * bound) {
+    for (int k = bound - 1; k > 1; --k) {
+      if (n % k == 0) return k;
+    }
+  } else {
+    // n < bound^2 <=> bound > sqrt(n) => divisor enumeration is cheaper
+    int best = 1;
+    for (int i = 1; static_cast<int64_t>(i) * i <= static_cast<int64_t>(n); ++i) {
+      if (n % i != 0) continue;
+      const int q = n / i;
+      if (q <= bound && q > best) best = q;
+      if (i <= bound && i > best) best = i;
     }
+    return best;
   }
   return 1;
 }

From 4121178a0c5b1ad9d8d1abbf112ed23da9b9b70b Mon Sep 17 00:00:00 2001
From: Shirasawa <764798966@qq.com>
Date: Tue, 3 Mar 2026 01:10:14 +0800
Subject: [PATCH 21/58] Document symmetric-padding-only limitation in deform_conv_test_gen

---
 onnxruntime/test/testdata/nn/deform_conv_test_gen.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/onnxruntime/test/testdata/nn/deform_conv_test_gen.py b/onnxruntime/test/testdata/nn/deform_conv_test_gen.py
index b9f2b2aa7bff6..301f14617e9a4 100644
--- a/onnxruntime/test/testdata/nn/deform_conv_test_gen.py
+++ b/onnxruntime/test/testdata/nn/deform_conv_test_gen.py
@@ -6,6 +6,10 @@
 Uses a moderately complex config: groups=2, offset_group=2, 2x2 kernel, non-zero offsets.
 Reference output from torchvision.ops.deform_conv2d.
 
+Limitation: Uses symmetric padding only. PyTorch padding=(pad_h, pad_w) and ONNX pads
+[pad_top, pad_left, pad_bottom, pad_right] = [pad_h, pad_w, pad_h, pad_w]. Asymmetric
+pads (e.g. pad_top != pad_bottom) would require PyTorch API support and are not generated.
+
 Run from repo root:
     python onnxruntime/test/testdata/nn/deform_conv_test_gen.py
 
@@ -65,6 +69,7 @@ def _generate_reference():
 
 def _build_onnx_model():
     """Build DeformConv ONNX model. ONNX pads = [pad_top, pad_left, pad_bottom, pad_right]."""
+    # Symmetric padding only: (pad_h, pad_w) -> [pad_h, pad_w, pad_h, pad_w]
     pads = [PAD_H, PAD_W, PAD_H, PAD_W]
 
     node = helper.make_node(

From df9d0b10aad8b27e0852c0219638aec8aaf92896 Mon Sep 17 00:00:00 2001
From: Shirasawa <764798966@qq.com>
Date: Tue, 3 Mar 2026 01:26:51 +0800
Subject: [PATCH 22/58] Skip cuda DeformConv op copy kernel when cur_parallel==1

---
 .../core/providers/cuda/nn/deform_conv.cc | 29 +++++++++++--------
 1 file changed, 17 insertions(+), 12 deletions(-)

diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc
index fcd24bb6abe38..eb917ed6e47e4 100644
--- a/onnxruntime/core/providers/cuda/nn/deform_conv.cc
+++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc
@@ -202,6 +202,11 @@ Status DeformConv<T>::ComputeInternal(OpKernelContext* context) const {
       // lda = cur_out_size.
       // ldb = kernel_dim.
       // ldc = cur_out_size.
+      //
+      // When cur_parallel == 1: cur_out_size == output_image_size, so C layout (pos, channel) matches
+      // NCHW Y_g[0, channel, pos] exactly. Write directly to Y_g and skip the copy kernel.
+      // When cur_parallel > 1: layouts differ, must copy via DeformConvCopyGemmOutputRowMajorToNCHW.
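+      // e.g. N=1 (single-image inference) always yields cur_parallel == 1, so the GEMM result lands
+      // in Y directly and both the intermediate buffer traffic and the copy kernel launch are saved.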
+      const bool gemm_writes_directly = (cur_parallel == 1);
 
       CUBLAS_RETURN_IF_ERROR((cublasGemmHelper(
           cublas,
@@ -216,21 +221,21 @@ Status DeformConv<T>::ComputeInternal(OpKernelContext* context) const {
           reinterpret_cast<const CudaT*>(W_g),
           narrow<int>(kernel_dim),
           &beta,
-          reinterpret_cast<CudaT*>(gemm_output_buffer.get()),
-          narrow<int>(cur_out_size),
+          (gemm_writes_directly ? reinterpret_cast<CudaT*>(Y_g) : reinterpret_cast<CudaT*>(gemm_output_buffer.get())),
+          narrow<int>(gemm_writes_directly ? output_image_size : cur_out_size),
           device_prop,
           UseTF32())));
 
-      // The output gemm_output_buffer is now Row-Major [M/group, cur_out_size].
-      // We need to copy it to Y_g (NCHW).
-      ORT_RETURN_IF_ERROR(DeformConvCopyGemmOutputRowMajorToNCHW(
-          stream,
-          gemm_output_buffer.get(),
-          Y_g,
-          M,
-          M / group,
-          output_image_size,
-          cur_parallel));
+      if (!gemm_writes_directly) {
+        ORT_RETURN_IF_ERROR(DeformConvCopyGemmOutputRowMajorToNCHW(
+            stream,
+            gemm_output_buffer.get(),
+            Y_g,
+            M,
+            M / group,
+            output_image_size,
+            cur_parallel));
+      }
     }
   }

From 03cc5e5b14bd3ea2bafd93f948ca2f2954308a07 Mon Sep 17 00:00:00 2001
From: Shirasawa <764798966@qq.com>
Date: Tue, 3 Mar 2026 01:47:47 +0800
Subject: [PATCH 23/58] Reformat code

---
 .../core/providers/cpu/nn/deform_conv.cc      | 71 +++++++------
 .../providers/cpu/nn/deform_conv_attributes.h | 20 ++--
 .../core/providers/cuda/nn/deform_conv.cc     | 44 ++++-----
 .../providers/cuda/nn/deform_conv_impl.cu     | 85 +++++++--------
 .../core/providers/cuda/nn/deform_conv_impl.h |  8 +-
 .../cpu/nn/deform_conv_expected_gen.py        | 94 +++++++++++++----
 .../providers/cpu/nn/deform_conv_op_test.cc   | 39 ++++----
 .../test/testdata/nn/deform_conv_test_gen.py  | 32 +++---
 8 files changed, 228 insertions(+), 165 deletions(-)

diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.cc b/onnxruntime/core/providers/cpu/nn/deform_conv.cc
index cdd67c4117cbe..7aa65b98869c5 100644
--- a/onnxruntime/core/providers/cpu/nn/deform_conv.cc
+++ b/onnxruntime/core/providers/cpu/nn/deform_conv.cc
@@ -51,21 +51,20 @@ T BilinearInterpolate(const T* in, int64_t height, int64_t width, T h, T w) {
 // Output 'data_col' shape: [C_in * kH * kW, H_out * W_out]
 template <typename T>
 void DeformableIm2col(
-    const T* data_im,        // Input image [C, H, W]
-    const T* data_offset,    // Offset [offset_groups * 2 * kH * kW, H_out, W_out]
-    const T* data_mask,      // Mask [offset_groups * kH * kW, H_out, W_out] (optional)
-    int64_t height, int64_t width,  // Input dimensions
-    int64_t kernel_h, int64_t kernel_w,  // Kernel dimensions
-    int64_t pad_h, int64_t pad_w,  // Padding
-    int64_t stride_h, int64_t stride_w,  // Stride
-    int64_t dilation_h, int64_t dilation_w,  // Dilation
-    int64_t channels,  // Input channels
-    int64_t offset_groups,  // Number of offset groups
-    int64_t height_col, int64_t width_col,  // Output dimensions
-    bool use_mask,
-    T* data_col,  // Output buffer
+    const T* data_im,                        // Input image [C, H, W]
+    const T* data_offset,                    // Offset [offset_groups * 2 * kH * kW, H_out, W_out]
+    const T* data_mask,                      // Mask [offset_groups * kH * kW, H_out, W_out] (optional)
+    int64_t height, int64_t width,           // Input dimensions
+    int64_t kernel_h, int64_t kernel_w,      // Kernel dimensions
+    int64_t pad_h, int64_t pad_w,            // Padding
+    int64_t stride_h, int64_t stride_w,      // Stride
+    int64_t dilation_h, int64_t dilation_w,  // Dilation
+    int64_t channels,                        // Input channels
+    int64_t offset_groups,                   // Number of offset groups
+    int64_t height_col, int64_t width_col,   // Output dimensions
+    bool use_mask,                           // Use mask
+    T* data_col,                             // Output buffer
     concurrency::ThreadPool* thread_pool) {
-  const int64_t
channel_per_offset_group = channels / offset_groups; // Loop order optimized for cache locality: @@ -86,7 +85,6 @@ void DeformableIm2col( // Iterate over kernel window for (int64_t i = 0; i < kernel_h; ++i) { for (int64_t j = 0; j < kernel_w; ++j) { - // Calculate the index in the offset/mask tensors. // The offset tensor is organized as: (offset_groups, 2 * kH * kW, H_out, W_out). // Flattened offset channel index relative to the start of the tensor: @@ -142,7 +140,7 @@ Status DeformConv::Compute(OpKernelContext* context) const { const auto* X = context->Input(0); const auto* W = context->Input(1); const auto* offset = context->Input(2); - const auto* B = context->Input(3); // optional + const auto* B = context->Input(3); // optional const auto* mask = context->Input(4); // optional DeformConvParams params; @@ -178,7 +176,7 @@ Status DeformConv::Compute(OpKernelContext* context) const { const int64_t kernel_size = kH * kW; const int64_t output_image_size = out_h * out_w; const int64_t input_image_size = H * W_in; - const int64_t kernel_dim = C / group * kernel_size; // The "K" dimension for GEMM (per group) + const int64_t kernel_dim = C / group * kernel_size; // The "K" dimension for GEMM (per group) // Total col buffer size: (C * kH * kW) * (out_h * out_w) // We allocate this per image to save memory compared to batch allocation if N is large, @@ -200,7 +198,6 @@ Status DeformConv::Compute(OpKernelContext* context) const { // Main Loop: Iterate over Batch for (int64_t n = 0; n < N; ++n) { - // 1. Perform Im2Col for the current image n // Pointers for current image const T* X_curr = Xdata + n * (C * input_image_size); @@ -251,16 +248,16 @@ Status DeformConv::Compute(OpKernelContext* context) const { math::Gemm( CblasNoTrans, CblasNoTrans, - narrow(M / group), // M - narrow(output_image_size),// N - narrow(kernel_dim), // K - static_cast(1), // alpha - weight_g, // A - col_g, // B - static_cast(0), // beta - Y_g, // C + narrow(M / group), // M + narrow(output_image_size), // N + narrow(kernel_dim), // K + static_cast(1), // alpha + weight_g, // A + col_g, // B + static_cast(0), // beta + Y_g, // C thread_pool, - nullptr); // mlas_backend_kernel_selector_config + nullptr); // mlas_backend_kernel_selector_config } } @@ -289,16 +286,16 @@ Status DeformConv::Compute(OpKernelContext* context) const { template class DeformConv; template class DeformConv; -#define REGISTER_DEFORMCONV_KERNEL_TYPED(T) \ - ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \ - DeformConv, 19, 21, T, \ - KernelDefBuilder() \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - DeformConv); \ - ONNX_CPU_OPERATOR_TYPED_KERNEL( \ - DeformConv, 22, T, \ - KernelDefBuilder() \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ +#define REGISTER_DEFORMCONV_KERNEL_TYPED(T) \ + ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \ + DeformConv, 19, 21, T, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + DeformConv); \ + ONNX_CPU_OPERATOR_TYPED_KERNEL( \ + DeformConv, 22, T, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ DeformConv) REGISTER_DEFORMCONV_KERNEL_TYPED(float) diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h index f2a8fdb58bd64..d103bc132f076 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h +++ b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h @@ -37,15 +37,15 @@ struct DeformConvAttributes { // Field names align with ONNX DeformConv 
spec: https://onnx.ai/onnx/operators/onnx__DeformConv.html struct DeformConvParams { // Input X shape (N, C, H, W) - int64_t N{0}; // Batch size - int64_t C{0}; // Number of input channels - int64_t H{0}; // Input height - int64_t W_in{0}; // Input width (W_in to avoid collision with weight W) + int64_t N{0}; // Batch size + int64_t C{0}; // Number of input channels + int64_t H{0}; // Input height + int64_t W_in{0}; // Input width (W_in to avoid collision with weight W) // Weight W shape (oC, C/group, kH, kW) - int64_t M{0}; // Number of output channels (oC) - int64_t kH{0}; // Kernel height - int64_t kW{0}; // Kernel width + int64_t M{0}; // Number of output channels (oC) + int64_t kH{0}; // Kernel height + int64_t kW{0}; // Kernel width // Pads [x1_begin, x2_begin, x1_end, x2_end] for spatial axes H, W int64_t pad_h{0}; @@ -64,10 +64,10 @@ struct DeformConvParams { int64_t offset_group{1}; // Number of groups of offset // Output Y shape (N, oC, oH, oW) - int64_t out_h{0}; // Output height (oH) - int64_t out_w{0}; // Output width (oW) + int64_t out_h{0}; // Output height (oH) + int64_t out_w{0}; // Output width (oW) - bool use_mask{false}; // Whether optional mask input is provided + bool use_mask{false}; // Whether optional mask input is provided }; // Validates inputs and parses attributes into params. diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc index eb917ed6e47e4..1b7912f4990c1 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc @@ -60,15 +60,15 @@ size_t GetDeformConvEffectiveMaxTempBytes() { size_t tier_cap; if (free_mem > 16ULL * 1024 * 1024 * 1024) { - tier_cap = 2ULL * 1024 * 1024 * 1024; // 16GB+ free → 2GB + tier_cap = 2ULL * 1024 * 1024 * 1024; // 16GB+ free → 2GB } else if (free_mem > 8ULL * 1024 * 1024 * 1024) { - tier_cap = 1ULL * 1024 * 1024 * 1024; // 8-16GB → 1GB + tier_cap = 1ULL * 1024 * 1024 * 1024; // 8-16GB → 1GB } else if (free_mem > 4ULL * 1024 * 1024 * 1024) { - tier_cap = 512ULL * 1024 * 1024; // 4-8GB → 512MB + tier_cap = 512ULL * 1024 * 1024; // 4-8GB → 512MB } else if (free_mem > 2ULL * 1024 * 1024 * 1024) { - tier_cap = 256ULL * 1024 * 1024; // 2-4GB → 256MB + tier_cap = 256ULL * 1024 * 1024; // 2-4GB → 256MB } else { - tier_cap = 128ULL * 1024 * 1024; // <2GB → 128MB + tier_cap = 128ULL * 1024 * 1024; // <2GB → 128MB } return std::max(kMinTempMemSize, std::min(tier_cap, free_mem)); @@ -246,23 +246,23 @@ Status DeformConv::ComputeInternal(OpKernelContext* context) const { return Status::OK(); } -#define REGISTER_DEFORMCONV_KERNEL_TYPED(T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - DeformConv, \ - kOnnxDomain, \ - 19, \ - 21, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - DeformConv); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - DeformConv, \ - kOnnxDomain, \ - 22, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ +#define REGISTER_DEFORMCONV_KERNEL_TYPED(T) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + DeformConv, \ + kOnnxDomain, \ + 19, \ + 21, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + DeformConv); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + DeformConv, \ + kOnnxDomain, \ + 22, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", 
DataTypeImpl::GetTensorType()), \ DeformConv); REGISTER_DEFORMCONV_KERNEL_TYPED(float) diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu index a55538ca1ddc2..b2f4e479bdd93 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu @@ -20,7 +20,9 @@ namespace { constexpr int kDeformConvThreadsPerBlock = 256; template -struct DeformConvKSize { static constexpr int value = N; }; +struct DeformConvKSize { + static constexpr int value = N; +}; // Calculate grid size with a safety limit to prevent overflow. // Since we use grid-stride loops in kernels, limiting the grid size is safe. @@ -158,7 +160,6 @@ __global__ void DeformableIm2ColKernel( DivMod channel_per_offset_grp_div, bool use_mask, T* __restrict__ data_col) { - constexpr bool is_fixed = (kH >= 0 && kW >= 0); const int64_t h_dim = is_fixed ? kH : weight_h; const int64_t w_dim = is_fixed ? kW : weight_w; @@ -281,10 +282,9 @@ template __global__ void DeformConvAddBiasKernel( T* Y, const T* B, - DivMod spatial_div, // For dividing by (H * W) - DivMod channel_div, // For dividing by M (channel count) + DivMod spatial_div, // For dividing by (H * W) + DivMod channel_div, // For dividing by M (channel count) int64_t total_elements) { - for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total_elements; idx += blockDim.x * gridDim.x) { int64_t val = idx; int64_t batch_channel_idx, pixel_idx; @@ -348,8 +348,7 @@ Status DeformConvAddBiasImpl(cudaStream_t stream, T* Y, const T* B, int64_t N, i B, spatial_div, channel_div, - total - ); + total); return CUDA_CALL(cudaGetLastError()); } @@ -400,7 +399,7 @@ Status DeformConvIm2ColImpl( const int64_t col_numel = static_cast(C) * kH * kW * parallel_imgs * out_h * out_w; const bool use_64bit = (num_kernels > static_cast(std::numeric_limits::max())) || - (col_numel > static_cast(std::numeric_limits::max())); + (col_numel > static_cast(std::numeric_limits::max())); int blocks = GetGridSize(static_cast(num_kernels), kDeformConvThreadsPerBlock); @@ -439,11 +438,11 @@ Status DeformConvIm2ColImpl( template Status DeformConvIm2ColImpl(cudaStream_t, const T*, const T*, const T*, T*, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, bool); INST_DeformConvIm2ColImpl(float) -INST_DeformConvIm2ColImpl(double) -INST_DeformConvIm2ColImpl(half) -INST_DeformConvIm2ColImpl(BFloat16) + INST_DeformConvIm2ColImpl(double) + INST_DeformConvIm2ColImpl(half) + INST_DeformConvIm2ColImpl(BFloat16) -template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const float*, float*, int64_t, int64_t, int64_t, int64_t); + template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const float*, float*, int64_t, int64_t, int64_t, int64_t); template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const double*, double*, int64_t, int64_t, int64_t, int64_t); template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const half*, half*, int64_t, int64_t, int64_t, int64_t); template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const BFloat16*, BFloat16*, int64_t, int64_t, int64_t, int64_t); @@ -454,37 +453,37 @@ template Status DeformConvAddBiasImpl(cudaStream_t, half*, const half*, in template Status DeformConvAddBiasImpl(cudaStream_t, BFloat16*, const BFloat16*, int64_t, int64_t, int64_t, int64_t); // Delegate ORT type to CUDA type (e.g. 
MLFloat16 -> half); avoids repeating three identical specializations. -#define DELEGATE_DEFORM_CONV_IMPL(ORT_T, CUDA_T) \ - template <> \ - Status DeformConvIm2ColImpl(cudaStream_t stream, const ORT_T* input, \ - const ORT_T* offset, const ORT_T* mask, ORT_T* col_buffer, \ - int64_t parallel_imgs, int64_t C, int64_t H, int64_t W, \ - int64_t kH, int64_t kW, int64_t out_h, int64_t out_w, \ - int64_t pad_h, int64_t pad_w, int64_t stride_h, int64_t stride_w, \ - int64_t dilation_h, int64_t dilation_w, int64_t offset_group, bool use_mask) { \ - return DeformConvIm2ColImpl(stream, reinterpret_cast(input), \ - reinterpret_cast(offset), \ - mask ? reinterpret_cast(mask) : nullptr, \ - reinterpret_cast(col_buffer), \ - parallel_imgs, C, H, W, kH, kW, out_h, out_w, \ - pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, \ - offset_group, use_mask); \ - } \ - template <> \ - Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t stream, \ - const ORT_T* gemm_output, ORT_T* Y_g, \ - int64_t M, int64_t M_per_group, \ - int64_t output_image_size, int64_t cur_parallel) { \ - return DeformConvCopyGemmOutputRowMajorToNCHW(stream, \ - reinterpret_cast(gemm_output), \ - reinterpret_cast(Y_g), \ - M, M_per_group, output_image_size, cur_parallel); \ - } \ - template <> \ - Status DeformConvAddBiasImpl(cudaStream_t stream, ORT_T* Y, const ORT_T* B, \ - int64_t N, int64_t M, int64_t out_h, int64_t out_w) { \ - return DeformConvAddBiasImpl(stream, reinterpret_cast(Y), \ - reinterpret_cast(B), N, M, out_h, out_w); \ +#define DELEGATE_DEFORM_CONV_IMPL(ORT_T, CUDA_T) \ + template <> \ + Status DeformConvIm2ColImpl(cudaStream_t stream, const ORT_T* input, \ + const ORT_T* offset, const ORT_T* mask, ORT_T* col_buffer, \ + int64_t parallel_imgs, int64_t C, int64_t H, int64_t W, \ + int64_t kH, int64_t kW, int64_t out_h, int64_t out_w, \ + int64_t pad_h, int64_t pad_w, int64_t stride_h, int64_t stride_w, \ + int64_t dilation_h, int64_t dilation_w, int64_t offset_group, bool use_mask) { \ + return DeformConvIm2ColImpl(stream, reinterpret_cast(input), \ + reinterpret_cast(offset), \ + mask ? 
reinterpret_cast(mask) : nullptr, \ + reinterpret_cast(col_buffer), \ + parallel_imgs, C, H, W, kH, kW, out_h, out_w, \ + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, \ + offset_group, use_mask); \ + } \ + template <> \ + Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t stream, \ + const ORT_T* gemm_output, ORT_T* Y_g, \ + int64_t M, int64_t M_per_group, \ + int64_t output_image_size, int64_t cur_parallel) { \ + return DeformConvCopyGemmOutputRowMajorToNCHW(stream, \ + reinterpret_cast(gemm_output), \ + reinterpret_cast(Y_g), \ + M, M_per_group, output_image_size, cur_parallel); \ + } \ + template <> \ + Status DeformConvAddBiasImpl(cudaStream_t stream, ORT_T * Y, const ORT_T* B, \ + int64_t N, int64_t M, int64_t out_h, int64_t out_w) { \ + return DeformConvAddBiasImpl(stream, reinterpret_cast(Y), \ + reinterpret_cast(B), N, M, out_h, out_w); \ } DELEGATE_DEFORM_CONV_IMPL(MLFloat16, half) diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.h b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.h index 2f38d6ef18a7c..0c26cb55311bc 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.h +++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.h @@ -38,10 +38,10 @@ Status DeformConvCopyGemmOutputRowMajorToNCHW( template Status DeformConvIm2ColImpl( cudaStream_t stream, - const T* input, // [parallel_imgs, C, H, W] - const T* offset, // [parallel_imgs, offset_group*2*kH*kW, out_h, out_w] - const T* mask, // [parallel_imgs, offset_group*kH*kW, out_h, out_w] or nullptr - T* col_buffer, // [C*kH*kW, parallel_imgs*out_h*out_w] + const T* input, // [parallel_imgs, C, H, W] + const T* offset, // [parallel_imgs, offset_group*2*kH*kW, out_h, out_w] + const T* mask, // [parallel_imgs, offset_group*kH*kW, out_h, out_w] or nullptr + T* col_buffer, // [C*kH*kW, parallel_imgs*out_h*out_w] int64_t parallel_imgs, int64_t C, int64_t H, diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py b/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py index 860e2f0322e72..cb3cd38d120a6 100644 --- a/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py @@ -8,22 +8,39 @@ [pad_h, pad_w, pad_h, pad_w] are derived from a single (pad_h, pad_w) pair. Asymmetric pads (e.g. pad_top != pad_bottom) would require PyTorch API support and are not generated. 
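Example (illustrative): PyTorch padding=(1, 0) maps to ONNX pads [1, 0, 1, 0].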
""" + import torch import torchvision.ops + def _pair(x): if isinstance(x, int): return (x, x) return x + def to_cpp_list(t: torch.Tensor, fmt="{:.6f}") -> str: """Flatten tensor in NCHW order and format as C++ initializer list.""" t = t.detach().float().contiguous() return ", ".join(fmt.format(x) + "f" for x in t.flatten().tolist()) -def run_case(name: str, batch_sz: int, n_in: int, n_out: int, n_weight_grps: int, n_offset_grps: int, - kernel_h: int, kernel_w: int, stride: tuple, pad: tuple, dilation: tuple, - in_h: int, in_w: int, seed: int = 42): + +def run_case( + name: str, + batch_sz: int, + n_in: int, + n_out: int, + n_weight_grps: int, + n_offset_grps: int, + kernel_h: int, + kernel_w: int, + stride: tuple, + pad: tuple, + dilation: tuple, + in_h: int, + in_w: int, + seed: int = 42, +): """Build inputs with seed, run deform_conv2d, print C++ snippets.""" torch.manual_seed(seed) stride_h, stride_w = _pair(stride) @@ -41,15 +58,21 @@ def run_case(name: str, batch_sz: int, n_in: int, n_out: int, n_weight_grps: int # Standard answer from torchvision out = torchvision.ops.deform_conv2d( - x, offset, weight, bias=bias, - stride=(stride_h, stride_w), padding=(pad_h, pad_w), dilation=(dil_h, dil_w), mask=mask + x, + offset, + weight, + bias=bias, + stride=(stride_h, stride_w), + padding=(pad_h, pad_w), + dilation=(dil_h, dil_w), + mask=mask, ) # ONNX pads = [top, left, bottom, right] (symmetric: single pad_h, pad_w expanded) pads_onnx = [pad_h, pad_w, pad_h, pad_w] print(f"// --- {name} (seed={seed}) ---") - print(f"// Shapes: X({batch_sz},{n_in},{in_h},{in_w}) W({n_out},{n_in//n_weight_grps},{kernel_h},{kernel_w})") + print(f"// Shapes: X({batch_sz},{n_in},{in_h},{in_w}) W({n_out},{n_in // n_weight_grps},{kernel_h},{kernel_w})") print(f"// stride=({stride_h},{stride_w}) pad=({pad_h},{pad_w}) dilation=({dil_h},{dil_w})") print(f"// out_h={out_h} out_w={out_w}") print() @@ -60,10 +83,24 @@ def run_case(name: str, batch_sz: int, n_in: int, n_out: int, n_weight_grps: int print("std::vector mask = {" + to_cpp_list(mask) + "};") print("std::vector expected_Y = {" + to_cpp_list(out) + "};") print() - print("// Params: kernel_shape={" + f"{kernel_h}, {kernel_w}" + "}, stride={" + f"{stride_h}, {stride_w}" + "}, pads={" + ", ".join(map(str, pads_onnx)) + "}, dilations={" + f"{dil_h}, {dil_w}" + "}, group=" + str(n_weight_grps) + ", offset_group=" + str(n_offset_grps)) + print( + "// Params: kernel_shape={" + f"{kernel_h}, {kernel_w}" + "}, stride={" + f"{stride_h}, {stride_w}" + "}, pads={" + + ", ".join(map(str, pads_onnx)) + + "}, dilations={" + + f"{dil_h}, {dil_w}" + + "}, group=" + + str(n_weight_grps) + + ", offset_group=" + + str(n_offset_grps) + ) print() return out + def main(): print("// Generated by deform_conv_expected_gen.py (torchvision.ops.deform_conv2d)") print() @@ -72,10 +109,17 @@ def main(): run_case( "PyTorch get_fn_args style (batch=1)", batch_sz=1, - n_in=6, n_out=2, n_weight_grps=2, n_offset_grps=3, - kernel_h=3, kernel_w=2, - stride=(2, 1), pad=(1, 0), dilation=(2, 1), - in_h=5, in_w=4, + n_in=6, + n_out=2, + n_weight_grps=2, + n_offset_grps=3, + kernel_h=3, + kernel_w=2, + stride=(2, 1), + pad=(1, 0), + dilation=(2, 1), + in_h=5, + in_w=4, seed=42, ) @@ -83,7 +127,7 @@ def main(): torch.manual_seed(42) n_in, n_out = 6, 2 n_weight_grps, n_offset_grps = 2, 3 - kH, kW = 3, 2 + kH, kW = 3, 2 # noqa: N806 stride_h, stride_w = 2, 1 pad_h, pad_w = 1, 0 dil_h, dil_w = 2, 1 @@ -98,8 +142,14 @@ def main(): bias = torch.randn(n_out, dtype=torch.float32) out_no_mask = 
torchvision.ops.deform_conv2d( - x, offset, weight, bias=bias, - stride=(stride_h, stride_w), padding=(pad_h, pad_w), dilation=(dil_h, dil_w), mask=None + x, + offset, + weight, + bias=bias, + stride=(stride_h, stride_w), + padding=(pad_h, pad_w), + dilation=(dil_h, dil_w), + mask=None, ) print("// --- Same inputs, no mask (expected_Y when mask is omitted) ---") print("std::vector expected_Y_no_mask = {" + to_cpp_list(out_no_mask) + "};") @@ -109,12 +159,20 @@ def main(): run_case( "Groups with non-zero offset (batch=1, 2 groups)", batch_sz=1, - n_in=4, n_out=2, n_weight_grps=2, n_offset_grps=2, - kernel_h=2, kernel_w=2, - stride=(1, 1), pad=(0, 0), dilation=(1, 1), - in_h=3, in_w=3, + n_in=4, + n_out=2, + n_weight_grps=2, + n_offset_grps=2, + kernel_h=2, + kernel_w=2, + stride=(1, 1), + pad=(0, 0), + dilation=(1, 1), + in_h=3, + in_w=3, seed=123, ) + if __name__ == "__main__": main() diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc index fa5f6a0945c15..26d2ba2dcc256 100644 --- a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc @@ -98,9 +98,13 @@ void RunDeformConvTest(const DeformConvTestParams& params, const int64_t kW = params.kernel_shape[1]; // ONNX pads format: [pad_top, pad_left, pad_bottom, pad_right] = [pad[0], pad[1], pad[2], pad[3]] const int64_t out_h = (params.in_h + params.pad[0] + params.pad[2] - - params.dilation[0] * (kH - 1) - 1) / params.stride[0] + 1; + params.dilation[0] * (kH - 1) - 1) / + params.stride[0] + + 1; const int64_t out_w = (params.in_w + params.pad[1] + params.pad[3] - - params.dilation[1] * (kW - 1) - 1) / params.stride[1] + 1; + params.dilation[1] * (kW - 1) - 1) / + params.stride[1] + + 1; OpTester test("DeformConv", opset); test.AddAttribute("kernel_shape", params.kernel_shape); @@ -161,9 +165,8 @@ TEST(DeformConvTest, MinimalBilinear) { // Layout: offset[n,c,oh,ow]. 
Flattened (NCHW): [ch0@00, ch0@01, ch0@10, ch0@11, ch1@00, ch1@01, ch1@10, ch1@11] // (0,0): (0.5, 0.5)->center of [1,2;3,4]->2.5; (0,1): (0,-1)->(0,0)->1; (1,0): (0,0)->3; (1,1): (0,0)->4 std::vector offset = { - 0.5f, 0.f, 0.f, 0.f, - 0.5f, -1.0f, 0.f, 0.f - }; + 0.5f, 0.f, 0.f, 0.f, + 0.5f, -1.0f, 0.f, 0.f}; std::vector B = {0.f}; std::vector mask = {1.f, 1.f, 1.f, 1.f}; // (1,1,2,2) std::vector expected_Y = {2.5f, 1.f, 3.f, 4.f}; @@ -190,9 +193,8 @@ TEST(DeformConvTest, MinimalBilinearFP16) { std::vector X = {1.f, 2.f, 3.f, 4.f}; std::vector W = {1.f}; std::vector offset = { - 0.5f, 0.f, 0.f, 0.f, - 0.5f, -1.0f, 0.f, 0.f - }; + 0.5f, 0.f, 0.f, 0.f, + 0.5f, -1.0f, 0.f, 0.f}; std::vector B = {0.f}; std::vector mask = {1.f, 1.f, 1.f, 1.f}; std::vector expected_Y = {2.5f, 1.f, 3.f, 4.f}; @@ -225,9 +227,8 @@ TEST(DeformConvTest, MinimalBilinearBFloat16) { std::vector X = {1.f, 2.f, 3.f, 4.f}; std::vector W = {1.f}; std::vector offset = { - 0.5f, 0.f, 0.f, 0.f, - 0.5f, -1.0f, 0.f, 0.f - }; + 0.5f, 0.f, 0.f, 0.f, + 0.5f, -1.0f, 0.f, 0.f}; std::vector B = {0.f}; std::vector mask = {1.f, 1.f, 1.f, 1.f}; std::vector expected_Y = {2.5f, 1.f, 3.f, 4.f}; @@ -254,9 +255,8 @@ TEST(DeformConvTest, MinimalBilinearDouble) { std::vector X = {1.f, 2.f, 3.f, 4.f}; std::vector W = {1.f}; std::vector offset = { - 0.5f, 0.f, 0.f, 0.f, - 0.5f, -1.0f, 0.f, 0.f - }; + 0.5f, 0.f, 0.f, 0.f, + 0.5f, -1.0f, 0.f, 0.f}; std::vector B = {0.f}; std::vector mask = {1.f, 1.f, 1.f, 1.f}; std::vector expected_Y = {2.5f, 1.f, 3.f, 4.f}; @@ -397,7 +397,7 @@ TEST(DeformConvTest, ForwardNoMask) { test.AddOptionalInputEdge(); // no mask test.AddOutput("Y", Y_shape, expected_Y, false, 1e-4f, 1e-4f); std::unordered_set excluded = {kTensorrtExecutionProvider, - kOpenVINOExecutionProvider, kQnnExecutionProvider}; + kOpenVINOExecutionProvider, kQnnExecutionProvider}; test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); } @@ -443,7 +443,7 @@ TEST(DeformConvTest, EmptyBatch) { test.AddOptionalInputEdge(); test.AddOutput("Y", Y_shape, expected_Y); std::unordered_set excluded = {kTensorrtExecutionProvider, - kOpenVINOExecutionProvider, kQnnExecutionProvider}; + kOpenVINOExecutionProvider, kQnnExecutionProvider}; test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); } @@ -582,7 +582,7 @@ TEST(DeformConvTest, NonSquareKernel) { const size_t x_size = static_cast(1 * 1 * 4 * 5); const size_t w_size = static_cast(1 * 1 * 2 * 3); const size_t offset_size = static_cast(1 * 1 * 2 * 2 * 3 * out_h * out_w); // n_offset_grps * 2 * kH * kW * out_h * out_w - const size_t mask_size = static_cast(1 * 1 * 2 * 3 * out_h * out_w); // n_offset_grps * kH * kW * out_h * out_w + const size_t mask_size = static_cast(1 * 1 * 2 * 3 * out_h * out_w); // n_offset_grps * kH * kW * out_h * out_w std::vector X(x_size, 0.1f); std::vector W(w_size, 0.1f); @@ -822,8 +822,7 @@ TEST(DeformConvTest, OffsetAtPixelCenters) { std::vector X = {1.f, 2.f, 3.f, 4.f}; std::vector W = {0.25f, 0.25f, 0.25f, 0.25f}; std::vector offset = { - 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f - }; + 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}; std::vector B = {0.f}; std::vector mask = {1.f, 1.f, 1.f, 1.f}; std::vector expected_Y = {1.6875f}; // op output: one center sample 2.5 + boundary samples @@ -868,7 +867,7 @@ TEST(DeformConvTest, LargeBatchSize) { TEST(DeformConvTest, Group1OffsetGroup2) { DeformConvTestParams p = {}; p.batch_sz = 1; - p.n_in_channels = 4; // C must be divisible by offset_group + p.n_in_channels = 4; // C must be divisible by offset_group 
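+  // With C = 4 and offset_group = 2, channel_per_offset_group = 2: input channels
+  // {0, 1} sample with offset group 0, channels {2, 3} with offset group 1.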
p.n_out_channels = 2; p.n_weight_grps = 1; p.n_offset_grps = 2; diff --git a/onnxruntime/test/testdata/nn/deform_conv_test_gen.py b/onnxruntime/test/testdata/nn/deform_conv_test_gen.py index 301f14617e9a4..1a61ff5da08bd 100644 --- a/onnxruntime/test/testdata/nn/deform_conv_test_gen.py +++ b/onnxruntime/test/testdata/nn/deform_conv_test_gen.py @@ -18,13 +18,19 @@ - deform_conv_test_data.npz (X, W, offset, B, mask, expected_Y) - deform_conv_test_data.inc (C++ arrays for op test) """ + from pathlib import Path import numpy as np import onnx -from onnx import TensorProto, helper + +try: + import onnxruntime as ort +except ImportError: + ort = None import torch import torchvision.ops +from onnx import TensorProto, helper # Config: groups=2, offset_group=2, 2x2 kernel (from deform_conv_expected_gen Case 3) BATCH = 1 @@ -53,8 +59,14 @@ def _generate_reference(): bias = torch.randn(N_OUT, dtype=torch.float32) out = torchvision.ops.deform_conv2d( - x, offset, weight, bias=bias, - stride=(STRIDE_H, STRIDE_W), padding=(PAD_H, PAD_W), dilation=(DIL_H, DIL_W), mask=mask + x, + offset, + weight, + bias=bias, + stride=(STRIDE_H, STRIDE_W), + padding=(PAD_H, PAD_W), + dilation=(DIL_H, DIL_W), + mask=mask, ) return { @@ -90,11 +102,11 @@ def _build_onnx_model(): [ helper.make_tensor_value_info("X", TensorProto.FLOAT, [BATCH, N_IN, IN_H, IN_W]), helper.make_tensor_value_info("W", TensorProto.FLOAT, [N_OUT, N_IN // N_WEIGHT_GRPS, KH, KW]), - helper.make_tensor_value_info("offset", TensorProto.FLOAT, - [BATCH, N_OFFSET_GRPS * 2 * KH * KW, OUT_H, OUT_W]), + helper.make_tensor_value_info( + "offset", TensorProto.FLOAT, [BATCH, N_OFFSET_GRPS * 2 * KH * KW, OUT_H, OUT_W] + ), helper.make_tensor_value_info("B", TensorProto.FLOAT, [N_OUT]), - helper.make_tensor_value_info("mask", TensorProto.FLOAT, - [BATCH, N_OFFSET_GRPS * KH * KW, OUT_H, OUT_W]), + helper.make_tensor_value_info("mask", TensorProto.FLOAT, [BATCH, N_OFFSET_GRPS * KH * KW, OUT_H, OUT_W]), ], [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [BATCH, N_OUT, OUT_H, OUT_W])], ) @@ -153,9 +165,7 @@ def main(): print(f" Saved {inc_path}") # Validate with onnxruntime if available - try: - import onnxruntime as ort - + if ort is not None: print("Validating with ONNX Runtime...") sess = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"]) ort_out = sess.run( @@ -175,7 +185,7 @@ def main(): else: diff = np.abs(ort_out.astype(np.float64) - data["expected_Y"].astype(np.float64)) print(f" FAIL: max |diff|={diff.max()}, mean={diff.mean()}") - except ImportError: + else: print(" (onnxruntime not installed; skip validation)") From d9f65fb7473c1b5218d72647e29ea5fa278b94a5 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Tue, 3 Mar 2026 02:01:37 +0800 Subject: [PATCH 24/58] Fix cuda fp16 test cases --- .../providers/cpu/nn/deform_conv_op_test.cc | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc index 26d2ba2dcc256..a00b736ffa9db 100644 --- a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc @@ -176,7 +176,15 @@ TEST(DeformConvTest, MinimalBilinear) { // Minimal case FP16: Same as MinimalBilinear but in FP16. // Validates CUDA FP16 implementation (specifically coordinate precision logic). +// DeformConv FP16 is CUDA-only; skip when CUDA is not available (e.g., Linux x64/arm64 CPU-only builds). 
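+// (The BFloat16 variant below keeps its own guard: it needs SM 8.0+, vs. SM 5.3+ for FP16.)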
+#if defined(USE_CUDA) TEST(DeformConvTest, MinimalBilinearFP16) { + int min_cuda_architecture = 530; // FP16 requires SM 5.3+ + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "DeformConv FP16: CUDA not available, skipping."; + return; + } + DeformConvTestParams p = {}; p.batch_sz = 1; p.n_in_channels = 1; @@ -203,7 +211,6 @@ TEST(DeformConvTest, MinimalBilinearFP16) { } // Minimal case BFloat16: Same as MinimalBilinear but in BFloat16 (CUDA BFloat16 coordinate precision). -#if defined(USE_CUDA) TEST(DeformConvTest, MinimalBilinearBFloat16) { int min_cuda_architecture = 800; if (!HasCudaEnvironment(min_cuda_architecture)) { @@ -235,7 +242,7 @@ TEST(DeformConvTest, MinimalBilinearBFloat16) { RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 22); } -#endif +#endif // defined(USE_CUDA) // Minimal case Double (FP64): Same as MinimalBilinear in double precision. TEST(DeformConvTest, MinimalBilinearDouble) { @@ -264,8 +271,15 @@ TEST(DeformConvTest, MinimalBilinearDouble) { RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); } -// Forward with mask and bias FP16 +// Forward with mask and bias FP16 (CUDA-only; skip when CUDA not available). +#if defined(USE_CUDA) TEST(DeformConvTest, ForwardWithMaskAndBiasFP16) { + int min_cuda_architecture = 530; // FP16 requires SM 5.3+ + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "DeformConv FP16: CUDA not available, skipping."; + return; + } + DeformConvTestParams p = {}; p.batch_sz = 2; p.n_in_channels = 4; @@ -305,6 +319,8 @@ TEST(DeformConvTest, ForwardWithMaskAndBiasFP16) { RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); } +#endif // defined(USE_CUDA) + // With offset=0 and mask=1, Y = Conv(X,W) + B. Use small inputs and compute expected. TEST(DeformConvTest, ForwardWithMaskAndBias) { DeformConvTestParams p = {}; From 15fe856f8fe00985226faedfde795c6e1149e430 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Thu, 5 Mar 2026 21:48:47 +0800 Subject: [PATCH 25/58] Fix int64_t to ptrdiff_t conversion in deform_conv --- onnxruntime/core/providers/cpu/nn/deform_conv.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.cc b/onnxruntime/core/providers/cpu/nn/deform_conv.cc index 7aa65b98869c5..73dc277d0c185 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cpu/nn/deform_conv.cc @@ -73,7 +73,7 @@ void DeformableIm2col( // This ensures sequential access to data_col and better locality for data_im. 
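  // ThreadPool::TryParallelFor takes std::ptrdiff_t work counts and range bounds, hence the
  // explicit casts below (int64_t can be wider than ptrdiff_t on 32-bit targets).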
concurrency::ThreadPool::TryParallelFor( - thread_pool, channels, 1.0, + thread_pool, static_cast(channels), 1.0, [&](ptrdiff_t c_im_start, ptrdiff_t c_im_end) { for (int64_t c_im = c_im_start; c_im < c_im_end; ++c_im) { const int64_t offset_grp = c_im / channel_per_offset_group; @@ -265,7 +265,7 @@ Status DeformConv::Compute(OpKernelContext* context) const { if (Bdata != nullptr) { int64_t total_work = N * M; concurrency::ThreadPool::TryParallelFor( - thread_pool, total_work, static_cast(output_image_size), + thread_pool, static_cast(total_work), static_cast(output_image_size), [&](ptrdiff_t first, ptrdiff_t last) { for (ptrdiff_t idx = first; idx < last; ++idx) { int64_t n = idx / M; From 931c3862b904b9a99a648d96ae10deefb3380e21 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Fri, 6 Mar 2026 19:06:53 +0800 Subject: [PATCH 26/58] Resolve pipeline failures caused by unit tests --- onnxruntime/test/testdata/deform_conv_test_data.inc | 12 ++++++------ onnxruntime/test/testdata/nn/deform_conv_test_gen.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/onnxruntime/test/testdata/deform_conv_test_data.inc b/onnxruntime/test/testdata/deform_conv_test_data.inc index 7f9901c823d55..206d8517dd3e3 100644 --- a/onnxruntime/test/testdata/deform_conv_test_data.inc +++ b/onnxruntime/test/testdata/deform_conv_test_data.inc @@ -2,9 +2,9 @@ #include -static const std::vector kDeformConvOnnxTest_X = {0.296111941, 0.516562283, 0.251670718, 0.68855679, 0.0739724636, 0.866521955, 0.136579871, 0.102479041, 0.184056461, 0.726446748, 0.315253913, 0.687106669, 0.075635314, 0.196638167, 0.316411972, 0.401740134, 0.118568301, 0.82739538, 0.382084429, 0.660493851, 0.853571773, 0.593153, 0.636725366, 0.982629359, 0.274495304, 0.658375621, 0.277541935, 0.857324839, 0.899328232, 0.0390138626, 0.926822901, 0.738757193, 0.717883527, 0.705837429, 0.915649533, 0.433980227}; -static const std::vector kDeformConvOnnxTest_W = {-1.18204546, -0.287744999, -0.604300678, 0.600236714, -1.42047262, -0.223827749, 0.430554837, -0.89885664, -0.0178579595, 0.426403075, -0.765740693, -0.0545141846, -0.732052684, 1.23474216, 1.18622088, -0.220098898}; -static const std::vector kDeformConvOnnxTest_offset = {-0.388483077, -0.934345901, -0.499144107, -1.08665264, 0.962421, 0.249208495, -0.484502077, -2.09291434, 0.0982837752, -0.0935074314, 0.266214728, -0.585035503, -0.343037993, -0.682147384, -0.988689423, -1.70183039, -1.2202903, 1.31385386, 1.05329967, 0.138805181, -0.204444751, -2.26852894, -0.913327932, -0.420362711, -0.659559608, -0.797927678, 0.18383126, 0.229347408, 0.617742658, -0.287577927, 0.821824312, 0.151177585, -0.0443819836, 1.62355745, -2.32287097, 1.08783054, -0.0635453761, -0.448640704, -1.27846932, -1.14400387, -0.152640373, 0.116741188, 0.44026047, -1.44654655, -0.558081627, -0.0516963229, -0.90827328, 0.350683212, -0.394808769, 0.489227712, -0.216814891, -1.74716449, 1.72284174, 0.773806036, 0.404629797, -1.64612663, -0.59508425, -0.711217523, 0.622964859, -1.37288189, -0.128064156, -1.28383458, -0.290120065, 1.27674019}; -static const std::vector kDeformConvOnnxTest_B = {0.983955026, 0.204511523}; -static const std::vector kDeformConvOnnxTest_mask = {-0.0318612382, -0.478955716, 0.766808629, 0.0274681915, 0.0474699028, -0.92386651, -1.06073678, -2.32444572, -2.06281757, 0.00637452863, -0.989554703, 0.701609194, -0.982237995, 0.277030349, 0.645495057, -0.895680785, 0.492752999, -0.0140781598, -0.274662733, -0.764091492, -0.58715719, 1.1951654, -1.20957518, -0.556007624, 
-0.0771045536, 1.27737665, -1.45962942, -2.15952778, -0.70670861, -0.92224431, 3.89537215, -0.602696717}; -static const std::vector kDeformConvOnnxTest_expected_Y = {0.971546292, 1.1398586, 0.452816963, 1.86388242, -0.565265715, 1.42318761, -2.46283293, -0.104923099}; +static const std::vector kDeformConvOnnxTest_X = {0.296111941f, 0.516562283f, 0.251670718f, 0.68855679f, 0.0739724636f, 0.866521955f, 0.136579871f, 0.102479041f, 0.184056461f, 0.726446748f, 0.315253913f, 0.687106669f, 0.075635314f, 0.196638167f, 0.316411972f, 0.401740134f, 0.118568301f, 0.82739538f, 0.382084429f, 0.660493851f, 0.853571773f, 0.593153f, 0.636725366f, 0.982629359f, 0.274495304f, 0.658375621f, 0.277541935f, 0.857324839f, 0.899328232f, 0.0390138626f, 0.926822901f, 0.738757193f, 0.717883527f, 0.705837429f, 0.915649533f, 0.433980227f}; +static const std::vector kDeformConvOnnxTest_W = {-1.18204546f, -0.287744999f, -0.604300678f, 0.600236714f, -1.42047262f, -0.223827749f, 0.430554837f, -0.89885664f, -0.0178579595f, 0.426403075f, -0.765740693f, -0.0545141846f, -0.732052684f, 1.23474216f, 1.18622088f, -0.220098898f}; +static const std::vector kDeformConvOnnxTest_offset = {-0.388483077f, -0.934345901f, -0.499144107f, -1.08665264f, 0.962421f, 0.249208495f, -0.484502077f, -2.09291434f, 0.0982837752f, -0.0935074314f, 0.266214728f, -0.585035503f, -0.343037993f, -0.682147384f, -0.988689423f, -1.70183039f, -1.2202903f, 1.31385386f, 1.05329967f, 0.138805181f, -0.204444751f, -2.26852894f, -0.913327932f, -0.420362711f, -0.659559608f, -0.797927678f, 0.18383126f, 0.229347408f, 0.617742658f, -0.287577927f, 0.821824312f, 0.151177585f, -0.0443819836f, 1.62355745f, -2.32287097f, 1.08783054f, -0.0635453761f, -0.448640704f, -1.27846932f, -1.14400387f, -0.152640373f, 0.116741188f, 0.44026047f, -1.44654655f, -0.558081627f, -0.0516963229f, -0.90827328f, 0.350683212f, -0.394808769f, 0.489227712f, -0.216814891f, -1.74716449f, 1.72284174f, 0.773806036f, 0.404629797f, -1.64612663f, -0.59508425f, -0.711217523f, 0.622964859f, -1.37288189f, -0.128064156f, -1.28383458f, -0.290120065f, 1.27674019f}; +static const std::vector kDeformConvOnnxTest_B = {0.983955026f, 0.204511523f}; +static const std::vector kDeformConvOnnxTest_mask = {-0.0318612382f, -0.478955716f, 0.766808629f, 0.0274681915f, 0.0474699028f, -0.92386651f, -1.06073678f, -2.32444572f, -2.06281757f, 0.00637452863f, -0.989554703f, 0.701609194f, -0.982237995f, 0.277030349f, 0.645495057f, -0.895680785f, 0.492752999f, -0.0140781598f, -0.274662733f, -0.764091492f, -0.58715719f, 1.1951654f, -1.20957518f, -0.556007624f, -0.0771045536f, 1.27737665f, -1.45962942f, -2.15952778f, -0.70670861f, -0.92224431f, 3.89537215f, -0.602696717f}; +static const std::vector kDeformConvOnnxTest_expected_Y = {0.971546292f, 1.1398586f, 0.452816963f, 1.86388242f, -0.565265715f, 1.42318761f, -2.46283293f, -0.104923099f}; diff --git a/onnxruntime/test/testdata/nn/deform_conv_test_gen.py b/onnxruntime/test/testdata/nn/deform_conv_test_gen.py index 1a61ff5da08bd..8fcd9e34a26a6 100644 --- a/onnxruntime/test/testdata/nn/deform_conv_test_gen.py +++ b/onnxruntime/test/testdata/nn/deform_conv_test_gen.py @@ -119,7 +119,7 @@ def _build_onnx_model(): def _to_cpp_array(name: str, arr: np.ndarray) -> str: """Format numpy array as C++ initializer list.""" flat = arr.flatten().tolist() - vals = ", ".join(f"{x:.9g}" for x in flat) + vals = ", ".join(f"{x:.9g}f" for x in flat) return f"static const std::vector {name} = {{{vals}}};" From fedd389896dfd3ae7339e111f12fe19063fe0a34 Mon Sep 17 00:00:00 2001 From: Shirasawa 
<764798966@qq.com> Date: Fri, 6 Mar 2026 19:09:16 +0800 Subject: [PATCH 27/58] Add comments and handle unused variables --- onnxruntime/core/providers/cpu/nn/deform_conv.cc | 3 +++ .../core/providers/cuda/nn/deform_conv_impl.cu | 13 +++++++------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.cc b/onnxruntime/core/providers/cpu/nn/deform_conv.cc index 73dc277d0c185..e8c85dc6e4c33 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cpu/nn/deform_conv.cc @@ -205,6 +205,9 @@ Status DeformConv::Compute(OpKernelContext* context) const { const T* mask_curr = use_mask ? (mask_data + n * (offset_group * kernel_size * output_image_size)) : nullptr; T* col_buffer_ptr = col_buffer.get(); + // DeformableIm2col only needs pad_h, pad_w (begin-side pads) for coordinate mapping. + // pad_h_end and pad_w_end are used in out_h/out_w computation (params) but do not affect + // the im2col sampling logic; they only influence output dimensions. DeformableIm2col( X_curr, offset_curr, diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu index b2f4e479bdd93..a6e1286f6f6d8 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu @@ -299,6 +299,7 @@ __global__ void DeformConvAddBiasKernel( // Equivalent to: channel_idx = batch_channel_idx % M; // We only need channel_idx (i.e. m) channel_div.divmod(batch_channel_idx, batch_idx, channel_idx); + (void)batch_idx; // Only channel_idx is needed // channel_idx is what we need (i.e. m) Y[idx] += DeformConvLdg(B + channel_idx); @@ -435,14 +436,14 @@ Status DeformConvIm2ColImpl( } #define INST_DeformConvIm2ColImpl(T) \ - template Status DeformConvIm2ColImpl(cudaStream_t, const T*, const T*, const T*, T*, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, bool); + template Status DeformConvIm2ColImpl(cudaStream_t, const T*, const T*, const T*, T*, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, bool) -INST_DeformConvIm2ColImpl(float) - INST_DeformConvIm2ColImpl(double) - INST_DeformConvIm2ColImpl(half) - INST_DeformConvIm2ColImpl(BFloat16) +INST_DeformConvIm2ColImpl(float); +INST_DeformConvIm2ColImpl(double); +INST_DeformConvIm2ColImpl(half); +INST_DeformConvIm2ColImpl(BFloat16); - template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const float*, float*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const float*, float*, int64_t, int64_t, int64_t, int64_t); template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const double*, double*, int64_t, int64_t, int64_t, int64_t); template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const half*, half*, int64_t, int64_t, int64_t, int64_t); template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const BFloat16*, BFloat16*, int64_t, int64_t, int64_t, int64_t); From 7ebc49889f5dabc273dba6772a9e81ac78d4b29d Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Sat, 7 Mar 2026 05:39:52 +0800 Subject: [PATCH 28/58] Address review feedback and align with Conv behavior - Add offset/mask batch size checks in DeformConvValidateAndParse - Replace cudaMemGetInfo call with UpdateState-style shape-based 
caching - Match Conv pattern: lock at ComputeInternal entry, consistent state handling --- .../providers/cpu/nn/deform_conv_attributes.h | 2 + .../core/providers/cuda/nn/deform_conv.cc | 47 ++++++++++++++++--- .../core/providers/cuda/nn/deform_conv.h | 15 ++++++ 3 files changed, 57 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h index d103bc132f076..ea1b5761f0913 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h +++ b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h @@ -119,6 +119,7 @@ inline Status DeformConvValidateAndParse( // Validate tensor shapes ORT_RETURN_IF_NOT(W_shape.NumDimensions() == 4, "Weight must be 4D."); ORT_RETURN_IF_NOT(offset_shape.NumDimensions() == 4, "Offset must be 4D."); + ORT_RETURN_IF_NOT(offset_shape[0] == params.N, "Offset batch size must match input batch size."); ORT_RETURN_IF_NOT( offset_shape[1] == params.offset_group * 2 * params.kH * params.kW, "Offset channel count must be offset_group * 2 * kH * kW."); @@ -131,6 +132,7 @@ inline Status DeformConvValidateAndParse( // Validate mask if present if (params.use_mask) { ORT_RETURN_IF_NOT(mask_shape->NumDimensions() == 4, "Mask must be 4D."); + ORT_RETURN_IF_NOT((*mask_shape)[0] == params.N, "Mask batch size must match input batch size."); ORT_RETURN_IF_NOT( (*mask_shape)[1] == params.offset_group * params.kH * params.kW, "Mask channel count must be offset_group * kH * kW."); diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc index 1b7912f4990c1..e81bbbf300b67 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc @@ -8,6 +8,7 @@ #include "deform_conv_impl.h" #include "core/common/narrow.h" +#include "core/common/span_utils.h" #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" @@ -48,6 +49,7 @@ int GetGreatestDivisorBelowBound(int n, int bound) { // Returns effective max temp memory (bytes) for DeformConv batching. // Uses 90% of free GPU memory with tiered cap; fallback 256MB if cudaMemGetInfo fails. // Mirrors Conv's approach (conv_8.h); tiered limits avoid OOM on smaller GPUs. +// Called only when input/weight shapes change (see UpdateState). 
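+//
+// Worked example (illustrative): with ~12 GB free after the 90% scaling, the 8-16 GB tier
+// applies, so tier_cap = 1 GB and the function returns max(32 MB, min(1 GB, free)) = 1 GB.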
size_t GetDeformConvEffectiveMaxTempBytes() { constexpr size_t kDefaultFallback = 256ULL * 1024 * 1024; constexpr size_t kMinTempMemSize = 32ULL * 1024 * 1024; @@ -76,10 +78,45 @@ size_t GetDeformConvEffectiveMaxTempBytes() { } // namespace +template +Status DeformConv::UpdateState(OpKernelContext* context, + const DeformConvParams& params, + int& n_parallel_imgs) const { + const auto* X = context->Input(0); + const auto* W = context->Input(1); + const auto x_dims = X->Shape().AsShapeVector(); + const auto w_dims = W->Shape().AsShapeVector(); + + bool input_dims_changed = (state_.last_x_dims != x_dims); + bool w_dims_changed = (state_.last_w_dims != w_dims); + + if (input_dims_changed || w_dims_changed) { + if (input_dims_changed) { + state_.last_x_dims = gsl::make_span(x_dims); + } + if (w_dims_changed) { + state_.last_w_dims = gsl::make_span(w_dims); + } + + const int64_t kernel_size = params.kH * params.kW; + const int64_t output_image_size = params.out_h * params.out_w; + const size_t bytes_per_image = SafeInt(output_image_size) * (params.C * kernel_size + params.M / params.group) * sizeof(T); + const size_t effective_max_temp = GetDeformConvEffectiveMaxTempBytes(); + const int max_parallel_imgs_mem = std::max(1, static_cast(effective_max_temp / std::max(size_t(1), bytes_per_image))); + const int target_parallel_imgs = std::min(kMaxParallelImgs, max_parallel_imgs_mem); + state_.cached_n_parallel_imgs = GetGreatestDivisorBelowBound(static_cast(params.N), target_parallel_imgs); + } + + n_parallel_imgs = state_.cached_n_parallel_imgs; + return Status::OK(); +} + template Status DeformConv::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; + std::lock_guard lock(state_.mutex); + const auto* X = context->Input(0); const auto* W = context->Input(1); const auto* offset = context->Input(2); @@ -113,18 +150,14 @@ Status DeformConv::ComputeInternal(OpKernelContext* context) const { return Status::OK(); } + int n_parallel_imgs; + ORT_RETURN_IF_ERROR(UpdateState(context, params, n_parallel_imgs)); + const int64_t kernel_size = kH * kW; const int64_t output_image_size = out_h * out_w; const int64_t input_image_size = H * W_in; const int64_t kernel_dim = (C / group) * kernel_size; - // col_buffer: C * kernel_size * output_image_size; gemm_output_buffer: (M/group) * output_image_size - const size_t bytes_per_image = SafeInt(output_image_size) * (C * kernel_size + M / group) * sizeof(T); - const size_t effective_max_temp = GetDeformConvEffectiveMaxTempBytes(); - const int max_parallel_imgs_mem = std::max(1, static_cast(effective_max_temp / std::max(size_t(1), bytes_per_image))); - const int target_parallel_imgs = std::min(kMaxParallelImgs, max_parallel_imgs_mem); - - const int n_parallel_imgs = GetGreatestDivisorBelowBound(static_cast(N), target_parallel_imgs); const int64_t col_stride = static_cast(n_parallel_imgs) * output_image_size; const int64_t col_buffer_size = (C * kernel_size) * col_stride; diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.h b/onnxruntime/core/providers/cuda/nn/deform_conv.h index fa564641d4b98..f89c105327b38 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv.h +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.h @@ -3,14 +3,24 @@ #pragma once +#include + #include "core/common/common.h" #include "core/framework/op_kernel.h" +#include "core/framework/tensor_shape.h" #include "core/providers/cpu/nn/deform_conv_attributes.h" #include "core/providers/cuda/cuda_kernel.h" namespace onnxruntime { namespace cuda { 
+struct DeformConvState { + TensorShape last_x_dims; + TensorShape last_w_dims; + int cached_n_parallel_imgs{0}; + std::mutex mutex; +}; + template class DeformConv final : public CudaKernel { public: @@ -19,7 +29,12 @@ class DeformConv final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; private: + Status UpdateState(OpKernelContext* context, + const DeformConvParams& params, + int& n_parallel_imgs) const; + DeformConvAttributes attrs_; + mutable DeformConvState state_; ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DeformConv); }; From 0479aded6896eaf02908e23eb36c457cc858178e Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Sat, 7 Mar 2026 05:47:11 +0800 Subject: [PATCH 29/58] Optimize DeformConv cpu bias add with Eigen SIMD --- onnxruntime/core/providers/cpu/nn/deform_conv.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.cc b/onnxruntime/core/providers/cpu/nn/deform_conv.cc index e8c85dc6e4c33..e9d5252ade110 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cpu/nn/deform_conv.cc @@ -8,6 +8,7 @@ #include #include "core/common/common.h" +#include "core/util/math_cpuonly.h" #include "core/common/narrow.h" #include "core/util/math.h" @@ -274,10 +275,8 @@ Status DeformConv::Compute(OpKernelContext* context) const { int64_t n = idx / M; int64_t m = idx % M; T* Y_ptr = Ydata + n * M * output_image_size + m * output_image_size; - T bias_val = Bdata[m]; - for (int64_t i = 0; i < output_image_size; ++i) { - Y_ptr[i] += bias_val; - } + // Vectorized: Y_ptr[i] += Bdata[m] for i in [0, output_image_size); uses Eigen SIMD. + EigenVectorArrayMap(Y_ptr, narrow(output_image_size)) += Bdata[m]; } }); } From f7819f14748a2d78099078c75f76c9920468a2b3 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Sat, 7 Mar 2026 05:50:43 +0800 Subject: [PATCH 30/58] Document GEMM layout trick in DeformConv cuBLAS path --- .../core/providers/cuda/nn/deform_conv.cc | 34 +++++++------------ 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc index e81bbbf300b67..3b1a30d9f1112 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc @@ -216,29 +216,19 @@ Status DeformConv::ComputeInternal(OpKernelContext* context) const { const T* col_g = col_buffer.get() + g * kernel_dim * col_stride; T* Y_g = Ydata + b * M * output_image_size + g * (M / group) * output_image_size; - // Avoid physical transpose by using cuBLAS OP_N/OP_N logic. - // We want Y = W * Col. - // W is [M/group, kernel_dim] (Row-Major). - // Col is [kernel_dim, cur_out_size] (Row-Major). - // We compute Y^T = Col^T * W^T. - // Col^T (Col-Major [cur_out_size, kernel_dim]) is exactly Col (Row-Major [kernel_dim, cur_out_size]) in memory. - // W^T (Col-Major [kernel_dim, M/group]) is exactly W (Row-Major [M/group, kernel_dim]) in memory. - // Result Y^T is Col-Major [cur_out_size, M/group]. - // In memory, Y^T (Col-Major) is exactly Y (Row-Major [M/group, cur_out_size]). - // So we get Y in Row-Major layout. - - // A = Col (Row-Major [kernel_dim, cur_out_size]) -> interpreted as Col-Major [cur_out_size, kernel_dim]. - // B = W (Row-Major [M/group, kernel_dim]) -> interpreted as Col-Major [kernel_dim, M/group]. - // C = A * B = Col^T * W^T = Y^T. - // C is Col-Major [cur_out_size, M/group]. 
- // m = cur_out_size, n = M/group, k = kernel_dim. - // lda = cur_out_size. - // ldb = kernel_dim. - // ldc = cur_out_size. + // GEMM layout trick: compute Y = W * Col without physical transpose. // - // When cur_parallel == 1: cur_out_size == output_image_size, so C layout (pos, channel) matches - // NCHW Y_g[0, channel, pos] exactly. Write directly to Y_g and skip the copy kernel. - // When cur_parallel > 1: layouts differ, must copy via DeformConvCopyGemmOutputRowMajorToNCHW. + // Our data is row-major: W [M/group, kernel_dim], Col [kernel_dim, cur_out_size], Y [M/group, cur_out_size]. + // cuBLAS is column-major. Key insight: row-major A[M,K] in memory equals column-major A^T[K,M]. + // We compute Y^T = Col^T * W^T by passing Col as A and W as B, both OP_N (no transpose): + // - Col (row [kernel_dim, cur_out_size]) -> cuBLAS interprets as col-major [cur_out_size, kernel_dim] = Col^T + // - W (row [M/group, kernel_dim]) -> cuBLAS interprets as col-major [kernel_dim, M/group] = W^T + // - C = A*B = Col^T * W^T = (W*Col)^T = Y^T; C is col-major [cur_out_size, M/group] = Y in row-major + // + // cublasGemmHelper: m=cur_out_size, n=M/group, k=kernel_dim; lda=cur_out_size, ldb=kernel_dim, ldc=cur_out_size + // + // cur_parallel==1: cur_out_size==output_image_size, C layout (pos, channel) matches NCHW Y_g[0,ch,pos] -> write + // into Y_g. cur_parallel>1: layouts differ -> write to gemm_output_buffer, then DeformConvCopyGemmOutputRowMajorToNCHW. const bool gemm_writes_directly = (cur_parallel == 1); CUBLAS_RETURN_IF_ERROR((cublasGemmHelper( From 34fae7d1f0b8d36a676911cf99a758216ddf5eab Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Sat, 7 Mar 2026 05:55:49 +0800 Subject: [PATCH 31/58] Use int64_t for bilinear interpolation indices --- onnxruntime/core/providers/cpu/nn/deform_conv.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.cc b/onnxruntime/core/providers/cpu/nn/deform_conv.cc index e9d5252ade110..02038ccc91645 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cpu/nn/deform_conv.cc @@ -24,10 +24,10 @@ T BilinearInterpolate(const T* in, int64_t height, int64_t width, T h, T w) { return static_cast(0); } - const int h_low = static_cast(std::floor(h)); - const int w_low = static_cast(std::floor(w)); - const int h_high = h_low + 1; - const int w_high = w_low + 1; + const int64_t h_low = static_cast(std::floor(h)); + const int64_t w_low = static_cast(std::floor(w)); + const int64_t h_high = h_low + 1; + const int64_t w_high = w_low + 1; const T lh = h - static_cast(h_low); const T lw = w - static_cast(w_low); From 173fd6be3e37da16fd7a49dd20268874af6f3921 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Sat, 7 Mar 2026 06:13:31 +0800 Subject: [PATCH 32/58] refactor(DeformConv CPU): template UseMask and improve im2col performance - Make use_mask a template parameter (UseMask) to eliminate branches at compile time - Refactor DeformableIm2col: parallelize over channels*kernel_size instead of channels; each task processes one full row for better cache/SIMD behavior - Always interpolate then multiply by mask (no skip when mask==0) to avoid branch misprediction and enable vectorization; document rationale in comments - Add English comments for design choices, tensor layouts, and GEMM pipeline --- .../core/providers/cpu/nn/deform_conv.cc | 227 +++++++++--------- 1 file changed, 110 insertions(+), 117 deletions(-) diff --git 
a/onnxruntime/core/providers/cpu/nn/deform_conv.cc b/onnxruntime/core/providers/cpu/nn/deform_conv.cc index 02038ccc91645..d63d619df924a 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cpu/nn/deform_conv.cc @@ -16,29 +16,35 @@ namespace onnxruntime { namespace { -// Bilinear interpolation at (h, w). Returns 0 if out of bounds. +// Bilinear interpolation at fractional coordinates (h, w). +// Returns 0 if (h,w) is out of bounds; otherwise computes weighted average of the four +// nearest integer grid points. Standard implementation for deformable convolution sampling. template T BilinearInterpolate(const T* in, int64_t height, int64_t width, T h, T w) { - // Check boundaries + // Out-of-bounds: return zero (same as zero-padding). if (h <= static_cast(-1) || h >= height || w <= static_cast(-1) || w >= width) { return static_cast(0); } + // Integer floor and ceiling indices for the 2x2 neighborhood. const int64_t h_low = static_cast(std::floor(h)); const int64_t w_low = static_cast(std::floor(w)); const int64_t h_high = h_low + 1; const int64_t w_high = w_low + 1; + // Fractional parts for interpolation weights. const T lh = h - static_cast(h_low); const T lw = w - static_cast(w_low); const T hh = static_cast(1) - lh; const T hw = static_cast(1) - lw; + // Load four corner values; use 0 when index is out of bounds. const T v1 = (h_low >= 0 && w_low >= 0) ? in[h_low * width + w_low] : static_cast(0); const T v2 = (h_low >= 0 && w_high < width) ? in[h_low * width + w_high] : static_cast(0); const T v3 = (h_high < height && w_low >= 0) ? in[h_high * width + w_low] : static_cast(0); const T v4 = (h_high < height && w_high < width) ? in[h_high * width + w_high] : static_cast(0); + // Bilinear weights: (1-lh)*(1-lw), (1-lh)*lw, lh*(1-lw), lh*lw. const T w1 = hh * hw; const T w2 = hh * lw; const T w3 = lh * hw; @@ -48,86 +54,86 @@ T BilinearInterpolate(const T* in, int64_t height, int64_t width, T h, T w) { } // Deformable Im2Col for a SINGLE image. -// Converts the input image into a matrix suitable for GEMM. +// Converts the input image into a matrix suitable for GEMM by sampling with learned offsets. // Output 'data_col' shape: [C_in * kH * kW, H_out * W_out] -template +// When UseMask=false, pass nullptr for data_mask; compiler eliminates dead code for mask. 
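+// Call-site dispatch (mirrored from Compute further below):
+//   use_mask ? DeformableIm2col<T, true>(..., mask_curr, ...)
+//            : DeformableIm2col<T, false>(..., nullptr, ...);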
+template void DeformableIm2col( const T* data_im, // Input image [C, H, W] const T* data_offset, // Offset [offset_groups * 2 * kH * kW, H_out, W_out] - const T* data_mask, // Mask [offset_groups * kH * kW, H_out, W_out] (optional) - int64_t height, int64_t width, // Input dimensions - int64_t kernel_h, int64_t kernel_w, // Kernel dimensions - int64_t pad_h, int64_t pad_w, // Padding - int64_t stride_h, int64_t stride_w, // Stride - int64_t dilation_h, int64_t dilation_w, // Dilation + const T* data_mask, // Mask [offset_groups * kH * kW, H_out, W_out] (nullptr when UseMask=false) + int64_t height, int64_t width, // Input spatial dimensions + int64_t kernel_h, int64_t kernel_w, // Kernel dimensions + int64_t pad_h, int64_t pad_w, // Padding (begin) for H and W + int64_t stride_h, int64_t stride_w, // Stride for H and W + int64_t dilation_h, int64_t dilation_w, // Dilation for H and W int64_t channels, // Input channels - int64_t offset_groups, // Number of offset groups - int64_t height_col, int64_t width_col, // Output dimensions - bool use_mask, // Use mask - T* data_col, // Output buffer + int64_t offset_groups, // Number of offset groups (channels shared per group) + int64_t height_col, int64_t width_col, // Output spatial dimensions (H_out, W_out) + T* data_col, // Output buffer for im2col result concurrency::ThreadPool* thread_pool) { const int64_t channel_per_offset_group = channels / offset_groups; + const int64_t kernel_size = kernel_h * kernel_w; + const int64_t output_size = height_col * width_col; - // Loop order optimized for cache locality: - // Outer loop: Channels - // Inner loop: Spatial locations (c_col) - // This ensures sequential access to data_col and better locality for data_im. - + // Parallelize over (channel, kernel_position) so each task processes one full row of data_col. + // This yields channels*kernel_size tasks, better CPU utilization and cache-friendly sequential writes. concurrency::ThreadPool::TryParallelFor( - thread_pool, static_cast(channels), 1.0, - [&](ptrdiff_t c_im_start, ptrdiff_t c_im_end) { - for (int64_t c_im = c_im_start; c_im < c_im_end; ++c_im) { + thread_pool, + static_cast(channels * kernel_size), + static_cast(output_size) * 10.0, + [&](ptrdiff_t begin, ptrdiff_t end) { + for (ptrdiff_t idx = begin; idx < end; ++idx) { + // Decompose idx into (c_im, i, j): which channel and kernel position. + const int64_t j = static_cast(idx) % kernel_w; + const int64_t i = (static_cast(idx) / kernel_w) % kernel_h; + const int64_t c_im = static_cast(idx) / kernel_size; const int64_t offset_grp = c_im / channel_per_offset_group; - for (int64_t c_col = 0; c_col < height_col * width_col; ++c_col) { - const int64_t w_col = c_col % width_col; - const int64_t h_col = c_col / width_col; - - // Iterate over kernel window - for (int64_t i = 0; i < kernel_h; ++i) { - for (int64_t j = 0; j < kernel_w; ++j) { - // Calculate the index in the offset/mask tensors. - // The offset tensor is organized as: (offset_groups, 2 * kH * kW, H_out, W_out). - // Flattened offset channel index relative to the start of the tensor: - // base = offset_grp * (2 * kH * kW). - // specific = 2 * (i * kW + j). 
- - const int64_t data_offset_h_ptr = - ((offset_grp * (2 * kernel_h * kernel_w) + 2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - - const int64_t data_offset_w_ptr = - ((offset_grp * (2 * kernel_h * kernel_w) + 2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; - - const int64_t data_mask_ptr = - ((offset_grp * (kernel_h * kernel_w) + (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - - const T offset_h = data_offset[data_offset_h_ptr]; - const T offset_w = data_offset[data_offset_w_ptr]; - - T val = static_cast(0); - T mask_val = static_cast(1); - if (use_mask) { - mask_val = data_mask[data_mask_ptr]; - } - - // Only compute interpolation if mask is not zero (optimization) - if (mask_val != 0) { - const T h_im = h_col * stride_h - pad_h + i * dilation_h + offset_h; - const T w_im = w_col * stride_w - pad_w + j * dilation_w + offset_w; - - // Map (c_im, h_im, w_im) back to input - // data_im is [C, H, W] - const T* data_im_ptr = data_im + c_im * (height * width); - val = BilinearInterpolate(data_im_ptr, height, width, h_im, w_im); - } - - // Assign to data_col - // The layout of data_col row is: [Channel, KernelH, KernelW] flattened. - // Row index: c_im * (kH * kW) + i * kW + j - const int64_t col_row_idx = (c_im * kernel_h * kernel_w) + (i * kernel_w + j); - - data_col[col_row_idx * (height_col * width_col) + c_col] = val * mask_val; + // Output row: one (channel, kernel_pos) across all spatial locations. + T* col_ptr = data_col + static_cast(idx) * output_size; + const T* im_ptr = data_im + c_im * height * width; + + // Offset tensor layout: [offset_grp, 2*kH*kW, H_out, W_out] flattened. + // For (i,j) we use channel indices 2*(i*kW+j) and 2*(i*kW+j)+1 for offset_h, offset_w. + const int64_t offset_base = + offset_grp * 2 * kernel_size + 2 * (i * kernel_w + j); + + // Mask base index; only used when UseMask=true (compiler removes when false). + [[maybe_unused]] int64_t mask_base = 0; + if constexpr (UseMask) { + mask_base = offset_grp * kernel_size + i * kernel_w + j; + } + + // Loop over output spatial positions. + for (int64_t h_col = 0; h_col < height_col; ++h_col) { + for (int64_t w_col = 0; w_col < width_col; ++w_col) { + const int64_t spatial_idx = h_col * width_col + w_col; + + // Fetch learned offsets for this (output_pos, kernel_pos). + const T offset_h = + data_offset[offset_base * output_size + spatial_idx]; + const T offset_w = + data_offset[(offset_base + 1) * output_size + spatial_idx]; + + // Deformed sampling coordinates (fractional, for bilinear interpolation). + const T h_im = h_col * stride_h - pad_h + i * dilation_h + offset_h; + const T w_im = w_col * stride_w - pad_w + j * dilation_w + offset_w; + + // Sample input at deformed location; returns 0 if out of bounds. + T val = BilinearInterpolate(im_ptr, height, width, h_im, w_im); + + // Modulate by mask when UseMask=true; compiled away when false. + // Design choice: we always interpolate then multiply, rather than skip when mask==0. + // Rationale: (1) Skipping adds a branch; unpredictable mask values cause misprediction + // penalties (~15-20 cycles). (2) Straight-line code vectorizes better; conditional + // skip blocks SIMD. (3) Multiplying by 0 is cheap when vectorized. In typical DCN + // usage (moderate mask density), the unconditional path usually wins. 
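+          // (For reference, the branchy alternative would read:
+          //   if (mask_val != 0) { val = BilinearInterpolate(...); }
+          // A later patch in this series removes exactly that mask==0 early-exit from the
+          // CUDA im2col kernel so both providers stay on the same branch-free path.)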
+ if constexpr (UseMask) { + val *= data_mask[mask_base * output_size + spatial_idx]; } + + col_ptr[spatial_idx] = val; } } } @@ -166,22 +172,21 @@ Status DeformConv::Compute(OpKernelContext* context) const { const int64_t out_w = params.out_w; const bool use_mask = params.use_mask; - // Allocate Output + // Allocate output tensor [N, M, out_h, out_w]. const TensorShape Y_shape({N, M, out_h, out_w}); Tensor* Y = context->Output(0, Y_shape); if (Y->Shape().Size() == 0) { return Status::OK(); } - // Common sizes + // Precompute common sizes for the im2col + GEMM pipeline. const int64_t kernel_size = kH * kW; const int64_t output_image_size = out_h * out_w; const int64_t input_image_size = H * W_in; - const int64_t kernel_dim = C / group * kernel_size; // The "K" dimension for GEMM (per group) + const int64_t kernel_dim = C / group * kernel_size; // K dimension for GEMM: C/group * kH * kW - // Total col buffer size: (C * kH * kW) * (out_h * out_w) - // We allocate this per image to save memory compared to batch allocation if N is large, - // or simply because Im2Col is easier to implement per-image. + // Col buffer: shape [C*kH*kW, out_h*out_w]. Allocate per-image (process one image at a time) + // to reduce peak memory when N is large; im2col is implemented per-image anyway. const int64_t col_buffer_size = (C * kernel_size) * output_image_size; AllocatorPtr alloc; @@ -197,58 +202,46 @@ Status DeformConv::Compute(OpKernelContext* context) const { concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool(); - // Main Loop: Iterate over Batch + // Process each image in the batch. for (int64_t n = 0; n < N; ++n) { - // 1. Perform Im2Col for the current image n - // Pointers for current image + // Step 1: Deformable Im2Col for image n. + // Gather deformed samples into col buffer for GEMM. const T* X_curr = Xdata + n * (C * input_image_size); const T* offset_curr = offset_data + n * (offset_group * 2 * kernel_size * output_image_size); const T* mask_curr = use_mask ? (mask_data + n * (offset_group * kernel_size * output_image_size)) : nullptr; T* col_buffer_ptr = col_buffer.get(); - // DeformableIm2col only needs pad_h, pad_w (begin-side pads) for coordinate mapping. - // pad_h_end and pad_w_end are used in out_h/out_w computation (params) but do not affect - // the im2col sampling logic; they only influence output dimensions. - DeformableIm2col( - X_curr, - offset_curr, - mask_curr, - H, W_in, - kH, kW, - pad_h, pad_w, - stride_h, stride_w, - dilation_h, dilation_w, - C, - offset_group, - out_h, out_w, - use_mask, - col_buffer_ptr, - thread_pool); - - // 2. Perform GEMM for each group + // Dispatch to template instantiation: UseMask=true or false eliminates branch in hot loop. + // Note: pad_h, pad_w are begin-side paddings for coordinate mapping; pad_h_end/pad_w_end + // affect only output size (already baked into out_h, out_w), not im2col sampling. + if (use_mask) { + DeformableIm2col( + X_curr, offset_curr, mask_curr, + H, W_in, kH, kW, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + C, offset_group, out_h, out_w, + col_buffer_ptr, thread_pool); + } else { + DeformableIm2col( + X_curr, offset_curr, nullptr, + H, W_in, kH, kW, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + C, offset_group, out_h, out_w, + col_buffer_ptr, thread_pool); + } + + // Step 2: GEMM for each group. Y = W * Col (per group). for (int64_t g = 0; g < group; ++g) { - // Weight pointer for group g - // Weight shape: [M, C/group, kH, kW]. 
- // Stride for group g is (M/group) * (C/group * kH * kW). + // Weight for group g: shape [M/group, C/group, kH, kW], row-major. const T* weight_g = Wdata + g * (M / group) * kernel_dim; - // Col buffer pointer for group g - // Col buffer shape: [C * kH * kW, output_image_size] - // We need the rows corresponding to group g. - // Row stride: output_image_size - // Group stride: (C/group * kH * kW) * output_image_size + // Col rows for group g: layout [C*kH*kW, out_h*out_w], group g spans rows [g*kernel_dim, (g+1)*kernel_dim). const T* col_g = col_buffer_ptr + g * kernel_dim * output_image_size; - // Output pointer for group g - // Output shape: [N, M, out_h, out_w] - // Current image offset: n * M * output_image_size - // Group offset: g * (M/group) * output_image_size + // Output slice for group g: [n, g*M/group:(g+1)*M/group, out_h, out_w]. T* Y_g = Ydata + n * M * output_image_size + g * (M / group) * output_image_size; - // Y = W * Col - // W matrix: [M/group, kernel_dim] - // Col matrix: [kernel_dim, output_image_size] - // Y matrix: [M/group, output_image_size] + // GEMM: Y = W * Col. W [M/group, kernel_dim], Col [kernel_dim, output_image_size]. math::Gemm( CblasNoTrans, CblasNoTrans, @@ -265,7 +258,7 @@ Status DeformConv::Compute(OpKernelContext* context) const { } } - // 3. Add Bias if present + // Step 3: Add bias if provided (broadcast over spatial dimensions). if (Bdata != nullptr) { int64_t total_work = N * M; concurrency::ThreadPool::TryParallelFor( @@ -275,7 +268,7 @@ Status DeformConv::Compute(OpKernelContext* context) const { int64_t n = idx / M; int64_t m = idx % M; T* Y_ptr = Ydata + n * M * output_image_size + m * output_image_size; - // Vectorized: Y_ptr[i] += Bdata[m] for i in [0, output_image_size); uses Eigen SIMD. + // Eigen vectorized add: Y_ptr += Bdata[m] over all spatial positions. EigenVectorArrayMap(Y_ptr, narrow(output_image_size)) += Bdata[m]; } }); From 33e4866b2c5d82f61170f917fc40543d05307698 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Sat, 7 Mar 2026 06:28:57 +0800 Subject: [PATCH 33/58] perf(DeformConv CPU): optimize im2col and BilinearInterpolate - Precompute ptr_offset_h/ptr_offset_w and ptr_mask to avoid redundant multiplies in the inner loop - Extract base_h, base_w (-pad + i*dilation, etc.) as invariants outside the spatial loop - Add BilinearInterpolate fast path when all 4 neighbors are in-bounds --- .../core/providers/cpu/nn/deform_conv.cc | 59 ++++++++++--------- .../providers/cpu/nn/deform_conv_op_test.cc | 10 ++-- 2 files changed, 37 insertions(+), 32 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.cc b/onnxruntime/core/providers/cpu/nn/deform_conv.cc index d63d619df924a..8cb8830c5e8e5 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cpu/nn/deform_conv.cc @@ -21,36 +21,35 @@ namespace { // nearest integer grid points. Standard implementation for deformable convolution sampling. template T BilinearInterpolate(const T* in, int64_t height, int64_t width, T h, T w) { - // Out-of-bounds: return zero (same as zero-padding). - if (h <= static_cast(-1) || h >= height || w <= static_cast(-1) || w >= width) { - return static_cast(0); - } - - // Integer floor and ceiling indices for the 2x2 neighborhood. const int64_t h_low = static_cast(std::floor(h)); const int64_t w_low = static_cast(std::floor(w)); const int64_t h_high = h_low + 1; const int64_t w_high = w_low + 1; - // Fractional parts for interpolation weights. 
+ // Fast path: all 4 corners in bounds (h in [0, height-1), w in [0, width-1)). + // Most sampling points in deformable conv fall here; avoids 4 per-corner branches. + if (h_low >= 0 && h_high < height && w_low >= 0 && w_high < width) { + const T lh = h - static_cast(h_low); + const T lw = w - static_cast(w_low); + const T hh = static_cast(1) - lh; + const T hw = static_cast(1) - lw; + return hh * hw * in[h_low * width + w_low] + hh * lw * in[h_low * width + w_high] + + lh * hw * in[h_high * width + w_low] + lh * lw * in[h_high * width + w_high]; + } + + // Slow path: near boundary or out of bounds. + if (h <= static_cast(-1) || h >= height || w <= static_cast(-1) || w >= width) { + return static_cast(0); + } const T lh = h - static_cast(h_low); const T lw = w - static_cast(w_low); const T hh = static_cast(1) - lh; const T hw = static_cast(1) - lw; - - // Load four corner values; use 0 when index is out of bounds. const T v1 = (h_low >= 0 && w_low >= 0) ? in[h_low * width + w_low] : static_cast(0); const T v2 = (h_low >= 0 && w_high < width) ? in[h_low * width + w_high] : static_cast(0); const T v3 = (h_high < height && w_low >= 0) ? in[h_high * width + w_low] : static_cast(0); const T v4 = (h_high < height && w_high < width) ? in[h_high * width + w_high] : static_cast(0); - - // Bilinear weights: (1-lh)*(1-lw), (1-lh)*lw, lh*(1-lw), lh*lw. - const T w1 = hh * hw; - const T w2 = hh * lw; - const T w3 = lh * hw; - const T w4 = lh * lw; - - return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; + return hh * hw * v1 + hh * lw * v2 + lh * hw * v3 + lh * lw * v4; } // Deformable Im2Col for a SINGLE image. @@ -96,13 +95,20 @@ void DeformableIm2col( // Offset tensor layout: [offset_grp, 2*kH*kW, H_out, W_out] flattened. // For (i,j) we use channel indices 2*(i*kW+j) and 2*(i*kW+j)+1 for offset_h, offset_w. + // Precompute pointers to avoid offset_base * output_size multiplication in inner loop. const int64_t offset_base = offset_grp * 2 * kernel_size + 2 * (i * kernel_w + j); + const T* ptr_offset_h = data_offset + offset_base * output_size; + const T* ptr_offset_w = data_offset + (offset_base + 1) * output_size; + + // Base terms for h_im, w_im: invariant in inner loop (i, j fixed). + const T base_h = -pad_h + static_cast(i) * dilation_h; + const T base_w = -pad_w + static_cast(j) * dilation_w; - // Mask base index; only used when UseMask=true (compiler removes when false). - [[maybe_unused]] int64_t mask_base = 0; + // Mask pointer; only used when UseMask=true (compiler removes when false). + [[maybe_unused]] const T* ptr_mask = nullptr; if constexpr (UseMask) { - mask_base = offset_grp * kernel_size + i * kernel_w + j; + ptr_mask = data_mask + (offset_grp * kernel_size + i * kernel_w + j) * output_size; } // Loop over output spatial positions. @@ -110,15 +116,12 @@ void DeformableIm2col( for (int64_t w_col = 0; w_col < width_col; ++w_col) { const int64_t spatial_idx = h_col * width_col + w_col; - // Fetch learned offsets for this (output_pos, kernel_pos). - const T offset_h = - data_offset[offset_base * output_size + spatial_idx]; - const T offset_w = - data_offset[(offset_base + 1) * output_size + spatial_idx]; + const T offset_h = ptr_offset_h[spatial_idx]; + const T offset_w = ptr_offset_w[spatial_idx]; // Deformed sampling coordinates (fractional, for bilinear interpolation). 
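+          // (Illustration: with stride 1, pad 0, dilation 1 and kernel position (0,0),
+          // these reduce to h_im = h_col + offset_h and w_im = w_col + offset_w.)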
- const T h_im = h_col * stride_h - pad_h + i * dilation_h + offset_h; - const T w_im = w_col * stride_w - pad_w + j * dilation_w + offset_w; + const T h_im = h_col * stride_h + base_h + offset_h; + const T w_im = w_col * stride_w + base_w + offset_w; // Sample input at deformed location; returns 0 if out of bounds. T val = BilinearInterpolate(im_ptr, height, width, h_im, w_im); @@ -130,7 +133,7 @@ void DeformableIm2col( // skip blocks SIMD. (3) Multiplying by 0 is cheap when vectorized. In typical DCN // usage (moderate mask density), the unconditional path usually wins. if constexpr (UseMask) { - val *= data_mask[mask_base * output_size + spatial_idx]; + val *= ptr_mask[spatial_idx]; } col_ptr[spatial_idx] = val; diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc index a00b736ffa9db..0122fff265568 100644 --- a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc @@ -41,7 +41,8 @@ template <> struct DeformConvTestTraits { static std::vector Convert(const std::vector& v) { return v; } static std::unordered_set ExcludedProviders() { - return {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider}; + return {kTensorrtExecutionProvider, kNvTensorRTRTXExecutionProvider, kOpenVINOExecutionProvider, + kQnnExecutionProvider}; } static constexpr float DefaultRtol() { return 1e-5f; } static constexpr float DefaultAtol() { return 1e-5f; } @@ -51,7 +52,7 @@ template <> struct DeformConvTestTraits { static std::vector Convert(const std::vector& v) { return FloatsToMLFloat16s(v); } static std::unordered_set ExcludedProviders() { - return {kCpuExecutionProvider, kTensorrtExecutionProvider, + return {kCpuExecutionProvider, kNvTensorRTRTXExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider}; } static constexpr float DefaultRtol() { return 1e-2f; } @@ -64,7 +65,8 @@ struct DeformConvTestTraits { return std::vector(v.begin(), v.end()); } static std::unordered_set ExcludedProviders() { - return {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider}; + return {kTensorrtExecutionProvider, kNvTensorRTRTXExecutionProvider, kOpenVINOExecutionProvider, + kQnnExecutionProvider}; } static constexpr double DefaultRtol() { return 1e-8; } static constexpr double DefaultAtol() { return 1e-8; } @@ -75,7 +77,7 @@ template <> struct DeformConvTestTraits { static std::vector Convert(const std::vector& v) { return FloatsToBFloat16s(v); } static std::unordered_set ExcludedProviders() { - return {kCpuExecutionProvider, kTensorrtExecutionProvider, + return {kCpuExecutionProvider, kNvTensorRTRTXExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider}; } static constexpr float DefaultRtol() { return 1e-2f; } From a482eb5ec5f24a37b9f42485788e5d0f30639646 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Sat, 7 Mar 2026 06:35:06 +0800 Subject: [PATCH 34/58] Early OOB check for BilinearInterpolate --- onnxruntime/core/providers/cpu/nn/deform_conv.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.cc b/onnxruntime/core/providers/cpu/nn/deform_conv.cc index 8cb8830c5e8e5..a86b7a5041451 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cpu/nn/deform_conv.cc @@ -21,6 +21,11 @@ namespace { // nearest integer grid points. 
Standard implementation for deformable convolution sampling. template T BilinearInterpolate(const T* in, int64_t height, int64_t width, T h, T w) { + // Early exit for clearly out-of-bounds (skip floor() for OOB case). + if (h <= static_cast(-1) || h >= height || w <= static_cast(-1) || w >= width) { + return static_cast(0); + } + const int64_t h_low = static_cast(std::floor(h)); const int64_t w_low = static_cast(std::floor(w)); const int64_t h_high = h_low + 1; @@ -37,10 +42,7 @@ T BilinearInterpolate(const T* in, int64_t height, int64_t width, T h, T w) { lh * hw * in[h_high * width + w_low] + lh * lw * in[h_high * width + w_high]; } - // Slow path: near boundary or out of bounds. - if (h <= static_cast(-1) || h >= height || w <= static_cast(-1) || w >= width) { - return static_cast(0); - } + // Slow path: near boundary (one or more of the 4 corners may be out of bounds). const T lh = h - static_cast(h_low); const T lw = w - static_cast(w_low); const T hh = static_cast(1) - lh; From b46f922c0e2e48cf6b436a06ae0916d644066d53 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Sat, 7 Mar 2026 06:41:43 +0800 Subject: [PATCH 35/58] Shrink DeformConv CUDA mutex to UpdateState only --- onnxruntime/core/providers/cuda/nn/deform_conv.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc index 3b1a30d9f1112..c7241642eebe5 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc @@ -115,8 +115,6 @@ template Status DeformConv::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; - std::lock_guard lock(state_.mutex); - const auto* X = context->Input(0); const auto* W = context->Input(1); const auto* offset = context->Input(2); @@ -151,7 +149,10 @@ Status DeformConv::ComputeInternal(OpKernelContext* context) const { } int n_parallel_imgs; - ORT_RETURN_IF_ERROR(UpdateState(context, params, n_parallel_imgs)); + { + std::lock_guard lock(state_.mutex); + ORT_RETURN_IF_ERROR(UpdateState(context, params, n_parallel_imgs)); + } const int64_t kernel_size = kH * kW; const int64_t output_image_size = out_h * out_w; From 82d122838a90fb129fa19f60e4ee433233ec6d0a Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Sat, 7 Mar 2026 06:48:14 +0800 Subject: [PATCH 36/58] Use cublasGemmStridedBatched for gemm_writes_directly path in DeformConv CUDA --- .../core/providers/cuda/nn/deform_conv.cc | 86 ++++++++++++------- 1 file changed, 57 insertions(+), 29 deletions(-) diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc index c7241642eebe5..57933549cecc9 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc @@ -212,45 +212,73 @@ Status DeformConv::ComputeInternal(OpKernelContext* context) const { offset_group, use_mask)); - for (int64_t g = 0; g < group; ++g) { - const T* W_g = Wdata + g * (M / group) * kernel_dim; - const T* col_g = col_buffer.get() + g * kernel_dim * col_stride; - T* Y_g = Ydata + b * M * output_image_size + g * (M / group) * output_image_size; - - // GEMM layout trick: compute Y = W * Col without physical transpose. - // - // Our data is row-major: W [M/group, kernel_dim], Col [kernel_dim, cur_out_size], Y [M/group, cur_out_size]. - // cuBLAS is column-major. 
Key insight: row-major A[M,K] in memory equals column-major A^T[K,M]. - // We compute Y^T = Col^T * W^T by passing Col as A and W as B, both OP_N (no transpose): - // - Col (row [kernel_dim, cur_out_size]) -> cuBLAS interprets as col-major [cur_out_size, kernel_dim] = Col^T - // - W (row [M/group, kernel_dim]) -> cuBLAS interprets as col-major [kernel_dim, M/group] = W^T - // - C = A*B = Col^T * W^T = (W*Col)^T = Y^T; C is col-major [cur_out_size, M/group] = Y in row-major - // - // cublasGemmHelper: m=cur_out_size, n=M/group, k=kernel_dim; lda=cur_out_size, ldb=kernel_dim, ldc=cur_out_size - // - // cur_parallel==1: cur_out_size==output_image_size, C layout (pos, channel) matches NCHW Y_g[0,ch,pos] -> write - // into Y_g. cur_parallel>1: layouts differ -> write to gemm_output_buffer, then DeformConvCopyGemmOutputRowMajorToNCHW. - const bool gemm_writes_directly = (cur_parallel == 1); - - CUBLAS_RETURN_IF_ERROR((cublasGemmHelper( + // GEMM layout trick: compute Y = W * Col without physical transpose. + // + // Our data is row-major: W [M/group, kernel_dim], Col [kernel_dim, cur_out_size], Y [M/group, cur_out_size]. + // cuBLAS is column-major. Key insight: row-major A[M,K] in memory equals column-major A^T[K,M]. + // We compute Y^T = Col^T * W^T by passing Col as A and W as B, both OP_N (no transpose): + // - Col (row [kernel_dim, cur_out_size]) -> cuBLAS interprets as col-major [cur_out_size, kernel_dim] = Col^T + // - W (row [M/group, kernel_dim]) -> cuBLAS interprets as col-major [kernel_dim, M/group] = W^T + // - C = A*B = Col^T * W^T = (W*Col)^T = Y^T; C is col-major [cur_out_size, M/group] = Y in row-major + // + // m=cur_out_size, n=M/group, k=kernel_dim; lda=cur_out_size, ldb=kernel_dim, ldc=cur_out_size. + // + // cur_parallel==1: cur_out_size==output_image_size, C layout (pos, channel) matches NCHW Y_g[0,ch,pos] -> write + // directly into Y_g. Use strided batched for all groups in one call. + // cur_parallel>1: layouts differ -> write to gemm_output_buffer, then DeformConvCopyGemmOutputRowMajorToNCHW. + + const bool gemm_writes_directly = (cur_parallel == 1); + if (gemm_writes_directly) { + // Strided batched: one call for all groups. Strides between batches: + const int64_t stride_col = kernel_dim * col_stride; // = kernel_dim * output_image_size when cur_parallel==1 + const int64_t stride_w = (M / group) * kernel_dim; + const int64_t stride_y = (M / group) * output_image_size; + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, - narrow(cur_out_size), + narrow(output_image_size), narrow(M / group), narrow(kernel_dim), &alpha, - reinterpret_cast(col_g), - narrow(cur_out_size), - reinterpret_cast(W_g), + reinterpret_cast(col_buffer.get()), + narrow(output_image_size), + stride_col, + reinterpret_cast(Wdata), narrow(kernel_dim), + stride_w, &beta, - (gemm_writes_directly ? reinterpret_cast(Y_g) : reinterpret_cast(gemm_output_buffer.get())), - narrow(gemm_writes_directly ? output_image_size : cur_out_size), + reinterpret_cast(Ydata + b * M * output_image_size), + narrow(output_image_size), + stride_y, + narrow(group), device_prop, - UseTF32()))); + UseTF32())); + } else { + // cur_parallel>1: GEMM output layout differs from NCHW; write to buffer then copy per group. 
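+    // (Editor's note, restating the identity used above with concrete numbers: a row-major
+    // 2x3 matrix {1,2,3; 4,5,6} stored as {1,2,3,4,5,6} is read by column-major cuBLAS as
+    // the 3x2 matrix {1,4; 2,5; 3,6}, i.e. its transpose -- so passing row-major buffers
+    // with OP_N computes C = Col^T * W^T = (W*Col)^T = Y^T, which in memory is row-major Y.)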
+ for (int64_t g = 0; g < group; ++g) { + const T* W_g = Wdata + g * (M / group) * kernel_dim; + const T* col_g = col_buffer.get() + g * kernel_dim * col_stride; + T* Y_g = Ydata + b * M * output_image_size + g * (M / group) * output_image_size; + + CUBLAS_RETURN_IF_ERROR((cublasGemmHelper( + cublas, + CUBLAS_OP_N, + CUBLAS_OP_N, + narrow(cur_out_size), + narrow(M / group), + narrow(kernel_dim), + &alpha, + reinterpret_cast(col_g), + narrow(cur_out_size), + reinterpret_cast(W_g), + narrow(kernel_dim), + &beta, + reinterpret_cast(gemm_output_buffer.get()), + narrow(cur_out_size), + device_prop, + UseTF32()))); - if (!gemm_writes_directly) { ORT_RETURN_IF_ERROR(DeformConvCopyGemmOutputRowMajorToNCHW( stream, gemm_output_buffer.get(), From da18ee30d9f3c2120046630a7eaa36cff3eb2eba Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Sat, 7 Mar 2026 06:50:49 +0800 Subject: [PATCH 37/58] Fix var name --- onnxruntime/core/providers/cuda/nn/deform_conv.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc index 57933549cecc9..8359e54e1e346 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc @@ -231,7 +231,7 @@ Status DeformConv::ComputeInternal(OpKernelContext* context) const { if (gemm_writes_directly) { // Strided batched: one call for all groups. Strides between batches: const int64_t stride_col = kernel_dim * col_stride; // = kernel_dim * output_image_size when cur_parallel==1 - const int64_t stride_w = (M / group) * kernel_dim; + const int64_t stride_weight = (M / group) * kernel_dim; const int64_t stride_y = (M / group) * output_image_size; CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, @@ -246,7 +246,7 @@ Status DeformConv::ComputeInternal(OpKernelContext* context) const { stride_col, reinterpret_cast(Wdata), narrow(kernel_dim), - stride_w, + stride_weight, &beta, reinterpret_cast(Ydata + b * M * output_image_size), narrow(output_image_size), From dcd00c307611e4dd2325b3bd22a8e8271ef32839 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Sat, 7 Mar 2026 07:04:01 +0800 Subject: [PATCH 38/58] Drop mask==0 branch in im2col to match CPU behavior --- onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu index a6e1286f6f6d8..4edd48528eff3 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu @@ -232,14 +232,6 @@ __global__ void DeformableIm2ColKernel( if (use_mask) { // Access mask using pre-calculated base and stride. mask_val = DeformConvLdg(mask_ptr_base + kernel_idx * out_size); - - // [Optimization 1] Early Exit / Pruning - // If mask is 0, the contribution is 0. Skip expensive offset load and interpolation. - // Note: casting to float for comparison is safe for standard floating point types. - if (static_cast(mask_val) == 0.0f) { - data_col_ptr_base[kernel_idx * col_stride] = static_cast(0); - return; - } } // Calculate offset pointers relative to the base. @@ -255,7 +247,7 @@ __global__ void DeformableIm2ColKernel( T val = BilinearInterpolate(input_ptr, height, width, h_im, w_im); - // Write result to data_col using pre-calculated base. 
+ // Match CPU path: always interpolate then apply mask to keep branch-free hot loop. data_col_ptr_base[kernel_idx * col_stride] = val * mask_val; }; From 6e727c201138194e3f162ebcc6cc08c837a24a67 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Sat, 7 Mar 2026 07:13:36 +0800 Subject: [PATCH 39/58] Add 1x1 im2col kernel specialization dispatch --- onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu index 4edd48528eff3..25feb9484198e 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu @@ -417,7 +417,9 @@ Status DeformConvIm2ColImpl( } }; - if (kH == 3 && kW == 3) { + if (kH == 1 && kW == 1) { + launch(DeformConvKSize<1>{}, DeformConvKSize<1>{}); + } else if (kH == 3 && kW == 3) { launch(DeformConvKSize<3>{}, DeformConvKSize<3>{}); } else if (kH == 5 && kW == 5) { launch(DeformConvKSize<5>{}, DeformConvKSize<5>{}); From 0166fa1d1e3effedb8a840f93398b90f01ef9ca1 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Sat, 7 Mar 2026 07:13:50 +0800 Subject: [PATCH 40/58] Reformat codes --- onnxruntime/core/providers/cpu/nn/deform_conv.cc | 4 ++-- onnxruntime/core/providers/cuda/nn/deform_conv.cc | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.cc b/onnxruntime/core/providers/cpu/nn/deform_conv.cc index a86b7a5041451..9e3842f5e2e41 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cpu/nn/deform_conv.cc @@ -64,9 +64,9 @@ void DeformableIm2col( const T* data_offset, // Offset [offset_groups * 2 * kH * kW, H_out, W_out] const T* data_mask, // Mask [offset_groups * kH * kW, H_out, W_out] (nullptr when UseMask=false) int64_t height, int64_t width, // Input spatial dimensions - int64_t kernel_h, int64_t kernel_w, // Kernel dimensions + int64_t kernel_h, int64_t kernel_w, // Kernel dimensions int64_t pad_h, int64_t pad_w, // Padding (begin) for H and W - int64_t stride_h, int64_t stride_w, // Stride for H and W + int64_t stride_h, int64_t stride_w, // Stride for H and W int64_t dilation_h, int64_t dilation_w, // Dilation for H and W int64_t channels, // Input channels int64_t offset_groups, // Number of offset groups (channels shared per group) diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc index 8359e54e1e346..edb760c12763e 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc @@ -230,7 +230,7 @@ Status DeformConv::ComputeInternal(OpKernelContext* context) const { const bool gemm_writes_directly = (cur_parallel == 1); if (gemm_writes_directly) { // Strided batched: one call for all groups. 
Strides between batches: - const int64_t stride_col = kernel_dim * col_stride; // = kernel_dim * output_image_size when cur_parallel==1 + const int64_t stride_col = kernel_dim * col_stride; // = kernel_dim * output_image_size when cur_parallel==1 const int64_t stride_weight = (M / group) * kernel_dim; const int64_t stride_y = (M / group) * output_image_size; CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( From f2d8f5dfa37755357fbfd1cdb151a66c06b4059e Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Sat, 7 Mar 2026 15:39:59 +0800 Subject: [PATCH 41/58] Fix C4244 in deform_conv_op_test by casting rtol/atol to float --- onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc index 0122fff265568..d09a656afc646 100644 --- a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc @@ -138,7 +138,9 @@ void RunDeformConvTest(const DeformConvTestParams& params, test.AddOptionalInputEdge(); } - test.AddOutput("Y", Y_shape, expected_Y_t, false, rtol, atol); + const float rtol_f = static_cast(rtol); + const float atol_f = static_cast(atol); + test.AddOutput("Y", Y_shape, expected_Y_t, false, rtol_f, atol_f); test.Run(OpTester::ExpectResult::kExpectSuccess, "", DeformConvTestTraits::ExcludedProviders()); } From 6ead850129bd3bf31a57f97733f6aee3c34d1898 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Thu, 12 Mar 2026 19:05:00 +0800 Subject: [PATCH 42/58] Add standard MIT license header --- .../providers/cpu/nn/deform_conv_expected_gen.py | 12 +++++++----- onnxruntime/test/testdata/nn/deform_conv_test_gen.py | 4 +++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py b/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py index cb3cd38d120a6..a6c4923cd961e 100644 --- a/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py @@ -1,4 +1,6 @@ -#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + """ Generate expected outputs for DeformConv tests using torchvision.ops.deform_conv2d. Run with: .venv/bin/python onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py @@ -13,7 +15,7 @@ import torchvision.ops -def _pair(x): +def _pair(x: int | tuple[int, int]) -> tuple[int, int]: if isinstance(x, int): return (x, x) return x @@ -34,9 +36,9 @@ def run_case( n_offset_grps: int, kernel_h: int, kernel_w: int, - stride: tuple, - pad: tuple, - dilation: tuple, + stride: tuple[int, int] | int, + pad: tuple[int, int] | int, + dilation: tuple[int, int] | int, in_h: int, in_w: int, seed: int = 42, diff --git a/onnxruntime/test/testdata/nn/deform_conv_test_gen.py b/onnxruntime/test/testdata/nn/deform_conv_test_gen.py index 8fcd9e34a26a6..3d1f276846181 100644 --- a/onnxruntime/test/testdata/nn/deform_conv_test_gen.py +++ b/onnxruntime/test/testdata/nn/deform_conv_test_gen.py @@ -1,4 +1,6 @@ -#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + """ Generate DeformConv ONNX model and test data for cross-platform validation. 
From 7b11badcce1efbc799ebb8772e753fab0f41fba3 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Thu, 12 Mar 2026 19:07:47 +0800 Subject: [PATCH 43/58] Register DeformConv BFloat16 only for opset 22 --- .../core/providers/cuda/nn/deform_conv.cc | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc index edb760c12763e..96fd84fb6bb25 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc @@ -315,12 +315,21 @@ Status DeformConv::ComputeInternal(OpKernelContext* context) const { T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - DeformConv); - -REGISTER_DEFORMCONV_KERNEL_TYPED(float) -REGISTER_DEFORMCONV_KERNEL_TYPED(double) -REGISTER_DEFORMCONV_KERNEL_TYPED(MLFloat16) -REGISTER_DEFORMCONV_KERNEL_TYPED(BFloat16) + DeformConv) + +REGISTER_DEFORMCONV_KERNEL_TYPED(float); +REGISTER_DEFORMCONV_KERNEL_TYPED(double); +REGISTER_DEFORMCONV_KERNEL_TYPED(MLFloat16); + +// BFloat16 only for opset 22; opset 19-21 do not support BFloat16. +ONNX_OPERATOR_TYPED_KERNEL_EX( + DeformConv, + kOnnxDomain, + 22, + BFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), + DeformConv); #undef REGISTER_DEFORMCONV_KERNEL_TYPED From 4a7227635062c037295fd0a3b925c2c6b85d1eaf Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Thu, 12 Mar 2026 19:14:22 +0800 Subject: [PATCH 44/58] Validate DeformConv input ranks and output size before indexing shapes --- onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h index ea1b5761f0913..c7ad2011f3d68 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h +++ b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h @@ -79,6 +79,10 @@ inline Status DeformConvValidateAndParse( const TensorShape& offset_shape, const TensorShape* mask_shape, DeformConvParams& params) { + ORT_RETURN_IF_NOT(X_shape.NumDimensions() == 4, "Input X must be 4D (N, C, H, W)."); + ORT_RETURN_IF_NOT(W_shape.NumDimensions() == 4, "Weight must be 4D."); + ORT_RETURN_IF_NOT(offset_shape.NumDimensions() == 4, "Offset must be 4D."); + // Parse input shapes params.N = X_shape[0]; params.C = X_shape[1]; @@ -115,10 +119,9 @@ inline Status DeformConvValidateAndParse( params.out_h = (params.H + params.pad_h + params.pad_h_end - params.dilation_h * (params.kH - 1) - 1) / params.stride_h + 1; params.out_w = (params.W_in + params.pad_w + params.pad_w_end - params.dilation_w * (params.kW - 1) - 1) / params.stride_w + 1; + ORT_RETURN_IF_NOT(params.out_h >= 0 && params.out_w >= 0, "Computed output spatial size must be non-negative."); // Validate tensor shapes - ORT_RETURN_IF_NOT(W_shape.NumDimensions() == 4, "Weight must be 4D."); - ORT_RETURN_IF_NOT(offset_shape.NumDimensions() == 4, "Offset must be 4D."); ORT_RETURN_IF_NOT(offset_shape[0] == params.N, "Offset batch size must match input batch size."); ORT_RETURN_IF_NOT( offset_shape[1] == params.offset_group * 2 * params.kH * params.kW, From 7ba7d1b57c12677a366dcd854b314c2d6e346ff5 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Thu, 12 Mar 2026 21:04:04 +0800 Subject: [PATCH 45/58] 
Refactor DeformConv MinimalBilinear tests with shared data/template and add optional bias omitted test --- .../providers/cpu/nn/deform_conv_op_test.cc | 140 +++++------------- 1 file changed, 41 insertions(+), 99 deletions(-) diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc index d09a656afc646..926f67f77117f 100644 --- a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc @@ -95,7 +95,8 @@ void RunDeformConvTest(const DeformConvTestParams& params, const std::vector& expected_Y, int opset = 19, decltype(DeformConvTestTraits::DefaultRtol()) rtol = DeformConvTestTraits::DefaultRtol(), - decltype(DeformConvTestTraits::DefaultAtol()) atol = DeformConvTestTraits::DefaultAtol()) { + decltype(DeformConvTestTraits::DefaultAtol()) atol = DeformConvTestTraits::DefaultAtol(), + bool omit_bias = false) { const int64_t kH = params.kernel_shape[0]; const int64_t kW = params.kernel_shape[1]; // ONNX pads format: [pad_top, pad_left, pad_bottom, pad_right] = [pad[0], pad[1], pad[2], pad[3]] @@ -124,13 +125,17 @@ void RunDeformConvTest(const DeformConvTestParams& params, auto X_t = DeformConvTestTraits::Convert(X); auto W_t = DeformConvTestTraits::Convert(W); auto offset_t = DeformConvTestTraits::Convert(offset); - auto B_t = DeformConvTestTraits::Convert(B); auto expected_Y_t = DeformConvTestTraits::Convert(expected_Y); test.AddInput("X", X_shape, X_t); test.AddInput("W", W_shape, W_t); test.AddInput("offset", offset_shape, offset_t); - test.AddInput("B", {params.n_out_channels}, B_t); + if (omit_bias) { + test.AddOptionalInputEdge(); + } else { + auto B_t = DeformConvTestTraits::Convert(B); + test.AddInput("B", {params.n_out_channels}, B_t); + } if (mask != nullptr) { const std::vector mask_shape = {params.batch_sz, params.n_offset_grps * kH * kW, out_h, out_w}; test.AddInput("mask", mask_shape, DeformConvTestTraits::Convert(*mask)); @@ -145,11 +150,15 @@ void RunDeformConvTest(const DeformConvTestParams& params, test.Run(OpTester::ExpectResult::kExpectSuccess, "", DeformConvTestTraits::ExcludedProviders()); } -} // namespace - -// Minimal case: 1x1 kernel, 2x2 input, one output position with fractional offset (bilinear). +// MinimalBilinear test: 1x1 kernel, 2x2 input, one output position with fractional offset (bilinear). // At (0,0) offset (0.5, 0.5) samples center of [1,2;3,4] -> 2.5. -TEST(DeformConvTest, MinimalBilinear) { +template +void RunMinimalBilinearTest(int opset = 19, int min_cuda_arch = 0, bool omit_bias = false) { +#if defined(USE_CUDA) + if (min_cuda_arch > 0 && !HasCudaEnvironment(min_cuda_arch)) { + return; + } +#endif DeformConvTestParams p = {}; p.batch_sz = 1; p.n_in_channels = 1; @@ -162,117 +171,50 @@ TEST(DeformConvTest, MinimalBilinear) { p.dilation = {1, 1}; p.in_h = 2; p.in_w = 2; - - std::vector X = {1.f, 2.f, 3.f, 4.f}; // NCHW + std::vector X = {1.f, 2.f, 3.f, 4.f}; std::vector W = {1.f}; // offset shape [N, 2*kH*kW, out_h, out_w] = [1, 2, 2, 2]: ch0=offset_h, ch1=offset_w (for kernel pt 0) // Layout: offset[n,c,oh,ow]. 
Flattened (NCHW): [ch0@00, ch0@01, ch0@10, ch0@11, ch1@00, ch1@01, ch1@10, ch1@11] // (0,0): (0.5, 0.5)->center of [1,2;3,4]->2.5; (0,1): (0,-1)->(0,0)->1; (1,0): (0,0)->3; (1,1): (0,0)->4 - std::vector offset = { - 0.5f, 0.f, 0.f, 0.f, - 0.5f, -1.0f, 0.f, 0.f}; + std::vector offset = {0.5f, 0.f, 0.f, 0.f, 0.5f, -1.0f, 0.f, 0.f}; std::vector B = {0.f}; - std::vector mask = {1.f, 1.f, 1.f, 1.f}; // (1,1,2,2) + std::vector mask = {1.f, 1.f, 1.f, 1.f}; std::vector expected_Y = {2.5f, 1.f, 3.f, 4.f}; + if (omit_bias) { + RunDeformConvTest(p, X, W, offset, {} /* B unused */, &mask, expected_Y, opset, + DeformConvTestTraits::DefaultRtol(), DeformConvTestTraits::DefaultAtol(), true); + } else { + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, opset, + DeformConvTestTraits::DefaultRtol(), DeformConvTestTraits::DefaultAtol(), false); + } +} +} // namespace - RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +// Minimal case: 1x1 kernel, 2x2 input, one output position with fractional offset (bilinear). +TEST(DeformConvTest, MinimalBilinear) { + RunMinimalBilinearTest(); +} + +// Optional bias omitted: same as MinimalBilinear but B is not provided; output must match B=0. +TEST(DeformConvTest, OptionalBiasOmitted) { + RunMinimalBilinearTest(19, 0, true); } -// Minimal case FP16: Same as MinimalBilinear but in FP16. -// Validates CUDA FP16 implementation (specifically coordinate precision logic). -// DeformConv FP16 is CUDA-only; skip when CUDA is not available (e.g., Linux x64/arm64 CPU-only builds). +// Minimal case FP16: Same as MinimalBilinear but in FP16 (CUDA-only). #if defined(USE_CUDA) TEST(DeformConvTest, MinimalBilinearFP16) { - int min_cuda_architecture = 530; // FP16 requires SM 5.3+ - if (!HasCudaEnvironment(min_cuda_architecture)) { - LOGS_DEFAULT(WARNING) << "DeformConv FP16: CUDA not available, skipping."; - return; - } - - DeformConvTestParams p = {}; - p.batch_sz = 1; - p.n_in_channels = 1; - p.n_out_channels = 1; - p.n_weight_grps = 1; - p.n_offset_grps = 1; - p.kernel_shape = {1, 1}; - p.stride = {1, 1}; - p.pad = {0, 0, 0, 0}; - p.dilation = {1, 1}; - p.in_h = 2; - p.in_w = 2; - - std::vector X = {1.f, 2.f, 3.f, 4.f}; - std::vector W = {1.f}; - std::vector offset = { - 0.5f, 0.f, 0.f, 0.f, - 0.5f, -1.0f, 0.f, 0.f}; - std::vector B = {0.f}; - std::vector mask = {1.f, 1.f, 1.f, 1.f}; - std::vector expected_Y = {2.5f, 1.f, 3.f, 4.f}; - - RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); + RunMinimalBilinearTest(19, 530); } -// Minimal case BFloat16: Same as MinimalBilinear but in BFloat16 (CUDA BFloat16 coordinate precision). +// Minimal case BFloat16: Same as MinimalBilinear but in BFloat16 (CUDA-only, opset 22). 
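+// (Editor's note: min_cuda_arch = 800 below mirrors the HasCudaEnvironment(800) guard of the
+// old test -- BF16 math needs SM 8.0+ -- and opset 22 matches the BFloat16-only-for-opset-22
+// registration patch earlier in this series.)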
TEST(DeformConvTest, MinimalBilinearBFloat16) { - int min_cuda_architecture = 800; - if (!HasCudaEnvironment(min_cuda_architecture)) { - LOGS_DEFAULT(WARNING) << "Hardware does NOT support BF16"; - return; - } - - DeformConvTestParams p = {}; - p.batch_sz = 1; - p.n_in_channels = 1; - p.n_out_channels = 1; - p.n_weight_grps = 1; - p.n_offset_grps = 1; - p.kernel_shape = {1, 1}; - p.stride = {1, 1}; - p.pad = {0, 0, 0, 0}; - p.dilation = {1, 1}; - p.in_h = 2; - p.in_w = 2; - - std::vector X = {1.f, 2.f, 3.f, 4.f}; - std::vector W = {1.f}; - std::vector offset = { - 0.5f, 0.f, 0.f, 0.f, - 0.5f, -1.0f, 0.f, 0.f}; - std::vector B = {0.f}; - std::vector mask = {1.f, 1.f, 1.f, 1.f}; - std::vector expected_Y = {2.5f, 1.f, 3.f, 4.f}; - - RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y, 22); + RunMinimalBilinearTest(22, 800); } #endif // defined(USE_CUDA) // Minimal case Double (FP64): Same as MinimalBilinear in double precision. TEST(DeformConvTest, MinimalBilinearDouble) { - DeformConvTestParams p = {}; - p.batch_sz = 1; - p.n_in_channels = 1; - p.n_out_channels = 1; - p.n_weight_grps = 1; - p.n_offset_grps = 1; - p.kernel_shape = {1, 1}; - p.stride = {1, 1}; - p.pad = {0, 0, 0, 0}; - p.dilation = {1, 1}; - p.in_h = 2; - p.in_w = 2; - - std::vector X = {1.f, 2.f, 3.f, 4.f}; - std::vector W = {1.f}; - std::vector offset = { - 0.5f, 0.f, 0.f, 0.f, - 0.5f, -1.0f, 0.f, 0.f}; - std::vector B = {0.f}; - std::vector mask = {1.f, 1.f, 1.f, 1.f}; - std::vector expected_Y = {2.5f, 1.f, 3.f, 4.f}; - - RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); + RunMinimalBilinearTest(); } // Forward with mask and bias FP16 (CUDA-only; skip when CUDA not available). From 432c1c681819276e98704c543537ef8b842bae2e Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Thu, 12 Mar 2026 21:09:01 +0800 Subject: [PATCH 46/58] Validate optional bias B shape (1D [M]) in DeformConv shared helper for CPU/CUDA --- onnxruntime/core/providers/cpu/nn/deform_conv.cc | 9 ++++++++- .../core/providers/cpu/nn/deform_conv_attributes.h | 6 ++++++ onnxruntime/core/providers/cuda/nn/deform_conv.cc | 9 ++++++++- 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.cc b/onnxruntime/core/providers/cpu/nn/deform_conv.cc index 9e3842f5e2e41..ae7ad8f0c63e5 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cpu/nn/deform_conv.cc @@ -156,7 +156,14 @@ Status DeformConv::Compute(OpKernelContext* context) const { const auto* mask = context->Input(4); // optional DeformConvParams params; - ORT_RETURN_IF_ERROR(DeformConvValidateAndParse(attrs_, X->Shape(), W->Shape(), offset->Shape(), mask ? &mask->Shape() : nullptr, params)); + ORT_RETURN_IF_ERROR(DeformConvValidateAndParse( + attrs_, + X->Shape(), + W->Shape(), + offset->Shape(), + B ? &B->Shape() : nullptr, + mask ? 
&mask->Shape() : nullptr,
+      params));
 
   const int64_t N = params.N;
   const int64_t C = params.C;
diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h
index c7ad2011f3d68..01a487a1e1664 100644
--- a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h
+++ b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h
@@ -77,6 +77,7 @@ inline Status DeformConvValidateAndParse(
     const TensorShape& X_shape,
     const TensorShape& W_shape,
     const TensorShape& offset_shape,
+    const TensorShape* B_shape,
     const TensorShape* mask_shape,
     DeformConvParams& params) {
   ORT_RETURN_IF_NOT(X_shape.NumDimensions() == 4, "Input X must be 4D (N, C, H, W).");
@@ -132,6 +133,11 @@ inline Status DeformConvValidateAndParse(
   ORT_RETURN_IF_NOT(params.C == W_shape[1] * params.group, "Input channels must match weight in channels * group.");
   ORT_RETURN_IF_NOT(params.M % params.group == 0, "Output channels must be divisible by group.");
 
+  if (B_shape != nullptr) {
+    ORT_RETURN_IF_NOT(B_shape->NumDimensions() == 1, "Bias B must be 1D.");
+    ORT_RETURN_IF_NOT((*B_shape)[0] == params.M, "Bias B must have shape [M] (M = number of output channels).");
+  }
+
   // Validate mask if present
   if (params.use_mask) {
     ORT_RETURN_IF_NOT(mask_shape->NumDimensions() == 4, "Mask must be 4D.");
diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc
index 96fd84fb6bb25..17c43cbdb5e4d 100644
--- a/onnxruntime/core/providers/cuda/nn/deform_conv.cc
+++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc
@@ -122,7 +122,14 @@ Status DeformConv<T>::ComputeInternal(OpKernelContext* context) const {
   const auto* mask = context->Input<Tensor>(4);
 
   DeformConvParams params;
-  ORT_RETURN_IF_ERROR(DeformConvValidateAndParse(attrs_, X->Shape(), W->Shape(), offset->Shape(), mask ? &mask->Shape() : nullptr, params));
+  ORT_RETURN_IF_ERROR(DeformConvValidateAndParse(
+      attrs_,
+      X->Shape(),
+      W->Shape(),
+      offset->Shape(),
+      B ? &B->Shape() : nullptr,
+      mask ? &mask->Shape() : nullptr,
+      params));
 
   const int64_t N = params.N;
   const int64_t C = params.C;
From a25c1f4a9f9ff58b75758a3f532349d38d719cc4 Mon Sep 17 00:00:00 2001
From: Shirasawa <764798966@qq.com>
Date: Fri, 13 Mar 2026 00:37:26 +0800
Subject: [PATCH 47/58] Use cached totalGlobalMem for temp budget, remove
 cudaMemGetInfo and per-kernel state

---
 .../core/providers/cuda/nn/deform_conv.cc     | 115 ++++++++----------
 .../core/providers/cuda/nn/deform_conv.h      |  15 ---
 2 files changed, 50 insertions(+), 80 deletions(-)

diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc
index 17c43cbdb5e4d..18aff162a6a63 100644
--- a/onnxruntime/core/providers/cuda/nn/deform_conv.cc
+++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc
@@ -7,6 +7,8 @@
 #include "deform_conv.h"
 #include "deform_conv_impl.h"
 
+#include <algorithm>
+
 #include "core/common/narrow.h"
 #include "core/common/span_utils.h"
 #include "core/providers/cuda/cuda_common.h"
@@ -46,71 +48,58 @@ int GetGreatestDivisorBelowBound(int n, int bound) {
   return 1;
 }
 
-// Returns effective max temp memory (bytes) for DeformConv batching.
-// Uses 90% of free GPU memory with tiered cap; fallback 256MB if cudaMemGetInfo fails.
-// Mirrors Conv's approach (conv_8.h); tiered limits avoid OOM on smaller GPUs.
-// Called only when input/weight shapes change (see UpdateState).
-size_t GetDeformConvEffectiveMaxTempBytes() { - constexpr size_t kDefaultFallback = 256ULL * 1024 * 1024; - constexpr size_t kMinTempMemSize = 32ULL * 1024 * 1024; - - size_t free_mem = 0, total_mem = 0; - if (cudaMemGetInfo(&free_mem, &total_mem) != cudaSuccess || free_mem == 0) { - return kDefaultFallback; - } - free_mem = static_cast(static_cast(free_mem) * 0.9); // 10% fragmentation buffer - - size_t tier_cap; - if (free_mem > 16ULL * 1024 * 1024 * 1024) { - tier_cap = 2ULL * 1024 * 1024 * 1024; // 16GB+ free → 2GB - } else if (free_mem > 8ULL * 1024 * 1024 * 1024) { - tier_cap = 1ULL * 1024 * 1024 * 1024; // 8-16GB → 1GB - } else if (free_mem > 4ULL * 1024 * 1024 * 1024) { - tier_cap = 512ULL * 1024 * 1024; // 4-8GB → 512MB - } else if (free_mem > 2ULL * 1024 * 1024 * 1024) { - tier_cap = 256ULL * 1024 * 1024; // 2-4GB → 256MB - } else { - tier_cap = 128ULL * 1024 * 1024; // <2GB → 128MB - } - - return std::max(kMinTempMemSize, std::min(tier_cap, free_mem)); +// Returns the maximum temp memory (bytes) allowed for DeformConv's im2col + GEMM buffers. +// Uses a fraction of total GPU memory to avoid OOM while leaving room for weights, activations, +// and other ops. No CUDA API is called; total_global_mem is expected from cached device props. +// +// Formula: +// budget = total_global_mem * kFraction +// return clamp(budget, kMin, kMax) +// with kFraction = 0.1 (10%), kMin = 32 MiB, kMax = 2 GiB. +// +// Example results (effective_max_temp after clamp): +// GPU | totalGlobalMem | effective_max_temp +// -----------------|----------------|-------------------- +// A100 80GB | 80 GiB | 2 GiB (capped) +// RTX 5080 16GB | 16 GiB | 1.6 GiB +// RTX 4090 24GB | 24 GiB | 2 GiB (capped) +// RTX 3080 10GB | 10 GiB | 1 GiB +// GTX 1060 6GB | 6 GiB | 614.4 MiB +// GTX 1050 4GB | 4 GiB | 409.6 MiB +// Jetson 2GB | 2 GiB | 204.8 MiB +size_t GetDeformConvEffectiveMaxTempBytes(size_t total_global_mem) { + constexpr double kFraction = 0.1; + constexpr size_t kMin = 32ULL * 1024 * 1024; + constexpr size_t kMax = 2ULL * 1024 * 1024 * 1024; + size_t budget = static_cast(static_cast(total_global_mem) * kFraction); + return std::clamp(budget, kMin, kMax); } -} // namespace - +// Returns how many images to process in parallel per batch chunk for DeformConv. +// Chooses the largest divisor of batch size N that fits in the temp budget and does not +// exceed kMaxParallelImgs, so that batch dimension is split evenly (no remainder). 
+// +// Formulas: +// kernel_size = kH * kW +// output_image_size = out_h * out_w +// bytes_per_image = output_image_size * (C * kernel_size + M / group) * sizeof(T) +// (temp bytes per image: im2col col buffer + GEMM output buffer per output position) +// max_parallel_imgs_mem = max(1, floor(effective_max_temp / bytes_per_image)) +// target_parallel_imgs = min(kMaxParallelImgs, max_parallel_imgs_mem) +// return GetGreatestDivisorBelowBound(N, target_parallel_imgs) template -Status DeformConv::UpdateState(OpKernelContext* context, - const DeformConvParams& params, - int& n_parallel_imgs) const { - const auto* X = context->Input(0); - const auto* W = context->Input(1); - const auto x_dims = X->Shape().AsShapeVector(); - const auto w_dims = W->Shape().AsShapeVector(); - - bool input_dims_changed = (state_.last_x_dims != x_dims); - bool w_dims_changed = (state_.last_w_dims != w_dims); - - if (input_dims_changed || w_dims_changed) { - if (input_dims_changed) { - state_.last_x_dims = gsl::make_span(x_dims); - } - if (w_dims_changed) { - state_.last_w_dims = gsl::make_span(w_dims); - } - - const int64_t kernel_size = params.kH * params.kW; - const int64_t output_image_size = params.out_h * params.out_w; - const size_t bytes_per_image = SafeInt(output_image_size) * (params.C * kernel_size + params.M / params.group) * sizeof(T); - const size_t effective_max_temp = GetDeformConvEffectiveMaxTempBytes(); - const int max_parallel_imgs_mem = std::max(1, static_cast(effective_max_temp / std::max(size_t(1), bytes_per_image))); - const int target_parallel_imgs = std::min(kMaxParallelImgs, max_parallel_imgs_mem); - state_.cached_n_parallel_imgs = GetGreatestDivisorBelowBound(static_cast(params.N), target_parallel_imgs); - } - - n_parallel_imgs = state_.cached_n_parallel_imgs; - return Status::OK(); +int GetNParallelImgs(const DeformConvParams& params, size_t total_global_mem) { + const size_t effective_max_temp = GetDeformConvEffectiveMaxTempBytes(total_global_mem); + const int64_t kernel_size = params.kH * params.kW; + const int64_t output_image_size = params.out_h * params.out_w; + const size_t bytes_per_image = SafeInt(output_image_size) * (params.C * kernel_size + params.M / params.group) * sizeof(T); + const int max_parallel_imgs_mem = std::max(1, static_cast(effective_max_temp / std::max(size_t(1), bytes_per_image))); + const int target_parallel_imgs = std::min(kMaxParallelImgs, max_parallel_imgs_mem); + return GetGreatestDivisorBelowBound(static_cast(params.N), target_parallel_imgs); } +} // namespace + template Status DeformConv::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; @@ -155,11 +144,7 @@ Status DeformConv::ComputeInternal(OpKernelContext* context) const { return Status::OK(); } - int n_parallel_imgs; - { - std::lock_guard lock(state_.mutex); - ORT_RETURN_IF_ERROR(UpdateState(context, params, n_parallel_imgs)); - } + const int n_parallel_imgs = GetNParallelImgs(params, GetDeviceProp().totalGlobalMem); const int64_t kernel_size = kH * kW; const int64_t output_image_size = out_h * out_w; diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.h b/onnxruntime/core/providers/cuda/nn/deform_conv.h index f89c105327b38..fa564641d4b98 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv.h +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.h @@ -3,24 +3,14 @@ #pragma once -#include - #include "core/common/common.h" #include "core/framework/op_kernel.h" -#include "core/framework/tensor_shape.h" #include 
"core/providers/cpu/nn/deform_conv_attributes.h" #include "core/providers/cuda/cuda_kernel.h" namespace onnxruntime { namespace cuda { -struct DeformConvState { - TensorShape last_x_dims; - TensorShape last_w_dims; - int cached_n_parallel_imgs{0}; - std::mutex mutex; -}; - template class DeformConv final : public CudaKernel { public: @@ -29,12 +19,7 @@ class DeformConv final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; private: - Status UpdateState(OpKernelContext* context, - const DeformConvParams& params, - int& n_parallel_imgs) const; - DeformConvAttributes attrs_; - mutable DeformConvState state_; ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DeformConv); }; From d98f3399abc3a22224aa1c6500310239c112e640 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Fri, 13 Mar 2026 00:48:51 +0800 Subject: [PATCH 48/58] Document int indices in CUDA BilinearInterpolate --- onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu | 6 ++++++ onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc | 2 ++ 2 files changed, 8 insertions(+) diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu index 25feb9484198e..099358b9914ad 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu @@ -32,6 +32,10 @@ inline int GetGridSize(size_t n, size_t threads_per_block) { } // Bilinear interpolation at (h, w). Returns 0 if out of bounds (ONNX spec). +// Indices h_low, w_low, h_high, w_high use int (not int64_t) to reduce register pressure in the +// hot path and improve occupancy. Limitation: height and width must not exceed INT_MAX; otherwise +// the cast from floor(h)/floor(w) to int may overflow and produce incorrect sampling. This is +// acceptable for typical deformable convolution use (spatial dimensions are far below INT_MAX). template __device__ __inline__ T BilinearInterpolate( const T* in, @@ -88,6 +92,7 @@ __device__ __inline__ half BilinearInterpolate( if (h <= -1.0f || h >= height || w <= -1.0f || w >= width) { return __float2half(0.0f); } + // int for indices to save registers; see limitation in BilinearInterpolate above. int h_low = static_cast(floorf(h)); int w_low = static_cast(floorf(w)); int h_high = h_low + 1; @@ -116,6 +121,7 @@ __device__ __inline__ BFloat16 BilinearInterpolate( if (h <= -1.0f || h >= height || w <= -1.0f || w >= width) { return BFloat16(0.0f); } + // int for indices to save registers; see limitation in BilinearInterpolate above. 
int h_low = static_cast(floorf(h)); int w_low = static_cast(floorf(w)); int h_high = h_low + 1; diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc index 926f67f77117f..122209f022531 100644 --- a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc @@ -158,6 +158,8 @@ void RunMinimalBilinearTest(int opset = 19, int min_cuda_arch = 0, bool omit_bia if (min_cuda_arch > 0 && !HasCudaEnvironment(min_cuda_arch)) { return; } +#else + (void)min_cuda_arch; #endif DeformConvTestParams p = {}; p.batch_sz = 1; From 74a7760af5878f3870200342286a276ac59f39f3 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Fri, 13 Mar 2026 00:55:44 +0800 Subject: [PATCH 49/58] Document why BFloat16 is not delegated in DeformConv CUDA impl --- onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu index 099358b9914ad..a90331a9c7bba 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu @@ -487,6 +487,8 @@ template Status DeformConvAddBiasImpl(cudaStream_t, BFloat16*, const B reinterpret_cast(B), N, M, out_h, out_w); \ } +// BFloat16 is not delegated: ORT's BFloat16 is the same type used in device code (ToCudaType in +// cuda_common.h), so the explicit instantiations above (INST_DeformConvIm2ColImpl(BFloat16), etc.) suffice. DELEGATE_DEFORM_CONV_IMPL(MLFloat16, half) } // namespace cuda From c2bf6f40ef8981e84e9646c51b999cb91102959f Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Fri, 13 Mar 2026 01:11:13 +0800 Subject: [PATCH 50/58] Document prime-batch fallback to single-image chunks in DeformConv GetNParallelImgs --- onnxruntime/core/providers/cuda/nn/deform_conv.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc index 18aff162a6a63..7a0b896acfe01 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc @@ -78,6 +78,8 @@ size_t GetDeformConvEffectiveMaxTempBytes(size_t total_global_mem) { // Returns how many images to process in parallel per batch chunk for DeformConv. // Chooses the largest divisor of batch size N that fits in the temp budget and does not // exceed kMaxParallelImgs, so that batch dimension is split evenly (no remainder). +// Note: if N is prime and N > target_parallel_imgs, the greatest divisor <= target_parallel_imgs is 1, +// so batching is effectively disabled (single-image chunks). 
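For reference, the divisor search that the note above describes behaves like the following minimal sketch. This is a hypothetical body written from the documented contract; the in-tree GetGreatestDivisorBelowBound may be implemented differently.

#include <algorithm>

// Largest divisor of n that is <= bound. For prime n > bound nothing between
// bound and 2 divides n, so this returns 1: the single-image-chunk fallback
// called out above.
int GetGreatestDivisorBelowBound(int n, int bound) {
  for (int d = std::min(n, bound); d > 1; --d) {
    if (n % d == 0) return d;
  }
  return 1;
}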
// // Formulas: // kernel_size = kH * kW From a8920b42bca929dc1b85f37190d78aac5e04ebf7 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Sat, 14 Mar 2026 00:07:53 +0800 Subject: [PATCH 51/58] Refine deform conv test generator imports and ONNX model save usage --- onnxruntime/test/testdata/nn/deform_conv_test_gen.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/onnxruntime/test/testdata/nn/deform_conv_test_gen.py b/onnxruntime/test/testdata/nn/deform_conv_test_gen.py index 3d1f276846181..120fb1ed4c211 100644 --- a/onnxruntime/test/testdata/nn/deform_conv_test_gen.py +++ b/onnxruntime/test/testdata/nn/deform_conv_test_gen.py @@ -24,7 +24,7 @@ from pathlib import Path import numpy as np -import onnx +from onnx import TensorProto, checker, helper, save try: import onnxruntime as ort @@ -32,7 +32,6 @@ ort = None import torch import torchvision.ops -from onnx import TensorProto, helper # Config: groups=2, offset_group=2, 2x2 kernel (from deform_conv_expected_gen Case 3) BATCH = 1 @@ -114,7 +113,7 @@ def _build_onnx_model(): ) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 19)]) - onnx.checker.check_model(model) + checker.check_model(model) return model @@ -157,7 +156,7 @@ def main(): print("Building ONNX model...") model = _build_onnx_model() - onnx.save(model, str(model_path)) + save(model, str(model_path)) print(f" Saved {model_path}") np.savez(str(data_path), **data) From b7b468138a27a4bb05c9675bc3b64474fc5c24bd Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Sat, 14 Mar 2026 01:07:12 +0800 Subject: [PATCH 52/58] Optimize DeformConv CPU bilinear interpolation --- .../core/providers/cpu/nn/deform_conv.cc | 66 +++++++++++-------- .../providers/cpu/nn/deform_conv_attributes.h | 7 ++ 2 files changed, 44 insertions(+), 29 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.cc b/onnxruntime/core/providers/cpu/nn/deform_conv.cc index ae7ad8f0c63e5..f128b0e0182ad 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cpu/nn/deform_conv.cc @@ -15,42 +15,50 @@ namespace onnxruntime { namespace { - -// Bilinear interpolation at fractional coordinates (h, w). -// Returns 0 if (h,w) is out of bounds; otherwise computes weighted average of the four -// nearest integer grid points. Standard implementation for deformable convolution sampling. +// Bilinear interpolation at (h, w). Out-of-bounds samples return 0 (ONNX spec). +// Indices use int (not int64_t) to reduce register pressure and improve occupancy in the hot path. +// Limitation: height and width must not exceed INT_MAX, or casting floor(h)/floor(w) to int may overflow. +// Acceptable in practice: deformable convolution spatial dimensions are typically well below INT_MAX. template -T BilinearInterpolate(const T* in, int64_t height, int64_t width, T h, T w) { - // Early exit for clearly out-of-bounds (skip floor() for OOB case). +T BilinearInterpolate(const T* in, int height, int width, T h, T w) { + // [Optimization 1]: Early exit for clearly out-of-bounds (skip floor() for OOB case). if (h <= static_cast(-1) || h >= height || w <= static_cast(-1) || w >= width) { return static_cast(0); } - const int64_t h_low = static_cast(std::floor(h)); - const int64_t w_low = static_cast(std::floor(w)); - const int64_t h_high = h_low + 1; - const int64_t w_high = w_low + 1; + // [Optimization 2]: Keep floor result in T; cast to int only for indices. Avoids float->int->float in lh/lw. 
+ const T h_floor = std::floor(h); + const T w_floor = std::floor(w); + const int h_low = static_cast(h_floor); + const int w_low = static_cast(w_floor); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const T lh = h - h_floor; + const T lw = w - w_floor; + const T hh = static_cast(1) - lh; + const T hw = static_cast(1) - lw; // Fast path: all 4 corners in bounds (h in [0, height-1), w in [0, width-1)). // Most sampling points in deformable conv fall here; avoids 4 per-corner branches. - if (h_low >= 0 && h_high < height && w_low >= 0 && w_high < width) { - const T lh = h - static_cast(h_low); - const T lw = w - static_cast(w_low); - const T hh = static_cast(1) - lh; - const T hw = static_cast(1) - lw; - return hh * hw * in[h_low * width + w_low] + hh * lw * in[h_low * width + w_high] + - lh * hw * in[h_high * width + w_low] + lh * lw * in[h_high * width + w_high]; + // [Optimization 3]: Use unsigned comparison to avoid branch on negative height/width. + if (static_cast(h_low) < static_cast(height - 1) && + static_cast(w_low) < static_cast(width - 1)) { + const int base_low = h_low * width; + const int base_high = h_high * width; + return hh * hw * in[base_low + w_low] + + hh * lw * in[base_low + w_high] + + lh * hw * in[base_high + w_low] + + lh * lw * in[base_high + w_high]; } // Slow path: near boundary (one or more of the 4 corners may be out of bounds). - const T lh = h - static_cast(h_low); - const T lw = w - static_cast(w_low); - const T hh = static_cast(1) - lh; - const T hw = static_cast(1) - lw; - const T v1 = (h_low >= 0 && w_low >= 0) ? in[h_low * width + w_low] : static_cast(0); - const T v2 = (h_low >= 0 && w_high < width) ? in[h_low * width + w_high] : static_cast(0); - const T v3 = (h_high < height && w_low >= 0) ? in[h_high * width + w_low] : static_cast(0); - const T v4 = (h_high < height && w_high < width) ? in[h_high * width + w_high] : static_cast(0); + const int base_low = h_low * width; + const int base_high = h_high * width; + const T v1 = (h_low >= 0 && w_low >= 0) ? in[base_low + w_low] : static_cast(0); + const T v2 = (h_low >= 0 && w_high < width) ? in[base_low + w_high] : static_cast(0); + const T v3 = (h_high < height && w_low >= 0) ? in[base_high + w_low] : static_cast(0); + const T v4 = (h_high < height && w_high < width) ? in[base_high + w_high] : static_cast(0); return hh * hw * v1 + hh * lw * v2 + lh * hw * v3 + lh * lw * v4; } @@ -63,7 +71,7 @@ void DeformableIm2col( const T* data_im, // Input image [C, H, W] const T* data_offset, // Offset [offset_groups * 2 * kH * kW, H_out, W_out] const T* data_mask, // Mask [offset_groups * kH * kW, H_out, W_out] (nullptr when UseMask=false) - int64_t height, int64_t width, // Input spatial dimensions + int height, int width, // Input spatial dimensions (validated H*W <= INT_MAX) int64_t kernel_h, int64_t kernel_w, // Kernel dimensions int64_t pad_h, int64_t pad_w, // Padding (begin) for H and W int64_t stride_h, int64_t stride_w, // Stride for H and W @@ -93,7 +101,7 @@ void DeformableIm2col( // Output row: one (channel, kernel_pos) across all spatial locations. T* col_ptr = data_col + static_cast(idx) * output_size; - const T* im_ptr = data_im + c_im * height * width; + const T* im_ptr = data_im + c_im * static_cast(height) * width; // Offset tensor layout: [offset_grp, 2*kH*kW, H_out, W_out] flattened. // For (i,j) we use channel indices 2*(i*kW+j) and 2*(i*kW+j)+1 for offset_h, offset_w. 
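As an aside, the flat indices described by the offset-layout comment above can be written out explicitly. A sketch under the stated layout; the function and struct names are illustrative and not from the patch, though base/specific mirror the terms used in the layout comment.

#include <cstdint>

// Flat positions of the (offset_h, offset_w) channels for kernel tap (i, j) at
// output location (h_col, w_col), layout [offset_grp, 2*kH*kW, H_out, W_out].
struct OffsetIndices { int64_t h; int64_t w; };

inline OffsetIndices OffsetFlatIndices(int64_t offset_grp, int64_t kH, int64_t kW,
                                       int64_t i, int64_t j,
                                       int64_t out_h, int64_t out_w,
                                       int64_t h_col, int64_t w_col) {
  const int64_t base = offset_grp * 2 * kH * kW;  // first channel of this offset group
  const int64_t specific = 2 * (i * kW + j);      // offset_h channel; +1 is offset_w
  const int64_t plane = out_h * out_w;            // one H_out x W_out spatial plane
  const int64_t pos = h_col * out_w + w_col;      // position within the plane
  return {(base + specific) * plane + pos, (base + specific + 1) * plane + pos};
}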
@@ -229,14 +237,14 @@ Status DeformConv::Compute(OpKernelContext* context) const { if (use_mask) { DeformableIm2col( X_curr, offset_curr, mask_curr, - H, W_in, kH, kW, + static_cast(H), static_cast(W_in), kH, kW, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, C, offset_group, out_h, out_w, col_buffer_ptr, thread_pool); } else { DeformableIm2col( X_curr, offset_curr, nullptr, - H, W_in, kH, kW, + static_cast(H), static_cast(W_in), kH, kW, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, C, offset_group, out_h, out_w, col_buffer_ptr, thread_pool); diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h index 01a487a1e1664..9517c6da70fa1 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h +++ b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h @@ -3,6 +3,8 @@ #pragma once +#include + #include "core/common/common.h" #include "core/framework/op_kernel.h" #include "core/framework/tensor_shape.h" @@ -122,6 +124,11 @@ inline Status DeformConvValidateAndParse( params.out_w = (params.W_in + params.pad_w + params.pad_w_end - params.dilation_w * (params.kW - 1) - 1) / params.stride_w + 1; ORT_RETURN_IF_NOT(params.out_h >= 0 && params.out_w >= 0, "Computed output spatial size must be non-negative."); + // CPU BilinearInterpolate uses int for indices (for performance optimization); W <= INT_MAX / (H+1) covers all index math. + ORT_RETURN_IF_NOT(params.H >= 0 && params.W_in >= 0, "Input spatial dimensions H and W must be non-negative."); + ORT_RETURN_IF_NOT(params.W_in <= static_cast(INT_MAX) / (params.H + 1), + "Input (H+1)*W must not exceed INT_MAX (for performance optimization)."); + // Validate tensor shapes ORT_RETURN_IF_NOT(offset_shape[0] == params.N, "Offset batch size must match input batch size."); ORT_RETURN_IF_NOT( From 6aeef468c94bfb53c2fe8af10d0c9a0dba119f88 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Sat, 14 Mar 2026 01:37:11 +0800 Subject: [PATCH 53/58] Optimize DeformConv BilinearInterpolation for performance on CUDA --- .../providers/cuda/nn/deform_conv_impl.cu | 195 ++++++++++-------- 1 file changed, 106 insertions(+), 89 deletions(-) diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu index a90331a9c7bba..7b3666fca810b 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu @@ -31,48 +31,6 @@ inline int GetGridSize(size_t n, size_t threads_per_block) { return static_cast(std::min(blocks_needed, static_cast(std::numeric_limits::max()))); } -// Bilinear interpolation at (h, w). Returns 0 if out of bounds (ONNX spec). -// Indices h_low, w_low, h_high, w_high use int (not int64_t) to reduce register pressure in the -// hot path and improve occupancy. Limitation: height and width must not exceed INT_MAX; otherwise -// the cast from floor(h)/floor(w) to int may overflow and produce incorrect sampling. This is -// acceptable for typical deformable convolution use (spatial dimensions are far below INT_MAX). 
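Spelled out, the host-side bound that lets both the CPU and CUDA samplers keep their index math in int is the following arithmetic. This is a sketch of the reasoning, not code from the patch.

// Largest flat index the bilinear sampler can form: with h < H and w < W we get
// h_high <= H and w_high <= W, so h_high * W + w_high <= H * W + W = (H + 1) * W.
// Validating (H + 1) * W <= INT_MAX therefore keeps every intermediate in range.
constexpr long long MaxBilinearFlatIndex(long long H, long long W) {
  return (H + 1) * W;
}
static_assert(MaxBilinearFlatIndex(1080, 1920) <= 2147483647LL,
              "example: 1080p inputs are comfortably safe");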
-template -__device__ __inline__ T BilinearInterpolate( - const T* in, - int64_t height, - int64_t width, - T h, - T w) { - if (h <= static_cast(-1) || h >= height || w <= static_cast(-1) || w >= width) { - return static_cast(0); - } - int h_low = static_cast(_Floor(h)); - int w_low = static_cast(_Floor(w)); - int h_high = h_low + 1; - int w_high = w_low + 1; - - T lh = h - static_cast(h_low); - T lw = w - static_cast(w_low); - T hh = static_cast(1) - lh; - T hw = static_cast(1) - lw; - - T v1 = (h_low >= 0 && w_low >= 0) ? __ldg(in + h_low * width + w_low) : static_cast(0); - T v2 = (h_low >= 0 && w_high < width) ? __ldg(in + h_low * width + w_high) : static_cast(0); - T v3 = (h_high < height && w_low >= 0) ? __ldg(in + h_high * width + w_low) : static_cast(0); - T v4 = (h_high < height && w_high < width) ? __ldg(in + h_high * width + w_high) : static_cast(0); - - T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; -} - -// FP16/BF16: coordinate and weight math in float to avoid precision loss. -template -struct DeformConvUseFloatCoords : std::false_type {}; -template <> -struct DeformConvUseFloatCoords : std::true_type {}; -template <> -struct DeformConvUseFloatCoords : std::true_type {}; - // __ldg has no overload for BFloat16*; use 16-bit load + FromBits. Other types use __ldg directly. template __device__ __inline__ T DeformConvLdg(const T* p) { @@ -83,62 +41,116 @@ __device__ __inline__ BFloat16 DeformConvLdg(const BFloat16* p) { return BFloat16::FromBits(__ldg(reinterpret_cast(p))); } -__device__ __inline__ half BilinearInterpolate( - const half* in, - int64_t height, - int64_t width, - float h, - float w) { - if (h <= -1.0f || h >= height || w <= -1.0f || w >= width) { +// Traits for bilinear interpolation math: +// - ComputeT: type used for coordinate/weight math (float for half/BFloat16, T otherwise) +// - Load: load one element and convert to ComputeT +// - ToResult: convert ComputeT result back to T +// - Zero: zero value of T +template +struct DeformConvBilinearTraits { + using ComputeT = T; + + __device__ static __inline__ ComputeT Load(const T* p) { + return __ldg(p); + } + + __device__ static __inline__ T ToResult(ComputeT v) { + return v; + } + + __device__ static __inline__ T Zero() { + return static_cast(0); + } +}; + +template <> +struct DeformConvBilinearTraits { + using ComputeT = float; + + __device__ static __inline__ ComputeT Load(const half* p) { + return __half2float(__ldg(p)); + } + + __device__ static __inline__ half ToResult(ComputeT v) { + return __float2half(v); + } + + __device__ static __inline__ half Zero() { return __float2half(0.0f); } - // int for indices to save registers; see limitation in BilinearInterpolate above. - int h_low = static_cast(floorf(h)); - int w_low = static_cast(floorf(w)); - int h_high = h_low + 1; - int w_high = w_low + 1; +}; - float lh = h - static_cast(h_low); - float lw = w - static_cast(w_low); - float hh = 1.0f - lh; - float hw = 1.0f - lw; +template <> +struct DeformConvBilinearTraits { + using ComputeT = float; - float v1 = (h_low >= 0 && w_low >= 0) ? __half2float(__ldg(in + h_low * width + w_low)) : 0.0f; - float v2 = (h_low >= 0 && w_high < width) ? __half2float(__ldg(in + h_low * width + w_high)) : 0.0f; - float v3 = (h_high < height && w_low >= 0) ? __half2float(__ldg(in + h_high * width + w_low)) : 0.0f; - float v4 = (h_high < height && w_high < width) ? 
__half2float(__ldg(in + h_high * width + w_high)) : 0.0f; + __device__ static __inline__ ComputeT Load(const BFloat16* p) { + return static_cast(DeformConvLdg(p)); + } - float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - return __float2half(w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); -} + __device__ static __inline__ BFloat16 ToResult(ComputeT v) { + return BFloat16(v); + } -__device__ __inline__ BFloat16 BilinearInterpolate( - const BFloat16* in, - int64_t height, - int64_t width, - float h, - float w) { - if (h <= -1.0f || h >= height || w <= -1.0f || w >= width) { + __device__ static __inline__ BFloat16 Zero() { return BFloat16(0.0f); } - // int for indices to save registers; see limitation in BilinearInterpolate above. - int h_low = static_cast(floorf(h)); - int w_low = static_cast(floorf(w)); - int h_high = h_low + 1; - int w_high = w_low + 1; +}; - float lh = h - static_cast(h_low); - float lw = w - static_cast(w_low); - float hh = 1.0f - lh; - float hw = 1.0f - lw; +// Bilinear interpolation at (h, w). Returns 0 if out of bounds (ONNX spec). +// Indices h_low, w_low, h_high, w_high use int (not int64_t) to reduce register pressure and +// improve occupancy in the hot path. Limitation: (H+1)*W must not exceed INT_MAX; this is +// validated on the host side in DeformConvValidateAndParse to guarantee index math in int +// does not overflow. For half/BFloat16, coordinate and weight math use float via +// DeformConvBilinearTraits to avoid precision loss. We keep floor() results in CoordT and +// cast to int only for indices (h_low/w_low), which avoids unnecessary CoordT->int->CoordT +// round trips when computing lh/lw/hh/hw. +template +__device__ __inline__ T BilinearInterpolate( + const T* in, + int height, + int width, + typename DeformConvBilinearTraits::ComputeT h, + typename DeformConvBilinearTraits::ComputeT w) { + using Traits = DeformConvBilinearTraits; + using CoordT = typename Traits::ComputeT; + + // [Optimization 1]: Early exit for clearly out-of-bounds (skip floor() for OOB case). + if (h <= static_cast(-1) || h >= height || w <= static_cast(-1) || w >= width) { + return Traits::Zero(); + } - float v1 = (h_low >= 0 && w_low >= 0) ? static_cast(DeformConvLdg(in + h_low * width + w_low)) : 0.0f; - float v2 = (h_low >= 0 && w_high < width) ? static_cast(DeformConvLdg(in + h_low * width + w_high)) : 0.0f; - float v3 = (h_high < height && w_low >= 0) ? static_cast(DeformConvLdg(in + h_high * width + w_low)) : 0.0f; - float v4 = (h_high < height && w_high < width) ? static_cast(DeformConvLdg(in + h_high * width + w_high)) : 0.0f; + // [Optimization 2]: Keep floor result in T; cast to int only for indices. Avoids float->int->float in lh/lw. + CoordT h_floor = _Floor(h); + CoordT w_floor = _Floor(w); + int h_low = static_cast(h_floor); + int w_low = static_cast(w_floor); + int h_high = h_low + 1; + int w_high = w_low + 1; - float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - return BFloat16(w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + CoordT lh = h - h_floor; + CoordT lw = w - w_floor; + CoordT hh = static_cast(1) - lh; + CoordT hw = static_cast(1) - lw; + + // [Optimization 3]: Avoid a second multiply for base_high. 
+ // Original code computed both bases as: + // base_low = h_low * width; + // base_high = h_high * width; + // Since h_high = h_low + 1, we can rewrite base_high as base_low + width and + // save one integer multiply in the hot path: + // base_low = h_low * width; + // base_high = base_low + width; + int base_low = h_low * width; + int base_high = base_low + width; + + CoordT v1 = (h_low >= 0 && w_low >= 0) ? Traits::Load(in + base_low + w_low) : static_cast(0); + CoordT v2 = (h_low >= 0 && w_high < width) ? Traits::Load(in + base_low + w_high) : static_cast(0); + CoordT v3 = (h_high < height && w_low >= 0) ? Traits::Load(in + base_high + w_low) : static_cast(0); + CoordT v4 = (h_high < height && w_high < width) ? Traits::Load(in + base_high + w_high) : static_cast(0); + + CoordT w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + return Traits::ToResult(w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); } // kH/kW = -1 means dynamic (runtime); >= 0 means compile-time constant for loop unrolling. @@ -179,7 +191,7 @@ __global__ void DeformableIm2ColKernel( // The stride for data_col is (parallel_imgs * out_h * out_w) const int64_t col_stride = parallel_imgs * out_size; - using CoordT = typename std::conditional::value, float, T>::type; + using CoordT = typename DeformConvBilinearTraits::ComputeT; for (IndexT index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; index += blockDim.x * gridDim.x) { IndexT val = index; @@ -251,7 +263,12 @@ __global__ void DeformableIm2ColKernel( const CoordT h_im = base_h_im + static_cast(i * dilation_h) + offset_h; const CoordT w_im = base_w_im + static_cast(j * dilation_w) + offset_w; - T val = BilinearInterpolate(input_ptr, height, width, h_im, w_im); + // height/width are validated on host (DeformConvValidateAndParse) so int is safe here. + T val = BilinearInterpolate(input_ptr, + static_cast(height), + static_cast(width), + h_im, + w_im); // Match CPU path: always interpolate then apply mask to keep branch-free hot loop. data_col_ptr_base[kernel_idx * col_stride] = val * mask_val; From 46f176c4bd8fbbe0606fce1c01f8d1ba5c89a3a3 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Tue, 17 Mar 2026 01:42:44 +0800 Subject: [PATCH 54/58] Enforce 2D attribute lengths and validate kernel_shape/pads/overflow-safe shapes --- .../providers/cpu/nn/deform_conv_attributes.h | 68 ++++++++++++++----- 1 file changed, 52 insertions(+), 16 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h index 9517c6da70fa1..521f459b9ff44 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h +++ b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h @@ -92,23 +92,57 @@ inline Status DeformConvValidateAndParse( params.H = X_shape[2]; params.W_in = X_shape[3]; params.M = W_shape[0]; + ORT_RETURN_IF_NOT(params.N > 0, "Batch size N must be positive."); + ORT_RETURN_IF_NOT(params.C > 0, "Input channels C must be positive."); + ORT_RETURN_IF_NOT(params.M > 0, "Output channels M (oC) must be positive."); + ORT_RETURN_IF_NOT(W_shape[1] > 0, "Weight W must have positive in-channels (W_shape[1] = C/group)."); + + // Handle kernel shape inference. If kernel_shape is provided, it must match weight spatial dims + // to avoid GEMM using wrong K and potential out-of-bounds reads from the weight buffer. 
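To make the failure mode named above concrete before the parsing code that follows: the GEMM inner dimension is derived from the kernel extent, so a mismatched kernel_shape silently changes K while the weight buffer keeps its real size. A sketch with assumed shapes, not patch code:

#include <cstdint>

// Weight W is [M, C/group, kH, kW]; the im2col GEMM uses K = (C/group) * kH * kW.
// If kernel_shape != (W_kH, W_kW), K computed from the attribute disagrees with
// the M * (C/group) * W_kH * W_kW elements actually allocated -> OOB reads.
inline int64_t GemmK(int64_t C, int64_t group, int64_t kH, int64_t kW) {
  return (C / group) * kH * kW;
}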
+ const int64_t W_kH = W_shape[2]; + const int64_t W_kW = W_shape[3]; + if (!attrs.kernel_shape.empty()) { + ORT_RETURN_IF_NOT(attrs.kernel_shape.size() == 2, + "kernel_shape must be absent or have exactly 2 values (kH, kW) for 2D DeformConv."); + ORT_RETURN_IF_NOT(attrs.kernel_shape[0] == W_kH && attrs.kernel_shape[1] == W_kW, + "kernel_shape must match weight spatial dimensions (W_shape[2], W_shape[3])."); + params.kH = attrs.kernel_shape[0]; + params.kW = attrs.kernel_shape[1]; + } else { + params.kH = W_kH; + params.kW = W_kW; + } - // Handle kernel shape inference - params.kH = attrs.kernel_shape.size() >= 1 ? attrs.kernel_shape[0] : W_shape[2]; - params.kW = attrs.kernel_shape.size() >= 2 ? attrs.kernel_shape[1] : W_shape[3]; - + // DeformConv is 2D-only: when an attribute is present, require exact length to avoid silently misinterpreting malformed models. params.pad_h = params.pad_w = params.pad_h_end = params.pad_w_end = 0; - if (attrs.pads.size() >= 4) { + if (!attrs.pads.empty()) { + ORT_RETURN_IF_NOT(attrs.pads.size() == 4, + "pads must be absent or have exactly 4 values [pad_h_begin, pad_w_begin, pad_h_end, pad_w_end] for 2D DeformConv."); params.pad_h = attrs.pads[0]; params.pad_w = attrs.pads[1]; params.pad_h_end = attrs.pads[2]; params.pad_w_end = attrs.pads[3]; + ORT_RETURN_IF_NOT(params.pad_h >= 0 && params.pad_w >= 0 && params.pad_h_end >= 0 && params.pad_w_end >= 0, + "Pads must be non-negative (ONNX spec)."); } - params.stride_h = attrs.strides.empty() ? 1 : attrs.strides[0]; - params.stride_w = attrs.strides.size() < 2 ? 1 : attrs.strides[1]; - params.dilation_h = attrs.dilations.empty() ? 1 : attrs.dilations[0]; - params.dilation_w = attrs.dilations.size() < 2 ? 1 : attrs.dilations[1]; + if (!attrs.strides.empty()) { + ORT_RETURN_IF_NOT(attrs.strides.size() == 2, + "strides must be absent or have exactly 2 values [stride_h, stride_w] for 2D DeformConv."); + params.stride_h = attrs.strides[0]; + params.stride_w = attrs.strides[1]; + } else { + params.stride_h = params.stride_w = 1; + } + + if (!attrs.dilations.empty()) { + ORT_RETURN_IF_NOT(attrs.dilations.size() == 2, + "dilations must be absent or have exactly 2 values [dilation_h, dilation_w] for 2D DeformConv."); + params.dilation_h = attrs.dilations[0]; + params.dilation_w = attrs.dilations[1]; + } else { + params.dilation_h = params.dilation_w = 1; + } params.group = attrs.group; params.offset_group = attrs.offset_group; params.use_mask = (mask_shape != nullptr); @@ -129,11 +163,12 @@ inline Status DeformConvValidateAndParse( ORT_RETURN_IF_NOT(params.W_in <= static_cast(INT_MAX) / (params.H + 1), "Input (H+1)*W must not exceed INT_MAX (for performance optimization)."); - // Validate tensor shapes + // Validate tensor shapes (use division to avoid int64 overflow in offset_group * 2 * kH * kW). 
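The division trick used by the checks below, condensed into one predicate (a sketch; the real checks are the ORT_RETURN_IF_NOT calls that follow):

#include <cstdint>

// offset_group * 2 * kH * kW can overflow int64 for hostile shapes, but dividing
// the observed channel count by the per-group block never does.
inline bool ChannelsMatchGroups(int64_t channels, int64_t block, int64_t groups) {
  return block > 0 && channels % block == 0 && channels / block == groups;
}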
ORT_RETURN_IF_NOT(offset_shape[0] == params.N, "Offset batch size must match input batch size."); - ORT_RETURN_IF_NOT( - offset_shape[1] == params.offset_group * 2 * params.kH * params.kW, - "Offset channel count must be offset_group * 2 * kH * kW."); + const int64_t offset_block = 2 * params.kH * params.kW; + ORT_RETURN_IF_NOT(offset_block > 0 && offset_shape[1] % offset_block == 0 && + offset_shape[1] / offset_block == params.offset_group, + "Offset channel count must be offset_group * 2 * kH * kW."); ORT_RETURN_IF_NOT(offset_shape[2] == params.out_h, "Offset spatial height must match output oH."); ORT_RETURN_IF_NOT(offset_shape[3] == params.out_w, "Offset spatial width must match output oW."); ORT_RETURN_IF_NOT(params.C % params.offset_group == 0, "Input channels must be divisible by offset_group."); @@ -149,9 +184,10 @@ inline Status DeformConvValidateAndParse( if (params.use_mask) { ORT_RETURN_IF_NOT(mask_shape->NumDimensions() == 4, "Mask must be 4D."); ORT_RETURN_IF_NOT((*mask_shape)[0] == params.N, "Mask batch size must match input batch size."); - ORT_RETURN_IF_NOT( - (*mask_shape)[1] == params.offset_group * params.kH * params.kW, - "Mask channel count must be offset_group * kH * kW."); + const int64_t mask_block = params.kH * params.kW; + ORT_RETURN_IF_NOT(mask_block > 0 && (*mask_shape)[1] % mask_block == 0 && + (*mask_shape)[1] / mask_block == params.offset_group, + "Mask channel count must be offset_group * kH * kW."); ORT_RETURN_IF_NOT((*mask_shape)[2] == params.out_h, "Mask spatial height must match output oH."); ORT_RETURN_IF_NOT((*mask_shape)[3] == params.out_w, "Mask spatial width must match output oW."); } From 7cd167b2d1955d1feb482ec55b9d8a39dbaab4a5 Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Tue, 17 Mar 2026 01:53:18 +0800 Subject: [PATCH 55/58] Clarify DeformConv OnnxModelTest comment as ORT-reference smoke test --- onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc index 122209f022531..456ee5cae28c1 100644 --- a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc @@ -920,8 +920,9 @@ TEST(DeformConvTest, ExtremeAspectRatio) { RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); } -// ONNX model data test: fixed inputs from deform_conv_test_gen.py (torchvision ref, seed=123). -// Validates output matches torch reference. The .onnx/.npz can be used for standalone model zoo validation. +// ONNX model data test: deform_conv_test_gen.py builds the ONNX model (via onnx.helper) +// and generates fixed inputs from torchvision (seed=123). This test is a model-loading/ +// integration smoke test that uses ORT-generated outputs from deform_conv_test.onnx as the reference. 
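For the standalone model-zoo style validation mentioned above, simply constructing a session over the generated model already exercises ORT's load-time checks. A hypothetical snippet: the file name comes from the generator, everything else is illustrative (e.g. on Windows the path argument must be a wide string).

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "deform_conv_smoke");
  Ort::SessionOptions opts;
  // Session construction loads, validates, and partitions the model.
  Ort::Session session(env, "deform_conv_test.onnx", opts);
  return 0;
}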
TEST(DeformConvTest, OnnxModelTest) { OpTester test("DeformConv", 19); test.AddAttribute("kernel_shape", std::vector{2, 2}); From ada5ca394598d6cb3bd82a1da15d7a02e9af793e Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Tue, 17 Mar 2026 03:10:32 +0800 Subject: [PATCH 56/58] DeformConv EmptyBatch test expects failure when batch size N is zero --- onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc index 456ee5cae28c1..4809d4b225fad 100644 --- a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc @@ -365,7 +365,7 @@ TEST(DeformConvTest, ForwardNoMask) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); } -// Empty batch (like PyTorch batch_sz=0). +// Empty batch (N=0): ONNX DeformConv does not allow batch size 0; expect validation failure. TEST(DeformConvTest, EmptyBatch) { DeformConvTestParams p = {}; p.batch_sz = 0; @@ -408,7 +408,7 @@ TEST(DeformConvTest, EmptyBatch) { test.AddOutput("Y", Y_shape, expected_Y); std::unordered_set excluded = {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider}; - test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); + test.Run(OpTester::ExpectResult::kExpectFailure, "Batch size N must be positive.", excluded); } // Wrong offset channel count -> expect failure (like PyTorch test_wrong_sizes). From 17b155acd8e996e4fd8e86fd48020505807bf6ea Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Tue, 17 Mar 2026 03:31:19 +0800 Subject: [PATCH 57/58] Allow DeformConv empty batch --- onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h | 2 +- onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h index 521f459b9ff44..8bc891bb4f377 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h +++ b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h @@ -92,7 +92,7 @@ inline Status DeformConvValidateAndParse( params.H = X_shape[2]; params.W_in = X_shape[3]; params.M = W_shape[0]; - ORT_RETURN_IF_NOT(params.N > 0, "Batch size N must be positive."); + ORT_RETURN_IF_NOT(params.N >= 0, "Batch size N must be non-negative."); ORT_RETURN_IF_NOT(params.C > 0, "Input channels C must be positive."); ORT_RETURN_IF_NOT(params.M > 0, "Output channels M (oC) must be positive."); ORT_RETURN_IF_NOT(W_shape[1] > 0, "Weight W must have positive in-channels (W_shape[1] = C/group)."); diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc index 4809d4b225fad..860c0d2f08b18 100644 --- a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc @@ -365,7 +365,7 @@ TEST(DeformConvTest, ForwardNoMask) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); } -// Empty batch (N=0): ONNX DeformConv does not allow batch size 0; expect validation failure. +// Empty batch (N=0): allowed, same as Conv/ConvTranspose/Pool — output shape [0, oC, oH, oW]. 
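The empty-batch semantics follow directly from the output-shape formula used in DeformConvValidateAndParse: N passes through to the leading output dimension untouched, so N = 0 requires no kernel work. A sketch of that formula:

#include <cstdint>

// ONNX Conv-style output extent; the batch size N only affects the leading dim.
inline int64_t OutDim(int64_t in, int64_t pad_begin, int64_t pad_end,
                      int64_t dilation, int64_t kernel, int64_t stride) {
  return (in + pad_begin + pad_end - dilation * (kernel - 1) - 1) / stride + 1;
}
// Y shape = {N /* may be 0 */, M, OutDim(H, ...), OutDim(W_in, ...)}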
TEST(DeformConvTest, EmptyBatch) { DeformConvTestParams p = {}; p.batch_sz = 0; @@ -408,7 +408,7 @@ TEST(DeformConvTest, EmptyBatch) { test.AddOutput("Y", Y_shape, expected_Y); std::unordered_set excluded = {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider}; - test.Run(OpTester::ExpectResult::kExpectFailure, "Batch size N must be positive.", excluded); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded); } // Wrong offset channel count -> expect failure (like PyTorch test_wrong_sizes). From 288e4c020536ea818edd87657a58227f10e2fa4d Mon Sep 17 00:00:00 2001 From: Shirasawa <764798966@qq.com> Date: Wed, 18 Mar 2026 12:34:48 +0800 Subject: [PATCH 58/58] Update docs --- docs/OperatorKernels.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 39c9145a40912..25829d08206cf 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -103,6 +103,8 @@ Do not modify directly.* |||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)|
 |DFT|*in* input:**T1**<br/> *in* dft_length:**T2**<br/> *in* axis:**tensor(int64)**<br/> *out* output:**T1**<br/><br/>or<br/><br/>*in* input:**T1**<br/> *in* dft_length:**T2**<br/> *out* output:**T1**|20+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int32), tensor(int64)|
 |||[17, 19]|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int32), tensor(int64)|
+|DeformConv|*in* X:**T**<br/> *in* W:**T**<br/> *in* offset:**T**<br/> *in* B:**T**<br/> *in* mask:**T**<br/> *out* Y:**T**|22+|**T** = tensor(double), tensor(float)|
+|||[19, 21]|**T** = tensor(double), tensor(float)|
 |DepthToSpace|*in* input:**T**<br/> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(uint8)|
 |||[11, 12]|**T** = tensor(double), tensor(float), tensor(uint8)|
 |||[1, 10]|**T** = tensor(double), tensor(float)|
@@ -697,6 +699,8 @@ Do not modify directly.*
 |Crop|*in* input:**T**<br/> *out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |CumSum|*in* x:**T**<br/> *in* axis:**T2**<br/> *out* y:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(int32), tensor(int64)|
 |||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(int32), tensor(int64)|
+|DeformConv|*in* X:**T**<br/> *in* W:**T**<br/> *in* offset:**T**<br/> *in* B:**T**<br/> *in* mask:**T**<br/> *out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
+|||[19, 21]|**T** = tensor(double), tensor(float), tensor(float16)|
 |DepthToSpace|*in* input:**T**<br/> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
 |||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
 |||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)|