Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for NHWC GridSample in the CUDA EP and enable grid_sample_test for all EPs #19562

Merged
merged 9 commits into from
Feb 23, 2024
7 changes: 7 additions & 0 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,10 @@
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSqueeze);
#endif

#ifdef ENABLE_CUDA_NHWC_OPS
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 16, float, GridSample);
#endif

template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
KernelCreateInfo info;
Expand Down Expand Up @@ -408,6 +412,9 @@
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSqueeze)>,
#endif

#ifdef ENABLE_CUDA_NHWC_OPS
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 16, float, GridSample)>,

Check warning on line 416 in onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc:416: Lines should be <= 120 characters long [whitespace/line_length] [2]
#endif
};

for (auto& function_table_entry : function_table) {
Expand Down
35 changes: 21 additions & 14 deletions onnxruntime/contrib_ops/cuda/grid_sample.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,23 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {

#define REGISTER_KERNEL_TYPED(T) \
#define REGISTER_KERNEL_TYPED(T, VERSION, LAYOUT, DOMAIN) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
GridSample, \
kMSDomain, \
1, \
DOMAIN, \
VERSION, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("T2", DataTypeImpl::GetTensorType<T>()), \
GridSample<T>);
onnxruntime::contrib::cuda::GridSample<T, LAYOUT>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(float, 1, LAYOUT_NCHW, kMSDomain)
REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NHWC, kMSInternalNHWCDomain)

