Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -887,7 +887,9 @@ Do not modify directly.*
|||[11, 12]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|ReverseSequence|*in* input:**T**<br> *in* sequence_lens:**tensor(int64)**<br> *out* Y:**T**|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|RoiAlign|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *out* Y:**T1**|10+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int64)|
|RoiAlign|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *out* Y:**T1**|22+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(int64)|
Comment thread
tianleiwu marked this conversation as resolved.
|||[16, 21]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(int64)|
|||[10, 15]|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int64)|
|RotaryEmbedding|*in* X:**T**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**M**<br> *out* Y:**T**|23+|**M** = tensor(int64)<br/> **T** = tensor(bfloat16), tensor(float), tensor(float16)|
|Round|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)|
|ScaledTanh|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/cpu/object_detection/roialign.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ class RoiAlignBase {
std::string coordinate_transformation_mode;
if (info.template GetAttr<std::string>("coordinate_transformation_mode", &coordinate_transformation_mode).IsOK()) {
half_pixel_ = coordinate_transformation_mode == "half_pixel";
} else {
// For opset 16+, the default is "half_pixel" per ONNX spec.
// For opset 10 (which has no coordinate_transformation_mode attribute), false is correct.
half_pixel_ = info.node().SinceVersion() >= 16;
}

if (mode_ == RoiAlignMode::max && sampling_ratio_ != 1) {
Expand Down
22 changes: 18 additions & 4 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -944,8 +944,11 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, int32_t, Resize);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, uint8_t, Resize);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, ReverseSequence);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, float, RoiAlign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, double, RoiAlign);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 15, float, RoiAlign);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 15, double, RoiAlign);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 21, float, RoiAlign);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 21, double, RoiAlign);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 21, MLFloat16, RoiAlign);
Comment thread
tianleiwu marked this conversation as resolved.
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, int32_t, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, int64_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, float, ThresholdedRelu);
Expand Down Expand Up @@ -1601,6 +1604,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, GRU);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, GRU);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, GRU);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, RoiAlign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, RoiAlign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, RoiAlign);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, RoiAlign);

// Opset 23.
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Attention);
Expand Down Expand Up @@ -2042,8 +2049,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, int32_t, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, uint8_t, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, ReverseSequence)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, float, RoiAlign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, double, RoiAlign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 15, float, RoiAlign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 15, double, RoiAlign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 21, float, RoiAlign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 21, double, RoiAlign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 21, MLFloat16, RoiAlign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, int32_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, int64_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, float, ThresholdedRelu)>,
Expand Down Expand Up @@ -2700,6 +2710,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, GRU)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, GRU)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, GRU)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, RoiAlign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, RoiAlign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, RoiAlign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, RoiAlign)>,

// Opset 23
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Attention)>,
Expand Down
45 changes: 40 additions & 5 deletions onnxruntime/core/providers/cuda/object_detection/roialign.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,37 @@
namespace onnxruntime {
namespace cuda {

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
// Registers the versioned CUDA RoiAlign kernel covering ONNX opsets [10, 15] for data type T.
// T1 (input / rois type) is constrained to T; T2 (batch_indices) is constrained to int64.
#define ADD_VERSIONED_TYPED_ROIALIGN_OP_10(T) \
  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
      RoiAlign, \
      kOnnxDomain, \
      10, \
      15, \
      T, \
      kCudaExecutionProvider, \
      (*KernelDefBuilder::Create()) \
          .TypeConstraint("T1", DataTypeImpl::GetTensorType<T>()) \
          .TypeConstraint("T2", DataTypeImpl::GetTensorType<int64_t>()), \
      RoiAlign<T>);

// Registers the versioned CUDA RoiAlign kernel covering ONNX opsets [16, 21] for data type T.
// Opset 16 introduced the coordinate_transformation_mode attribute; the same RoiAlign<T>
// implementation class serves all opset ranges. T2 (batch_indices) is constrained to int64.
#define ADD_VERSIONED_TYPED_ROIALIGN_OP_16(T) \
  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
      RoiAlign, \
      kOnnxDomain, \
      16, \
      21, \
      T, \
      kCudaExecutionProvider, \
      (*KernelDefBuilder::Create()) \
          .TypeConstraint("T1", DataTypeImpl::GetTensorType<T>()) \
          .TypeConstraint("T2", DataTypeImpl::GetTensorType<int64_t>()), \
      RoiAlign<T>);

#define ADD_TYPED_ROIALIGN_OP_22(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
RoiAlign, \
kOnnxDomain, \
22, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
Expand Down Expand Up @@ -67,13 +93,22 @@ Status RoiAlign<T>::ComputeInternal(OpKernelContext* context) const {
return Status::OK();
}

