diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv.cc b/onnxruntime/core/providers/cpu/nn/deform_conv.cc index f128b0e0182ad..dd0da83850a1f 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cpu/nn/deform_conv.cc @@ -2,159 +2,627 @@ // Licensed under the MIT License. // // CPU implementation of DeformConv (deformable convolution 2D). +// +// High-level pipeline (one batch item at a time for peak memory): +// (1) Build a bilinear sampling plan from offsets (parallel): for each (offset_group, kernel tap, output pixel), +// store 4 neighbor indices + 4 weights in AoSoA blocks (see kPlanAoSoALanes). +// (2) Fill im2col matrix (parallel over channels x kernel taps): each row is one (channel, i, j) slice, +// reusing the plan row shared by all channels in the same offset group. +// (3) Grouped GEMM: Y_g = W_g * Col_g per group (highly optimized in math::Gemm / BLAS). +// (4) Optional bias: add B[m] to each output channel map (vectorized per row). +// +// Biggest win vs a naive loop: reusing the sampling plan across C/offset_group channels (plan built once per +// offset row, not per channel) plus AoSoA layout so the gather/interpolate inner loop can SIMD-unroll 8-wide. #include "deform_conv.h" -#include +#include +#include +#include #include "core/common/common.h" +#include "core/common/inlined_containers.h" #include "core/util/math_cpuonly.h" -#include "core/common/narrow.h" +#include "core/util/force_inline.h" #include "core/util/math.h" +#if defined(__GNUC__) && !defined(__wasm__) +#define ORT_CPU_RESTRICT __restrict__ +#elif defined(_MSC_VER) +#define ORT_CPU_RESTRICT __restrict +#else +#define ORT_CPU_RESTRICT +#endif + +// Hint the inner lane loop for SIMD / vectorization (OpenMP simd, Clang loop, or GCC ivdep); empty otherwise. +#if defined(_OPENMP) +#define ORT_CPU_SIMD_INNER_LOOP _Pragma("omp simd") +#elif defined(__clang__) +#define ORT_CPU_SIMD_INNER_LOOP _Pragma("clang loop vectorize(enable)") +#elif defined(__GNUC__) +#define ORT_CPU_SIMD_INNER_LOOP _Pragma("GCC ivdep") +#else +#define ORT_CPU_SIMD_INNER_LOOP +#endif + namespace onnxruntime { namespace { -// Bilinear interpolation at (h, w). Out-of-bounds samples return 0 (ONNX spec). -// Indices use int (not int64_t) to reduce register pressure and improve occupancy in the hot path. -// Limitation: height and width must not exceed INT_MAX, or casting floor(h)/floor(w) to int may overflow. -// Acceptable in practice: deformable convolution spatial dimensions are typically well below INT_MAX. + +// AoSoA "lane" count: each BilinearSamplePlanBlock holds 8 output pixels' worth of idx/weights per corner. +// For T=float, 8 matches one 256-bit AVX2 vector of floats; auto-vectorizers often turn the lane loop into +// SIMD. For T=double, SIMD is typically 4-wide; the 8-lane layout still unrolls the scalar work and keeps +// the same indexing (pidx / 8, pidx % 8) in PlanStoreSample — changing it requires revisiting all offsets. +constexpr int64_t kPlanAoSoALanes = 8; + +// Overflow-safe size_t multiply: returns a * b only if the product fits in size_t. +// Guard: for a > 0, a * b <= max <=> b <= max / a (integer division; avoids computing a*b first). +// If a == 0 the product is 0 and is always representable, so the check is skipped. +// Used when deriving batch/group strides from tensor shapes so pointer arithmetic like +// base + n * stride cannot wrap size_t on valid ONNX shapes. +ORT_FORCEINLINE size_t CheckedMulSizeT(size_t a, size_t b, size_t max_size_t, const char* err) { + ORT_ENFORCE(a == 0 || b <= max_size_t / a, err); + return a * b; +} + +// Verifies that batch indexing over n items with byte stride `stride` stays within addressable size_t range. +// For n > 0, the largest offset used when stepping batch index 0..n-1 is (n-1)*stride plus element spans; +// requiring n * stride <= max_size_t is a conservative upper bound that n * stride itself does not overflow +// (same inequality as CheckedMulSizeT with arguments n and stride). +ORT_FORCEINLINE void CheckedBatchSpan(size_t n, size_t stride, size_t max_size_t, const char* err) { + ORT_ENFORCE(n == 0 || stride <= max_size_t / n, err); +} + +struct CpuDeformConvStrides { + size_t x_batch_stride = 0; + size_t y_batch_stride = 0; + size_t w_group_stride = 0; + size_t col_group_stride = 0; + size_t y_group_stride = 0; + size_t offset_batch_stride = 0; + size_t mask_batch_stride = 0; +}; + +struct CpuDeformConvExecutionDims { + int64_t plan_rows = 0; + int64_t padded_spatial_count = 0; + size_t block_count = 0; + int64_t im2col_rows = 0; + int64_t col_buffer_size = 0; +}; + +inline CpuDeformConvExecutionDims ComputeCpuDeformConvExecutionDims(const DeformConvParams& params, + const DeformConvCommonDims& common_dims, + int64_t ptrdiff_max, + size_t max_size_t) { + const int64_t int64_max = std::numeric_limits::max(); + + ORT_ENFORCE(params.offset_group <= int64_max / common_dims.kernel_size, "plan_rows overflows int64."); + const int64_t plan_rows = params.offset_group * common_dims.kernel_size; + + ORT_ENFORCE(plan_rows > 0 && common_dims.output_image_size > 0, "Invalid plan dimensions."); + ORT_ENFORCE(common_dims.output_image_size <= int64_max - (kPlanAoSoALanes - 1), + "output_image_size is too large and will overflow."); + // Round output_image_size up to a multiple of kPlanAoSoALanes so each plan "row" occupies an integer number + // of BilinearSamplePlanBlocks. Storage is row-major: row r starts at offset r * padded_spatial_count in + // "logical pixels"; block index = that offset / kPlanAoSoALanes. Only indices [0, output_image_size) are + // read when filling im2col; the padded tail slots in the last block are never read (FillColRow uses output_size). + // [IMPORTANT] Plan buffer is not zero-filled; tail lanes in the last block stay uninitialized (FillColRow uses tail_count only). + const int64_t padded_spatial_count = (common_dims.output_image_size + kPlanAoSoALanes - 1) / + kPlanAoSoALanes * kPlanAoSoALanes; + const size_t blocks_per_row = static_cast(padded_spatial_count) / kPlanAoSoALanes; + ORT_ENFORCE(blocks_per_row <= (max_size_t / static_cast(plan_rows)), + "Sampling plan size overflows size_t."); + const size_t block_count = static_cast(plan_rows) * blocks_per_row; + + ORT_ENFORCE(plan_rows <= ptrdiff_max, "plan_rows exceeds ptrdiff_t range."); + ORT_ENFORCE(common_dims.output_image_size == 0 || plan_rows <= int64_max / common_dims.output_image_size, + "Flattened bilinear plan task count overflows int64."); + const int64_t flattened_plan_tasks = plan_rows * common_dims.output_image_size; + ORT_ENFORCE(flattened_plan_tasks <= ptrdiff_max, + "Flattened bilinear plan tasks exceed ptrdiff_t range (needed for thread pool parallelization)."); + ORT_ENFORCE(params.C <= int64_max / common_dims.kernel_size, "im2col row count overflows int64."); + const int64_t im2col_rows = params.C * common_dims.kernel_size; + ORT_ENFORCE(im2col_rows <= ptrdiff_max, "im2col row count exceeds ptrdiff_t range."); + ORT_ENFORCE(im2col_rows <= int64_max / common_dims.output_image_size, "col_buffer_size overflows int64."); + const int64_t col_buffer_size = im2col_rows * common_dims.output_image_size; + + return CpuDeformConvExecutionDims{ + plan_rows, + padded_spatial_count, + block_count, + im2col_rows, + col_buffer_size}; +} + +inline CpuDeformConvStrides ComputeCpuDeformConvStrides(const DeformConvParams& params, + const DeformConvCommonDims& common_dims, + size_t max_size_t) { + const size_t c_size = static_cast(params.C); + const size_t m_size = static_cast(params.M); + const size_t n_size = static_cast(params.N); + const size_t group_size = static_cast(params.group); + const size_t offset_group_size = static_cast(params.offset_group); + const size_t input_image_size_sz = static_cast(common_dims.input_image_size); + const size_t output_image_size_sz = static_cast(common_dims.output_image_size); + const size_t kernel_size_sz = static_cast(common_dims.kernel_size); + const size_t kernel_dim_sz = static_cast(common_dims.kernel_dim); + const size_t m_per_group_sz = m_size / group_size; + + // Flat strides (elements between batch items or group slices). Computed once per Compute(), not per pixel. + // Hot paths then use pointer += stride instead of repeated rank-4/5 index math — typically saves several + // multiplies/adds per inner iteration in exchange for this O(1) setup cost. + CpuDeformConvStrides strides; + strides.x_batch_stride = CheckedMulSizeT(c_size, input_image_size_sz, max_size_t, "X batch stride overflows size_t."); + strides.y_batch_stride = CheckedMulSizeT(m_size, output_image_size_sz, max_size_t, "Y batch stride overflows size_t."); + strides.w_group_stride = CheckedMulSizeT(m_per_group_sz, kernel_dim_sz, max_size_t, "weight group stride overflows size_t."); + strides.col_group_stride = CheckedMulSizeT(kernel_dim_sz, output_image_size_sz, max_size_t, "col group stride overflows size_t."); + strides.y_group_stride = CheckedMulSizeT(m_per_group_sz, output_image_size_sz, max_size_t, "Y group stride overflows size_t."); + + const size_t offset_rows = CheckedMulSizeT( + CheckedMulSizeT(offset_group_size, kernel_size_sz, max_size_t, "offset rows overflows size_t."), + static_cast(2), max_size_t, "offset rows overflows size_t."); + strides.offset_batch_stride = CheckedMulSizeT(offset_rows, output_image_size_sz, max_size_t, "offset batch stride overflows size_t."); + + const size_t mask_rows = CheckedMulSizeT(offset_group_size, kernel_size_sz, max_size_t, "mask rows overflows size_t."); + strides.mask_batch_stride = CheckedMulSizeT(mask_rows, output_image_size_sz, max_size_t, "mask batch stride overflows size_t."); + + CheckedBatchSpan(n_size, strides.x_batch_stride, max_size_t, "X batch indexing overflows size_t."); + CheckedBatchSpan(n_size, strides.y_batch_stride, max_size_t, "Y batch indexing overflows size_t."); + CheckedBatchSpan(n_size, strides.offset_batch_stride, max_size_t, "offset batch indexing overflows size_t."); + if (params.use_mask) { + CheckedBatchSpan(n_size, strides.mask_batch_stride, max_size_t, "mask batch indexing overflows size_t."); + } + + return strides; +} + +namespace sampling_plan_internal { + +// One AoSoA "macro-cell": 8 output pixels x 4 bilinear corners. See kPlanAoSoALanes and PlanStoreSample. +// [IMPORTANT] Last-block tail lanes may be uninitialized; keep in sync with padded_spatial_count / FillColRow tail_count. +template +struct alignas(64) BilinearSamplePlanBlock { + int32_t idx[4][kPlanAoSoALanes]; + T w[4][kPlanAoSoALanes]; +}; + +template +struct DeformableIm2colContext { + const T* ORT_CPU_RESTRICT data_im = nullptr; + const T* ORT_CPU_RESTRICT data_offset = nullptr; + const T* ORT_CPU_RESTRICT data_mask = nullptr; + int height = 0; + int width = 0; + int64_t kernel_h = 0; + int64_t kernel_w = 0; + int64_t stride_h = 0; + int64_t stride_w = 0; + int64_t channels = 0; + int64_t offset_groups = 0; + int64_t height_col = 0; + int64_t width_col = 0; + int64_t padded_spatial_count = 0; + const size_t* ORT_CPU_RESTRICT kernel_offset_base_delta = nullptr; + const T* ORT_CPU_RESTRICT kernel_base_h = nullptr; + const T* ORT_CPU_RESTRICT kernel_base_w = nullptr; + BilinearSamplePlanBlock* ORT_CPU_RESTRICT sampling_plan_blocks = nullptr; + T* ORT_CPU_RESTRICT data_col = nullptr; + concurrency::ThreadPool* thread_pool = nullptr; +}; + +template +ORT_FORCEINLINE void PlanStoreSample(BilinearSamplePlanBlock* ORT_CPU_RESTRICT blocks, int64_t pidx, + int32_t idx00, int32_t idx01, int32_t idx10, int32_t idx11, + T w00, T w01, T w10, T w11) { + // Scatter one output pixel into lane `pidx % 8` across the four corners. AoSoA vs AoS: here `w[k][0..7]` + // are contiguous in memory for corner k, so the gather loop can load 8 weights per corner with vector + // loads; with AoS (one pixel's 4 corners packed together), the same 8 pixels would be strided and harder + // to SIMD. alignas(64) on the block type aligns starts to cache lines (struct may still span multiple lines). + const int64_t block = pidx / kPlanAoSoALanes; + const int64_t lane = pidx % kPlanAoSoALanes; + auto& dst = blocks[block]; + dst.idx[0][lane] = idx00; + dst.idx[1][lane] = idx01; + dst.idx[2][lane] = idx10; + dst.idx[3][lane] = idx11; + dst.w[0][lane] = w00; + dst.w[1][lane] = w01; + dst.w[2][lane] = w10; + dst.w[3][lane] = w11; +} + +// Matches std::floor for in-range finite x. Call only after coords pass the inverted bounds check below +// (NaN makes each comparison false, so the && fails and ! rejects without a separate isfinite branch). +// Performance trick: std::floor can be slow due to handling edge cases (NaN, Inf, negative zero). +// This custom implementation uses a simple cast to int and a boolean subtraction, which compiles +// to fast, branchless instructions on most architectures. +template +ORT_FORCEINLINE int DeformConvFastFloor(T x) { + // Assumes x is in int range after prior bounds filtering; T→int truncates toward zero. + const int i = static_cast(x); + return i - static_cast(i > x); +} + template -T BilinearInterpolate(const T* in, int height, int width, T h, T w) { - // [Optimization 1]: Early exit for clearly out-of-bounds (skip floor() for OOB case). - if (h <= static_cast(-1) || h >= height || w <= static_cast(-1) || w >= width) { - return static_cast(0); +ORT_FORCEINLINE void BilinearPlanOneSample( + const T* ORT_CPU_RESTRICT ptr_offset_h, + const T* ORT_CPU_RESTRICT ptr_offset_w, + int height, + int width, + int64_t h_col, + int64_t w_col, + int64_t local_idx, + int64_t stride_h, + int64_t stride_w, + T base_h, + T base_w, + BilinearSamplePlanBlock* ORT_CPU_RESTRICT plan_blocks) { + // Deformable sampling point in input space (fractional): col (h_col,w_col) -> image (h_im, w_im). + const T h_im = static_cast(h_col * stride_h) + base_h + ptr_offset_h[local_idx]; + const T w_im = static_cast(w_col * stride_w) + base_w + ptr_offset_w[local_idx]; + + // In-bounds test on open rectangle (-1, H) x (-1, W) (same as strict && on comparisons). Bitwise & evaluates + // all four preds (no short-circuit); NaN makes each compare false → treated as out-of-bounds without isnan(). + // One branch remains on `in_bounds == 0` to skip bilinear work when fully outside; inner fast/slow path is separate. + const T neg1 = static_cast(-1); + const T h_max = static_cast(height); + const T w_max = static_cast(width); + const unsigned in_bounds = static_cast( + (h_im > neg1) & (h_im < h_max) & (w_im > neg1) & (w_im < w_max)); + if (in_bounds == 0u) { + PlanStoreSample(plan_blocks, local_idx, 0, 0, 0, 0, + static_cast(0), static_cast(0), static_cast(0), static_cast(0)); + return; } - // [Optimization 2]: Keep floor result in T; cast to int only for indices. Avoids float->int->float in lh/lw. - const T h_floor = std::floor(h); - const T w_floor = std::floor(w); - const int h_low = static_cast(h_floor); - const int w_low = static_cast(w_floor); + const int h_low = DeformConvFastFloor(h_im); + const int w_low = DeformConvFastFloor(w_im); + const T h_floor = static_cast(h_low); + const T w_floor = static_cast(w_low); const int h_high = h_low + 1; const int w_high = w_low + 1; - const T lh = h - h_floor; - const T lw = w - w_floor; + const T lh = h_im - h_floor; + const T lw = w_im - w_floor; const T hh = static_cast(1) - lh; const T hw = static_cast(1) - lw; - // Fast path: all 4 corners in bounds (h in [0, height-1), w in [0, width-1)). - // Most sampling points in deformable conv fall here; avoids 4 per-corner branches. - // [Optimization 3]: Use unsigned comparison to avoid branch on negative height/width. + // Bilinear interpolation weights calculation: + // w00 (top-left) = (1 - dy) * (1 - dx) + // w01 (top-right) = (1 - dy) * dx + // w10 (bottom-left) = dy * (1 - dx) + // w11 (bottom-right) = dy * dx + T plan_w00 = hh * hw; + T plan_w01 = hh * lw; + T plan_w10 = lh * hw; + T plan_w11 = lh * lw; + + // Safe under DeformConvValidateAndParse precondition: (H + 1) * W <= int_max. + // With h_high <= H and w_high <= W, these linearized int indices stay in range. + // Near borders h_low/w_low can be -1, but lower bounds also remain representable in int32. + const int base_low = h_low * width; + const int base_high = h_high * width; + int32_t idx00 = base_low + w_low; + int32_t idx01 = base_low + w_high; + int32_t idx10 = base_high + w_low; + int32_t idx11 = base_high + w_high; + + // Fast path: If the entire 2x2 interpolation window is strictly inside the image boundaries, + // we can safely store the indices and weights without any further bounds checking. + // This branch is taken for the vast majority of pixels, significantly speeding up the plan generation. if (static_cast(h_low) < static_cast(height - 1) && static_cast(w_low) < static_cast(width - 1)) { - const int base_low = h_low * width; - const int base_high = h_high * width; - return hh * hw * in[base_low + w_low] + - hh * lw * in[base_low + w_high] + - lh * hw * in[base_high + w_low] + - lh * lw * in[base_high + w_high]; + PlanStoreSample( + plan_blocks, local_idx, + idx00, idx01, idx10, idx11, + plan_w00, plan_w01, plan_w10, plan_w11); + } else { + // Slow path (Edge cases): The interpolation window overlaps with the image boundary. + // We must check each of the 4 corners individually. If a corner is out of bounds, + // its corresponding weight and index are forced to 0. This ensures that out-of-bounds + // reads fetch a safe value (at index 0) which is then multiplied by a 0.0 weight, + // effectively contributing 0 to the final interpolated result (zero-padding semantics). + const bool v00 = (h_low >= 0 && w_low >= 0); + const bool v01 = (h_low >= 0 && w_high < width); + const bool v10 = (h_high < height && w_low >= 0); + const bool v11 = (h_high < height && w_high < width); + + plan_w00 = v00 ? plan_w00 : static_cast(0); + plan_w01 = v01 ? plan_w01 : static_cast(0); + plan_w10 = v10 ? plan_w10 : static_cast(0); + plan_w11 = v11 ? plan_w11 : static_cast(0); + + idx00 = v00 ? idx00 : 0; + idx01 = v01 ? idx01 : 0; + idx10 = v10 ? idx10 : 0; + idx11 = v11 ? idx11 : 0; + + PlanStoreSample( + plan_blocks, local_idx, + idx00, idx01, idx10, idx11, + plan_w00, plan_w01, plan_w10, plan_w11); } +} - // Slow path: near boundary (one or more of the 4 corners may be out of bounds). - const int base_low = h_low * width; - const int base_high = h_high * width; - const T v1 = (h_low >= 0 && w_low >= 0) ? in[base_low + w_low] : static_cast(0); - const T v2 = (h_low >= 0 && w_high < width) ? in[base_low + w_high] : static_cast(0); - const T v3 = (h_high < height && w_low >= 0) ? in[base_high + w_low] : static_cast(0); - const T v4 = (h_high < height && w_high < width) ? in[base_high + w_high] : static_cast(0); - return hh * hw * v1 + hh * lw * v2 + lh * hw * v3 + lh * lw * v4; +template +void BuildAllBilinearSamplingPlansImpl( + const T* ORT_CPU_RESTRICT data_offset, + int height, + int width, + int64_t stride_h, + int64_t stride_w, + int64_t offset_groups, + size_t output_size, + int64_t padded_spatial_count, + int64_t width_col, + int64_t kernel_size, + const size_t* ORT_CPU_RESTRICT kernel_offset_base_delta, + const T* ORT_CPU_RESTRICT kernel_base_h, + const T* ORT_CPU_RESTRICT kernel_base_w, + BilinearSamplePlanBlock* ORT_CPU_RESTRICT sampling_plan_blocks, + concurrency::ThreadPool* thread_pool) { + const int64_t plan_rows = offset_groups * kernel_size; + ORT_ENFORCE(kernel_offset_base_delta != nullptr, "kernel_offset_base_delta must not be null."); + ORT_ENFORCE(kernel_base_h != nullptr, "kernel_base_h must not be null."); + ORT_ENFORCE(kernel_base_w != nullptr, "kernel_base_w must not be null."); + const size_t kernel_size_sz = static_cast(kernel_size); + const size_t offset_group_stride = static_cast(2) * kernel_size_sz; + const int64_t output_size_i64 = static_cast(output_size); + + // Plan is built once per (offset_group, kernel tap) row and reused for every input channel in that group: + // work factor ~ O(offset_group * kH * kW * out_h * out_w) instead of O(C * kH * kW * ...) for bilinear setup. + // Flatten (row, output pixel) to one task range so TryParallelFor can split fine-grained work even when + // offset_group * kernel_size is small (parallelizing only the outer dimension would under-use threads). + const int64_t total_plan_tasks = plan_rows * output_size_i64; + // Unit cost is a dimensionless heuristic for ORT's thread pool splitter, not CPU cycles. + // We keep plan-build chunking slightly finer than before so offset_group==1 cases can expose + // enough parallel tasks early instead of leaving work concentrated in the later fill stage. + constexpr double kCostPerBilinearSample = 8.0; + concurrency::ThreadPool::TryParallelFor( + thread_pool, + static_cast(total_plan_tasks), + kCostPerBilinearSample, + [&](ptrdiff_t begin, ptrdiff_t end) { + const int64_t end_task = static_cast(end); + int64_t task = static_cast(begin); + int64_t row = task / output_size_i64; + int64_t local_idx = task % output_size_i64; + int64_t offset_grp = row / kernel_size; + int64_t kernel_idx = row % kernel_size; + size_t offset_grp_base = static_cast(offset_grp) * offset_group_stride; + + while (task < end_task) { + const size_t kernel_idx_sz = static_cast(kernel_idx); + const size_t offset_base = offset_grp_base + kernel_offset_base_delta[kernel_idx_sz]; + const T* ORT_CPU_RESTRICT ptr_offset_h = data_offset + offset_base * output_size; + const T* ORT_CPU_RESTRICT ptr_offset_w = data_offset + (offset_base + 1) * output_size; + const size_t plan_row_base = static_cast(row) * static_cast(padded_spatial_count); + BilinearSamplePlanBlock* ORT_CPU_RESTRICT row_plan = sampling_plan_blocks + (plan_row_base / kPlanAoSoALanes); + + // Output pixel index: local_idx = h_col * width_col + w_col (row-major flatten of [0, out_h) x [0, out_w)). + const int64_t h_col = local_idx / width_col; + const int64_t w_col = local_idx % width_col; + BilinearPlanOneSample(ptr_offset_h, ptr_offset_w, height, width, h_col, w_col, local_idx, stride_h, + stride_w, kernel_base_h[kernel_idx_sz], kernel_base_w[kernel_idx_sz], row_plan); + + ++task; + if (++local_idx == output_size_i64) { + local_idx = 0; + ++row; + if (++kernel_idx == kernel_size) { + kernel_idx = 0; + ++offset_grp; + offset_grp_base += offset_group_stride; + } + } + } + }); } -// Deformable Im2Col for a SINGLE image. -// Converts the input image into a matrix suitable for GEMM by sampling with learned offsets. -// Output 'data_col' shape: [C_in * kH * kW, H_out * W_out] -// When UseMask=false, pass nullptr for data_mask; compiler eliminates dead code for mask. template -void DeformableIm2col( - const T* data_im, // Input image [C, H, W] - const T* data_offset, // Offset [offset_groups * 2 * kH * kW, H_out, W_out] - const T* data_mask, // Mask [offset_groups * kH * kW, H_out, W_out] (nullptr when UseMask=false) - int height, int width, // Input spatial dimensions (validated H*W <= INT_MAX) - int64_t kernel_h, int64_t kernel_w, // Kernel dimensions - int64_t pad_h, int64_t pad_w, // Padding (begin) for H and W - int64_t stride_h, int64_t stride_w, // Stride for H and W - int64_t dilation_h, int64_t dilation_w, // Dilation for H and W - int64_t channels, // Input channels - int64_t offset_groups, // Number of offset groups (channels shared per group) - int64_t height_col, int64_t width_col, // Output spatial dimensions (H_out, W_out) - T* data_col, // Output buffer for im2col result - concurrency::ThreadPool* thread_pool) { - const int64_t channel_per_offset_group = channels / offset_groups; - const int64_t kernel_size = kernel_h * kernel_w; - const int64_t output_size = height_col * width_col; +void FillColRowFromSamplingPlanImpl( + const T* ORT_CPU_RESTRICT im_ptr, + const BilinearSamplePlanBlock* ORT_CPU_RESTRICT plan_blocks, + int64_t spatial_count, + size_t mask_row_base, + const T* ORT_CPU_RESTRICT ptr_mask, + T* ORT_CPU_RESTRICT col_ptr) { + // val = sum_{c in corners} w_c * im[idx_c]; optionally val *= mask[local_idx] (DeformConv v2). + // UseMask is a template parameter so the no-mask build has zero mask branches/loads in this loop (better SIMD). + const int64_t block_count = spatial_count / kPlanAoSoALanes; + const int64_t tail_count = spatial_count % kPlanAoSoALanes; + + int64_t local_idx = 0; + for (int64_t b = 0; b < block_count; ++b) { + const auto& block = plan_blocks[b]; + // Inner lane loop: 8 pixels; for float, compilers often SIMD this (commonly a few× faster than scalar, ISA/optimizer dependent). + ORT_CPU_SIMD_INNER_LOOP for (int lane = 0; lane < kPlanAoSoALanes; ++lane) { + T val = block.w[0][lane] * im_ptr[block.idx[0][lane]] + + block.w[1][lane] * im_ptr[block.idx[1][lane]] + + block.w[2][lane] * im_ptr[block.idx[2][lane]] + + block.w[3][lane] * im_ptr[block.idx[3][lane]]; + if constexpr (UseMask) { + val *= ptr_mask[mask_row_base + local_idx]; + } + col_ptr[local_idx] = val; + ++local_idx; + } + } + + // [IMPORTANT] Last partial block: only lanes [0, tail_count) are valid; do not SIMD-load all 8 without init/zero. + if (tail_count > 0) { + const auto& block = plan_blocks[block_count]; + ORT_CPU_SIMD_INNER_LOOP for (int lane = 0; lane < tail_count; ++lane) { + T val = block.w[0][lane] * im_ptr[block.idx[0][lane]] + + block.w[1][lane] * im_ptr[block.idx[1][lane]] + + block.w[2][lane] * im_ptr[block.idx[2][lane]] + + block.w[3][lane] * im_ptr[block.idx[3][lane]]; + if constexpr (UseMask) { + val *= ptr_mask[mask_row_base + local_idx]; + } + col_ptr[local_idx] = val; + ++local_idx; + } + } +} - // Parallelize over (channel, kernel_position) so each task processes one full row of data_col. - // This yields channels*kernel_size tasks, better CPU utilization and cache-friendly sequential writes. +template +void DeformableIm2colPlanned(const DeformableIm2colContext& ctx) { + ORT_ENFORCE(ctx.sampling_plan_blocks != nullptr, "sampling_plan_blocks must not be null."); + ORT_ENFORCE(ctx.data_col != nullptr, "data_col must not be null."); + // Single-image im2col: col buffer is [C*kH*kW, out_h*out_w] per batch index n only → peak memory O(C*kH*kW*HWout) + // instead of O(N*...) when N>1. UseMask is compile-time so mask loads/branches are absent when false. + const int64_t channel_per_offset_group = ctx.channels / ctx.offset_groups; + const int64_t kernel_size = ctx.kernel_h * ctx.kernel_w; + const int64_t output_size = ctx.height_col * ctx.width_col; + const size_t output_size_sz = static_cast(output_size); + + BuildAllBilinearSamplingPlansImpl( + ctx.data_offset, ctx.height, ctx.width, + ctx.stride_h, ctx.stride_w, + ctx.offset_groups, + output_size_sz, ctx.padded_spatial_count, ctx.width_col, kernel_size, + ctx.kernel_offset_base_delta, ctx.kernel_base_h, ctx.kernel_base_w, + ctx.sampling_plan_blocks, ctx.thread_pool); + + // Heuristic cost per im2col row (one channel x one kernel tap): ~one full pass over output pixels with gathers. + // Slightly higher when UseMask (extra multiply per pixel). Same note as kCostPerBilinearSample: for scheduling only. + // For small offset_group (especially 1), each sampling-plan row is reused by many channels, so this stage + // dominates; reduce cost to encourage finer split and better load balance at high C. + const double base_cost = static_cast(output_size) * (UseMask ? 12.0 : 10.0); + const double offset_group_adjust = (ctx.offset_groups == 1) ? 0.5 : 1.0; + const double parallel_cost = base_cost * offset_group_adjust; concurrency::ThreadPool::TryParallelFor( - thread_pool, - static_cast(channels * kernel_size), - static_cast(output_size) * 10.0, + ctx.thread_pool, + static_cast(ctx.channels * kernel_size), + parallel_cost, [&](ptrdiff_t begin, ptrdiff_t end) { + int64_t c_im = begin / kernel_size; + int64_t rem = begin % kernel_size; + int64_t i = rem / ctx.kernel_w; + int64_t j = rem % ctx.kernel_w; + for (ptrdiff_t idx = begin; idx < end; ++idx) { - // Decompose idx into (c_im, i, j): which channel and kernel position. - const int64_t j = static_cast(idx) % kernel_w; - const int64_t i = (static_cast(idx) / kernel_w) % kernel_h; - const int64_t c_im = static_cast(idx) / kernel_size; const int64_t offset_grp = c_im / channel_per_offset_group; - // Output row: one (channel, kernel_pos) across all spatial locations. - T* col_ptr = data_col + static_cast(idx) * output_size; - const T* im_ptr = data_im + c_im * static_cast(height) * width; + // Pointer arithmetic and index calculation for the current im2col row. + // `col_ptr`: Points to the start of the current row in the output `col_buffer`. + // Shape of col_buffer is [C * kH * kW, out_h * out_w]. + // Row-major flatten over (channel, kernel_y, kernel_x): idx = c_im * (kH*kW) + i * kW + j. + T* ORT_CPU_RESTRICT col_ptr = ctx.data_col + static_cast(idx) * output_size; + + // `im_ptr`: Points to the start of the current channel `c_im` in the input image. + // Shape of input image is [C, H, W]. + const T* ORT_CPU_RESTRICT im_ptr = ctx.data_im + c_im * static_cast(ctx.height) * ctx.width; - // Offset tensor layout: [offset_grp, 2*kH*kW, H_out, W_out] flattened. - // For (i,j) we use channel indices 2*(i*kW+j) and 2*(i*kW+j)+1 for offset_h, offset_w. - // Precompute pointers to avoid offset_base * output_size multiplication in inner loop. - const int64_t offset_base = - offset_grp * 2 * kernel_size + 2 * (i * kernel_w + j); - const T* ptr_offset_h = data_offset + offset_base * output_size; - const T* ptr_offset_w = data_offset + (offset_base + 1) * output_size; + // `row`: Identifies which pre-computed sampling plan to use. + // The sampling plan is shared across channels that belong to the same `offset_grp`. + // Formula: plan_row_index = offset_grp * (kH * kW) + (i * kW + j) + const int64_t row = offset_grp * kernel_size + i * ctx.kernel_w + j; - // Base terms for h_im, w_im: invariant in inner loop (i, j fixed). - const T base_h = -pad_h + static_cast(i) * dilation_h; - const T base_w = -pad_w + static_cast(j) * dilation_w; + // `row_plan`: Points to the start of the AoSoA blocks for this specific `row`. + // Since each block holds `kPlanAoSoALanes` elements, we divide the padded base index by it. + const size_t plan_row_base = static_cast(row) * static_cast(ctx.padded_spatial_count); + const BilinearSamplePlanBlock* ORT_CPU_RESTRICT row_plan = ctx.sampling_plan_blocks + (plan_row_base / kPlanAoSoALanes); - // Mask pointer; only used when UseMask=true (compiler removes when false). - [[maybe_unused]] const T* ptr_mask = nullptr; + const T* ORT_CPU_RESTRICT ptr_mask = nullptr; + size_t mask_row_base = 0; if constexpr (UseMask) { - ptr_mask = data_mask + (offset_grp * kernel_size + i * kernel_w + j) * output_size; + // If DeformConv v2 (with modulation mask), fetch the mask pointer. + // Shape of mask is [offset_group * kH * kW, out_h * out_w]. + // The `row` index perfectly matches the mask's row index. + ptr_mask = ctx.data_mask; + mask_row_base = static_cast(row) * output_size_sz; } - // Loop over output spatial positions. - for (int64_t h_col = 0; h_col < height_col; ++h_col) { - for (int64_t w_col = 0; w_col < width_col; ++w_col) { - const int64_t spatial_idx = h_col * width_col + w_col; + // Execute the gather and interpolation for this specific (channel, kernel_y, kernel_x) combination + // across all spatial output pixels [0, out_h * out_w). + FillColRowFromSamplingPlanImpl( + im_ptr, row_plan, output_size, mask_row_base, ptr_mask, col_ptr); - const T offset_h = ptr_offset_h[spatial_idx]; - const T offset_w = ptr_offset_w[spatial_idx]; + if (++j == ctx.kernel_w) { + j = 0; + if (++i == ctx.kernel_h) { + i = 0; + ++c_im; + } + } + } + }); +} - // Deformed sampling coordinates (fractional, for bilinear interpolation). - const T h_im = h_col * stride_h + base_h + offset_h; - const T w_im = w_col * stride_w + base_w + offset_w; +} // namespace sampling_plan_internal - // Sample input at deformed location; returns 0 if out of bounds. - T val = BilinearInterpolate(im_ptr, height, width, h_im, w_im); +} // namespace - // Modulate by mask when UseMask=true; compiled away when false. - // Design choice: we always interpolate then multiply, rather than skip when mask==0. - // Rationale: (1) Skipping adds a branch; unpredictable mask values cause misprediction - // penalties (~15-20 cycles). (2) Straight-line code vectorizes better; conditional - // skip blocks SIMD. (3) Multiplying by 0 is cheap when vectorized. In typical DCN - // usage (moderate mask density), the unconditional path usually wins. - if constexpr (UseMask) { - val *= ptr_mask[spatial_idx]; - } +template +ORT_FORCEINLINE void DeformConvCpuAddBiasToRow(T* ORT_CPU_RESTRICT row, const T* ORT_CPU_RESTRICT bias_data, + int64_t channel, ptrdiff_t spatial_len) { + // row[s] += bias[channel] for all s; Eigen maps to SIMD on large spatial_len (often several x vs scalar loop). + EigenVectorArrayMap(row, spatial_len) += bias_data[channel]; +} - col_ptr[spatial_idx] = val; - } +template +void DeformConvCpuAddBias(T* ORT_CPU_RESTRICT y_data, const T* ORT_CPU_RESTRICT bias_data, int64_t batch_n, + int64_t num_output_channels, int64_t output_image_size, size_t output_image_size_elements, + size_t y_batch_stride, concurrency::ThreadPool* thread_pool) { + const int64_t int64_max = std::numeric_limits::max(); + const int64_t ptrdiff_max = static_cast(std::numeric_limits::max()); + const ptrdiff_t spatial_len = static_cast(output_image_size); + const int64_t M = num_output_channels; + + // N==1: parallelize over M channels only. Avoids the N>1 path's initial k/M and k%M per thread chunk and keeps + // the hot loop free of division (integer div is ~tens of cycles; negligible vs spatial SIMD work for large HW, + // but this path is the common inference case and is simpler for the pool to split). + // Y[0, m, :] += B[m] elementwise over spatial indices. + if (batch_n == 1) { + const double cost_per_channel_slice = static_cast(output_image_size); + concurrency::ThreadPool::TryParallelFor( + thread_pool, static_cast(M), cost_per_channel_slice, + [&](ptrdiff_t first, ptrdiff_t last) { + for (ptrdiff_t m = first; m < last; ++m) { + const size_t m_sz = static_cast(m); + T* ORT_CPU_RESTRICT y_row = y_data + m_sz * output_image_size_elements; + DeformConvCpuAddBiasToRow(y_row, bias_data, static_cast(m), spatial_len); + } + }); + return; + } + + ORT_ENFORCE(batch_n <= int64_max / M, "N*M overflows int64 for bias parallelization."); + + // N>1: flatten (n, m) to k = n * M + m so TryParallelFor sees enough tasks; update (n,m) by increment/wrap + // inside the loop to avoid div/mod per iteration (see loop body). + const int64_t total_tasks = batch_n * M; + ORT_ENFORCE(total_tasks <= ptrdiff_max, "N*M exceeds ptrdiff_t range for bias parallelization."); + const double cost_per_task = static_cast(output_image_size); + + concurrency::ThreadPool::TryParallelFor( + thread_pool, static_cast(total_tasks), cost_per_task, + [&](ptrdiff_t first, ptrdiff_t last) { + // Initialize (n,m) from `first` only once per [first,last); advance by m++ with wrap (no per-iter div). + int64_t n = static_cast(first) / M; + int64_t m = static_cast(first) % M; + for (ptrdiff_t k = first; k < last; ++k) { + const size_t n_sz = static_cast(n); + const size_t m_sz = static_cast(m); + + // Pointer arithmetic formula: Y_row_ptr = y_data + (n * y_batch_stride) + (m * output_image_size) + // Mathematical operation: Y[n, m, spatial_idx] += B[m] for all spatial_idx in [0, output_image_size). + T* ORT_CPU_RESTRICT y_row = y_data + n_sz * y_batch_stride + m_sz * output_image_size_elements; + DeformConvCpuAddBiasToRow(y_row, bias_data, m, spatial_len); + + // For subsequent tasks, we simply increment `m` and wrap around to increment `n`. + // This completely eliminates division and modulo operations inside the hot loop. + if (++m == M) { + m = 0; + ++n; } } }); } -} // namespace - template Status DeformConv::Compute(OpKernelContext* context) const { const auto* X = context->Input(0); @@ -180,124 +648,175 @@ Status DeformConv::Compute(OpKernelContext* context) const { const int64_t M = params.M; const int64_t kH = params.kH; const int64_t kW = params.kW; - const int64_t pad_h = params.pad_h; - const int64_t pad_w = params.pad_w; const int64_t stride_h = params.stride_h; const int64_t stride_w = params.stride_w; - const int64_t dilation_h = params.dilation_h; - const int64_t dilation_w = params.dilation_w; const int64_t group = params.group; const int64_t offset_group = params.offset_group; const int64_t out_h = params.out_h; const int64_t out_w = params.out_w; const bool use_mask = params.use_mask; - // Allocate output tensor [N, M, out_h, out_w]. + // --- Phase 1: Pre-computation and Memory Allocation --- + // 1.0) Output Y [N, M, out_h, out_w]; early exit if empty. const TensorShape Y_shape({N, M, out_h, out_w}); Tensor* Y = context->Output(0, Y_shape); if (Y->Shape().Size() == 0) { return Status::OK(); } - // Precompute common sizes for the im2col + GEMM pipeline. - const int64_t kernel_size = kH * kW; - const int64_t output_image_size = out_h * out_w; - const int64_t input_image_size = H * W_in; - const int64_t kernel_dim = C / group * kernel_size; // K dimension for GEMM: C/group * kH * kW + // 1.1) Shared (CPU/CUDA) runtime bounds + derived dimensions. + const int64_t ptrdiff_max = static_cast(std::numeric_limits::max()); + + DeformConvCommonDims common_dims; + ORT_RETURN_IF_ERROR(DeformConvValidateAndComputeCommonDims(params, common_dims)); + const int64_t output_image_size = common_dims.output_image_size; + const int64_t kernel_dim = common_dims.kernel_dim; // K dimension for GEMM: C/group * kH * kW + const size_t max_size_t = std::numeric_limits::max(); + const CpuDeformConvExecutionDims exec_dims = ComputeCpuDeformConvExecutionDims(params, common_dims, ptrdiff_max, max_size_t); + const int64_t padded_spatial_count = exec_dims.padded_spatial_count; + const size_t block_count = exec_dims.block_count; + + // Compute base sampling points and offset deltas on the fly using InlinedVector. + // This avoids heap allocations (std::vector) while completely eliminating the need for + // shared_mutex, atomic reference counting, and mutable state in the OpKernel. + // The computation cost (a few dozen cycles) is vastly lower than lock/atomic overhead. + const size_t kernel_size_sz = static_cast(common_dims.kernel_size); + // 49 is enough to inline up to 7x7 kernels without heap allocation. + onnxruntime::InlinedVector offset_base_delta(kernel_size_sz); + onnxruntime::InlinedVector base_h(kernel_size_sz); + onnxruntime::InlinedVector base_w(kernel_size_sz); + for (int64_t kernel_idx = 0; kernel_idx < common_dims.kernel_size; ++kernel_idx) { + const int64_t i = kernel_idx / params.kW; + const int64_t j = kernel_idx % params.kW; + const size_t kernel_idx_sz = static_cast(kernel_idx); + // Offset tensor layout per ONNX DeformConv: for each offset_group and kernel tap, two maps (dy, dx) + // of shape [out_h, out_w]. Flat row-major offset index for tap (i,j) is 2 * kernel_idx within the group. + offset_base_delta[kernel_idx_sz] = static_cast(2) * kernel_idx_sz; + // Base sampling point in input space (before adding deform offsets): standard conv unwarped grid. + base_h[kernel_idx_sz] = static_cast(-params.pad_h + i * params.dilation_h); + base_w[kernel_idx_sz] = static_cast(-params.pad_w + j * params.dilation_w); + } // Col buffer: shape [C*kH*kW, out_h*out_w]. Allocate per-image (process one image at a time) // to reduce peak memory when N is large; im2col is implemented per-image anyway. - const int64_t col_buffer_size = (C * kernel_size) * output_image_size; + const int64_t col_buffer_size = exec_dims.col_buffer_size; + + // 1.2) Flat strides (element counts) for batch/group pointer bumping — see ComputeCpuDeformConvStrides body. + // x_batch_stride = C * H * W; y_batch_stride = M * out_h * out_w; w_group_stride = (M/group) * kernel_dim; + // col_group_stride = kernel_dim * out_h * out_w; y_group_stride = (M/group) * out_h * out_w; + // offset_batch_stride = (2 * offset_group * kH * kW) * out_h * out_w (dy and dx maps per tap). + // mask_batch_stride = (offset_group * kH * kW) * out_h * out_w (one modulation weight per tap, no factor 2). + const CpuDeformConvStrides strides = ComputeCpuDeformConvStrides(params, common_dims, max_size_t); + const size_t output_image_size_sz = static_cast(output_image_size); + const size_t x_batch_stride = strides.x_batch_stride; + const size_t y_batch_stride = strides.y_batch_stride; + const size_t w_group_stride = strides.w_group_stride; + const size_t col_group_stride = strides.col_group_stride; + const size_t y_group_stride = strides.y_group_stride; + const size_t offset_batch_stride = strides.offset_batch_stride; + const size_t mask_batch_stride = strides.mask_batch_stride; + + // 1.3) GEMM call-site bounds (checked once outside group loop). + ORT_ENFORCE((M / group) <= ptrdiff_max, "GEMM M dimension exceeds ptrdiff_t range."); + ORT_ENFORCE(output_image_size <= ptrdiff_max, "GEMM N dimension exceeds ptrdiff_t range."); + ORT_ENFORCE(kernel_dim <= ptrdiff_max, "GEMM K dimension exceeds ptrdiff_t range."); AllocatorPtr alloc; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); auto col_buffer = IAllocator::MakeUniquePtr(alloc, SafeInt(col_buffer_size)); - const T* Xdata = X->Data(); - const T* Wdata = W->Data(); - const T* offset_data = offset->Data(); - const T* mask_data = use_mask ? mask->Data() : nullptr; - T* Ydata = Y->MutableData(); - const T* Bdata = (B != nullptr) ? B->Data() : nullptr; + ORT_ENFORCE(static_cast(H) * static_cast(W_in) <= static_cast(std::numeric_limits::max()), + "DeformConv requires H*W to fit in int for sampling indices."); - concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool(); + auto plan_blocks = IAllocator::MakeUniquePtr>(alloc, SafeInt(block_count)); - // Process each image in the batch. + // Aliasing contract for this optimized path: + // - input tensors may alias each other (read-only is fine), + // - output Y must not overlap any input tensor (DeformConv is not an in-place kernel). + const T* ORT_CPU_RESTRICT Xdata = X->Data(); + const T* ORT_CPU_RESTRICT Wdata = W->Data(); + const T* ORT_CPU_RESTRICT offset_data = offset->Data(); + const T* ORT_CPU_RESTRICT mask_data = use_mask ? mask->Data() : nullptr; + T* ORT_CPU_RESTRICT Ydata = Y->MutableData(); + const T* ORT_CPU_RESTRICT Bdata = (B != nullptr) ? B->Data() : nullptr; + + // --- Phase 2: Core Computation (Im2Col + GEMM) --- + // Process each image in the batch sequentially to save peak memory. + concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool(); for (int64_t n = 0; n < N; ++n) { - // Step 1: Deformable Im2Col for image n. + const size_t n_idx = static_cast(n); + + // 2.1) Deformable Im2Col for image n. // Gather deformed samples into col buffer for GEMM. - const T* X_curr = Xdata + n * (C * input_image_size); - const T* offset_curr = offset_data + n * (offset_group * 2 * kernel_size * output_image_size); - const T* mask_curr = use_mask ? (mask_data + n * (offset_group * kernel_size * output_image_size)) : nullptr; - T* col_buffer_ptr = col_buffer.get(); - - // Dispatch to template instantiation: UseMask=true or false eliminates branch in hot loop. - // Note: pad_h, pad_w are begin-side paddings for coordinate mapping; pad_h_end/pad_w_end - // affect only output size (already baked into out_h, out_w), not im2col sampling. + const T* ORT_CPU_RESTRICT X_curr = Xdata + n_idx * x_batch_stride; + const T* ORT_CPU_RESTRICT offset_curr = offset_data + n_idx * offset_batch_stride; + const T* ORT_CPU_RESTRICT mask_curr = use_mask ? (mask_data + n_idx * mask_batch_stride) : nullptr; + T* ORT_CPU_RESTRICT col_buffer_ptr = col_buffer.get(); + + sampling_plan_internal::DeformableIm2colContext im2col_ctx{ + X_curr, offset_curr, mask_curr, + static_cast(H), static_cast(W_in), + kH, kW, stride_h, stride_w, C, offset_group, out_h, out_w, + padded_spatial_count, + offset_base_delta.data(), + base_h.data(), + base_w.data(), + plan_blocks.get(), + col_buffer_ptr, + thread_pool}; + // use_mask is runtime, but the hot gather loop is compiled twice (UseMask true/false) so the false + // build has no mask load/multiply/branch per pixel — see FillColRowFromSamplingPlanImpl. if (use_mask) { - DeformableIm2col( - X_curr, offset_curr, mask_curr, - static_cast(H), static_cast(W_in), kH, kW, - pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, - C, offset_group, out_h, out_w, - col_buffer_ptr, thread_pool); + sampling_plan_internal::DeformableIm2colPlanned(im2col_ctx); } else { - DeformableIm2col( - X_curr, offset_curr, nullptr, - static_cast(H), static_cast(W_in), kH, kW, - pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, - C, offset_group, out_h, out_w, - col_buffer_ptr, thread_pool); + sampling_plan_internal::DeformableIm2colPlanned(im2col_ctx); } - // Step 2: GEMM for each group. Y = W * Col (per group). + // 2.2) GEMM for each group. Y = W * Col (per group). + // The deformable convolution is cast as a Matrix Multiplication (GEMM). + // For each group, the weight matrix W has shape [M/group, C/group * kH * kW] + // and the gathered column matrix Col has shape [C/group * kH * kW, out_h * out_w]. + // The result Y_g is [M/group, out_h * out_w]. for (int64_t g = 0; g < group; ++g) { + const size_t g_idx = static_cast(g); // Weight for group g: shape [M/group, C/group, kH, kW], row-major. - const T* weight_g = Wdata + g * (M / group) * kernel_dim; + const T* ORT_CPU_RESTRICT weight_g = Wdata + g_idx * w_group_stride; // Col rows for group g: layout [C*kH*kW, out_h*out_w], group g spans rows [g*kernel_dim, (g+1)*kernel_dim). - const T* col_g = col_buffer_ptr + g * kernel_dim * output_image_size; + const T* ORT_CPU_RESTRICT col_g = col_buffer_ptr + g_idx * col_group_stride; // Output slice for group g: [n, g*M/group:(g+1)*M/group, out_h, out_w]. - T* Y_g = Ydata + n * M * output_image_size + g * (M / group) * output_image_size; + T* ORT_CPU_RESTRICT Y_g = Ydata + n_idx * y_batch_stride + g_idx * y_group_stride; - // GEMM: Y = W * Col. W [M/group, kernel_dim], Col [kernel_dim, output_image_size]. + // GEMM: C = alpha * A * B + beta * C with alpha=1, beta=0 => Y_g = W_g * Col_g. + // Dimensions: A is (M_g, K), B is (K, N_out), C is (M_g, N_out), where M_g=M/group, K=kernel_dim, N_out=output_image_size. math::Gemm( CblasNoTrans, CblasNoTrans, - narrow(M / group), // M - narrow(output_image_size), // N - narrow(kernel_dim), // K - static_cast(1), // alpha - weight_g, // A - col_g, // B - static_cast(0), // beta - Y_g, // C + static_cast(M / group), // M + static_cast(output_image_size), // N + static_cast(kernel_dim), // K + static_cast(1), // alpha + weight_g, // A + col_g, // B + static_cast(0), // beta + Y_g, // C thread_pool, nullptr); // mlas_backend_kernel_selector_config } } - // Step 3: Add bias if provided (broadcast over spatial dimensions). + // --- Phase 3: Post-processing --- + // 3.1) Add bias if provided (broadcast over spatial dimensions). if (Bdata != nullptr) { - int64_t total_work = N * M; - concurrency::ThreadPool::TryParallelFor( - thread_pool, static_cast(total_work), static_cast(output_image_size), - [&](ptrdiff_t first, ptrdiff_t last) { - for (ptrdiff_t idx = first; idx < last; ++idx) { - int64_t n = idx / M; - int64_t m = idx % M; - T* Y_ptr = Ydata + n * M * output_image_size + m * output_image_size; - // Eigen vectorized add: Y_ptr += Bdata[m] over all spatial positions. - EigenVectorArrayMap(Y_ptr, narrow(output_image_size)) += Bdata[m]; - } - }); + DeformConvCpuAddBias(Ydata, Bdata, N, M, output_image_size, output_image_size_sz, y_batch_stride, thread_pool); } return Status::OK(); } -// Explicit template instantiation for float and double +// Explicit instantiation in this .cc keeps DeformConv definitions out of other TUs that only +// include deform_conv.h — one copy of Compute() per T in the library, faster builds and predictable link size. template class DeformConv; template class DeformConv; diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h index a5c0606a55e7d..60f628a58d6ef 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h +++ b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h @@ -3,7 +3,7 @@ #pragma once -#include +#include #include "core/common/common.h" #include "core/framework/op_kernel.h" @@ -73,6 +73,42 @@ struct DeformConvParams { bool use_mask{false}; // Whether optional mask input is provided }; +// Common derived dimensions used by both CPU and CUDA kernels. +struct DeformConvCommonDims { + int64_t kernel_size{0}; // kH * kW + int64_t output_image_size{0}; // out_h * out_w + int64_t input_image_size{0}; // H * W_in + int64_t kernel_dim{0}; // (C / group) * kernel_size +}; + +// Validates shared runtime bounds and computes common derived dimensions. +// This helper is backend-agnostic and intended to be reused by both CPU/CUDA +// after DeformConvValidateAndParse() succeeds. +inline Status DeformConvValidateAndComputeCommonDims(const DeformConvParams& params, + DeformConvCommonDims& dims) { + const int64_t int64_max = std::numeric_limits::max(); + ORT_RETURN_IF_NOT(params.N > 0 && params.C > 0 && params.M > 0 && + params.group > 0 && params.offset_group > 0 && + params.kH > 0 && params.kW > 0 && + params.H > 0 && params.W_in > 0 && + params.out_h > 0 && params.out_w > 0, + "Invalid deform conv dimensions."); + + ORT_RETURN_IF_NOT(params.kH <= int64_max / params.kW, "kernel_size overflows int64."); + dims.kernel_size = params.kH * params.kW; + + ORT_RETURN_IF_NOT(params.out_h <= int64_max / params.out_w, "output_image_size overflows int64."); + dims.output_image_size = params.out_h * params.out_w; + + ORT_RETURN_IF_NOT(params.H <= int64_max / params.W_in, "input_image_size overflows int64."); + dims.input_image_size = params.H * params.W_in; + + ORT_RETURN_IF_NOT((params.C / params.group) <= int64_max / dims.kernel_size, "kernel_dim overflows int64."); + dims.kernel_dim = (params.C / params.group) * dims.kernel_size; + + return Status::OK(); +} + // Validates inputs and parses attributes into params. // Returns Status::OK() on success; on failure, params may be partially filled. inline Status DeformConvValidateAndParse( @@ -159,10 +195,10 @@ inline Status DeformConvValidateAndParse( params.out_w = (params.W_in + params.pad_w + params.pad_w_end - params.dilation_w * (params.kW - 1) - 1) / params.stride_w + 1; ORT_RETURN_IF_NOT(params.out_h >= 0 && params.out_w >= 0, "Computed output spatial size must be non-negative."); - // CPU BilinearInterpolate uses int for indices (for performance optimization); W <= INT_MAX / (H+1) covers all index math. + // CPU BilinearInterpolate uses int for indices (for performance optimization); W <= int_max / (H+1) covers all index math. ORT_RETURN_IF_NOT(params.H >= 0 && params.W_in >= 0, "Input spatial dimensions H and W must be non-negative."); - ORT_RETURN_IF_NOT(params.W_in <= static_cast(INT_MAX) / (params.H + 1), - "Input (H+1)*W must not exceed INT_MAX (for performance optimization)."); + ORT_RETURN_IF_NOT(params.W_in <= static_cast(std::numeric_limits::max()) / (params.H + 1), + "Input (H+1)*W must not exceed int max (for performance optimization)."); // Validate tensor shapes (use division to avoid int64 overflow in offset_group * 2 * kH * kW). ORT_RETURN_IF_NOT(offset_shape[0] == params.N, "Offset batch size must match input batch size."); diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv.cc b/onnxruntime/core/providers/cuda/nn/deform_conv.cc index 7a0b896acfe01..4349ec4fc4773 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv.cc +++ b/onnxruntime/core/providers/cuda/nn/deform_conv.cc @@ -2,6 +2,17 @@ // Licensed under the MIT License. // // CUDA implementation of DeformConv (deformable convolution 2D). +// High-level pipeline matches CPU `nn/deform_conv.cc`: im2col then grouped GEMM then optional bias; +// this file hosts the EP and batch chunking; device kernels live in `deform_conv_impl.cu`. +// +// High-level pipeline (batch may be chunked for col_buffer memory; see GetNParallelImgs): +// (1) Deformable im2col per chunk: DeformConvIm2ColImpl launches GPU kernels that fill col_buffer +// (bilinear sampling + optional mask fused in threads; no separate sampling plan like CPU). +// (2) Grouped strided batched GEMM: Y = W * Col via cuBLAS (row-major vs column-major mapping in ComputeInternal). +// (3) Optional bias: add B[m] to each output channel map (DeformConvAddBiasImpl). +// +// Main difference vs CPU path: CPU builds an AoSoA bilinear plan once per image then reuses it across channels; +// CUDA recomputes bilinear samples in the im2col kernel while walking offset/mask tensors. #include "core/providers/shared_library/provider_api.h" #include "deform_conv.h" @@ -21,31 +32,30 @@ namespace { constexpr int kMaxParallelImgs = 32; -// Returns the greatest divisor of n that is <= bound. Used to choose uniform batch chunk sizes. -// Fast path: if n % bound == 0 (common for batch 32/64/128), return immediately. -// When n >= bound^2, linear scan from bound down is O(bound). Otherwise divisor enumeration -// from 1 to sqrt(n) is O(sqrt(n)). Uses integer comparison (no sqrt) for branch decision. -int GetGreatestDivisorBelowBound(int n, int bound) { - if (bound <= 0 || n <= 0) return 1; - if (n % bound == 0) return bound; // Fast path: batch is multiple of target - - // n >= bound^2 <=> bound <= sqrt(n) => linear scan is cheaper - if (static_cast(n) >= static_cast(bound) * bound) { - for (int k = bound - 1; k > 1; --k) { - if (n % k == 0) return k; - } - } else { - // n < bound^2 <=> bound > sqrt(n) => divisor enumeration is cheaper - int best = 1; - for (int i = 1; static_cast(i) * i <= static_cast(n); ++i) { - if (n % i != 0) continue; - const int q = n / i; - if (q <= bound && q > best) best = q; - if (i <= bound && i > best) best = i; - } - return best; - } - return 1; +// ceil(numer / denom) for numer >= 0, denom > 0 (integer, no floating point). +// Avoid (numer + denom - 1) / denom: numer near INT_MAX overflows signed int (UB in C++). +inline int CeilDiv(int numer, int denom) { + return numer / denom + (numer % denom != 0 ? 1 : 0); +} + +// Chooses DeformConv batch chunk size k (images per outer-loop iteration) given batch N and +// a hard cap T from temp-memory budget (target_parallel_imgs). +// +// Goals (in order): +// 1) Minimize the number of outer rounds I = ceil(N / k). Under k <= T, the minimum achievable +// I is I* = ceil(N / min(N, T)) — take the largest allowed step min(N, T), same as always +// using k = T when N > T, or one round when N <= T. +// 2) Among all k with ceil(N/k) == I*, pick k = ceil(N / I*) so chunk sizes are as balanced as +// possible (last chunk is only slightly smaller than full chunks). k need not divide N; choosing +// k = ceil(N / I*) instead of always k = T often shrinks col_buffer stride when a full-T last +// chunk would leave a much smaller tail. +// +// Closed form: k_cap = min(N, T), I = ceil(N / k_cap), return ceil(N / I). +inline int GetDeformConvParallelChunkSize(int N, int T) { + if (N <= 0 || T <= 0) return 1; + const int k_cap = std::min(N, T); + const int num_rounds = CeilDiv(N, k_cap); + return CeilDiv(N, num_rounds); } // Returns the maximum temp memory (bytes) allowed for DeformConv's im2col + GEMM buffers. @@ -76,28 +86,25 @@ size_t GetDeformConvEffectiveMaxTempBytes(size_t total_global_mem) { } // Returns how many images to process in parallel per batch chunk for DeformConv. -// Chooses the largest divisor of batch size N that fits in the temp budget and does not -// exceed kMaxParallelImgs, so that batch dimension is split evenly (no remainder). -// Note: if N is prime and N > target_parallel_imgs, the greatest divisor <= target_parallel_imgs is 1, -// so batching is effectively disabled (single-image chunks). +// +// Temp budget → cap T (see below). Chunk size k = GetDeformConvParallelChunkSize(N, T): minimize +// outer-loop rounds first, then balance chunk sizes via ceil(N / ceil(N / min(N,T))). +// The host loop still uses cur_parallel = min(k, N - b), so k need not divide N. // // Formulas: -// kernel_size = kH * kW -// output_image_size = out_h * out_w -// bytes_per_image = output_image_size * (C * kernel_size + M / group) * sizeof(T) -// (temp bytes per image: im2col col buffer + GEMM output buffer per output position) +// kernel_size / output_image_size come from validated common dims +// bytes_per_image = output_image_size * C * kernel_size * sizeof(T) +// (temp bytes per image: im2col col buffer only; GEMM writes directly to Y) // max_parallel_imgs_mem = max(1, floor(effective_max_temp / bytes_per_image)) -// target_parallel_imgs = min(kMaxParallelImgs, max_parallel_imgs_mem) -// return GetGreatestDivisorBelowBound(N, target_parallel_imgs) +// target_parallel_imgs T = min(kMaxParallelImgs, max_parallel_imgs_mem) +// return GetDeformConvParallelChunkSize(N, T) template -int GetNParallelImgs(const DeformConvParams& params, size_t total_global_mem) { +int GetNParallelImgs(const DeformConvParams& params, int64_t kernel_size, int64_t output_image_size, size_t total_global_mem) { const size_t effective_max_temp = GetDeformConvEffectiveMaxTempBytes(total_global_mem); - const int64_t kernel_size = params.kH * params.kW; - const int64_t output_image_size = params.out_h * params.out_w; - const size_t bytes_per_image = SafeInt(output_image_size) * (params.C * kernel_size + params.M / params.group) * sizeof(T); + const size_t bytes_per_image = SafeInt(output_image_size) * params.C * kernel_size * sizeof(T); const int max_parallel_imgs_mem = std::max(1, static_cast(effective_max_temp / std::max(size_t(1), bytes_per_image))); const int target_parallel_imgs = std::min(kMaxParallelImgs, max_parallel_imgs_mem); - return GetGreatestDivisorBelowBound(static_cast(params.N), target_parallel_imgs); + return GetDeformConvParallelChunkSize(narrow(params.N), target_parallel_imgs); } } // namespace @@ -146,12 +153,13 @@ Status DeformConv::ComputeInternal(OpKernelContext* context) const { return Status::OK(); } - const int n_parallel_imgs = GetNParallelImgs(params, GetDeviceProp().totalGlobalMem); - - const int64_t kernel_size = kH * kW; - const int64_t output_image_size = out_h * out_w; - const int64_t input_image_size = H * W_in; - const int64_t kernel_dim = (C / group) * kernel_size; + DeformConvCommonDims common_dims; + ORT_RETURN_IF_ERROR(DeformConvValidateAndComputeCommonDims(params, common_dims)); + const int64_t kernel_size = common_dims.kernel_size; + const int64_t output_image_size = common_dims.output_image_size; + const int64_t input_image_size = common_dims.input_image_size; + const int64_t kernel_dim = common_dims.kernel_dim; + const int n_parallel_imgs = GetNParallelImgs(params, kernel_size, output_image_size, GetDeviceProp().totalGlobalMem); const int64_t col_stride = static_cast(n_parallel_imgs) * output_image_size; const int64_t col_buffer_size = (C * kernel_size) * col_stride; @@ -159,8 +167,6 @@ Status DeformConv::ComputeInternal(OpKernelContext* context) const { AllocatorPtr alloc; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); auto col_buffer = IAllocator::MakeUniquePtr(alloc, SafeInt(col_buffer_size)); - // Removed col_transposed allocation as we avoid physical transpose. - auto gemm_output_buffer = IAllocator::MakeUniquePtr(alloc, SafeInt((M / group) * col_stride)); const T* Xdata = X->Data(); const T* Wdata = W->Data(); @@ -180,6 +186,7 @@ Status DeformConv::ComputeInternal(OpKernelContext* context) const { const int64_t cur_out_size = static_cast(cur_parallel) * output_image_size; const T* X_block = Xdata + b * (C * input_image_size); + // Stride per full image along N: offset [N, offset_group*2*kH*kW, OH, OW] -> offset_group * 2*kH*kW * OH*OW floats. const T* offset_block = offset_data + b * (offset_group * 2 * kernel_size * output_image_size); const T* mask_block = use_mask ? (mask_data + b * (offset_group * kernel_size * output_image_size)) : nullptr; @@ -215,16 +222,18 @@ Status DeformConv::ComputeInternal(OpKernelContext* context) const { // - W (row [M/group, kernel_dim]) -> cuBLAS interprets as col-major [kernel_dim, M/group] = W^T // - C = A*B = Col^T * W^T = (W*Col)^T = Y^T; C is col-major [cur_out_size, M/group] = Y in row-major // - // m=cur_out_size, n=M/group, k=kernel_dim; lda=cur_out_size, ldb=kernel_dim, ldc=cur_out_size. + // Per batch image: m=output_image_size, n=M/group, k=kernel_dim; lda=cur_out_size, ldb=kernel_dim, + // ldc=output_image_size (row-major Y slice [M/group, OH*OW]). // - // cur_parallel==1: cur_out_size==output_image_size, C layout (pos, channel) matches NCHW Y_g[0,ch,pos] -> write - // directly into Y_g. Use strided batched for all groups in one call. - // cur_parallel>1: layouts differ -> write to gemm_output_buffer, then DeformConvCopyGemmOutputRowMajorToNCHW. - - const bool gemm_writes_directly = (cur_parallel == 1); - if (gemm_writes_directly) { - // Strided batched: one call for all groups. Strides between batches: - const int64_t stride_col = kernel_dim * col_stride; // = kernel_dim * output_image_size when cur_parallel==1 + // cur_parallel==1: one strided-batched GEMM over all groups (single launch). + // cur_parallel>1: per group, strided-batched GEMM with batch_count=cur_parallel; each batch writes one image + // directly into NCHW Y (strideC = M * output_image_size), avoiding a temp buffer + scatter kernel. + + if (cur_parallel == 1) { + // col_buffer is packed per iteration with the current chunk width (cur_out_size). + // Using outer-scope col_stride (based on n_parallel_imgs) breaks tail chunks where + // cur_out_size != col_stride (including one-image tails) when group > 1. + const int64_t stride_col = kernel_dim * cur_out_size; const int64_t stride_weight = (M / group) * kernel_dim; const int64_t stride_y = (M / group) * output_image_size; CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( @@ -249,44 +258,42 @@ Status DeformConv::ComputeInternal(OpKernelContext* context) const { device_prop, UseTF32())); } else { - // cur_parallel>1: GEMM output layout differs from NCHW; write to buffer then copy per group. + const int64_t stride_a_col = output_image_size; + const int64_t stride_b = 0; + const int64_t stride_c_y = M * output_image_size; for (int64_t g = 0; g < group; ++g) { const T* W_g = Wdata + g * (M / group) * kernel_dim; - const T* col_g = col_buffer.get() + g * kernel_dim * col_stride; + const T* col_g = col_buffer.get() + g * kernel_dim * cur_out_size; T* Y_g = Ydata + b * M * output_image_size + g * (M / group) * output_image_size; - CUBLAS_RETURN_IF_ERROR((cublasGemmHelper( + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, - narrow(cur_out_size), + narrow(output_image_size), narrow(M / group), narrow(kernel_dim), &alpha, reinterpret_cast(col_g), narrow(cur_out_size), + stride_a_col, reinterpret_cast(W_g), narrow(kernel_dim), + stride_b, &beta, - reinterpret_cast(gemm_output_buffer.get()), - narrow(cur_out_size), + reinterpret_cast(Y_g), + narrow(output_image_size), + stride_c_y, + narrow(cur_parallel), device_prop, - UseTF32()))); - - ORT_RETURN_IF_ERROR(DeformConvCopyGemmOutputRowMajorToNCHW( - stream, - gemm_output_buffer.get(), - Y_g, - M, - M / group, - output_image_size, - cur_parallel)); + UseTF32())); } } } if (Bdata != nullptr) { - ORT_RETURN_IF_ERROR(DeformConvAddBiasImpl(stream, Ydata, Bdata, N, M, out_h, out_w)); + ORT_RETURN_IF_ERROR(DeformConvAddBiasImpl(stream, Ydata, Bdata, N, M, out_h, out_w, + static_cast(device_prop.maxGridSize[1]))); } return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu index 7b3666fca810b..033c31fbca112 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu @@ -1,8 +1,25 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. // -// CUDA implementation of DeformConv: deformable im2col kernel + bilinear interpolation. -// Reference: torchvision deform_conv2d_kernel.cu, ONNX DeformConv spec. +// CUDA device code for DeformConv: deformable im2col kernel(s) and bias-add kernel. +// Host orchestration and GEMM: `deform_conv.cc` (pipeline described there, aligned with CPU `nn/deform_conv.cc`). +// +// This file corresponds to CPU step (1) on GPU: each thread contributes im2col entries by sampling X with +// bilinear interpolation at offset positions (+ optional mask), instead of CPU's precomputed AoSoA plan + fill. +// +// Reference: torchvision deform_conv2d_kernel.cu, ONNX DeformConv. +// +// ONNX shapes (this EP; batch chunk = parallel_imgs): +// X [parallel_imgs, C, H, W] +// offset[parallel_imgs, offset_group * 2*kH*kW, out_h, out_w] — per (n, oh, ow), channels are +// (dy, dx) pairs for kernel taps in order (i=0..kH-1, j=0..kW-1): ch = 2*(i*kW+j) for dy, +1 for dx. +// mask [parallel_imgs, offset_group * kH*kW, out_h, out_w] — optional; ch = i*kW+j. +// col row-major [C * kH * kW, parallel_imgs * out_h * out_w]; GEMM uses this as in deform_conv.cc. +// +// Sampling (same as CPU / typical DCN): for output (oh, ow), kernel tap (i, j), +// h_ref = oh * stride_h - pad_h + i * dilation_h + Δh(oh,ow,i,j) +// w_ref = ow * stride_w - pad_w + j * dilation_w + Δw(oh,ow,i,j) +// then bilinear sample X at (h_ref, w_ref); multiply by mask if present. #include "deform_conv_impl.h" #include "core/providers/cuda/cu_inc/common.cuh" @@ -31,13 +48,33 @@ inline int GetGridSize(size_t n, size_t threads_per_block) { return static_cast(std::min(blocks_needed, static_cast(std::numeric_limits::max()))); } +template +inline bool Needs64BitIndex(Values... values) { + constexpr int64_t kInt32Max = static_cast(std::numeric_limits::max()); + return ((static_cast(values) > kInt32Max) || ...); +} + +inline bool ProductExceedsInt32Max(std::initializer_list factors) { + constexpr int64_t kInt32Max = static_cast(std::numeric_limits::max()); + int64_t acc = 1; + for (int64_t v : factors) { + // DeformConv dimensions are expected to be non-negative after validation. + // If violated unexpectedly, conservatively force the 64-bit kernel path. + if (v < 0) return true; + if (v == 0) return false; + if (acc > kInt32Max / v) return true; + acc *= v; + } + return false; +} + // __ldg has no overload for BFloat16*; use 16-bit load + FromBits. Other types use __ldg directly. template -__device__ __inline__ T DeformConvLdg(const T* p) { +__device__ __inline__ T DeformConvLdg(const T* __restrict__ p) { return __ldg(p); } template <> -__device__ __inline__ BFloat16 DeformConvLdg(const BFloat16* p) { +__device__ __inline__ BFloat16 DeformConvLdg(const BFloat16* __restrict__ p) { return BFloat16::FromBits(__ldg(reinterpret_cast(p))); } @@ -50,7 +87,7 @@ template struct DeformConvBilinearTraits { using ComputeT = T; - __device__ static __inline__ ComputeT Load(const T* p) { + __device__ static __inline__ ComputeT Load(const T* __restrict__ p) { return __ldg(p); } @@ -67,7 +104,7 @@ template <> struct DeformConvBilinearTraits { using ComputeT = float; - __device__ static __inline__ ComputeT Load(const half* p) { + __device__ static __inline__ ComputeT Load(const half* __restrict__ p) { return __half2float(__ldg(p)); } @@ -84,7 +121,7 @@ template <> struct DeformConvBilinearTraits { using ComputeT = float; - __device__ static __inline__ ComputeT Load(const BFloat16* p) { + __device__ static __inline__ ComputeT Load(const BFloat16* __restrict__ p) { return static_cast(DeformConvLdg(p)); } @@ -105,9 +142,19 @@ struct DeformConvBilinearTraits { // DeformConvBilinearTraits to avoid precision loss. We keep floor() results in CoordT and // cast to int only for indices (h_low/w_low), which avoids unnecessary CoordT->int->CoordT // round trips when computing lh/lw/hh/hw. +// +// Historical note: before switching to branchless masked loads, this workload had the following +// "edge sample" ratio (counts = samples with >=1 OOB neighbor / total bilinear samples). +// The numbers remain useful as boundary-hit context, but no longer imply control-flow divergence. +// Example workload only; not a benchmark or representative ratio. +// kernel 1x1: 1.3746% (2421 / 176128) +// kernel 3x3: 1.4833% (11756 / 792576) +// kernel 7x7: 4.7593% (52537 / 1103872) +// Current implementation always issues safe-address loads and masks invalid neighbors to zero. +// Offsets are often spatially smooth, so nearby threads still tend to exhibit similar validity patterns. template __device__ __inline__ T BilinearInterpolate( - const T* in, + const T* __restrict__ in, int height, int width, typename DeformConvBilinearTraits::ComputeT h, @@ -115,12 +162,20 @@ __device__ __inline__ T BilinearInterpolate( using Traits = DeformConvBilinearTraits; using CoordT = typename Traits::ComputeT; - // [Optimization 1]: Early exit for clearly out-of-bounds (skip floor() for OOB case). + // [Optimization 1]: Early exit for clearly out-of-bounds (skip floor() and neighbor loads for OOB case). + // Semantics guardrail: if sample point is outside [-1, H) x [-1, W), ONNX bilinear contribution is exactly 0. + // Why keep this even with branchless masked loads below: + // - The branchless path guarantees safe addressing and correct masked zero, but still pays floor/weight math + // and four global loads. + // - This early return avoids all of that work for clearly OOB samples. + // About divergence: mixed in/out-of-bound warps can diverge here, but OOB lanes terminate immediately while + // in-bound lanes continue useful work; in practice this often wins unless OOB distribution is highly random + // and branch hit-rate is very high. if (h <= static_cast(-1) || h >= height || w <= static_cast(-1) || w >= width) { return Traits::Zero(); } - // [Optimization 2]: Keep floor result in T; cast to int only for indices. Avoids float->int->float in lh/lw. + // [Optimization 2]: Keep floor result in CoordT; cast to int only for indices. Avoids float->int->float in lh/lw. CoordT h_floor = _Floor(h); CoordT w_floor = _Floor(w); int h_low = static_cast(h_floor); @@ -133,28 +188,44 @@ __device__ __inline__ T BilinearInterpolate( CoordT hh = static_cast(1) - lh; CoordT hw = static_cast(1) - lw; - // [Optimization 3]: Avoid a second multiply for base_high. - // Original code computed both bases as: - // base_low = h_low * width; - // base_high = h_high * width; - // Since h_high = h_low + 1, we can rewrite base_high as base_low + width and - // save one integer multiply in the hot path: - // base_low = h_low * width; - // base_high = base_low + width; - int base_low = h_low * width; - int base_high = base_low + width; - - CoordT v1 = (h_low >= 0 && w_low >= 0) ? Traits::Load(in + base_low + w_low) : static_cast(0); - CoordT v2 = (h_low >= 0 && w_high < width) ? Traits::Load(in + base_low + w_high) : static_cast(0); - CoordT v3 = (h_high < height && w_low >= 0) ? Traits::Load(in + base_high + w_low) : static_cast(0); - CoordT v4 = (h_high < height && w_high < width) ? Traits::Load(in + base_high + w_high) : static_cast(0); - - CoordT w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - return Traits::ToResult(w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + // [Optimization 3]: Branchless neighbor loads via "safe address + one-sided clamp". + // Given the early return above, coordinates are in (-1, H) x (-1, W), so each index only needs one-sided clamp: + // h_low in [-1, H-1], h_high in [0, H], w_low in [-1, W-1], w_high in [0, W]. + // We always load from legal addresses; validity is applied by 2D neighbor masks below. + // CUDA compilers usually lower this to predicated/selp-style code without control-flow branches. + const int safe_h_low = max(0, h_low); + const int safe_h_high = min(h_high, height - 1); + const int safe_w_low = max(0, w_low); + const int safe_w_high = min(w_high, width - 1); + + // [Optimization 4]: One-sided validity checks under the same invariant. + // Keep 2D neighbor masks (m1..m4), algebraically equivalent to masking invalid neighbor terms to zero. + // Use one/zero ternaries directly in CoordT to encourage selp.f32/f16 generation. + const CoordT one = static_cast(1); + const CoordT zero = static_cast(0); + const CoordT m1 = (h_low >= 0 && w_low >= 0) ? one : zero; + const CoordT m2 = (h_low >= 0 && w_high < width) ? one : zero; + const CoordT m3 = (h_high < height && w_low >= 0) ? one : zero; + const CoordT m4 = (h_high < height && w_high < width) ? one : zero; + + const int safe_base_low = safe_h_low * width; + const int safe_base_high = safe_h_high * width; + + const CoordT v1 = Traits::Load(in + safe_base_low + safe_w_low) * m1; + const CoordT v2 = Traits::Load(in + safe_base_low + safe_w_high) * m2; + const CoordT v3 = Traits::Load(in + safe_base_high + safe_w_low) * m3; + const CoordT v4 = Traits::Load(in + safe_base_high + safe_w_high) * m4; + + // [Optimization 5]: Factor bilinear into horizontal blends on two rows, then vertical blend. + // Algebraically equivalent to w1*v1 + w2*v2 + w3*v3 + w4*v4 with w1..w4 from hh/hw/lh/lw; + // this form tends to produce fewer independent multiplies and friendlier FFMA scheduling. + CoordT top = hw * v1 + lw * v2; + CoordT bottom = hw * v3 + lw * v4; + return Traits::ToResult(hh * top + lh * bottom); } // kH/kW = -1 means dynamic (runtime); >= 0 means compile-time constant for loop unrolling. -template +template __global__ void DeformableIm2ColKernel( IndexT num_kernels, const T* __restrict__ input, @@ -176,213 +247,316 @@ __global__ void DeformableIm2ColKernel( DivMod out_w_div, DivMod parallel_imgs_div, DivMod channel_per_offset_grp_div, - bool use_mask, T* __restrict__ data_col) { + // Aliasing contract for this kernel: + // - input/offset/mask are read-only and may alias each other, + // - data_col is write-only and must not overlap any input buffer. constexpr bool is_fixed = (kH >= 0 && kW >= 0); - const int64_t h_dim = is_fixed ? kH : weight_h; - const int64_t w_dim = is_fixed ? kW : weight_w; - - // Reconstruct dimensions from DivMod objects - const int64_t out_h = out_h_div.d_; - const int64_t out_w = out_w_div.d_; - const int64_t parallel_imgs = parallel_imgs_div.d_; - - const int64_t out_size = out_h * out_w; + const int64_t h_dim_i64 = is_fixed ? kH : weight_h; + const int64_t w_dim_i64 = is_fixed ? kW : weight_w; + const IndexT h_dim = static_cast(h_dim_i64); + const IndexT w_dim = static_cast(w_dim_i64); + + // Linear thread index `index` encodes (in_c, out_b, out_y, out_x) with x fastest: + // index = out_x + out_w * (out_y + out_h * (out_b + parallel_imgs * in_c)) + // Unroll: divmod by out_w -> out_x; by out_h -> out_y; by parallel_imgs -> out_b, in_c. + const IndexT out_h = out_h_div.d_; + const IndexT out_w = out_w_div.d_; + const IndexT parallel_imgs = parallel_imgs_div.d_; + + const IndexT out_size = out_h * out_w; // The stride for data_col is (parallel_imgs * out_h * out_w) - const int64_t col_stride = parallel_imgs * out_size; + const IndexT col_stride = parallel_imgs * out_size; // columns span one spatial map per image in the chunk + const int64_t out_size_i64 = static_cast(out_size); + const int64_t col_stride_i64 = static_cast(col_stride); + const int64_t channel_hw_i64 = static_cast(height) * static_cast(width); + const int64_t batch_input_stride_i64 = static_cast(channels) * channel_hw_i64; + // One (n, offset_group g) slice of `offset` in linear memory: 2*kH*kW planes of shape (out_h, out_w). + const int64_t offset_group_block_size_i64 = static_cast(2) * h_dim_i64 * w_dim_i64 * out_size_i64; + // Same for `mask`: kH*kW planes of (out_h, out_w). + [[maybe_unused]] const int64_t mask_group_block_size_i64 = UseMask ? (h_dim_i64 * w_dim_i64 * out_size_i64) : int64_t{0}; + const int height_i = static_cast(height); + const int width_i = static_cast(width); using CoordT = typename DeformConvBilinearTraits::ComputeT; - for (IndexT index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; index += blockDim.x * gridDim.x) { + for (IndexT index = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; index < num_kernels; index += static_cast(blockDim.x) * gridDim.x) { IndexT val = index; IndexT out_x, out_y, out_b, in_c; - // Fast division/modulo to recover coordinates out_w_div.divmod(val, val, out_x); out_h_div.divmod(val, val, out_y); parallel_imgs_div.divmod(val, in_c, out_b); - // [Optimization 3] Avoid expensive division if offset_group is 1 (very common case). + // [Im2Col] offset_group==1: channel_per_offset_grp_div is unused; skip divmod. IndexT offset_grp = 0; if (offset_group > 1) { IndexT dummy; channel_per_offset_grp_div.divmod(in_c, offset_grp, dummy); } - // [Optimization 2] Common Subexpression Elimination (CSE) & Pointer Arithmetic - // Pre-calculate base pointers to reduce integer arithmetic inside the inner loops. + // [Im2Col] CSE: base pointers for this thread (one output pixel × input channel). - // 1. Input pointer base for this batch and channel. - const T* input_ptr = input + static_cast(out_b) * (channels * height * width) + static_cast(in_c) * (height * width); + // 1. Input X: NCHW; offset to (out_b, in_c) is out_b * (C*H*W) + in_c * (H*W). + const IndexT channel_hw = static_cast(channel_hw_i64); + const IndexT batch_input_stride = static_cast(batch_input_stride_i64); + const IndexT input_base = out_b * batch_input_stride + in_c * channel_hw; + const T* __restrict__ input_ptr = input + static_cast(input_base); // 2. Spatial index in the output feature map. - const int64_t spatial_idx = static_cast(out_y) * out_w + static_cast(out_x); - - // 3. Offset pointer base calculation. - // Layout: (N, offset_groups, 2*KH*KW, OH, OW) - // We pre-calculate the pointer to the start of the specific (n, g) block, plus spatial_idx. - const int64_t offset_group_block_size = 2 * h_dim * w_dim * out_size; - const T* offset_ptr_base = offset + (static_cast(out_b) * offset_group + static_cast(offset_grp)) * offset_group_block_size + spatial_idx; - - // 4. Mask pointer base calculation (if used). - // Layout: (N, offset_groups, KH*KW, OH, OW) - const T* mask_ptr_base = nullptr; - if (use_mask) { - const int64_t mask_group_block_size = h_dim * w_dim * out_size; - mask_ptr_base = mask + (static_cast(out_b) * offset_group + static_cast(offset_grp)) * mask_group_block_size + spatial_idx; + const IndexT spatial_idx = static_cast(out_y * out_w + out_x); + + // 3. Offset: linear index to (dy,dx) channel 0 at (out_y, out_x) for image out_b, deformable group offset_grp. + // ng = out_b * offset_group + offset_grp + // offset_base = ng * (2*kH*kW*out_h*out_w) + (out_y*out_w + out_x) + const IndexT offset_group_idx = static_cast(offset_group); + const IndexT ng = out_b * offset_group_idx + offset_grp; + const IndexT offset_group_block_size = static_cast(offset_group_block_size_i64); + const IndexT offset_base = ng * offset_group_block_size + spatial_idx; + const T* __restrict__ offset_ptr_base = offset + static_cast(offset_base); + + // 4. Mask: same as offset but kH*kW planes: mask_base = ng * (kH*kW*out_h*out_w) + spatial_idx. + const T* __restrict__ mask_ptr_base = nullptr; + if constexpr (UseMask) { + const IndexT mask_group_block_size = static_cast(mask_group_block_size_i64); + const IndexT mask_base = ng * mask_group_block_size + spatial_idx; + mask_ptr_base = + mask + static_cast(mask_base); } - // 5. Output pointer base calculation. - // data_col Layout: (C * KH * KW, N * OH * OW) - // The current thread writes to the column `c_col` = (b * OH * OW) + spatial_idx. - // The starting row for this channel is `in_c * KH * KW`. - const int64_t c_col = static_cast(out_b) * out_size + spatial_idx; - T* data_col_ptr_base = data_col + (static_cast(in_c) * h_dim * w_dim) * col_stride + c_col; + // 5. col_buffer row-major: row r = in_c * (kH*kW) + kernel_flat; column c_col = out_b * out_h*out_w + spatial_idx. + // Element (r, c_col) at col_buffer[r * col_stride + c_col]. + const IndexT c_col = out_b * out_size + spatial_idx; + const IndexT row_base = static_cast((in_c * h_dim) * w_dim); + T* __restrict__ data_col_ptr_base = + data_col + static_cast(row_base) * col_stride_i64 + static_cast(c_col); - // 6. Pre-calculate invariant coordinate parts. - // Use float for coordinate math when T is half or BFloat16 to avoid precision loss. + // 6. Undilated top-left of the kernel anchor for this output pixel: base_* = out_* * stride_* - pad_*. + // Row i / col j add i*dilation_h / j*dilation_w before applying offsets (see run_deform_row). const CoordT base_h_im = static_cast(out_y * stride_h - pad_h); const CoordT base_w_im = static_cast(out_x * stride_w - pad_w); - auto process_kernel_point = [&](int64_t i, int64_t j) { - const int64_t kernel_idx = i * w_dim + j; + // Per (output location, channel): one sample from offset/mask tensors and bilinear input. + auto process_kernel_point = [&](const T* __restrict__ offset_h_ptr, const T* __restrict__ offset_w_ptr, + const T* __restrict__ mask_ptr, T* __restrict__ data_col_ptr, CoordT h_base, + CoordT w_base) { T mask_val = static_cast(1); - if (use_mask) { - // Access mask using pre-calculated base and stride. - mask_val = DeformConvLdg(mask_ptr_base + kernel_idx * out_size); + if constexpr (UseMask) { + mask_val = DeformConvLdg(mask_ptr); } - // Calculate offset pointers relative to the base. - // The offset tensor stores (y_offset, x_offset) pairs for each kernel weight. - // Stride between y_offset and x_offset is `out_size`. - const int64_t offset_offset_idx = (2 * kernel_idx) * out_size; - - const CoordT offset_h = static_cast(DeformConvLdg(offset_ptr_base + offset_offset_idx)); - const CoordT offset_w = static_cast(DeformConvLdg(offset_ptr_base + offset_offset_idx + out_size)); + const CoordT offset_h = static_cast(DeformConvLdg(offset_h_ptr)); + const CoordT offset_w = static_cast(DeformConvLdg(offset_w_ptr)); - const CoordT h_im = base_h_im + static_cast(i * dilation_h) + offset_h; - const CoordT w_im = base_w_im + static_cast(j * dilation_w) + offset_w; + const CoordT h_im = h_base + offset_h; + const CoordT w_im = w_base + offset_w; // height/width are validated on host (DeformConvValidateAndParse) so int is safe here. - T val = BilinearInterpolate(input_ptr, - static_cast(height), - static_cast(width), - h_im, - w_im); + T val = BilinearInterpolate(input_ptr, height_i, width_i, h_im, w_im); // Match CPU path: always interpolate then apply mask to keep branch-free hot loop. - data_col_ptr_base[kernel_idx * col_stride] = val * mask_val; + *data_col_ptr = val * mask_val; }; - if constexpr (is_fixed) { + // One row of kernel weights (fixed kW or runtime weight_w): compute row base once, then walk j with pointer + // adds only (no kernel_idx * stride rebuild each j). Shared by compile-time and dynamic kernel sizes. + // Along the kernel row, dy/dx planes are spaced by out_h*out_w; each (dy,dx) pair spans 2*out_size elements. + const IndexT offset_pair_stride = static_cast(2) * out_size; + auto run_deform_row = [&](IndexT row_kernel_base, CoordT h_base, IndexT row_width) { + CoordT w_base = base_w_im; + const IndexT offset_elem_offset = static_cast(2 * row_kernel_base) * out_size; + const T* __restrict__ offset_h_ptr = offset_ptr_base + offset_elem_offset; + const T* __restrict__ offset_w_ptr = offset_h_ptr + out_size; + const T* __restrict__ mask_ptr = nullptr; + if constexpr (UseMask) { + mask_ptr = mask_ptr_base + row_kernel_base * out_size; + } + T* __restrict__ data_col_ptr = data_col_ptr_base + row_kernel_base * col_stride; + + auto step_kernel_point = [&]() { + process_kernel_point(offset_h_ptr, offset_w_ptr, mask_ptr, data_col_ptr, h_base, w_base); + offset_h_ptr += offset_pair_stride; + offset_w_ptr += offset_pair_stride; + if constexpr (UseMask) { + mask_ptr += out_size; + } + data_col_ptr += col_stride; + w_base += static_cast(dilation_w); + }; + + // Small fixed kernels: unroll inner j so codegen matches the old fully-unrolled 1x1/3x3 path. + if constexpr (is_fixed && kH * kW <= 9) { #pragma unroll - for (int i = 0; i < kH; ++i) { + for (IndexT j = 0; j < row_width; ++j) { + step_kernel_point(); + } + } else { + for (IndexT j = 0; j < row_width; ++j) { + step_kernel_point(); + } + } + }; + + if constexpr (is_fixed) { + if constexpr (kH * kW <= 9) { + // For 1x1 and 3x3, unroll the outer i loop; inner j uses run_deform_row with #pragma unroll there. #pragma unroll - for (int j = 0; j < kW; ++j) { - process_kernel_point(i, j); + for (int i = 0; i < kH; ++i) { + const IndexT i_idx = static_cast(i); + run_deform_row(i_idx * w_dim, base_h_im + static_cast(i_idx * dilation_h), w_dim); + } + } else { + // Larger fixed kernels (including 7x7): keep both outer i and inner j rolled to limit register + // pressure from the heavy bilinear body. 7x7 still benefits from launch-time kH/kW constants + // without inner #pragma unroll. + for (int i = 0; i < kH; ++i) { + const IndexT i_idx = static_cast(i); + run_deform_row(i_idx * w_dim, base_h_im + static_cast(i_idx * dilation_h), w_dim); } } } else { - for (int64_t i = 0; i < weight_h; ++i) { - for (int64_t j = 0; j < weight_w; ++j) { - process_kernel_point(i, j); - } + const IndexT weight_h_idx = static_cast(weight_h); + const IndexT weight_w_idx = static_cast(weight_w); + for (IndexT i = 0; i < weight_h_idx; ++i) { + const IndexT row_base_idx = static_cast(i * weight_w_idx); + run_deform_row(row_base_idx, base_h_im + static_cast(i * dilation_h), weight_w_idx); } } } } -// Bias add: Y[n,m,oh,ow] += B[m]. Layout NCHW. -template +// Bias add: Y[n,m,oh,ow] += B[m]. Y linear row-major NCHW: idx = n*(M*HW) + m*HW + (oh*W+ow). +template __global__ void DeformConvAddBiasKernel( - T* Y, - const T* B, - DivMod spatial_div, // For dividing by (H * W) - DivMod channel_div, // For dividing by M (channel count) - int64_t total_elements) { - for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total_elements; idx += blockDim.x * gridDim.x) { - int64_t val = idx; - int64_t batch_channel_idx, pixel_idx; - - // 1. First decomposition: decompose idx into (batch_channel_idx, pixel_idx) - // Equivalent to: batch_channel_idx = idx / (H*W); pixel_idx = idx % (H*W); + T* __restrict__ Y, + const T* __restrict__ B, + DivMod spatial_div, // For dividing by (H * W) + DivMod channel_div, // For dividing by M (channel count) + IndexT total_elements) { + for (IndexT idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; idx < total_elements; + idx += static_cast(blockDim.x) * gridDim.x) { + IndexT val = idx; + IndexT batch_channel_idx, pixel_idx; + + // idx -> (batch_channel_idx, pixel_idx) with pixel_idx = oh*out_w+ow fastest. spatial_div.divmod(val, batch_channel_idx, pixel_idx); - int64_t batch_idx, channel_idx; - - // 2. Second decomposition: decompose batch_channel_idx into (batch_idx, channel_idx) - // Equivalent to: channel_idx = batch_channel_idx % M; - // We only need channel_idx (i.e. m) + // batch_channel_idx = n*M + m -> bias index is m = batch_channel_idx % M. + IndexT batch_idx, channel_idx; channel_div.divmod(batch_channel_idx, batch_idx, channel_idx); - (void)batch_idx; // Only channel_idx is needed + ORT_UNUSED_PARAMETER(batch_idx); - // channel_idx is what we need (i.e. m) Y[idx] += DeformConvLdg(B + channel_idx); } } -// Copy GEMM output (row-major [M_per_group, cur_parallel*output_image_size]) into NCHW Y_g. -// src(c, j) with j = b_idx*output_image_size + pos -> dst[b_idx*M*output_image_size + c*output_image_size + pos]. -template -__global__ void CopyGemmOutputRowMajorToNCHWKernel( - const T* __restrict__ src, - T* __restrict__ dst, - int64_t M, - int64_t M_per_group, - int64_t output_image_size, - int64_t cur_parallel) { - int64_t total = cur_parallel * M_per_group * output_image_size; - for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; idx += blockDim.x * gridDim.x) { - int64_t pos = idx % output_image_size; - int64_t c = (idx / output_image_size) % M_per_group; - int64_t b_idx = idx / (output_image_size * M_per_group); - int64_t j = b_idx * output_image_size + pos; - // src index for row-major: c * (cur_parallel * output_image_size) + j - dst[b_idx * M * output_image_size + c * output_image_size + pos] = src[c * (cur_parallel * output_image_size) + j]; +// 2D launch: blockIdx.y -> batch_channel_idx in [0, N*M), threadIdx -> pixel_idx in [0, out_h*out_w). +// Indexing: Y[batch_channel_idx * spatial_size + pixel_idx]. Pick IndexT from Needs64BitIndex like the 1D kernel. +template +__global__ void DeformConvAddBias2DKernel(T* __restrict__ Y, const T* __restrict__ B, IndexT spatial_size, + int32_t channels) { + // blockIdx.y maps to batch_channel_idx (N * M) + const IndexT batch_channel_idx = static_cast(blockIdx.y); + const IndexT channel_idx = batch_channel_idx % static_cast(channels); + T bias_val = DeformConvLdg(B + channel_idx); + + const IndexT pixel_idx = static_cast(blockIdx.x) * static_cast(blockDim.x) + static_cast(threadIdx.x); + if (pixel_idx < spatial_size) { + Y[batch_channel_idx * spatial_size + pixel_idx] += bias_val; } } } // namespace template -Status DeformConvAddBiasImpl(cudaStream_t stream, T* Y, const T* B, int64_t N, int64_t M, int64_t out_h, int64_t out_w) { +Status DeformConvAddBiasImpl(cudaStream_t stream, T* Y, const T* B, int64_t N, int64_t M, int64_t out_h, int64_t out_w, int64_t max_grid_y) { int64_t total = N * M * out_h * out_w; if (total <= 0) return Status::OK(); // 1. Prepare divisor - int64_t out_size = out_h * out_w; - - // 2. Create FastDivMod object (note: ensure int64_t version of DivMod is used here) - DivMod spatial_div(out_size); - DivMod channel_div(M); + const int64_t out_size = out_h * out_w; + const int64_t batch_channels = N * M; + // For 1D DivMod kernel only: int32 fast path vs int64. Orthogonal to 2D launch (gridDim.y limit). + const bool use_64bit = Needs64BitIndex(total, out_size, M); + + // Fast 2D launch path: map blockIdx.y to (N*M) to avoid per-thread DivMod in bias add. + // Use it only when the device allows enough grid rows: below ~32 blocks in y, the extra + // parallelism (warps scheduled across blockIdx.y) is often too small to outweigh maintaining + // a second launch + kernel variant; the threshold is a heuristic—revisit if future GPUs change + // occupancy sweet spots or typical batch×channel counts. + constexpr int kMinGridYForBias2DPath = 32; + if (max_grid_y > kMinGridYForBias2DPath && batch_channels <= static_cast(max_grid_y)) { + dim3 block(kDeformConvThreadsPerBlock); + dim3 grid(static_cast(GetGridSize(static_cast(out_size), block.x)), + static_cast(batch_channels)); + const int32_t m_i32 = static_cast(M); + if (use_64bit) { + DeformConvAddBias2DKernel<<>>(Y, B, out_size, m_i32); + } else { + DeformConvAddBias2DKernel<<>>(Y, B, static_cast(out_size), m_i32); + } + return CUDA_CALL(cudaGetLastError()); + } int blocks = GetGridSize(static_cast(total), kDeformConvThreadsPerBlock); - - // 3. Pass DivMod objects - DeformConvAddBiasKernel<<>>( - Y, - B, - spatial_div, - channel_div, - total); + if (use_64bit) { + // 2. Create FastDivMod object (note: ensure int64_t version of DivMod is used here) + // 3. Pass DivMod objects + DeformConvAddBiasKernel<<>>( + Y, B, + DivMod(out_size), + DivMod(M), + total); + } else { + // 2. Create FastDivMod object + // 3. Pass DivMod objects + DeformConvAddBiasKernel<<>>( + Y, B, + DivMod(static_cast(out_size)), + DivMod(static_cast(M)), + static_cast(total)); + } return CUDA_CALL(cudaGetLastError()); } -template -Status DeformConvCopyGemmOutputRowMajorToNCHW( - cudaStream_t stream, - const T* gemm_output, - T* Y_g, - int64_t M, - int64_t M_per_group, - int64_t output_image_size, - int64_t cur_parallel) { - int64_t total = cur_parallel * M_per_group * output_image_size; - if (total <= 0) return Status::OK(); - int blocks = GetGridSize(static_cast(total), kDeformConvThreadsPerBlock); - CopyGemmOutputRowMajorToNCHWKernel<<>>( - gemm_output, Y_g, M, M_per_group, output_image_size, cur_parallel); - return CUDA_CALL(cudaGetLastError()); +// Determine if we need to fall back to 64-bit integer arithmetic in the CUDA kernel. +// 32-bit arithmetic is significantly faster and uses fewer registers. +// We check if any of the intermediate index calculations could exceed INT32_MAX (~2.14 billion). +// The most likely variable to exceed this is `col_numel`: +// col_numel = C * kH * kW * parallel_imgs * out_h * out_w +// +// Examples of when 64-bit fallback is triggered (col_numel > 2,147,483,647): +// - High Resolution (1K): C=256, kH=3, kW=3, parallel_imgs=1, out_h=1024, out_w=1024 +// col_numel = 256 * 3 * 3 * 1 * 1024 * 1024 = 2,415,919,104 (> 2.14B) +// - Large Kernel & Batch: C=128, kH=5, kW=5, parallel_imgs=11, out_h=256, out_w=256 +// col_numel = 128 * 5 * 5 * 11 * 256 * 256 = 2,306,867,200 (> 2.14B) +// - Massive Channels: C=4096, kH=3, kW=3, parallel_imgs=1, out_h=256, out_w=256 +// col_numel = 4096 * 3 * 3 * 1 * 256 * 256 = 2,415,919,104 (> 2.14B) +// - 3D-like Large Kernel: C=512, kH=7, kW=7, parallel_imgs=1, out_h=512, out_w=512 +// col_numel = 512 * 7 * 7 * 1 * 512 * 512 = 6,576,668,672 (> 2.14B) +// +// Example of a safe 32-bit case: +// - Typical ResNet: C=256, kH=3, kW=3, parallel_imgs=32, out_h=128, out_w=128 +// col_numel = 256 * 3 * 3 * 32 * 128 * 128 = 1,207,959,552 (< 2.14B) +// +// In practice, due to the 2GB hard limit on temp memory allocation in GetDeformConvEffectiveMaxTempBytes(), +// col_numel will almost never exceed INT32_MAX without OOMing first. +inline bool CheckDeformConvNeeds64BitIndex( + int64_t num_kernels, int64_t C, int64_t H, int64_t W, int64_t kH, int64_t kW, int64_t out_h, int64_t out_w, + int64_t parallel_imgs, int64_t offset_group) { + if (Needs64BitIndex(num_kernels, C, H, W, kH, kW, out_h, out_w, parallel_imgs, offset_group)) { + return true; + } + + // Check potentially large products without evaluating intermediate multiplications. + return ProductExceedsInt32Max({C, kH, kW, parallel_imgs, out_h, out_w}) || // col_numel + ProductExceedsInt32Max({2, kH, kW, out_h, out_w}) || // offset_inner_size + ProductExceedsInt32Max({kH, kW, out_h, out_w}) || // mask_inner_size + ProductExceedsInt32Max({parallel_imgs, offset_group, 2, kH, kW, out_h, out_w}) || // offset_numel + ProductExceedsInt32Max({parallel_imgs, offset_group, kH, kW, out_h, out_w}) || // mask_numel + ProductExceedsInt32Max({H, W}) || // channel_hw + ProductExceedsInt32Max({C, H, W}) || // batch_input_stride + ProductExceedsInt32Max({parallel_imgs, C, H, W}); // input_numel } template @@ -413,41 +587,51 @@ Status DeformConvIm2ColImpl( return Status::OK(); } - const int64_t col_numel = static_cast(C) * kH * kW * parallel_imgs * out_h * out_w; - const bool use_64bit = (num_kernels > static_cast(std::numeric_limits::max())) || - (col_numel > static_cast(std::numeric_limits::max())); + const bool use_64bit = CheckDeformConvNeeds64BitIndex(num_kernels, C, H, W, kH, kW, out_h, out_w, parallel_imgs, offset_group); int blocks = GetGridSize(static_cast(num_kernels), kDeformConvThreadsPerBlock); - auto launch = [&](auto kH_tag, auto kW_tag) { + auto launch = [&](auto kH_tag, auto kW_tag, auto use_mask_tag) { constexpr int KH = decltype(kH_tag)::value; constexpr int KW = decltype(kW_tag)::value; + constexpr bool UseMask = decltype(use_mask_tag)::value; if (use_64bit) { - DeformableIm2ColKernel<<>>( + DeformableIm2ColKernel<<>>( num_kernels, input, offset, mask, H, W, kH, kW, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, C, offset_group, DivMod(out_h), DivMod(out_w), DivMod(parallel_imgs), - DivMod(C / offset_group), use_mask, col_buffer); + DivMod(C / offset_group), col_buffer); } else { - DeformableIm2ColKernel<<>>( + DeformableIm2ColKernel<<>>( static_cast(num_kernels), input, offset, mask, H, W, kH, kW, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, C, offset_group, DivMod(static_cast(out_h)), DivMod(static_cast(out_w)), DivMod(static_cast(parallel_imgs)), DivMod(static_cast(C / offset_group)), - use_mask, col_buffer); + col_buffer); + } + }; + + auto launch_with_mask = [&](auto k_size_tag) { + if (use_mask) { + launch(k_size_tag, k_size_tag, std::integral_constant{}); + } else { + launch(k_size_tag, k_size_tag, std::integral_constant{}); } }; + // Keep template specializations for the most common kernel sizes in modern models. + // 5x5 is intentionally not specialized: it is less common in current architectures and is often + // replaced by stacked 3x3 blocks (similar receptive field with better optimization flexibility). if (kH == 1 && kW == 1) { - launch(DeformConvKSize<1>{}, DeformConvKSize<1>{}); + launch_with_mask(DeformConvKSize<1>{}); } else if (kH == 3 && kW == 3) { - launch(DeformConvKSize<3>{}, DeformConvKSize<3>{}); - } else if (kH == 5 && kW == 5) { - launch(DeformConvKSize<5>{}, DeformConvKSize<5>{}); + launch_with_mask(DeformConvKSize<3>{}); + } else if (kH == 7 && kW == 7) { + launch_with_mask(DeformConvKSize<7>{}); } else { - launch(DeformConvKSize<-1>{}, DeformConvKSize<-1>{}); + launch_with_mask(DeformConvKSize<-1>{}); } return CUDA_CALL(cudaGetLastError()); } @@ -460,15 +644,10 @@ INST_DeformConvIm2ColImpl(double); INST_DeformConvIm2ColImpl(half); INST_DeformConvIm2ColImpl(BFloat16); -template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const float*, float*, int64_t, int64_t, int64_t, int64_t); -template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const double*, double*, int64_t, int64_t, int64_t, int64_t); -template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const half*, half*, int64_t, int64_t, int64_t, int64_t); -template Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t, const BFloat16*, BFloat16*, int64_t, int64_t, int64_t, int64_t); - -template Status DeformConvAddBiasImpl(cudaStream_t, float*, const float*, int64_t, int64_t, int64_t, int64_t); -template Status DeformConvAddBiasImpl(cudaStream_t, double*, const double*, int64_t, int64_t, int64_t, int64_t); -template Status DeformConvAddBiasImpl(cudaStream_t, half*, const half*, int64_t, int64_t, int64_t, int64_t); -template Status DeformConvAddBiasImpl(cudaStream_t, BFloat16*, const BFloat16*, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvAddBiasImpl(cudaStream_t, float*, const float*, int64_t, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvAddBiasImpl(cudaStream_t, double*, const double*, int64_t, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvAddBiasImpl(cudaStream_t, half*, const half*, int64_t, int64_t, int64_t, int64_t, int64_t); +template Status DeformConvAddBiasImpl(cudaStream_t, BFloat16*, const BFloat16*, int64_t, int64_t, int64_t, int64_t, int64_t); // Delegate ORT type to CUDA type (e.g. MLFloat16 -> half); avoids repeating three identical specializations. #define DELEGATE_DEFORM_CONV_IMPL(ORT_T, CUDA_T) \ @@ -488,20 +667,10 @@ template Status DeformConvAddBiasImpl(cudaStream_t, BFloat16*, const B offset_group, use_mask); \ } \ template <> \ - Status DeformConvCopyGemmOutputRowMajorToNCHW(cudaStream_t stream, \ - const ORT_T* gemm_output, ORT_T* Y_g, \ - int64_t M, int64_t M_per_group, \ - int64_t output_image_size, int64_t cur_parallel) { \ - return DeformConvCopyGemmOutputRowMajorToNCHW(stream, \ - reinterpret_cast(gemm_output), \ - reinterpret_cast(Y_g), \ - M, M_per_group, output_image_size, cur_parallel); \ - } \ - template <> \ Status DeformConvAddBiasImpl(cudaStream_t stream, ORT_T * Y, const ORT_T* B, \ - int64_t N, int64_t M, int64_t out_h, int64_t out_w) { \ + int64_t N, int64_t M, int64_t out_h, int64_t out_w, int64_t max_grid_y) { \ return DeformConvAddBiasImpl(stream, reinterpret_cast(Y), \ - reinterpret_cast(B), N, M, out_h, out_w); \ + reinterpret_cast(B), N, M, out_h, out_w, max_grid_y); \ } // BFloat16 is not delegated: ORT's BFloat16 is the same type used in device code (ToCudaType in diff --git a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.h b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.h index 0c26cb55311bc..60d0c27e7b081 100644 --- a/onnxruntime/core/providers/cuda/nn/deform_conv_impl.h +++ b/onnxruntime/core/providers/cuda/nn/deform_conv_impl.h @@ -1,5 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// +// CUDA DeformConv kernel entry points (im2col, bias). Host pipeline and chunking: `deform_conv.cc` +// (see CPU `nn/deform_conv.cc` for the same high-level im2col → GEMM → bias flow). #pragma once @@ -10,7 +13,6 @@ namespace onnxruntime { namespace cuda { // Adds bias to output: Y[n,m,oh,ow] += B[m]. Y is [N, M, out_h, out_w], B is [M]. -// T may be float, double, MLFloat16 (FP16), or BFloat16. template Status DeformConvAddBiasImpl( cudaStream_t stream, @@ -19,22 +21,12 @@ Status DeformConvAddBiasImpl( int64_t N, int64_t M, int64_t out_h, - int64_t out_w); - -// Copies GEMM output (row-major [M_per_group, cur_parallel*output_image_size]) to NCHW slice at Y_g. -// T may be float, double, MLFloat16 (FP16), or BFloat16. -template -Status DeformConvCopyGemmOutputRowMajorToNCHW( - cudaStream_t stream, - const T* gemm_output, - T* Y_g, - int64_t M, - int64_t M_per_group, - int64_t output_image_size, - int64_t cur_parallel); + int64_t out_w, + int64_t max_grid_y); -// Fills col_buffer with deformable im2col. col_buffer layout: row-major [C*kH*kW, parallel_imgs*out_h*out_w]. -// Called once per batch block; caller does GEMM and bias. T may be float, double, MLFloat16 (FP16), or BFloat16. +// Fills col_buffer with deformable im2col. Row-major [C*kH*kW, parallel_imgs*out_h*out_w]: +// row = c * (kH*kW) + (i*kW + j), col = n * (out_h*out_w) + (oh*out_w + ow), same semantics as ONNX DeformConv im2col. +// Called once per batch chunk; caller GEMM + bias. template Status DeformConvIm2ColImpl( cudaStream_t stream, diff --git a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc index 860c0d2f08b18..0153886d480eb 100644 --- a/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc @@ -4,10 +4,13 @@ // Unit tests for DeformConv (CPU and Cuda), aligned with PyTorch Vision deform_conv2d tests. // Reference: https://github.com/pytorch/vision/blob/main/test/test_ops.py (TestDeformConv) +// No CI test for CUDA int64 index path: needs huge shapes to exceed INT32_MAX (often OOM/slow). + #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" #include "test/testdata/deform_conv_test_data.inc" #include "test/unittest_util/conversion.h" +#include #if defined(USE_CUDA) #include "test/common/cuda_op_test_utils.h" @@ -109,14 +112,6 @@ void RunDeformConvTest(const DeformConvTestParams& params, params.stride[1] + 1; - OpTester test("DeformConv", opset); - test.AddAttribute("kernel_shape", params.kernel_shape); - test.AddAttribute("strides", params.stride); - test.AddAttribute("pads", params.pad); - test.AddAttribute("dilations", params.dilation); - test.AddAttribute("group", params.n_weight_grps); - test.AddAttribute("offset_group", params.n_offset_grps); - const std::vector X_shape = {params.batch_sz, params.n_in_channels, params.in_h, params.in_w}; const std::vector W_shape = {params.n_out_channels, params.n_in_channels / params.n_weight_grps, kH, kW}; const std::vector offset_shape = {params.batch_sz, params.n_offset_grps * 2 * kH * kW, out_h, out_w}; @@ -127,27 +122,58 @@ void RunDeformConvTest(const DeformConvTestParams& params, auto offset_t = DeformConvTestTraits::Convert(offset); auto expected_Y_t = DeformConvTestTraits::Convert(expected_Y); - test.AddInput("X", X_shape, X_t); - test.AddInput("W", W_shape, W_t); - test.AddInput("offset", offset_shape, offset_t); - if (omit_bias) { - test.AddOptionalInputEdge(); - } else { - auto B_t = DeformConvTestTraits::Convert(B); - test.AddInput("B", {params.n_out_channels}, B_t); - } - if (mask != nullptr) { - const std::vector mask_shape = {params.batch_sz, params.n_offset_grps * kH * kW, out_h, out_w}; - test.AddInput("mask", mask_shape, DeformConvTestTraits::Convert(*mask)); - } else { - test.AddOptionalInputEdge(); - } - const float rtol_f = static_cast(rtol); const float atol_f = static_cast(atol); - test.AddOutput("Y", Y_shape, expected_Y_t, false, rtol_f, atol_f); + auto run_once = [&](bool force_cuda) { + OpTester test("DeformConv", opset); + test.AddAttribute("kernel_shape", params.kernel_shape); + test.AddAttribute("strides", params.stride); + test.AddAttribute("pads", params.pad); + test.AddAttribute("dilations", params.dilation); + test.AddAttribute("group", params.n_weight_grps); + test.AddAttribute("offset_group", params.n_offset_grps); + + test.AddInput("X", X_shape, X_t); + test.AddInput("W", W_shape, W_t); + test.AddInput("offset", offset_shape, offset_t); + if (omit_bias) { + test.AddOptionalInputEdge(); + } else { + auto B_t = DeformConvTestTraits::Convert(B); + test.AddInput("B", {params.n_out_channels}, B_t); + } + if (mask != nullptr) { + const std::vector mask_shape = {params.batch_sz, params.n_offset_grps * kH * kW, out_h, out_w}; + test.AddInput("mask", mask_shape, DeformConvTestTraits::Convert(*mask)); + } else { + test.AddOptionalInputEdge(); + } + test.AddOutput("Y", Y_shape, expected_Y_t, false, rtol_f, atol_f); + + if (!force_cuda) { + test.Run(OpTester::ExpectResult::kExpectSuccess, "", DeformConvTestTraits::ExcludedProviders()); + return; + } - test.Run(OpTester::ExpectResult::kExpectSuccess, "", DeformConvTestTraits::ExcludedProviders()); +#if defined(USE_CUDA) + auto cuda_ep = DefaultCudaExecutionProvider(); + if (cuda_ep) { + std::vector> execution_providers; + execution_providers.emplace_back(std::move(cuda_ep)); + test.ConfigEps(std::move(execution_providers)).RunWithConfig(); + } +#endif + }; + + // Keep existing behavior first (typically CPU for float/double, CUDA for half/bfloat16). + run_once(false); + +#if defined(USE_CUDA) + // For types supported by both CPU and CUDA, additionally force a CUDA-only run. + if constexpr (std::is_same_v || std::is_same_v) { + run_once(true); + } +#endif } // MinimalBilinear test: 1x1 kernel, 2x2 input, one output position with fractional offset (bilinear). @@ -190,6 +216,38 @@ void RunMinimalBilinearTest(int opset = 19, int min_cuda_arch = 0, bool omit_bia DeformConvTestTraits::DefaultRtol(), DeformConvTestTraits::DefaultAtol(), false); } } + +void RunChunkTailGroupedIdentityTest(int64_t n, int64_t group1_base) { + DeformConvTestParams p = {}; + p.batch_sz = n; + p.n_in_channels = 2; + p.n_out_channels = 2; + p.n_weight_grps = 2; + p.n_offset_grps = 1; + p.kernel_shape = {1, 1}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 1; + p.in_w = 1; + + std::vector X(static_cast(n * p.n_in_channels), 0.f); + for (int64_t i = 0; i < n; ++i) { + X[static_cast(i * 2)] = static_cast(i + 1); // group 0 input + X[static_cast(i * 2 + 1)] = static_cast(group1_base + i); // group 1 input + } + + std::vector W = {1.f, 1.f}; // identity mapping per group + std::vector offset(static_cast(n * 2), 0.f); + + std::vector expected_Y(static_cast(n * p.n_out_channels), 0.f); + for (int64_t i = 0; i < n; ++i) { + expected_Y[static_cast(i * 2)] = static_cast(i + 1); + expected_Y[static_cast(i * 2 + 1)] = static_cast(group1_base + i); + } + + RunDeformConvTest(p, X, W, offset, {} /* B unused */, nullptr, expected_Y, 19, 1e-5f, 1e-5f, true); +} } // namespace // Minimal case: 1x1 kernel, 2x2 input, one output position with fractional offset (bilinear). @@ -827,6 +885,114 @@ TEST(DeformConvTest, LargeBatchSize) { RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); } +// output_image_size = 3*3 = 9 (not a multiple of 8): exercises AoSoA tail path in FillColRowFromSamplingPlanImpl (CPU). +TEST(DeformConvTest, OutputPixelsNotMultipleOf8_AoSoATail) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {1, 1}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 3; + const int64_t out_w = 3; + + std::vector X(1 * 1 * 3 * 3, 0.1f); + std::vector W(1 * 1 * 1 * 1, 0.1f); + const size_t offset_size = static_cast(1 * 1 * 2 * 1 * 1 * out_h * out_w); + const size_t mask_size = static_cast(1 * 1 * 1 * 1 * out_h * out_w); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f}; + std::vector expected_Y(static_cast(out_h * out_w), 0.01f); + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// Prime batch N=7: uneven last chunk vs fixed chunk size on CUDA (GetDeformConvParallelChunkSize path). +TEST(DeformConvTest, PrimeBatchSizeSeven) { + const int64_t N = 7; + DeformConvTestParams p = {}; + p.batch_sz = N; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {2, 2}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 3; + p.in_w = 3; + const int64_t out_h = 2; + const int64_t out_w = 2; + + const size_t x_size = static_cast(N * 1 * 3 * 3); + const size_t offset_size = static_cast(N * 1 * 2 * 2 * 2 * out_h * out_w); + const size_t mask_size = static_cast(N * 1 * 2 * 2 * out_h * out_w); + const size_t y_size = static_cast(N * 1 * out_h * out_w); + + std::vector X(x_size, 0.1f); + std::vector W(1 * 1 * 2 * 2, 0.1f); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f}; + std::vector expected_Y(y_size, 0.04f); + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + +// CUDA chunking regression guard: +// - Keep bytes_per_image tiny so target_parallel_imgs hits kMaxParallelImgs(32) on all CUDA devices. +// - N=993 -> balanced chunk size k=32 and final tail chunk size 1 (after many 32-sized chunks). +// - group=2 exercises cur_parallel==1 grouped GEMM path where per-group col stride must use cur_out_size. +TEST(DeformConvTest, ChunkTailOneWithGroups) { + RunChunkTailGroupedIdentityTest(/*n=*/993, /*group1_base=*/1000); +} + +// CUDA chunking regression guard for partial tail chunk: +// - Keep bytes_per_image tiny so target_parallel_imgs hits kMaxParallelImgs(32) on all CUDA devices. +// - N=33 -> balanced chunk size k=17 and final tail chunk size 16 (1 < cur_parallel < n_parallel_imgs). +// - group=2 exercises grouped path where each group's col base must use current cur_out_size. +TEST(DeformConvTest, ChunkTailPartialWithGroups) { + RunChunkTailGroupedIdentityTest(/*n=*/33, /*group1_base=*/2000); +} + +// 7x7 kernel with 9x9 input -> 3x3 output: exercises compile-time kH=kW=7 CUDA im2col specialization. +TEST(DeformConvTest, Kernel7x7) { + DeformConvTestParams p = {}; + p.batch_sz = 1; + p.n_in_channels = 1; + p.n_out_channels = 1; + p.n_weight_grps = 1; + p.n_offset_grps = 1; + p.kernel_shape = {7, 7}; + p.stride = {1, 1}; + p.pad = {0, 0, 0, 0}; + p.dilation = {1, 1}; + p.in_h = 9; + p.in_w = 9; + const int64_t out_h = 3; + const int64_t out_w = 3; + + std::vector X(1 * 1 * 9 * 9, 0.1f); + std::vector W(1 * 1 * 7 * 7, 0.1f); + const size_t offset_size = static_cast(1 * 1 * 2 * 7 * 7 * out_h * out_w); + const size_t mask_size = static_cast(1 * 1 * 7 * 7 * out_h * out_w); + std::vector offset(offset_size, 0.f); + std::vector mask(mask_size, 1.f); + std::vector B = {0.f}; + const float expected_val = 49.f * 0.1f * 0.1f; // 7x7 dot with uniform 0.1 * 0.1 + std::vector expected_Y(static_cast(out_h * out_w), expected_val); + + RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); +} + // group=1, offset_group=2: weights not grouped, offset/mask grouped. TEST(DeformConvTest, Group1OffsetGroup2) { DeformConvTestParams p = {}; @@ -860,7 +1026,7 @@ TEST(DeformConvTest, Group1OffsetGroup2) { RunDeformConvTest(p, X, W, offset, B, &mask, expected_Y); } -// Mask with zeros: exercises CUDA early-exit when mask_val == 0. +// Mask with zeros: verifies zero mask suppresses sampled values (val * mask == 0). TEST(DeformConvTest, MaskWithZeros) { DeformConvTestParams p = {}; p.batch_sz = 1;