template <typename T>
GridSample<T>::GridSample(const OpKernelInfo& info) : CudaKernel(info) {
template <typename T, bool IsNHWC>
GridSample<T, IsNHWC>::GridSample(const OpKernelInfo& info) : CudaKernel(info) {
std::string mode_str = info.GetAttrOrDefault<std::string>("mode", "bilinear");
std::string padding_mode_str = info.GetAttrOrDefault<std::string>("padding_mode", "zeros");
align_corners_ = static_cast<bool>(info.GetAttrOrDefault<int64_t>("align_corners", 0));
Expand All @@ -48,8 +49,8 @@ GridSample<T>::GridSample(const OpKernelInfo& info) : CudaKernel(info) {
}
}

template <typename T>
Status GridSample<T>::ComputeInternal(OpKernelContext* context) const {
template <typename T, bool IsNHWC>
Status GridSample<T, IsNHWC>::ComputeInternal(OpKernelContext* context) const {
const Tensor* X = context->Input<Tensor>(0);
const auto& dims_input = X->Shape().GetDims();
const Tensor* Grid = context->Input<Tensor>(1);
Expand All @@ -61,11 +62,13 @@ Status GridSample<T>::ComputeInternal(OpKernelContext* context) const {
ORT_ENFORCE(dims_grid[0] == dims_input[0], "Grid batch size ", dims_grid[0], " does not match input batch size ", dims_input[0]);
ORT_ENFORCE(dims_grid[3] == 2, "Last dimension of grid: ", dims_grid[3], ", expect 2");

using Ch = Channels<IsNHWC>;

TensorShapeVector dims_output(4);
dims_output[0] = dims_input[0];
dims_output[1] = dims_input[1];
dims_output[2] = dims_grid[1];
dims_output[3] = dims_grid[2];
dims_output[Ch::N] = dims_input[Ch::N];
dims_output[Ch::C] = dims_input[Ch::C];
dims_output[Ch::H] = dims_grid[1 /* Grid::H */];
dims_output[Ch::W] = dims_grid[2 /* Grid::W */];
Tensor* Y = context->Output(0, dims_output);
// Return early if the output tensor is going to be of size 0
if (Y->Shape().Size() == 0) {
Expand All @@ -74,7 +77,7 @@ Status GridSample<T>::ComputeInternal(OpKernelContext* context) const {

typedef typename ToCudaType<T>::MappedType CudaT;
CudaT* Y_data = reinterpret_cast<CudaT*>(Y->MutableData<T>());
GridSampleImpl<CudaT>(
GridSampleImpl<CudaT, IsNHWC>(
Stream(context),
reinterpret_cast<const CudaT*>(X->Data<T>()),
reinterpret_cast<const CudaT*>(Grid->Data<T>()),
Expand All @@ -89,4 +92,8 @@ Status GridSample<T>::ComputeInternal(OpKernelContext* context) const {
}
} // namespace cuda
} // namespace contrib

namespace cuda {
REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NCHW, kOnnxDomain)
} // namespace cuda
} // namespace onnxruntime
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/grid_sample.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace cuda {

using namespace onnxruntime::cuda;

template <typename T>
template <typename T, bool IsNHWC>
class GridSample final : public CudaKernel {
public:
explicit GridSample(const OpKernelInfo& info);
Expand Down
94 changes: 59 additions & 35 deletions onnxruntime/contrib_ops/cuda/grid_sample_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,28 +50,32 @@
return static_cast<T>(fx);
}

template <typename T>
template <typename T, bool Layout>
__device__ T PixelAtGrid(const T* input_data, int64_t bIdx, int64_t cIdx, int64_t y, int64_t x,
int64_t padding_mode, int64_t N, int64_t C, int64_t H, int64_t W, float border[4]) {
int64_t padding_mode, int64_t N, int64_t C, int64_t H, int64_t W, float border[4]) {
T pixel = 0.0f;

auto PixelOffset = [bIdx, cIdx, N, C, H, W](int64_t x, int64_t y) -> int64_t {
return Layout == LAYOUT_NCHW ? (bIdx * C * H * W + cIdx * H * W + y * W + x) : (bIdx * H * W * C + y * W * C + x * C + cIdx);

Check warning on line 59 in onnxruntime/contrib_ops/cuda/grid_sample_impl.cu

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/cuda/grid_sample_impl.cu:59: Lines should be <= 120 characters long [whitespace/line_length] [2]
};

if (padding_mode == 0) { // zeros
if (x >= 0 && x < W && y >= 0 && y < H) {
pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x];
pixel = input_data[PixelOffset(x, y)];
}
} else if (padding_mode == 1) { //border
} else if (padding_mode == 1) { // border
x = max((int64_t)0, min((int64_t)W - 1, (int64_t)x));
y = max((int64_t)0, min((int64_t)H - 1, (int64_t)y));
pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x];
pixel = input_data[PixelOffset(x, y)];
} else { // Reflection
x = (int64_t) GsReflect<T>(x, border[0], border[2]);
y = (int64_t) GsReflect<T>(y, border[1], border[3]);
pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x];
x = (int64_t)GsReflect<T>(x, border[0], border[2]);
y = (int64_t)GsReflect<T>(y, border[1], border[3]);
pixel = input_data[PixelOffset(x, y)];
}
return pixel;
}

__device__ void GsGetCubicCoeffs(float x, float coeffs[4])
{
__device__ void GsGetCubicCoeffs(float x, float coeffs[4]) {
float cubic_alpha = -0.75f;
x = abs(x);
coeffs[0] = (((cubic_alpha * (x + 1) - 5 * cubic_alpha) * (x + 1) + 8 * cubic_alpha) * (x + 1) - 4 * cubic_alpha);
Expand All @@ -93,7 +97,7 @@
return pixel;
}

template <typename T>
template <typename T, bool Layout>
__global__ void _GridSampleKernel(
const T* input_data,
const T* grid_data,
Expand All @@ -110,16 +114,32 @@
{
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(idx, N * C * H_out * W_out);
// extract batch index, channel index, y index, x index for current thread
int BIdx = idx / (C * H_out * W_out );
int tmpBCnt = BIdx * (C * H_out * W_out);
int BIdx, yIdx, xIdx, cIdx;
if constexpr (Layout == LAYOUT_NCHW) {
BIdx = idx / (C * H_out * W_out);
int tmpBCnt = BIdx * (C * H_out * W_out);

cIdx = (idx - tmpBCnt) / (H_out * W_out);
int tmpCCnt = tmpBCnt + cIdx * (H_out * W_out);

int cIdx = (idx - tmpBCnt) / (H_out * W_out);
int tmpCCnt = tmpBCnt + cIdx * (H_out * W_out);
yIdx = (idx - tmpCCnt) / W_out;
int tmpHCnt = tmpCCnt + yIdx * W_out;

int yIdx = (idx - tmpCCnt) / W_out;
int tmpHCnt = tmpCCnt + yIdx * W_out;
xIdx = (idx - tmpHCnt);
} else {
static_assert(Layout == LAYOUT_NHWC, "Unsupported layout");

int xIdx = (idx - tmpHCnt);
BIdx = idx / (H_out * W_out * C);
int tmpBCnt = BIdx * (H_out * W_out * C);

yIdx = (idx - tmpBCnt) / (W_out * C);
int tmpHCnt = tmpBCnt + yIdx * (W_out * C);

xIdx = (idx - tmpHCnt) / C;
int tmpWCnt = tmpHCnt + xIdx * C;

cIdx = (idx - tmpWCnt);
}

int grid_idx = BIdx * H_out * W_out + yIdx * W_out + xIdx;
T grid_X = grid_data[grid_idx * 2 + 0];
Expand Down Expand Up @@ -147,8 +167,9 @@
if (grid_x_imgSpace < x_min || grid_x_imgSpace > x_max ||
grid_y_imgSpace < y_min || grid_y_imgSpace > y_max) { // out of bound
if (padding_mode == 1) { // border
grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f));
grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f));
// Clamping must not be done here, see #10607
//grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f));

Check warning on line 171 in onnxruntime/contrib_ops/cuda/grid_sample_impl.cu

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Should have a space between // and comment [whitespace/comments] [4] Raw Output: onnxruntime/contrib_ops/cuda/grid_sample_impl.cu:171: Should have a space between // and comment [whitespace/comments] [4]
//grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f));

Check warning on line 172 in onnxruntime/contrib_ops/cuda/grid_sample_impl.cu

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Should have a space between // and comment [whitespace/comments] [4] Raw Output: onnxruntime/contrib_ops/cuda/grid_sample_impl.cu:172: Should have a space between // and comment [whitespace/comments] [4]
} else if (padding_mode == 2) { // reflection
grid_x_imgSpace = GsReflect(grid_x_imgSpace, x_min, x_max);
grid_y_imgSpace = GsReflect(grid_y_imgSpace, y_min, y_max);
Expand All @@ -175,18 +196,18 @@
w_lb = w_b * w_l;
w_rb = w_b * w_r;

T lt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x1, padding_mode, N, C, H_in, W_in, border);
T rt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border);
T lb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border);
T rb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border);
T lt_v = PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y1, x1, padding_mode, N, C, H_in, W_in, border);
T rt_v = PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border);
T lb_v = PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border);
T rb_v = PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border);
T interpoV = w_lt * lt_v + w_rt * rt_v + w_lb * lb_v + w_rb * rb_v;
output_data[outIdx] = interpoV;
return;
}
if (mode == 1) { // nearest
int x_n = grid_x_imgSpace;
int y_n = grid_y_imgSpace;
output_data[outIdx] = PixelAtGrid(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border);
output_data[outIdx] = PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border);

