Commit 2eb104b

Michalis Papapdimitriou committed: Address PR comments

1 parent 5bdd0ed commit 2eb104b

File tree: 2 files changed (+102 -20 lines)

python/tvm/relay/op/contrib/tensorrt.py
Lines changed: 96 additions & 18 deletions
@@ -28,6 +28,20 @@
 from tvm.relay.expr_functor import ExprMutator, ExprVisitor
 
 logger = logging.getLogger("TensorRT")
+supported_types = ["float32", "float16"]
+
+
+def is_supported_trt_dtype(args):
+    """Check if the TensorRT BYOC supports the input tensor dtypes.
+    Returns
+    -------
+    ret: bool
+        True if supported, False if not.
+    """
+    if any([x.checked_type.dtype not in supported_types for x in args]):
+        logger.info("Only float32 and float16 inputs are supported for TensorRT BYOC.")
+        return False
+    return True
 
 
 def is_tensorrt_runtime_enabled():
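The helper rejects offloading whenever any input falls outside `supported_types`. A minimal sketch of its behavior, assuming TVM is installed; the float64 add workload below is illustrative, not part of the commit:

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.tensorrt import is_supported_trt_dtype

    # Hypothetical workload: an elementwise add over float64 tensors, which the
    # dtype guard should reject (only float32 and float16 are accepted).
    x = relay.var("x", shape=(1, 3, 224, 224), dtype="float64")
    y = relay.var("y", shape=(1, 3, 224, 224), dtype="float64")
    mod = tvm.IRModule.from_expr(relay.add(x, y))
    mod = relay.transform.InferType()(mod)  # populates checked_type on each expression

    call = mod["main"].body
    assert not is_supported_trt_dtype(call.args)  # float64 inputs stay off TensorRT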
@@ -113,8 +127,10 @@ def partition_for_tensorrt(
         How many bytes of workspace size to allow each subgraph to use for TensorRT engine creation.
         See TensorRT documentation for more info.
     use_fp16: Optional[bool]
-        Allows, TRT to automatically convert FP32 inputs to FP16. Also, it is required to be enabled if FP16 inputs tensors and weights are used.
-        Note that TensorRT will still choose a higher-precision kernel if it results in overall lower runtime, or if no low-precision implementation exists.
+        Allows TRT to automatically convert FP32 inputs to FP16. It is also required to be
+        enabled if FP16 input tensors and weights are used.
+        Note that TensorRT will still choose a higher-precision kernel if it results in overall
+        lower runtime, or if no low-precision implementation exists.
     use_uint8: Optional[bool]
         Allows, TRT to automatically convert FP32 inputs to UINT8.
     Returns
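As context for the reflowed docstring, a minimal sketch of how `use_fp16` is used end to end, assuming a Relay module `mod` and a `params` dict from a frontend import, and assuming this version of the API returns the partitioned module together with a config dict to feed back through the `PassContext`:

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

    # mod and params are assumed to come from e.g. relay.frontend.from_onnx(...).
    mod, config = partition_for_tensorrt(mod, params=params, use_fp16=True)
    with tvm.transform.PassContext(
        opt_level=3, config={"relay.ext.tensorrt.options": config}
    ):
        lib = relay.build(mod, target="cuda", params=params)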
@@ -209,6 +225,8 @@ def _register_external_op_helper_with_checker(op_name, checker):
     def _func_wrapper(expr):
         attrs, args = expr.attrs, expr.args
         # ops with dynamic shapes are offloaded to VM
+        if not is_supported_trt_dtype(args):
+            return False
         if check_dynamism(args, op_name):
             return False
         if op_name == "multiply":
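Because the guard sits in the shared `_func_wrapper`, every op registered through `_register_external_op_helper_with_checker` now rejects unsupported dtypes automatically; the annotation functions below repeat the same two-line check because they are registered individually rather than through this helper.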
@@ -321,7 +339,8 @@ def add_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if add is supported by TensorRT."""
 
     args = expr.args
-
+    if not is_supported_trt_dtype(args):
+        return False
     shapes = [
         [int(x) if not isinstance(x, tvm.tir.expr.Any) else -1 for x in arg.checked_type.shape]
         for arg in args

@@ -350,6 +369,8 @@ def batch_norm_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.batch_norm is supported by TensorRT."""
 
     attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if len(args[0].checked_type.shape) == 5 and get_tensorrt_version() < (6, 0, 1):
         logger.info("nn.batch_norm: TensorRT 6.0.1 or higher is required for rank 5 inputs.")
         return False

@@ -366,7 +387,9 @@ def batch_norm_annotate_fn(expr):  # pylint: disable=unused-variable
 def softmax_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.softmax is supported by TensorRT."""
 
-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
         logger.info("nn.softmax: can't modify batch dimension.")
         return False

@@ -377,7 +400,9 @@ def softmax_annotate_fn(expr):  # pylint: disable=unused-variable
 def conv1d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.conv1d is supported by TensorRT."""
 
-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if attrs.data_layout != "NCW":
         logger.info("nn.conv1d: data_layout is %s but must be NCW.", attrs.data_layout)
         return False

@@ -391,7 +416,9 @@ def conv1d_annotate_fn(expr):  # pylint: disable=unused-variable
 def conv2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.conv2d is supported by TensorRT."""
 
-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if attrs.data_layout != "NCHW":
         logger.info("nn.conv2d: data_layout is %s but must be NCHW.", attrs.data_layout)
         return False

@@ -409,6 +436,8 @@ def dense_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if dense is supported by TensorRT."""
 
     args = expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     input_rank = len(args[0].checked_type.shape)
     weight_rank = len(args[1].checked_type.shape)
     if input_rank not in (2, 3, 4):

@@ -424,6 +453,9 @@ def dense_annotate_fn(expr):  # pylint: disable=unused-variable
 def batch_matmul_annotate_fn(expr):
     """Check if dense is supported by TensorRT."""
 
+    args = expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if get_tensorrt_use_implicit_batch_mode() and len(expr.args[0].checked_type.shape) != len(
         expr.args[1].checked_type.shape
     ):

@@ -436,6 +468,9 @@ def batch_matmul_annotate_fn(expr):
 def layer_norm_annotate_fn(expr):
     """Check if dense is supported by TensorRT."""
 
+    args = expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if get_tensorrt_use_implicit_batch_mode() and int(expr.attrs.axis) == 0:
         logger.info("nn.layer_norm: requires use_implict_batch=False.")
         return False

@@ -446,7 +481,9 @@ def layer_norm_annotate_fn(expr):
 def bias_add_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.bias_add is supported by TensorRT."""
 
-    args = expr.args
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     input_rank = len(args[0].checked_type.shape)
     if input_rank not in (2, 3, 4):
         logger.info("nn.bias_add: input rank is %d but must be 2, 3 or 4.", input_rank)

@@ -458,7 +495,9 @@ def bias_add_annotate_fn(expr):  # pylint: disable=unused-variable
 def max_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.max_pool2d is supported by TensorRT."""
 
-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if attrs.layout != "NCHW":
         logger.info("nn.max_pool2d: layout is %s but must be NCHW.", attrs.layout)
         return False

@@ -472,7 +511,9 @@ def max_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
 def avg_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.avg_pool2d is supported by TensorRT."""
 
-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if attrs.layout != "NCHW":
         logger.info("nn.avg_pool2d: layout is %d but must be NCHW.", attrs.layout)
         return False

@@ -499,7 +540,9 @@ def avg_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
 def global_max_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.global_max_pool2d is supported by TensorRT."""
 
-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if attrs.layout != "NCHW":
         logger.info("nn.global_max_pool2d: layout is %s but must be NCHW.", attrs.layout)
         return False

@@ -510,7 +553,9 @@ def global_max_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
 def global_avg_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.global_avg_pool2d is supported by TensorRT."""
 
-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if attrs.layout != "NCHW":
         logger.info("nn.global_avg_pool2d: layout is %s but must be NCHW.", attrs.layout)
         return False

@@ -521,7 +566,9 @@ def global_avg_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
 def expand_dims_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if expand_dims is supported by TensorRT."""
 
-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
         logger.info("expand_dims: can't modify batch dimension.")
         return False

@@ -532,7 +579,9 @@ def expand_dims_annotate_fn(expr):  # pylint: disable=unused-variable
 def squeeze_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if squeeze is supported by TensorRT."""
 
-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if not attrs.axis:
         logger.info("squeeze: must explicitly set axis.")
         return False

@@ -547,6 +596,8 @@ def concatenate_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if concatenate is supported by TensorRT."""
 
     attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if not get_tensorrt_use_implicit_batch_mode():
         return True
     if int(attrs.axis) == 0:

@@ -564,6 +615,9 @@ def concatenate_annotate_fn(expr):  # pylint: disable=unused-variable
 def split_annotate_fn(expr):
     """Check if split is supported by TensorRT."""
 
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if get_tensorrt_use_implicit_batch_mode() and int(expr.attrs.axis) == 0:
         logger.info("split: can't modify batch dimension.")
         return False

@@ -574,7 +628,9 @@ def split_annotate_fn(expr):
 def conv2d_transpose_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.conv2d_transpose is supported by TensorRT."""
 
-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if attrs.data_layout != "NCHW":
         logger.info("nn.conv2d_transpose: data_layout is %s but must be NCHW.", attrs.data_layout)
         return False

@@ -596,7 +652,9 @@ def conv2d_transpose_annotate_fn(expr):  # pylint: disable=unused-variable
 def transpose_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if transpose is supported by TensorRT."""
 
-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if get_tensorrt_use_implicit_batch_mode() and int(attrs.axes[0]) != 0:
         logger.info("transpose: can't modify batch dimension.")
         return False

@@ -607,7 +665,9 @@ def transpose_annotate_fn(expr):  # pylint: disable=unused-variable
 def layout_transform_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if layout_transform is supported by TensorRT."""
 
-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if (attrs.src_layout, attrs.dst_layout) not in [
         ("NCHW", "NHWC"),
         ("NHWC", "NCHW"),

@@ -625,6 +685,8 @@ def layout_transform_annotate_fn(expr):  # pylint: disable=unused-variable
 def reshape_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if reshape is supported by TensorRT."""
     attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if any([x < -1 for x in map(int, attrs.newshape)]):
         logger.info("reshape: new shape dims must be explicit.")
         return False

@@ -680,6 +742,8 @@ def pad_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.pad is supported by TensorRT."""
 
     attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     pad_value = args[1]
     assert isinstance(pad_value, relay.Constant)
     pad_value = pad_value.data.numpy().item()

@@ -706,6 +770,8 @@ def strided_slice_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if strided_slice is supported by TensorRT."""
 
     attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if not trt_version_annotate_fn((5, 1, 5))(attrs, args, "strided_slice"):
         return False
     if get_tensorrt_use_implicit_batch_mode():

@@ -750,7 +816,9 @@ def strided_slice_annotate_fn(expr):  # pylint: disable=unused-variable
 def adaptive_max_pool2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.adaptive_max_pool2d is supported by TensorRT."""
 
-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if len(attrs.output_size) == 0 or any([size != 1 for size in map(int, attrs.output_size)]):
         logger.info("nn.adaptive_max_pool2d: output size must be (1, 1).")
         return False

@@ -761,7 +829,9 @@ def adaptive_max_pool2d_annotate_fn(expr):  # pylint: disable=unused-variable
 def adaptive_avg_pool2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.adaptive_avg_pool2d is supported by TensorRT."""
 
-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if len(attrs.output_size) == 0 or any([size != 1 for size in map(int, attrs.output_size)]):
         logger.info("nn.adaptive_avg_pool2d: output size must be (1, 1).")
         return False

@@ -773,6 +843,8 @@ def conv3d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.conv3d is supported by TensorRT."""
 
     attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.conv3d"):
         return False
     if attrs.data_layout != "NCDHW":

@@ -792,6 +864,8 @@ def max_pool_3d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.max_pool3d is supported by TensorRT."""
 
     attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.max_pool3d"):
         return False
     if attrs.layout != "NCDHW":

@@ -805,6 +879,8 @@ def avg_pool_3d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.avg_pool3d is supported by TensorRT."""
 
     attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.avg_pool3d"):
         return False
     if attrs.layout != "NCDHW":

@@ -818,6 +894,8 @@ def conv3d_transpose_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.conv3d_transpose is supported by TensorRT."""
 
     attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if not trt_version_annotate_fn((6, 0, 1))(attrs, args, "nn.conv3d_transpose"):
         return False
     if attrs.data_layout != "NCDHW":

src/runtime/contrib/tensorrt/tensorrt_builder.cc
Lines changed: 6 additions & 2 deletions
@@ -85,6 +85,9 @@ void TensorRTBuilder::AddInput(int nid, uint32_t entry_id, const JSONGraphNode&
     shape.erase(shape.begin());
   }
   nvinfer1::Dims dims = VectorToTrtDims(shape);
+  ICHECK(dtypes[i].bits == 16 || dtypes[i].bits == 32)
+      << "Invalid input Tensor type. Float16 and Float32 are supported";
+
   auto tensor_dtype =
       (dtypes[i].bits == 16) ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT;
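This runtime guard mirrors the Python-side `is_supported_trt_dtype` check: a graph whose inputs are neither 16- nor 32-bit floats now fails fast at engine-build time instead of being silently mapped to `kFLOAT` by the ternary that follows.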

@@ -210,7 +213,8 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() {
 nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr,
                                                         DLDeviceType src_device) {
   ICHECK_EQ(dptr->device.device_type, src_device);
-
+  ICHECK(dptr->dtype.bits == 16 || dptr->dtype.bits == 32)
+      << "Invalid input Tensor type. Float16 and Float32 are supported";
   const auto trt_dtype = (static_cast<int>(dptr->dtype.bits) == 16) ? nvinfer1::DataType::kHALF
                                                                     : nvinfer1::DataType::kFLOAT;

@@ -253,7 +257,7 @@ void TensorRTBuilder::CleanUp() {
 #endif
   builder_->destroy();
   for (auto weight : trt_weights_) {
-    if (static_cast<int>(weight.type) <= 1) {
+    if (weight.type == nvinfer1::DataType::kFLOAT || weight.type == nvinfer1::DataType::kHALF) {
       delete[] static_cast<const float*>(weight.values);
     } else {
       delete[] static_cast<const uint16_t*>(weight.values);
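Comparing `weight.type` against the `nvinfer1::DataType` enumerators makes the float/half cleanup path explicit, rather than relying on `kFLOAT` and `kHALF` happening to be the first two values of the enum.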
