Skip to content

Commit

Permalink
Run lintrunner & remove nhwc hack.
Browse files Browse the repository at this point in the history
  • Loading branch information
mtavenrath committed Feb 21, 2024
1 parent dac6e76 commit 16bc365
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 9 deletions.
8 changes: 4 additions & 4 deletions onnxruntime/contrib_ops/cuda/grid_sample.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {

#define REGISTER_KERNEL_TYPED(T, VERSION, LAYOUT, DOMAIN) \
#define REGISTER_KERNEL_TYPED(T, VERSION, LAYOUT, DOMAIN) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
GridSample, \
DOMAIN, \
VERSION, \
VERSION, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
Expand Down Expand Up @@ -68,7 +68,7 @@ Status GridSample<T, IsNHWC>::ComputeInternal(OpKernelContext* context) const {
dims_output[Ch::N] = dims_input[Ch::N];
dims_output[Ch::C] = dims_input[Ch::C];
dims_output[Ch::H] = dims_grid[1 /* Grid::H */];
dims_output[Ch::W] = dims_grid[2 /* Grid::W */];
dims_output[Ch::W] = dims_grid[2 /* Grid::W */];
Tensor* Y = context->Output(0, dims_output);
// Return early if the output tensor is going to be of size 0
if (Y->Shape().Size() == 0) {
Expand All @@ -94,6 +94,6 @@ Status GridSample<T, IsNHWC>::ComputeInternal(OpKernelContext* context) const {
} // namespace contrib

namespace cuda {
REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NCHW, kOnnxDomain)
REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NCHW, kOnnxDomain)
} // namespace cuda
} // namespace onnxruntime
3 changes: 0 additions & 3 deletions onnxruntime/core/providers/cuda/cuda_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,6 @@ struct CUDA_Provider : Provider {
info.cudnn_conv_use_max_workspace = params->cudnn_conv_use_max_workspace != 0;
info.enable_cuda_graph = params->enable_cuda_graph != 0;
info.prefer_nhwc = params->prefer_nhwc;
// HACK
info.prefer_nhwc = true;
//info.prefer_nhwc = false;
info.cudnn_conv1d_pad_to_nc1d = params->cudnn_conv1d_pad_to_nc1d != 0;
info.tunable_op.enable = params->tunable_op_enable;
info.tunable_op.tuning_enable = params->tunable_op_tuning_enable;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ struct NumericLimits<double> {
}
};

// TODO Where to put this? good places might be
// TODO Where to put this? good places might be
// core/framework/tensor_shape.h
// core/util/matrix_layout.h

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@
print('test.AddAttribute("padding_mode", padding_mode);')
print('test.AddAttribute("align_corners", align_corners);')
print('test.AddOutput<float>("Y", Y_shape, Y_data);')
print(f'test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders({opset_version}));')
print(
f'test.Run(OpTester::ExpectResult::kExpectSuccess, "", GetExcludedExecutionProviders({opset_version}));'
)
print("}")
print("\n")

0 comments on commit 16bc365

Please sign in to comment.