Check warning on line 210 in onnxruntime/contrib_ops/cuda/grid_sample_impl.cu

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/cuda/grid_sample_impl.cu:210: Lines should be <= 120 characters long [whitespace/line_length] [2]
return;
}
if (mode == 2) { // bicubic
Expand All @@ -195,7 +216,7 @@
T p[4][4] = {}; // [H][W]
for (int64_t h = 0; h < 4; h++) {
for (int64_t w = 0; w < 4; w++) {
p[h][w] = PixelAtGrid(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border);
p[h][w] = PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border);

Check warning on line 219 in onnxruntime/contrib_ops/cuda/grid_sample_impl.cu

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/cuda/grid_sample_impl.cu:219: Lines should be <= 120 characters long [whitespace/line_length] [2]
}
}
T dx = grid_x_imgSpace - x0 - 1;
Expand All @@ -204,7 +225,7 @@
}
}

template <typename T>
template <typename T, bool IsNHWC>
void GridSampleImpl(
cudaStream_t stream,
const T* input_data,
Expand All @@ -216,17 +237,20 @@
const int64_t H_out,
const int64_t W_out,
T* output_data) {
int blocksPerGrid = (int)(ceil(static_cast<T>(dims[0] * dims[1] * H_out * W_out) / GridDim::maxThreadsPerBlock));
_GridSampleKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
input_data, grid_data, mode, padding_mode, align_corners, dims[0], dims[1], dims[2], dims[3], H_out, W_out, output_data);
using Ch = Channels<IsNHWC>;

