diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 08840c623b709..9a13cd76e477a 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -747,7 +747,9 @@ Do not modify directly.*
|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
|GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)|
|||[12, 15]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)|
-|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float)
**T2** = tensor(float)|
+|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|22+|**T1** = tensor(float)
**T2** = tensor(float)|
+|||[20, 21]|**T1** = tensor(float)
**T2** = tensor(float)|
+|||[16, 19]|**T1** = tensor(float)
**T2** = tensor(float)|
|HardSigmoid|*in* X:**T**
*out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|||[6, 21]|**T** = tensor(double), tensor(float), tensor(float16)|
|HardSwish|*in* X:**T**
*out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
@@ -1062,7 +1064,9 @@ Do not modify directly.*
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
|GlobalAveragePool|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GlobalMaxPool|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float)
**T2** = tensor(float)|
+|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|22+|**T1** = tensor(float)
**T2** = tensor(float)|
+|||[20, 21]|**T1** = tensor(float)
**T2** = tensor(float)|
+|||[16, 19]|**T1** = tensor(float)
**T2** = tensor(float)|
|LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|MaxPool|*in* X:**T**
*out* Y:**T**
or
*in* X:**T**
*out* Y:**T**
*out* Indices:**I**|12+|**I** = tensor(int64)
**T** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)|
diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc
old mode 100644
new mode 100755
index 6f2538bcde3b1..f0b1a30b293ff
--- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc
+++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc
@@ -748,7 +748,7 @@ std::vector ChannelLastToFirstPerm(size_t rank) {
}
std::vector p(rank);
- p[0] = 0;
+ p[0] = 0; // This is usually the batch dimension (hence preserve this position)
p[1] = rank - 1;
for (size_t i = 2; i < rank; ++i) {
p[i] = i - 1;
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
old mode 100644
new mode 100755
index eb29e4edbf897..f2d3f28ba2bf6
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1421,7 +1421,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterND);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 19, float, GridSample);
// Opset 17
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
@@ -1510,6 +1510,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsNaN);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, 21, float, GridSample);
// Opset 21.
// TODO(fajin): support other quantized types
@@ -1583,6 +1584,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, HardSwish);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, HardSwish);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSwish);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, GridSample);
// Opset 23.
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Attention);
@@ -2485,7 +2487,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
// Opset 17
BuildKernelCreateInfo,
@@ -2582,6 +2584,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
// Opset 21
// TODO(fajin): support other quantized types
@@ -2654,6 +2657,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
// Opset 23
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc
old mode 100644
new mode 100755
index e8995a0ec623a..8239a8ac252e6
--- a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc
+++ b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc
@@ -164,12 +164,17 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) {
#ifndef DISABLE_CONTRIB_OPS
namespace onnxruntime::contrib::cuda {
-class CUDA_NHWC_OP_TYPED_CLASS_NAME(16, float, GridSample);
+class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(16, 19, float, GridSample);
+class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(20, 21, float, GridSample);
+class CUDA_NHWC_OP_TYPED_CLASS_NAME(22, float, GridSample);
onnxruntime::common::Status RegisterCudaNhwcContribKernels(onnxruntime::KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn nhwc_function_table[] = {
BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+
};
for (auto& function_table_entry : nhwc_function_table) {
diff --git a/onnxruntime/core/providers/cuda/tensor/grid_sample.cc b/onnxruntime/core/providers/cuda/tensor/grid_sample.cc
old mode 100644
new mode 100755
index 1884ae4689899..b9d47a27e8e83
--- a/onnxruntime/core/providers/cuda/tensor/grid_sample.cc
+++ b/onnxruntime/core/providers/cuda/tensor/grid_sample.cc
@@ -21,28 +21,66 @@ namespace cuda {
.TypeConstraint("T2", DataTypeImpl::GetTensorType()), \
onnxruntime::contrib::cuda::GridSample);
+#define REGISTER_KERNEL_VERSIONED_TYPED(T, FROM_VERSION, TO_VERSION, LAYOUT, DOMAIN) \
+ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
+ GridSample, \
+ DOMAIN, \
+ FROM_VERSION, \
+ TO_VERSION, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \
+ onnxruntime::contrib::cuda::GridSample);
+
REGISTER_KERNEL_TYPED(float, 1, LAYOUT_NCHW, kMSDomain)
#ifdef ENABLE_CUDA_NHWC_OPS
-REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NHWC, kMSInternalNHWCDomain)
+// Op was introduced in opset 16
+REGISTER_KERNEL_VERSIONED_TYPED(float, 16, 19, LAYOUT_NHWC, kMSInternalNHWCDomain)
+
+// Op was modified to support multiple spatial dimensions in opset 20
+REGISTER_KERNEL_VERSIONED_TYPED(float, 20, 21, LAYOUT_NHWC, kMSInternalNHWCDomain)
+
+// Op spec introduced BFloat16 support in opset 22
+REGISTER_KERNEL_TYPED(float, 22, LAYOUT_NHWC, kMSInternalNHWCDomain)
#endif
template
GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) {
+ opset_start_version_ = info.node().SinceVersion();
+
std::string mode_str = info.GetAttrOrDefault("mode", "bilinear");
std::string padding_mode_str = info.GetAttrOrDefault("padding_mode", "zeros");
align_corners_ = static_cast(info.GetAttrOrDefault("align_corners", 0));
- ORT_ENFORCE(mode_str == "bilinear" || mode_str == "nearest" || mode_str == "bicubic",
- "mode \"", mode_str, "\" not supported, expect bilinear, nearest or bicubic");
- ORT_ENFORCE(padding_mode_str == "zeros" || padding_mode_str == "border" || padding_mode_str == "reflection",
- "padding_mode \"", padding_mode_str, "\" not supported, expect zeros, border or reflection");
- if (mode_str == "bicubic") {
- mode_i_ = 2;
- } else if (mode_str == "nearest") {
- mode_i_ = 1;
+
+ if (opset_start_version_ >= 20) {
+ std::string mode_str = info.GetAttrOrDefault("mode", "linear");
+ if (mode_str == "cubic") {
+ mode_i_ = 2;
+ } else if (mode_str == "nearest") {
+ mode_i_ = 1;
+ } else if (mode_str == "linear") {
+ mode_i_ = 0;
+ } else {
+ ORT_THROW("mode \"", mode_str, "\" not supported, expect linear, nearest or cubic");
+ }
} else {
- mode_i_ = 0;
+ std::string mode_str = info.GetAttrOrDefault("mode", "bilinear");
+ if (mode_str == "bicubic") {
+ mode_i_ = 2;
+ } else if (mode_str == "nearest") {
+ mode_i_ = 1;
+ } else if (mode_str == "bilinear") {
+ mode_i_ = 0;
+ } else {
+ ORT_THROW("mode \"", mode_str, "\" not supported, expect bilinear, nearest or bicubic");
+ }
}
+
+ ORT_ENFORCE(padding_mode_str == "zeros" || padding_mode_str == "border" || padding_mode_str == "reflection",
+ "padding_mode \"", padding_mode_str, "\" not supported, expect zeros, border or reflection");
if (padding_mode_str == "reflection") {
padding_mode_i_ = 2;
} else if (padding_mode_str == "border") {
@@ -59,20 +97,52 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const {
const Tensor* Grid = context->Input(1);
const auto& dims_grid = Grid->Shape().GetDims();
- if (dims_input.size() != 4 || dims_grid.size() != 4) {
- return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Only 4-D tensor is supported");
+ if (dims_input.size() != dims_grid.size()) {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Input and grid must have the same number of dimensions");
+ }
+
+ if (opset_start_version_ < 20 && dims_input.size() != 4) {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Opset 16-19 versions of this op only supports 4-D input tensors");
+ }
+ // Check the rank before any dims_input[i] / dims_grid[i] access below, so a rank-0 tensor is never indexed.
+ if (dims_input.size() != 4 && dims_input.size() != 5) {
+ return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Only 4-D and 5-D input tensors are supported");
+ }
+
+ if (dims_input[0] != dims_grid[0]) {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Grid batch size does not match input batch size ");
+ }
+
+ if ((dims_input.size() == 4 && dims_grid[3] != 2) || (dims_input.size() == 5 && dims_grid[4] != 3)) {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+ "Last dimension of grid input must match the number of "
+ "spatial dimensions in the input (2 for 2D, 3 for 3D).");
+ }
+
+ if (dims_input.size() == 5 && mode_i_ == 2) {
+ // Neither the CPU nor the CUDA kernel implements cubic mode for 5-D input, so rejecting it here
+ // does not regress CUDA users whose 5-D models previously fell back to the CPU version of the op.
+ return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Cubic mode is only supported in 4-D cases.");
+ }
- ORT_ENFORCE(dims_grid[0] == dims_input[0], "Grid batch size ", dims_grid[0], " does not match input batch size ", dims_input[0]);
- ORT_ENFORCE(dims_grid[3] == 2, "Last dimension of grid: ", dims_grid[3], ", expect 2");
using Ch = Channels;
- TensorShapeVector dims_output(4);
- dims_output[Ch::N] = dims_input[Ch::N];
- dims_output[Ch::C] = dims_input[Ch::C];
- dims_output[Ch::H] = dims_grid[1 /* Grid::H */];
- dims_output[Ch::W] = dims_grid[2 /* Grid::W */];
+ TensorShapeVector dims_output(dims_input.size());
+ if (dims_input.size() == 4) {
+ dims_output[Ch::N] = dims_input[Ch::N];
+ dims_output[Ch::C] = dims_input[Ch::C];
+ dims_output[Ch::H] = dims_grid[1 /* Grid::H */];
+ dims_output[Ch::W] = dims_grid[2 /* Grid::W */];
+ } else {
+ // 5D input - deal with both NCHW and NHWC layouts
+ dims_output[0] = dims_input[0];
+ dims_output[1] = !IsNHWC ? dims_input[1] : dims_grid[1];
+ dims_output[2] = !IsNHWC ? dims_grid[1] : dims_grid[2];
+ dims_output[3] = !IsNHWC ? dims_grid[2] : dims_grid[3];
+ dims_output[4] = !IsNHWC ? dims_grid[3] : dims_input[4];
+ }
Tensor* Y = context->Output(0, dims_output);
+
// Return early if the output tensor is going to be of size 0
if (Y->Shape().Size() == 0) {
return Status::OK();
@@ -80,23 +150,49 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const {
typedef typename ToCudaType::MappedType CudaT;
CudaT* Y_data = reinterpret_cast(Y->MutableData());
- GridSampleImpl(
- Stream(context),
- reinterpret_cast(X->Data()),
- reinterpret_cast(Grid->Data()),
- mode_i_,
- padding_mode_i_,
- align_corners_,
- dims_input.data(),
- dims_grid[1],
- dims_grid[2],
- Y_data);
+
+ if (dims_input.size() == 4) {
+ // sample 2d
+ GridSampleImpl(
+ Stream(context),
+ reinterpret_cast(X->Data()),
+ reinterpret_cast(Grid->Data()),
+ mode_i_,
+ padding_mode_i_,
+ align_corners_,
+ dims_input.data(),
+ dims_grid[1],
+ dims_grid[2],
+ Y_data);
+ } else {
+ // sample 3d
+ GridSampleImpl3D(
+ Stream(context),
+ reinterpret_cast(X->Data()),
+ reinterpret_cast(Grid->Data()),
+ mode_i_,
+ padding_mode_i_,
+ align_corners_,
+ dims_input.data(),
+ dims_grid[1],
+ dims_grid[2],
+ dims_grid[3],
+ Y_data);
+ }
+
return Status::OK();
}
} // namespace cuda
} // namespace contrib
namespace cuda {
-REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NCHW, kOnnxDomain)
+// Op was introduced in opset 16
+REGISTER_KERNEL_VERSIONED_TYPED(float, 16, 19, LAYOUT_NCHW, kOnnxDomain)
+
+// Op was modified to support multiple spatial dimensions in opset 20
+REGISTER_KERNEL_VERSIONED_TYPED(float, 20, 21, LAYOUT_NCHW, kOnnxDomain)
+
+// Op spec introduced BFloat16 support in opset 22
+REGISTER_KERNEL_TYPED(float, 22, LAYOUT_NCHW, kOnnxDomain)
} // namespace cuda
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/tensor/grid_sample.h b/onnxruntime/core/providers/cuda/tensor/grid_sample.h
old mode 100644
new mode 100755
index 16581bfe77482..74be67c921aae
--- a/onnxruntime/core/providers/cuda/tensor/grid_sample.h
+++ b/onnxruntime/core/providers/cuda/tensor/grid_sample.h
@@ -22,6 +22,7 @@ class GridSample final : public CudaKernel {
int64_t mode_i_; // 0: bilinear (default), 1: nearest 2: bicubic
int64_t padding_mode_i_; // 0:'zeros', 1: 'border', 2:'reflection'
int64_t align_corners_;
+ int opset_start_version_;
};
} // namespace cuda
diff --git a/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu b/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu
old mode 100644
new mode 100755
index b5b4a84576bbe..0e7d947741924
--- a/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu
+++ b/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu
@@ -142,6 +142,10 @@ __global__ void _GridSampleKernel(
}
int grid_idx = BIdx * H_out * W_out + yIdx * W_out + xIdx;
+ // ONNX spec guideline about ordering of grid indices:
+ // Following computer vision convention, the coordinates in the length-r location vector
+ // are listed from the innermost tensor dimension to the outermost, the opposite of regular
+ // tensor indexing.
T grid_X = grid_data[grid_idx * 2 + 0];
T grid_Y = grid_data[grid_idx * 2 + 1];
int outIdx = idx;
@@ -159,9 +163,9 @@ __global__ void _GridSampleKernel(
if (align_corners) {
x_min = 0.0f;
- x_max = W_in - 1.0;
+ x_max = float(W_in) - 1.0f;
y_min = 0.0f;
- y_max = H_in - 1.0f;
+ y_max = float(H_in) - 1.0f;
}
float border[] = {x_min, y_min, x_max, y_max}; // l-t-r-b
if (grid_x_imgSpace < x_min || grid_x_imgSpace > x_max ||
@@ -257,6 +261,258 @@ void GridSampleImpl(
SPECIALIZED_IMPL(float, false) // NCHW
SPECIALIZED_IMPL(float, true) // NHWC
+template
+__device__ T PixelAtGrid3D(const T* input_data, int64_t bIdx, int64_t cIdx, int64_t z, int64_t y, int64_t x,
+ int64_t padding_mode, int64_t N, int64_t C, int64_t D, int64_t H, int64_t W, float border[6]) {
+ T pixel = 0.0f;
+
+ auto PixelOffset3D = [bIdx, cIdx, C, D, H, W](int64_t z, int64_t y, int64_t x) -> int64_t {
+ return Layout == LAYOUT_NCHW
+ ? (bIdx * C * D * H * W + cIdx * D * H * W + z * H * W + y * W + x)
+ : (bIdx * D * H * W * C + z * H * W * C + y * W * C + x * C + cIdx);
+ };
+
+ if (padding_mode == 0) { // zeros
+ if (z >= 0 && z < D && y >= 0 && y < H && x >= 0 && x < W) {
+ pixel = input_data[PixelOffset3D(z, y, x)];
+ }
+ } else if (padding_mode == 1) { // border
+ z = max((int64_t)0, min((int64_t)D - 1, (int64_t)z));
+ y = max((int64_t)0, min((int64_t)H - 1, (int64_t)y));
+ x = max((int64_t)0, min((int64_t)W - 1, (int64_t)x));
+
+ pixel = input_data[PixelOffset3D(z, y, x)];
+ } else { // Reflection
+ z = (int64_t)GsReflect(z, border[0], border[3]);
+ y = (int64_t)GsReflect(y, border[1], border[4]);
+ x = (int64_t)GsReflect(x, border[2], border[5]);
+
+ pixel = input_data[PixelOffset3D(z, y, x)];
+ }
+ return pixel;
+}
+
+template
+__global__ void _GridSampleKernel3D(
+ const T* __restrict__ input_data,
+ const T* __restrict__ grid_data,
+ const int64_t mode,
+ const int64_t padding_mode,
+ const int64_t align_corners,
+ const int64_t N,
+ const int64_t C,
+ const int64_t D_in,
+ const int64_t H_in,
+ const int64_t W_in,
+ const int64_t D_out,
+ const int64_t H_out,
+ const int64_t W_out,
+ T* __restrict__ output_data) {
+ CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(idx, N * C * D_out * H_out * W_out);
+
+ // extract batch index, channel index, z index, y index, x index for current thread
+ int BIdx, cIdx, zIdx, yIdx, xIdx;
+ if constexpr (Layout == LAYOUT_NCHW) {
+ BIdx = idx / (C * D_out * H_out * W_out);
+ int tmpBCnt = BIdx * (C * D_out * H_out * W_out);
+
+ cIdx = (idx - tmpBCnt) / (D_out * H_out * W_out);
+ int tmpCCnt = tmpBCnt + cIdx * (D_out * H_out * W_out);
+
+ zIdx = (idx - tmpCCnt) / (H_out * W_out);
+ int tmpDCnt = tmpCCnt + zIdx * (H_out * W_out);
+
+ yIdx = (idx - tmpDCnt) / (W_out);
+ int tmpHCnt = tmpDCnt + yIdx * W_out;
+
+ xIdx = (idx - tmpHCnt);
+ } else { // Layout == LAYOUT_NHWC
+ BIdx = idx / (D_out * H_out * W_out * C);
+ int tmpBCnt = BIdx * (D_out * H_out * W_out * C);
+
+ zIdx = (idx - tmpBCnt) / (H_out * W_out * C);
+ int tmpDCnt = tmpBCnt + zIdx * (H_out * W_out * C);
+
+ yIdx = (idx - tmpDCnt) / (W_out * C);
+ int tmpHCnt = tmpDCnt + yIdx * (W_out * C);
+
+ xIdx = (idx - tmpHCnt) / C;
+ int tmpWCnt = tmpHCnt + xIdx * C;
+
+ cIdx = (idx - tmpWCnt);
+ }
+
+ int grid_idx = BIdx * D_out * H_out * W_out + zIdx * H_out * W_out + yIdx * W_out + xIdx;
+
+ // ONNX spec guideline about ordering of grid indices:
+ // Following computer vision convention, the coordinates in the length-r location vector
+ // are listed from the innermost tensor dimension to the outermost, the opposite of regular
+ // tensor indexing.
+ T grid_X = grid_data[grid_idx * 3 + 0];
+ T grid_Y = grid_data[grid_idx * 3 + 1];
+ T grid_Z = grid_data[grid_idx * 3 + 2];
+
+ int outIdx = idx;
+
+ T grid_x_volSpace = GsDenormalize(grid_X, W_in, align_corners == 1);
+ T grid_y_volSpace = GsDenormalize(grid_Y, H_in, align_corners == 1);
+ T grid_z_volSpace = GsDenormalize(grid_Z, D_in, align_corners == 1);
+
+ if (mode == 1) { // nearest
+ grid_x_volSpace = nearbyint(grid_x_volSpace);
+ grid_y_volSpace = nearbyint(grid_y_volSpace);
+ grid_z_volSpace = nearbyint(grid_z_volSpace);
+ }
+
+ float z_min = -0.5f;
+ float z_max = D_in - 0.5f;
+ float y_min = -0.5f;
+ float y_max = H_in - 0.5f;
+ float x_min = -0.5f;
+ float x_max = W_in - 0.5f;
+
+ if (align_corners) {
+ z_min = 0.0f;
+ z_max = float(D_in) - 1.0f;
+ y_min = 0.0f;
+ y_max = float(H_in) - 1.0f;
+ x_min = 0.0f;
+ x_max = float(W_in) - 1.0f;
+ }
+
+ float border[] = {z_min, y_min, x_min, z_max, y_max, x_max}; // zmin,ymin,xmin,zmax,ymax,xmax
+ if (grid_z_volSpace < z_min || grid_z_volSpace > z_max ||
+ grid_y_volSpace < y_min || grid_y_volSpace > y_max ||
+ grid_x_volSpace < x_min || grid_x_volSpace > x_max) { // out of bound
+ if (padding_mode == 1) { // border
+ // Clamping must not be done here, see #10607
+ // grid_z_volSpace = max(0.0f, min(grid_z_volSpace, D_in - 1.0f));
+ // grid_y_volSpace = max(0.0f, min(grid_y_volSpace, H_in - 1.0f));
+ // grid_x_volSpace = max(0.0f, min(grid_x_volSpace, W_in - 1.0f));
+ } else if (padding_mode == 2) { // reflection
+ grid_z_volSpace = GsReflect(grid_z_volSpace, z_min, z_max);
+ grid_y_volSpace = GsReflect(grid_y_volSpace, y_min, y_max);
+ grid_x_volSpace = GsReflect(grid_x_volSpace, x_min, x_max);
+ }
+ }
+
+ if (mode == 0) { // bilinear
+ int z1 = floor(grid_z_volSpace);
+ int y1 = floor(grid_y_volSpace);
+ int x1 = floor(grid_x_volSpace);
+ int z2 = z1 + 1;
+ int y2 = y1 + 1;
+ int x2 = x1 + 1;
+
+ // Weights
+ T w_lt_front = 0.0f;
+ T w_rt_front = 0.0f;
+ T w_lb_front = 0.0f;
+ T w_rb_front = 0.0f;
+
+ T w_lt_back = 0.0f;
+ T w_rt_back = 0.0f;
+ T w_lb_back = 0.0f;
+ T w_rb_back = 0.0f;
+
+ // Assign weight values
+ T w_back = grid_z_volSpace - z1;
+ T w_front = 1.0f - w_back;
+ T w_b = grid_y_volSpace - y1;
+ T w_t = 1.0f - w_b;
+ T w_r = grid_x_volSpace - x1;
+ T w_l = 1.0f - w_r;
+
+ w_lt_front = w_t * w_l * w_front;
+ w_rt_front = w_t * w_r * w_front;
+ w_lb_front = w_b * w_l * w_front;
+ w_rb_front = w_b * w_r * w_front;
+
+ w_lt_back = w_t * w_l * w_back;
+ w_rt_back = w_t * w_r * w_back;
+ w_lb_back = w_b * w_l * w_back;
+ w_rb_back = w_b * w_r * w_back;
+
+ T lt_front_v = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y1, x1, padding_mode, N, C, D_in, H_in, W_in, border);
+ T rt_front_v = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y1, x2, padding_mode, N, C, D_in, H_in, W_in, border);
+ T lb_front_v = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y2, x1, padding_mode, N, C, D_in, H_in, W_in, border);
+ T rb_front_v = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y2, x2, padding_mode, N, C, D_in, H_in, W_in, border);
+
+ T lt_back_v = PixelAtGrid3D(input_data, BIdx, cIdx, z2, y1, x1, padding_mode, N, C, D_in, H_in, W_in, border);
+ T rt_back_v = PixelAtGrid3D(input_data, BIdx, cIdx, z2, y1, x2, padding_mode, N, C, D_in, H_in, W_in, border);
+ T lb_back_v = PixelAtGrid3D(input_data, BIdx, cIdx, z2, y2, x1, padding_mode, N, C, D_in, H_in, W_in, border);
+ T rb_back_v = PixelAtGrid3D(input_data, BIdx, cIdx, z2, y2, x2, padding_mode, N, C, D_in, H_in, W_in, border);
+
+ T interpoV = w_lt_front * lt_front_v + w_rt_front * rt_front_v + w_lb_front * lb_front_v + w_rb_front * rb_front_v +
+ w_lt_back * lt_back_v + w_rt_back * rt_back_v + w_lb_back * lb_back_v + w_rb_back * rb_back_v;
+
+ output_data[outIdx] = interpoV;
+ return;
+ }
+ if (mode == 1) { // nearest
+ int x_n = grid_x_volSpace;
+ int y_n = grid_y_volSpace;
+ int z_n = grid_z_volSpace;
+
+ output_data[outIdx] =
+ PixelAtGrid3D(input_data, BIdx, cIdx, z_n, y_n, x_n, padding_mode, N, C, D_in, H_in, W_in, border);
+ return;
+ }
+ if (mode == 2) { // cubic
+ // Cubic mode is not implemented for 3D (5-D tensor) input; host-side input validation rejects it, so this branch is not expected to be reached
+ }
+}
+
+template
+void GridSampleImpl3D(
+ cudaStream_t stream,
+ const T* input_data,
+ const T* grid_data,
+ const int64_t mode,
+ const int64_t padding_mode,
+ const int64_t align_corners,
+ const int64_t dims[5],
+ const int64_t D_out,
+ const int64_t H_out,
+ const int64_t W_out,
+ T* output_data) {
+ int64_t N = 0;
+ int64_t C = 0;
+ int64_t D_in = 0;
+ int64_t H_in = 0;
+ int64_t W_in = 0;
+
+ if constexpr (IsNHWC) {
+ N = dims[0];
+ D_in = dims[1];
+ H_in = dims[2];
+ W_in = dims[3];
+ C = dims[4];
+ } else {
+ N = dims[0];
+ C = dims[1];
+ D_in = dims[2];
+ H_in = dims[3];
+ W_in = dims[4];
+ }
+
+ int blocksPerGrid = static_cast(
+ ceil(static_cast(N * C * D_out * H_out * W_out) / GridDim::maxThreadsPerBlock));
+ _GridSampleKernel3D<<>>(
+ input_data, grid_data, mode, padding_mode, align_corners,
+ N, C, D_in, H_in, W_in,
+ D_out, H_out, W_out, output_data);
+}
+
+template void GridSampleImpl3D(cudaStream_t stream, const float* input_data, const float* grid_data,
+ const int64_t mode, const int64_t padding_mode, const int64_t align_corners,
+ const int64_t dims[5], const int64_t D_out, const int64_t H_out, const int64_t W_out,
+ float* output_data);
+
+template void GridSampleImpl3D(cudaStream_t stream, const float* input_data, const float* grid_data,
+ const int64_t mode, const int64_t padding_mode, const int64_t align_corners,
+ const int64_t dims[5], const int64_t D_out, const int64_t H_out, const int64_t W_out,
+ float* output_data);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.h b/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.h
old mode 100644
new mode 100755
index 62cd66a48fa84..156de20ed11e3
--- a/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.h
+++ b/onnxruntime/core/providers/cuda/tensor/grid_sample_impl.h
@@ -21,6 +21,20 @@ void GridSampleImpl(
const int64_t W_out,
T* output_data);
+template
+void GridSampleImpl3D(
+ cudaStream_t stream,
+ const T* input_data,
+ const T* grid_data,
+ const int64_t mode,
+ const int64_t padding_mode,
+ const int64_t align_corners,
+ const int64_t dims_input[5],
+ const int64_t D_out,
+ const int64_t H_out,
+ const int64_t W_out,
+ T* output_data);
+
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc
old mode 100644
new mode 100755
index 05cfb5c13d689..217233fe89d23
--- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc
@@ -7,20 +7,16 @@
namespace onnxruntime {
namespace test {
-std::vector> GetExecutionProviders(int opset_version) {
- ORT_UNUSED_PARAMETER(opset_version);
-
+std::vector> GetExecutionProviders() {
std::vector> execution_providers;
execution_providers.emplace_back(DefaultCpuExecutionProvider());
#ifdef USE_CUDA
- if (opset_version < 20) {
- execution_providers.emplace_back(DefaultCudaExecutionProvider());
+ execution_providers.emplace_back(DefaultCudaExecutionProvider());
#ifdef ENABLE_CUDA_NHWC_OPS
- execution_providers.push_back(DefaultCudaNHWCExecutionProvider());
+ execution_providers.push_back(DefaultCudaNHWCExecutionProvider());
#endif
- }
#endif
#if defined(USE_COREML)
@@ -64,7 +60,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_zeros_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(16));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) {
@@ -84,7 +80,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(16));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_border_align_corners) {
@@ -104,7 +100,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_border_align_corners)
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(16));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) {
@@ -124,7 +120,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_border_no_align_corner
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(16));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) {
@@ -144,7 +140,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_reflection_align_corne
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(16));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners) {
@@ -164,7 +160,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_co
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(16));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) {
@@ -184,7 +180,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners)
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(16));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) {
@@ -204,7 +200,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corner
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(16));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) {
@@ -224,7 +220,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_border_align_corners)
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(16));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) {
@@ -244,7 +240,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corne
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(16));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) {
@@ -264,7 +260,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corn
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(16));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners) {
@@ -284,7 +280,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_c
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(16));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) {
@@ -304,7 +300,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(16));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) {
@@ -324,7 +320,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(16));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) {
@@ -344,7 +340,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_border_align_corners)
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(16));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) {
@@ -364,7 +360,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corner
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(16));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) {
@@ -384,7 +380,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corne
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(16));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners) {
@@ -404,7 +400,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_co
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(16));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) {
@@ -424,7 +420,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) {
@@ -444,7 +440,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners) {
@@ -464,7 +460,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners) {
@@ -484,7 +480,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_border_align_corners) {
@@ -504,7 +500,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_border_align_corners)
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_border_align_corners) {
@@ -524,7 +520,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_border_align_corners)
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_border_no_align_corners) {
@@ -544,7 +540,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_border_no_align_corner
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_border_no_align_corners) {
@@ -564,7 +560,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_border_no_align_corner
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_reflection_align_corners) {
@@ -584,7 +580,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_reflection_align_corne
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_reflection_align_corners) {
@@ -604,7 +600,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_reflection_align_corne
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_corners) {
@@ -624,7 +620,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_co
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_corners) {
@@ -644,7 +640,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_co
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) {
@@ -664,7 +660,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners)
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) {
@@ -684,7 +680,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners)
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) {
@@ -704,11 +700,11 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corner
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
-TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corners) {
- OpTester test("GridSample", 20);
+TYPED_TEST(GridSampleTest, test_grid_sample_22_5D_bilinear_zeros_no_align_corners) {
+ OpTester test("GridSample", 22);
std::string mode = "linear";
std::string padding_mode = "zeros";
int64_t align_corners = 0;
@@ -724,10 +720,10 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corner
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
-TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) {
+TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) {
OpTester test("GridSample", 20);
std::string mode = "linear";
std::string padding_mode = "border";
@@ -744,11 +740,11 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_align_corners)
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
-TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) {
- OpTester test("GridSample", 20);
+TYPED_TEST(GridSampleTest, test_grid_sample_22_5D_bilinear_border_align_corners) {
+ OpTester test("GridSample", 22);
std::string mode = "linear";
std::string padding_mode = "border";
int64_t align_corners = 1;
@@ -764,11 +760,11 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_border_align_corners)
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
-TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) {
- OpTester test("GridSample", 20);
+TYPED_TEST(GridSampleTest, test_grid_sample_22_4D_bilinear_border_no_align_corners) {
+ OpTester test("GridSample", 22);
std::string mode = "linear";
std::string padding_mode = "border";
int64_t align_corners = 0;
@@ -784,11 +780,11 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corne
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
-TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corners) {
- OpTester test("GridSample", 20);
+TYPED_TEST(GridSampleTest, test_grid_sample_22_5D_bilinear_border_no_align_corners) {
+ OpTester test("GridSample", 22);
std::string mode = "linear";
std::string padding_mode = "border";
int64_t align_corners = 0;
@@ -804,7 +800,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corne
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corners) {
@@ -824,11 +820,11 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corn
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
-TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corners) {
- OpTester test("GridSample", 20);
+TYPED_TEST(GridSampleTest, test_grid_sample_22_5D_bilinear_reflection_align_corners) {
+ OpTester test("GridSample", 22);
std::string mode = "linear";
std::string padding_mode = "reflection";
int64_t align_corners = 1;
@@ -844,7 +840,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corn
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_corners) {
@@ -864,11 +860,11 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_c
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
-TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_corners) {
- OpTester test("GridSample", 20);
+TYPED_TEST(GridSampleTest, test_grid_sample_22_5D_bilinear_reflection_no_align_corners) {
+ OpTester test("GridSample", 22);
std::string mode = "linear";
std::string padding_mode = "reflection";
int64_t align_corners = 0;
@@ -884,7 +880,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_c
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) {
@@ -904,7 +900,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners) {
@@ -924,7 +920,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) {
@@ -944,7 +940,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_border_align_corners)
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corners) {
@@ -964,7 +960,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corner
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corners) {
@@ -984,7 +980,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corne
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_corners) {
@@ -1004,7 +1000,7 @@ TYPED_TEST(GridSampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_co
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- RunTests(test, GetExecutionProviders(20));
+ RunTests(test, GetExecutionProviders());
}
} // namespace test