#define SPECIALIZED_COMPUTE(T) \
REGISTER_KERNEL_TYPED(T) \
// Emits every RoiAlign kernel registration for type T (opsets [10,15], [16,21], and 22)
// plus the explicit instantiation of the ComputeInternal template for that type.
// Used below for types supported since opset 10 (float, double); types added in later
// opsets (MLFloat16, BFloat16) invoke the individual registration macros directly.
#define SPECIALIZED_COMPUTE(T) \
  ADD_VERSIONED_TYPED_ROIALIGN_OP_10(T) \
  ADD_VERSIONED_TYPED_ROIALIGN_OP_16(T) \
  ADD_TYPED_ROIALIGN_OP_22(T) \
  template Status RoiAlign<T>::ComputeInternal(OpKernelContext* ctx) const;

SPECIALIZED_COMPUTE(float)
SPECIALIZED_COMPUTE(double)
// SPECIALIZED_COMPUTE(MLFloat16)
// MLFloat16 is available for RoiAlign op from version 16 (not version 10):
ADD_VERSIONED_TYPED_ROIALIGN_OP_16(MLFloat16)
ADD_TYPED_ROIALIGN_OP_22(MLFloat16)
template Status RoiAlign<MLFloat16>::ComputeInternal(OpKernelContext* ctx) const;

// BFloat16 is available for RoiAlign op from version 22:
ADD_TYPED_ROIALIGN_OP_22(BFloat16)
template Status RoiAlign<BFloat16>::ComputeInternal(OpKernelContext* ctx) const;

} // namespace cuda
}; // namespace onnxruntime
112 changes: 63 additions & 49 deletions onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,64 +17,72 @@

#include "roialign_impl.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/shared_inc/accumulation_type.h"