int blocksPerGrid = (int)(ceil(static_cast<T>(dims[Ch::N] * dims[Ch::C] * H_out * W_out) / GridDim::maxThreadsPerBlock));

Check warning on line 242 in onnxruntime/contrib_ops/cuda/grid_sample_impl.cu

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/cuda/grid_sample_impl.cu:242: Lines should be <= 120 characters long [whitespace/line_length] [2]

Check warning on line 242 in onnxruntime/contrib_ops/cuda/grid_sample_impl.cu

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Using C-style cast. Use static_cast<int>(...) instead [readability/casting] [4] Raw Output: onnxruntime/contrib_ops/cuda/grid_sample_impl.cu:242: Using C-style cast. Use static_cast<int>(...) instead [readability/casting] [4]
_GridSampleKernel<T, IsNHWC><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
input_data, grid_data, mode, padding_mode, align_corners, dims[Ch::N], dims[Ch::C], dims[Ch::H], dims[Ch::W], H_out, W_out, output_data);

Check warning on line 244 in onnxruntime/contrib_ops/cuda/grid_sample_impl.cu

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/cuda/grid_sample_impl.cu:244: Lines should be <= 120 characters long [whitespace/line_length] [2]
}

#define SPECIALIZED_IMPL(T) \
template void GridSampleImpl<T>(cudaStream_t stream, const T* input_data, const T* grid_data, \
const int64_t mode, const int64_t padding_mode, const int64_t align_corners, \
const int64_t[4], const int64_t H_out, const int64_t W_out, T* output_data);
#define SPECIALIZED_IMPL(T, IsNHWC) \
template void GridSampleImpl<T, IsNHWC>(cudaStream_t stream, const T* input_data, const T* grid_data, \
const int64_t mode, const int64_t padding_mode, const int64_t align_corners, \
const int64_t[4], const int64_t H_out, const int64_t W_out, T* output_data);

SPECIALIZED_IMPL(float)
SPECIALIZED_IMPL(float, false) // NCHW
SPECIALIZED_IMPL(float, true) // NHWC

} // namespace cuda
} // namespace contrib
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/grid_sample_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {

template <typename T>
template <typename T, bool IsNHWC>
void GridSampleImpl(
cudaStream_t stream,
const T* input_data,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
}

#if defined(USE_CUDA) && ENABLE_CUDA_NHWC_OPS
// TODO: Generate this list from the kernels registered in the NHWC domain.

Check warning on line 34 in onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc:34: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
const std::unordered_set<std::string_view>& GetCUDALayoutSensitiveOps() {
static std::unordered_set<std::string_view> cuda_nhwc_ops = []() {
return std::unordered_set<std::string_view>{
Expand All @@ -41,6 +42,7 @@
"MaxPool",
"GlobalAveragePool",
"AveragePool",
"GridSample",
};
}();
return cuda_nhwc_ops;
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1256,6 +1256,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample);

// Opset 17
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
Expand Down Expand Up @@ -2143,6 +2144,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample)>,

// Opset 17
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization)>,
Expand Down
26 changes: 26 additions & 0 deletions onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,5 +168,31 @@ struct NumericLimits<double> {
}
};

// TODO: Decide where these layout helpers should live. Good candidates are:
//   core/framework/tensor_shape.h
//   core/util/matrix_layout.h

// Compile-time tags naming the two supported 4-D tensor layouts. They are the
// values passed as the bool template argument of Channels<>.
constexpr bool LAYOUT_NCHW = false;
constexpr bool LAYOUT_NHWC = true;

// Gives the position of each logical axis (N, C, H, W) within a 4-element
// dims array for the layout selected by IsNHWC:
//   LAYOUT_NCHW -> N = 0, C = 1, H = 2, W = 3
//   LAYOUT_NHWC -> N = 0, H = 1, W = 2, C = 3
template <bool IsNHWC>
struct Channels {
  static constexpr size_t N = 0;
  static constexpr size_t C = IsNHWC ? 3 : 1;
  static constexpr size_t H = IsNHWC ? 1 : 2;
  static constexpr size_t W = IsNHWC ? 2 : 3;
};

} // namespace cuda
} // namespace onnxruntime
Loading
Loading