 from tvm.relay.expr_functor import ExprMutator, ExprVisitor

 logger = logging.getLogger("TensorRT")
+supported_types = ["float32", "float16"]
+
+
+def is_supported_trt_dtype(args):
+    """Check whether the input tensor dtypes are supported by the TensorRT BYOC integration.
+
+    Returns
+    -------
+    ret: bool
+        True if every input dtype is supported, False otherwise.
+    """
+    if not all(x.checked_type.dtype in supported_types for x in args):
+        logger.info("Only float32 and float16 inputs are supported for TensorRT BYOC.")
+        return False
+    return True


 def is_tensorrt_runtime_enabled():
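As a quick reference for the corrected semantics, here is a small sanity check (illustrative only, not part of the commit; it assumes `is_supported_trt_dtype` is in scope and a working TVM install):

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 8), dtype="float32")
y = relay.var("y", shape=(1, 8), dtype="int8")
mod = relay.transform.InferType()(tvm.IRModule.from_expr(relay.Tuple([x, y])))
f32_arg, i8_arg = mod["main"].body.fields

assert is_supported_trt_dtype([f32_arg])               # all float32: accepted
assert not is_supported_trt_dtype([f32_arg, i8_arg])   # one int8 arg: rejected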
@@ -113,8 +127,10 @@ def partition_for_tensorrt(
         How many bytes of workspace size to allow each subgraph to use for TensorRT engine creation.
         See TensorRT documentation for more info.
     use_fp16: Optional[bool]
-        Allows, TRT to automatically convert FP32 inputs to FP16. Also, it is required to be enabled if FP16 inputs tensors and weights are used.
-        Note that TensorRT will still choose a higher-precision kernel if it results in overall lower runtime, or if no low-precision implementation exists.
+        Allows TRT to automatically convert FP32 inputs to FP16. This must also be enabled
+        if FP16 input tensors and weights are used.
+        Note that TensorRT will still choose a higher-precision kernel if it results in overall
+        lower runtime, or if no low-precision implementation exists.
     use_uint8: Optional[bool]
         Allows TRT to automatically convert FP32 inputs to UINT8.
     Returns
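For context on how these options are consumed, a hypothetical end-to-end invocation follows (the tiny model is illustrative; the `(mod, config)` return value and the `relay.ext.tensorrt.options` config key follow TVM's TensorRT deployment docs for releases of this vintage and may differ in later versions):

import tvm
from tvm import relay
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

x = relay.var("x", shape=(1, 3, 224, 224), dtype="float16")
mod = tvm.IRModule.from_expr(relay.nn.relu(x))

# Partition supported subgraphs for TensorRT, allowing FP32 -> FP16 conversion.
mod, config = partition_for_tensorrt(mod, use_fp16=True)
with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
    lib = relay.build(mod, target="cuda")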
@@ -209,6 +225,8 @@ def _register_external_op_helper_with_checker(op_name, checker):
     def _func_wrapper(expr):
         attrs, args = expr.attrs, expr.args
         # ops with dynamic shapes are offloaded to VM
+        if not is_supported_trt_dtype(args):
+            return False
         if check_dynamism(args, op_name):
             return False
         if op_name == "multiply":
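For readers unfamiliar with the enclosing helper, the hunk above sits inside a registration wrapper roughly shaped like the sketch below (only the signature and a few wrapper lines appear in this diff; the rest is inferred and may not match the file exactly):

def _register_external_op_helper_with_checker(op_name, checker):
    # Register `checker` as the "target.tensorrt" annotation function for
    # `op_name`. Returning False keeps the op on the host; True offloads it.
    @tvm.ir.register_op_attr(op_name, "target.tensorrt")
    def _func_wrapper(expr):
        attrs, args = expr.attrs, expr.args
        if not is_supported_trt_dtype(args):  # dtype gate added by this change
            return False
        if check_dynamism(args, op_name):  # dynamic shapes are offloaded to VM
            return False
        return checker(attrs, args, op_name)

    return _func_wrapper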
@@ -321,7 +339,8 @@ def add_annotate_fn(expr): # pylint: disable=unused-variable
321339 """Check if add is supported by TensorRT."""
322340
323341 args = expr .args
324-
342+ if not is_supported_trt_dtype (args ):
343+ return False
325344 shapes = [
326345 [int (x ) if not isinstance (x , tvm .tir .expr .Any ) else - 1 for x in arg .checked_type .shape ]
327346 for arg in args
@@ -350,6 +369,8 @@ def batch_norm_annotate_fn(expr): # pylint: disable=unused-variable
350369 """Check if nn.batch_norm is supported by TensorRT."""
351370
352371 attrs , args = expr .attrs , expr .args
372+ if not is_supported_trt_dtype (args ):
373+ return False
353374 if len (args [0 ].checked_type .shape ) == 5 and get_tensorrt_version () < (6 , 0 , 1 ):
354375 logger .info ("nn.batch_norm: TensorRT 6.0.1 or higher is required for rank 5 inputs." )
355376 return False
@@ -366,7 +387,9 @@ def batch_norm_annotate_fn(expr): # pylint: disable=unused-variable
 def softmax_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.softmax is supported by TensorRT."""

-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
         logger.info("nn.softmax: can't modify batch dimension.")
         return False
@@ -377,7 +400,9 @@ def softmax_annotate_fn(expr): # pylint: disable=unused-variable
 def conv1d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.conv1d is supported by TensorRT."""

-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if attrs.data_layout != "NCW":
         logger.info("nn.conv1d: data_layout is %s but must be NCW.", attrs.data_layout)
         return False
@@ -391,7 +416,9 @@ def conv1d_annotate_fn(expr): # pylint: disable=unused-variable
 def conv2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.conv2d is supported by TensorRT."""

-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if attrs.data_layout != "NCHW":
         logger.info("nn.conv2d: data_layout is %s but must be NCHW.", attrs.data_layout)
         return False
@@ -409,6 +436,8 @@ def dense_annotate_fn(expr): # pylint: disable=unused-variable
409436 """Check if dense is supported by TensorRT."""
410437
411438 args = expr .args
439+ if not is_supported_trt_dtype (args ):
440+ return False
412441 input_rank = len (args [0 ].checked_type .shape )
413442 weight_rank = len (args [1 ].checked_type .shape )
414443 if input_rank not in (2 , 3 , 4 ):
@@ -424,6 +453,9 @@ def dense_annotate_fn(expr): # pylint: disable=unused-variable
 def batch_matmul_annotate_fn(expr):
     """Check if nn.batch_matmul is supported by TensorRT."""

+    args = expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if get_tensorrt_use_implicit_batch_mode() and len(expr.args[0].checked_type.shape) != len(
         expr.args[1].checked_type.shape
     ):
@@ -436,6 +468,9 @@ def batch_matmul_annotate_fn(expr):
 def layer_norm_annotate_fn(expr):
     """Check if nn.layer_norm is supported by TensorRT."""

+    args = expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if get_tensorrt_use_implicit_batch_mode() and int(expr.attrs.axis) == 0:
         logger.info("nn.layer_norm: requires use_implicit_batch=False.")
         return False
@@ -446,7 +481,9 @@ def layer_norm_annotate_fn(expr):
 def bias_add_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.bias_add is supported by TensorRT."""

     args = expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     input_rank = len(args[0].checked_type.shape)
     if input_rank not in (2, 3, 4):
         logger.info("nn.bias_add: input rank is %d but must be 2, 3 or 4.", input_rank)
@@ -458,7 +495,9 @@ def bias_add_annotate_fn(expr): # pylint: disable=unused-variable
 def max_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.max_pool2d is supported by TensorRT."""

-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if attrs.layout != "NCHW":
         logger.info("nn.max_pool2d: layout is %s but must be NCHW.", attrs.layout)
         return False
@@ -472,7 +511,9 @@ def max_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
 def avg_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.avg_pool2d is supported by TensorRT."""

-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if attrs.layout != "NCHW":
         logger.info("nn.avg_pool2d: layout is %s but must be NCHW.", attrs.layout)
         return False
@@ -499,7 +540,9 @@ def avg_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
 def global_max_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.global_max_pool2d is supported by TensorRT."""

-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if attrs.layout != "NCHW":
         logger.info("nn.global_max_pool2d: layout is %s but must be NCHW.", attrs.layout)
         return False
@@ -510,7 +553,9 @@ def global_max_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
 def global_avg_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.global_avg_pool2d is supported by TensorRT."""

-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if attrs.layout != "NCHW":
         logger.info("nn.global_avg_pool2d: layout is %s but must be NCHW.", attrs.layout)
         return False
@@ -521,7 +566,9 @@ def global_avg_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
 def expand_dims_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if expand_dims is supported by TensorRT."""

-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if get_tensorrt_use_implicit_batch_mode() and int(attrs.axis) == 0:
         logger.info("expand_dims: can't modify batch dimension.")
         return False
@@ -532,7 +579,9 @@ def expand_dims_annotate_fn(expr): # pylint: disable=unused-variable
 def squeeze_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if squeeze is supported by TensorRT."""

-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if not attrs.axis:
         logger.info("squeeze: must explicitly set axis.")
         return False
@@ -547,6 +596,8 @@ def concatenate_annotate_fn(expr): # pylint: disable=unused-variable
547596 """Check if concatenate is supported by TensorRT."""
548597
549598 attrs , args = expr .attrs , expr .args
599+ if not is_supported_trt_dtype (args ):
600+ return False
550601 if not get_tensorrt_use_implicit_batch_mode ():
551602 return True
552603 if int (attrs .axis ) == 0 :
@@ -564,6 +615,9 @@ def concatenate_annotate_fn(expr): # pylint: disable=unused-variable
 def split_annotate_fn(expr):
     """Check if split is supported by TensorRT."""

+    args = expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if get_tensorrt_use_implicit_batch_mode() and int(expr.attrs.axis) == 0:
         logger.info("split: can't modify batch dimension.")
         return False
@@ -574,7 +628,9 @@ def split_annotate_fn(expr):
 def conv2d_transpose_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.conv2d_transpose is supported by TensorRT."""

-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if attrs.data_layout != "NCHW":
         logger.info("nn.conv2d_transpose: data_layout is %s but must be NCHW.", attrs.data_layout)
         return False
@@ -596,7 +652,9 @@ def conv2d_transpose_annotate_fn(expr): # pylint: disable=unused-variable
 def transpose_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if transpose is supported by TensorRT."""

-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if get_tensorrt_use_implicit_batch_mode() and int(attrs.axes[0]) != 0:
         logger.info("transpose: can't modify batch dimension.")
         return False
@@ -607,7 +665,9 @@ def transpose_annotate_fn(expr): # pylint: disable=unused-variable
 def layout_transform_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if layout_transform is supported by TensorRT."""

-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if (attrs.src_layout, attrs.dst_layout) not in [
         ("NCHW", "NHWC"),
         ("NHWC", "NCHW"),
@@ -625,6 +685,8 @@ def layout_transform_annotate_fn(expr): # pylint: disable=unused-variable
 def reshape_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if reshape is supported by TensorRT."""
     attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if any([x < -1 for x in map(int, attrs.newshape)]):
         logger.info("reshape: new shape dims must be explicit.")
         return False
@@ -680,6 +742,8 @@ def pad_annotate_fn(expr): # pylint: disable=unused-variable
680742 """Check if nn.pad is supported by TensorRT."""
681743
682744 attrs , args = expr .attrs , expr .args
745+ if not is_supported_trt_dtype (args ):
746+ return False
683747 pad_value = args [1 ]
684748 assert isinstance (pad_value , relay .Constant )
685749 pad_value = pad_value .data .numpy ().item ()
@@ -706,6 +770,8 @@ def strided_slice_annotate_fn(expr): # pylint: disable=unused-variable
706770 """Check if strided_slice is supported by TensorRT."""
707771
708772 attrs , args = expr .attrs , expr .args
773+ if not is_supported_trt_dtype (args ):
774+ return False
709775 if not trt_version_annotate_fn ((5 , 1 , 5 ))(attrs , args , "strided_slice" ):
710776 return False
711777 if get_tensorrt_use_implicit_batch_mode ():
@@ -750,7 +816,9 @@ def strided_slice_annotate_fn(expr): # pylint: disable=unused-variable
 def adaptive_max_pool2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.adaptive_max_pool2d is supported by TensorRT."""

-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if len(attrs.output_size) == 0 or any([size != 1 for size in map(int, attrs.output_size)]):
         logger.info("nn.adaptive_max_pool2d: output size must be (1, 1).")
         return False
@@ -761,7 +829,9 @@ def adaptive_max_pool2d_annotate_fn(expr): # pylint: disable=unused-variable
 def adaptive_avg_pool2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.adaptive_avg_pool2d is supported by TensorRT."""

-    attrs = expr.attrs
+    attrs, args = expr.attrs, expr.args
+    if not is_supported_trt_dtype(args):
+        return False
     if len(attrs.output_size) == 0 or any([size != 1 for size in map(int, attrs.output_size)]):
         logger.info("nn.adaptive_avg_pool2d: output size must be (1, 1).")
         return False
@@ -773,6 +843,8 @@ def conv3d_annotate_fn(expr): # pylint: disable=unused-variable
773843 """Check if nn.conv3d is supported by TensorRT."""
774844
775845 attrs , args = expr .attrs , expr .args
846+ if not is_supported_trt_dtype (args ):
847+ return False
776848 if not trt_version_annotate_fn ((6 , 0 , 1 ))(attrs , args , "nn.conv3d" ):
777849 return False
778850 if attrs .data_layout != "NCDHW" :
@@ -792,6 +864,8 @@ def max_pool_3d_annotate_fn(expr): # pylint: disable=unused-variable
792864 """Check if nn.max_pool3d is supported by TensorRT."""
793865
794866 attrs , args = expr .attrs , expr .args
867+ if not is_supported_trt_dtype (args ):
868+ return False
795869 if not trt_version_annotate_fn ((6 , 0 , 1 ))(attrs , args , "nn.max_pool3d" ):
796870 return False
797871 if attrs .layout != "NCDHW" :
@@ -805,6 +879,8 @@ def avg_pool_3d_annotate_fn(expr): # pylint: disable=unused-variable
805879 """Check if nn.avg_pool3d is supported by TensorRT."""
806880
807881 attrs , args = expr .attrs , expr .args
882+ if not is_supported_trt_dtype (args ):
883+ return False
808884 if not trt_version_annotate_fn ((6 , 0 , 1 ))(attrs , args , "nn.avg_pool3d" ):
809885 return False
810886 if attrs .layout != "NCDHW" :
@@ -818,6 +894,8 @@ def conv3d_transpose_annotate_fn(expr): # pylint: disable=unused-variable
818894 """Check if nn.conv3d_transpose is supported by TensorRT."""
819895
820896 attrs , args = expr .attrs , expr .args
897+ if not is_supported_trt_dtype (args ):
898+ return False
821899 if not trt_version_annotate_fn ((6 , 0 , 1 ))(attrs , args , "nn.conv3d_transpose" ):
822900 return False
823901 if attrs .data_layout != "NCDHW" :
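Finally, a hedged sketch of how the new dtype gate could be exercised in a unit test (not part of this commit; the substring check is a crude proxy for "no TensorRT region was created", and the exact partition naming may vary by TVM version):

import tvm
from tvm import relay
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

def test_int8_inputs_not_offloaded():
    x = relay.var("x", shape=(1, 8), dtype="int8")
    y = relay.var("y", shape=(1, 8), dtype="int8")
    mod = tvm.IRModule.from_expr(relay.add(x, y))
    partitioned, _ = partition_for_tensorrt(mod)
    # With the dtype gate in place, no subgraph should be handed to TensorRT.
    assert "tensorrt" not in str(partitioned)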