namespace onnxruntime {
namespace cuda {

template <typename T>
__device__ T bilinear_interpolate(
__device__ AccumulationType_t<T> bilinear_interpolate(
const T* bottom_data,
const int height,
const int width,
T y,
T x,
AccumulationType_t<T> y,
AccumulationType_t<T> x,
const bool is_mode_avg,
const int index /* index for debug only*/) {
using TAcc = AccumulationType_t<T>;

// deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) {
if (y < static_cast<TAcc>(-1.0f) || y > static_cast<TAcc>(height) ||
x < static_cast<TAcc>(-1.0f) || x > static_cast<TAcc>(width)) {
// empty
return 0;
return static_cast<TAcc>(0.0f);
}

if (y <= 0) {
y = 0;
if (y <= static_cast<TAcc>(0.0f)) {
y = static_cast<TAcc>(0.0f);
}
if (x <= 0) {
x = 0;
if (x <= static_cast<TAcc>(0.0f)) {
x = static_cast<TAcc>(0.0f);
}

int y_low = (int)y;
int x_low = (int)x;
int y_low = static_cast<int>(y);
int x_low = static_cast<int>(x);
int y_high;
int x_high;

if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
y = static_cast<TAcc>(y_low);
} else {
y_high = y_low + 1;
}

if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
x = static_cast<TAcc>(x_low);
} else {
x_high = x_low + 1;
}

T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
TAcc ly = y - static_cast<TAcc>(y_low);
TAcc lx = x - static_cast<TAcc>(x_low);
TAcc hy = static_cast<TAcc>(1.0f) - ly;
TAcc hx = static_cast<TAcc>(1.0f) - lx;
// do bilinear interpolation
T v1 = bottom_data[y_low * width + x_low];
T v2 = bottom_data[y_low * width + x_high];
T v3 = bottom_data[y_high * width + x_low];
T v4 = bottom_data[y_high * width + x_high];
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
TAcc v1 = static_cast<TAcc>(bottom_data[y_low * width + x_low]);
TAcc v2 = static_cast<TAcc>(bottom_data[y_low * width + x_high]);
TAcc v3 = static_cast<TAcc>(bottom_data[y_high * width + x_low]);
TAcc v4 = static_cast<TAcc>(bottom_data[y_high * width + x_high]);
TAcc w1 = hy * hx;
TAcc w2 = hy * lx;
TAcc w3 = ly * hx;
TAcc w4 = ly * lx;

T val = is_mode_avg
? (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4) // mode Avg
: max(max(max(w1 * v1, w2 * v2), w3 * v3), w4 * v4); // mode Max
TAcc val = is_mode_avg
? (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4) // mode Avg
: max(max(max(w1 * v1, w2 * v2), w3 * v3), w4 * v4); // mode Max
Comment thread
tianleiwu marked this conversation as resolved.

return val;
}
Expand All @@ -97,6 +105,8 @@ __global__ void RoIAlignForward(
const bool half_pixel,
const int64_t* batch_indices_ptr,
const int64_t batch_size) {
using TAcc = AccumulationType_t<T>;

for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
Expand All @@ -111,53 +121,55 @@ __global__ void RoIAlignForward(
// If the index is out of range, we set the output to 0 for this RoI element.
if (roi_batch_ind < 0 || roi_batch_ind >= batch_size) {
CUDA_KERNEL_ASSERT(false && "batch_indices values are out of range");
top_data[index] = 0;
top_data[index] = static_cast<T>(0.0f);
continue;
}

// Do not use rounding; this implementation detail is critical
T roi_offset = half_pixel ? T(0.5) : T(0);
T roi_start_w = offset_bottom_rois[0] * spatial_scale - roi_offset;
T roi_start_h = offset_bottom_rois[1] * spatial_scale - roi_offset;
T roi_end_w = offset_bottom_rois[2] * spatial_scale - roi_offset;
T roi_end_h = offset_bottom_rois[3] * spatial_scale - roi_offset;

T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
const TAcc spatial_scale_acc = static_cast<TAcc>(spatial_scale);
const TAcc roi_offset = half_pixel ? static_cast<TAcc>(0.5f) : static_cast<TAcc>(0.0f);
TAcc roi_start_w = static_cast<TAcc>(offset_bottom_rois[0]) * spatial_scale_acc - roi_offset;
TAcc roi_start_h = static_cast<TAcc>(offset_bottom_rois[1]) * spatial_scale_acc - roi_offset;
TAcc roi_end_w = static_cast<TAcc>(offset_bottom_rois[2]) * spatial_scale_acc - roi_offset;
TAcc roi_end_h = static_cast<TAcc>(offset_bottom_rois[3]) * spatial_scale_acc - roi_offset;

TAcc roi_width = roi_end_w - roi_start_w;
TAcc roi_height = roi_end_h - roi_start_h;
if (!half_pixel) { // backward compatibility
// Force malformed ROIs to be 1x1
roi_width = max(roi_width, (T)1.);
roi_height = max(roi_height, (T)1.);
roi_width = max(roi_width, static_cast<TAcc>(1.0f));
roi_height = max(roi_height, static_cast<TAcc>(1.0f));
}
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
const TAcc bin_size_h = roi_height / static_cast<TAcc>(pooled_height);
const TAcc bin_size_w = roi_width / static_cast<TAcc>(pooled_width);

const T* offset_bottom_data =
bottom_data + static_cast<int64_t>((roi_batch_ind * channels + c) * height * width);

// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: _Ceil(roi_height / pooled_height); // e.g., = 2
: static_cast<int>(_Ceil(roi_height / static_cast<TAcc>(pooled_height))); // e.g., = 2
int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : _Ceil(roi_width / pooled_width);
(sampling_ratio > 0) ? sampling_ratio : static_cast<int>(_Ceil(roi_width / static_cast<TAcc>(pooled_width)));

// We do average (integral) pooling inside a bin
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
const int grid_count = max(roi_bin_grid_h * roi_bin_grid_w, 1);
const TAcc count = static_cast<TAcc>(grid_count); // e.g. = 4

T output_val = 0.;
TAcc output_val = static_cast<TAcc>(0.0f);
bool max_flag = false;
for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
{
const T y = roi_start_h + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
const TAcc y = roi_start_h + static_cast<TAcc>(ph) * bin_size_h +
(static_cast<TAcc>(iy) + static_cast<TAcc>(0.5f)) * bin_size_h /
static_cast<TAcc>(roi_bin_grid_h); // e.g., 0.5, 1.5
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_start_w + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
const TAcc x = roi_start_w + static_cast<TAcc>(pw) * bin_size_w +
(static_cast<TAcc>(ix) + static_cast<TAcc>(0.5f)) * bin_size_w /
static_cast<TAcc>(roi_bin_grid_w);

T val = bilinear_interpolate(
const TAcc val = bilinear_interpolate(
offset_bottom_data, height, width, y, x, is_mode_avg, index);

if (is_mode_avg) {
Expand All @@ -176,7 +188,7 @@ __global__ void RoIAlignForward(
output_val /= count;
}

top_data[index] = output_val;
top_data[index] = static_cast<T>(output_val);
}
}

Expand Down Expand Up @@ -241,6 +253,8 @@ void RoiAlignImpl(

SPECIALIZED_IMPL(float)
SPECIALIZED_IMPL(double)
SPECIALIZED_IMPL(half)
SPECIALIZED_IMPL(BFloat16)

} // namespace cuda
} // namespace onnxruntime
Loading
Loading