@@ -202,9 +202,6 @@ def _func_wrapper(expr):
202202 # ops with dynamic shapes are offloaded to VM
203203 if check_dynamism (args , op_name ):
204204 return False
205- if any ([x .checked_type .dtype != "float32" for x in args ]):
206- logger .info ("Only float32 inputs are supported for TensorRT." )
207- return False
208205 if op_name == "multiply" :
209206 shapes = [
210207 [
@@ -325,9 +322,6 @@ def add_annotate_fn(expr): # pylint: disable=unused-variable
325322 if get_tensorrt_use_implicit_batch_mode () and any ([len (shape ) < 1 for shape in shapes ]):
326323 return False
327324
328- if any ([x .checked_type .dtype != "float32" for x in args ]):
329- logger .info ("Only float32 inputs are supported for TensorRT." )
330- return False
331325 if (
332326 not get_tensorrt_use_implicit_batch_mode ()
333327 and (isinstance (args [0 ], Constant ) or isinstance (args [1 ], Constant ))
@@ -347,9 +341,6 @@ def batch_norm_annotate_fn(expr): # pylint: disable=unused-variable
347341 """Check if nn.batch_norm is supported by TensorRT."""
348342
349343 attrs , args = expr .attrs , expr .args
350- if any ([x .checked_type .dtype != "float32" for x in args ]):
351- logger .info ("Only float32 inputs are supported for TensorRT." )
352- return False
353344 if len (args [0 ].checked_type .shape ) == 5 and get_tensorrt_version () < (6 , 0 , 1 ):
354345 logger .info ("nn.batch_norm: TensorRT 6.0.1 or higher is required for rank 5 inputs." )
355346 return False
@@ -366,10 +357,7 @@ def batch_norm_annotate_fn(expr): # pylint: disable=unused-variable
366357def softmax_annotate_fn (expr ): # pylint: disable=unused-variable
367358 """Check if nn.softmax is supported by TensorRT."""
368359
369- attrs , args = expr .attrs , expr .args
370- if any ([x .checked_type .dtype != "float32" for x in args ]):
371- logger .info ("Only float32 inputs are supported for TensorRT." )
372- return False
360+ attrs = expr .attrs
373361 if get_tensorrt_use_implicit_batch_mode () and int (attrs .axis ) == 0 :
374362 logger .info ("nn.softmax: can't modify batch dimension." )
375363 return False
@@ -380,10 +368,7 @@ def softmax_annotate_fn(expr): # pylint: disable=unused-variable
380368def conv1d_annotate_fn (expr ): # pylint: disable=unused-variable
381369 """Check if nn.conv1d is supported by TensorRT."""
382370
383- attrs , args = expr .attrs , expr .args
384- if any ([x .checked_type .dtype != "float32" for x in args ]):
385- logger .info ("Only float32 inputs are supported for TensorRT." )
386- return False
371+ attrs = expr .attrs
387372 if attrs .data_layout != "NCW" :
388373 logger .info ("nn.conv1d: data_layout is %s but must be NCW." , attrs .data_layout )
389374 return False
@@ -397,10 +382,7 @@ def conv1d_annotate_fn(expr): # pylint: disable=unused-variable
397382def conv2d_annotate_fn (expr ): # pylint: disable=unused-variable
398383 """Check if nn.conv2d is supported by TensorRT."""
399384
400- attrs , args = expr .attrs , expr .args
401- if any ([x .checked_type .dtype != "float32" for x in args ]):
402- logger .info ("Only float32 inputs are supported for TensorRT." )
403- return False
385+ attrs = expr .attrs
404386 if attrs .data_layout != "NCHW" :
405387 logger .info ("nn.conv2d: data_layout is %s but must be NCHW." , attrs .data_layout )
406388 return False
@@ -418,9 +400,6 @@ def dense_annotate_fn(expr): # pylint: disable=unused-variable
418400 """Check if dense is supported by TensorRT."""
419401
420402 args = expr .args
421- if any ([x .checked_type .dtype != "float32" for x in args ]):
422- logger .info ("Only float32 inputs are supported for TensorRT." )
423- return False
424403 input_rank = len (args [0 ].checked_type .shape )
425404 weight_rank = len (args [1 ].checked_type .shape )
426405 if input_rank not in (2 , 3 , 4 ):
@@ -436,9 +415,6 @@ def dense_annotate_fn(expr): # pylint: disable=unused-variable
436415def batch_matmul_annotate_fn (expr ):
437416 """Check if dense is supported by TensorRT."""
438417
439- if any ([x .checked_type .dtype != "float32" for x in expr .args ]):
440- logger .info ("Only float32 inputs are supported for TensorRT." )
441- return False
442418 if get_tensorrt_use_implicit_batch_mode () and len (expr .args [0 ].checked_type .shape ) != len (
443419 expr .args [1 ].checked_type .shape
444420 ):
@@ -451,9 +427,6 @@ def batch_matmul_annotate_fn(expr):
451427def layer_norm_annotate_fn (expr ):
452428 """Check if dense is supported by TensorRT."""
453429
454- if any ([x .checked_type .dtype != "float32" for x in expr .args ]):
455- logger .info ("Only float32 inputs are supported for TensorRT." )
456- return False
457430 if get_tensorrt_use_implicit_batch_mode () and int (expr .attrs .axis ) == 0 :
458431 logger .info ("nn.layer_norm: requires use_implict_batch=False." )
459432 return False
@@ -465,9 +438,6 @@ def bias_add_annotate_fn(expr): # pylint: disable=unused-variable
465438 """Check if nn.bias_add is supported by TensorRT."""
466439
467440 args = expr .args
468- if any ([x .checked_type .dtype != "float32" for x in args ]):
469- logger .info ("Only float32 inputs are supported for TensorRT." )
470- return False
471441 input_rank = len (args [0 ].checked_type .shape )
472442 if input_rank not in (2 , 3 , 4 ):
473443 logger .info ("nn.bias_add: input rank is %d but must be 2, 3 or 4." , input_rank )
@@ -479,10 +449,7 @@ def bias_add_annotate_fn(expr): # pylint: disable=unused-variable
479449def max_pool_2d_annotate_fn (expr ): # pylint: disable=unused-variable
480450 """Check if nn.max_pool2d is supported by TensorRT."""
481451
482- attrs , args = expr .attrs , expr .args
483- if any ([x .checked_type .dtype != "float32" for x in args ]):
484- logger .info ("Only float32 inputs are supported for TensorRT." )
485- return False
452+ attrs = expr .attrs
486453 if attrs .layout != "NCHW" :
487454 logger .info ("nn.max_pool2d: layout is %s but must be NCHW." , attrs .layout )
488455 return False
@@ -496,10 +463,7 @@ def max_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
496463def avg_pool_2d_annotate_fn (expr ): # pylint: disable=unused-variable
497464 """Check if nn.avg_pool2d is supported by TensorRT."""
498465
499- attrs , args = expr .attrs , expr .args
500- if any ([x .checked_type .dtype != "float32" for x in args ]):
501- logger .info ("Only float32 inputs are supported for TensorRT." )
502- return False
466+ attrs = expr .attrs
503467 if attrs .layout != "NCHW" :
504468 logger .info ("nn.avg_pool2d: layout is %d but must be NCHW." , attrs .layout )
505469 return False
@@ -526,10 +490,7 @@ def avg_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
526490def global_max_pool_2d_annotate_fn (expr ): # pylint: disable=unused-variable
527491 """Check if nn.global_max_pool2d is supported by TensorRT."""
528492
529- attrs , args = expr .attrs , expr .args
530- if any ([x .checked_type .dtype != "float32" for x in args ]):
531- logger .info ("Only float32 inputs are supported for TensorRT." )
532- return False
493+ attrs = expr .attrs
533494 if attrs .layout != "NCHW" :
534495 logger .info ("nn.global_max_pool2d: layout is %s but must be NCHW." , attrs .layout )
535496 return False
@@ -540,10 +501,7 @@ def global_max_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
540501def global_avg_pool_2d_annotate_fn (expr ): # pylint: disable=unused-variable
541502 """Check if nn.global_avg_pool2d is supported by TensorRT."""
542503
543- attrs , args = expr .attrs , expr .args
544- if any ([x .checked_type .dtype != "float32" for x in args ]):
545- logger .info ("Only float32 inputs are supported for TensorRT." )
546- return False
504+ attrs = expr .attrs
547505 if attrs .layout != "NCHW" :
548506 logger .info ("nn.global_avg_pool2d: layout is %s but must be NCHW." , attrs .layout )
549507 return False
@@ -554,10 +512,7 @@ def global_avg_pool_2d_annotate_fn(expr): # pylint: disable=unused-variable
554512def expand_dims_annotate_fn (expr ): # pylint: disable=unused-variable
555513 """Check if expand_dims is supported by TensorRT."""
556514
557- attrs , args = expr .attrs , expr .args
558- if any ([x .checked_type .dtype != "float32" for x in args ]):
559- logger .info ("Only float32 inputs are supported for TensorRT." )
560- return False
515+ attrs = expr .attrs
561516 if get_tensorrt_use_implicit_batch_mode () and int (attrs .axis ) == 0 :
562517 logger .info ("expand_dims: can't modify batch dimension." )
563518 return False
@@ -568,10 +523,7 @@ def expand_dims_annotate_fn(expr): # pylint: disable=unused-variable
568523def squeeze_annotate_fn (expr ): # pylint: disable=unused-variable
569524 """Check if squeeze is supported by TensorRT."""
570525
571- attrs , args = expr .attrs , expr .args
572- if any ([x .checked_type .dtype != "float32" for x in args ]):
573- logger .info ("Only float32 inputs are supported for TensorRT." )
574- return False
526+ attrs = expr .attrs
575527 if not attrs .axis :
576528 logger .info ("squeeze: must explicitly set axis." )
577529 return False
@@ -586,9 +538,6 @@ def concatenate_annotate_fn(expr): # pylint: disable=unused-variable
586538 """Check if concatenate is supported by TensorRT."""
587539
588540 attrs , args = expr .attrs , expr .args
589- if any ([x .dtype != "float32" for x in args [0 ].checked_type .fields ]):
590- logger .info ("Only float32 inputs are supported for TensorRT." )
591- return False
592541 if not get_tensorrt_use_implicit_batch_mode ():
593542 return True
594543 if int (attrs .axis ) == 0 :
@@ -606,9 +555,6 @@ def concatenate_annotate_fn(expr): # pylint: disable=unused-variable
606555def split_annotate_fn (expr ):
607556 """Check if split is supported by TensorRT."""
608557
609- if any ([x .checked_type .dtype != "float32" for x in expr .args ]):
610- logger .info ("Only float32 inputs are supported for TensorRT." )
611- return False
612558 if get_tensorrt_use_implicit_batch_mode () and int (expr .attrs .axis ) == 0 :
613559 logger .info ("split: can't modify batch dimension." )
614560 return False
@@ -619,10 +565,7 @@ def split_annotate_fn(expr):
619565def conv2d_transpose_annotate_fn (expr ): # pylint: disable=unused-variable
620566 """Check if nn.conv2d_transpose is supported by TensorRT."""
621567
622- attrs , args = expr .attrs , expr .args
623- if any ([x .checked_type .dtype != "float32" for x in args ]):
624- logger .info ("Only float32 inputs are supported for TensorRT." )
625- return False
568+ attrs = expr .attrs
626569 if attrs .data_layout != "NCHW" :
627570 logger .info ("nn.conv2d_transpose: data_layout is %s but must be NCHW." , attrs .data_layout )
628571 return False
@@ -644,10 +587,7 @@ def conv2d_transpose_annotate_fn(expr): # pylint: disable=unused-variable
644587def transpose_annotate_fn (expr ): # pylint: disable=unused-variable
645588 """Check if transpose is supported by TensorRT."""
646589
647- attrs , args = expr .attrs , expr .args
648- if any ([x .checked_type .dtype != "float32" for x in args ]):
649- logger .info ("Only float32 inputs are supported for TensorRT." )
650- return False
590+ attrs = expr .attrs
651591 if get_tensorrt_use_implicit_batch_mode () and int (attrs .axes [0 ]) != 0 :
652592 logger .info ("transpose: can't modify batch dimension." )
653593 return False
@@ -658,10 +598,7 @@ def transpose_annotate_fn(expr): # pylint: disable=unused-variable
658598def layout_transform_annotate_fn (expr ): # pylint: disable=unused-variable
659599 """Check if layout_transform is supported by TensorRT."""
660600
661- attrs , args = expr .attrs , expr .args
662- if any ([x .checked_type .dtype != "float32" for x in args ]):
663- logger .info ("Only float32 inputs are supported for TensorRT." )
664- return False
601+ attrs = expr .attrs
665602 if (attrs .src_layout , attrs .dst_layout ) not in [
666603 ("NCHW" , "NHWC" ),
667604 ("NHWC" , "NCHW" ),
@@ -679,9 +616,6 @@ def layout_transform_annotate_fn(expr): # pylint: disable=unused-variable
679616def reshape_annotate_fn (expr ): # pylint: disable=unused-variable
680617 """Check if reshape is supported by TensorRT."""
681618 attrs , args = expr .attrs , expr .args
682- if args [0 ].checked_type .dtype != "float32" :
683- logger .info ("Only float32 inputs are supported for TensorRT." )
684- return False
685619 if any ([x < - 1 for x in map (int , attrs .newshape )]):
686620 logger .info ("reshape: new shape dims must be explicit." )
687621 return False
@@ -740,9 +674,6 @@ def pad_annotate_fn(expr): # pylint: disable=unused-variable
740674 pad_value = args [1 ]
741675 assert isinstance (pad_value , relay .Constant )
742676 pad_value = pad_value .data .numpy ().item ()
743- if any ([x .checked_type .dtype != "float32" for x in args ]):
744- logger .info ("Only float32 inputs are supported for TensorRT." )
745- return False
746677 if attrs .pad_mode != "constant" :
747678 logger .info ("nn.pad: pad mode is %s but must be constant." , attrs .pad_mode )
748679 return False
@@ -766,9 +697,6 @@ def strided_slice_annotate_fn(expr): # pylint: disable=unused-variable
766697 """Check if strided_slice is supported by TensorRT."""
767698
768699 attrs , args = expr .attrs , expr .args
769- if args [0 ].checked_type .dtype != "float32" :
770- logger .info ("Only float32 inputs are supported for TensorRT." )
771- return False
772700 if not trt_version_annotate_fn ((5 , 1 , 5 ))(attrs , args , "strided_slice" ):
773701 return False
774702 if get_tensorrt_use_implicit_batch_mode ():
@@ -813,10 +741,7 @@ def strided_slice_annotate_fn(expr): # pylint: disable=unused-variable
813741def adaptive_max_pool2d_annotate_fn (expr ): # pylint: disable=unused-variable
814742 """Check if nn.adaptive_max_pool2d is supported by TensorRT."""
815743
816- attrs , args = expr .attrs , expr .args
817- if any ([x .checked_type .dtype != "float32" for x in args ]):
818- logger .info ("Only float32 inputs are supported for TensorRT." )
819- return False
744+ attrs = expr .attrs
820745 if len (attrs .output_size ) == 0 or any ([size != 1 for size in map (int , attrs .output_size )]):
821746 logger .info ("nn.adaptive_max_pool2d: output size must be (1, 1)." )
822747 return False
@@ -827,10 +752,7 @@ def adaptive_max_pool2d_annotate_fn(expr): # pylint: disable=unused-variable
827752def adaptive_avg_pool2d_annotate_fn (expr ): # pylint: disable=unused-variable
828753 """Check if nn.adaptive_avg_pool2d is supported by TensorRT."""
829754
830- attrs , args = expr .attrs , expr .args
831- if any ([x .checked_type .dtype != "float32" for x in args ]):
832- logger .info ("Only float32 inputs are supported for TensorRT." )
833- return False
755+ attrs = expr .attrs
834756 if len (attrs .output_size ) == 0 or any ([size != 1 for size in map (int , attrs .output_size )]):
835757 logger .info ("nn.adaptive_avg_pool2d: output size must be (1, 1)." )
836758 return False
@@ -842,9 +764,6 @@ def conv3d_annotate_fn(expr): # pylint: disable=unused-variable
842764 """Check if nn.conv3d is supported by TensorRT."""
843765
844766 attrs , args = expr .attrs , expr .args
845- if any ([x .checked_type .dtype != "float32" for x in args ]):
846- logger .info ("Only float32 inputs are supported for TensorRT." )
847- return False
848767 if not trt_version_annotate_fn ((6 , 0 , 1 ))(attrs , args , "nn.conv3d" ):
849768 return False
850769 if attrs .data_layout != "NCDHW" :
@@ -864,9 +783,6 @@ def max_pool_3d_annotate_fn(expr): # pylint: disable=unused-variable
864783 """Check if nn.max_pool3d is supported by TensorRT."""
865784
866785 attrs , args = expr .attrs , expr .args
867- if any ([x .checked_type .dtype != "float32" for x in args ]):
868- logger .info ("Only float32 inputs are supported for TensorRT." )
869- return False
870786 if not trt_version_annotate_fn ((6 , 0 , 1 ))(attrs , args , "nn.max_pool3d" ):
871787 return False
872788 if attrs .layout != "NCDHW" :
@@ -880,9 +796,6 @@ def avg_pool_3d_annotate_fn(expr): # pylint: disable=unused-variable
880796 """Check if nn.avg_pool3d is supported by TensorRT."""
881797
882798 attrs , args = expr .attrs , expr .args
883- if any ([x .checked_type .dtype != "float32" for x in args ]):
884- logger .info ("Only float32 inputs are supported for TensorRT." )
885- return False
886799 if not trt_version_annotate_fn ((6 , 0 , 1 ))(attrs , args , "nn.avg_pool3d" ):
887800 return False
888801 if attrs .layout != "NCDHW" :
@@ -896,9 +809,6 @@ def conv3d_transpose_annotate_fn(expr): # pylint: disable=unused-variable
896809 """Check if nn.conv3d_transpose is supported by TensorRT."""
897810
898811 attrs , args = expr .attrs , expr .args
899- if any ([x .checked_type .dtype != "float32" for x in args ]):
900- logger .info ("Only float32 inputs are supported for TensorRT." )
901- return False
902812 if not trt_version_annotate_fn ((6 , 0 , 1 ))(attrs , args , "nn.conv3d_transpose" ):
903813 return False
904814 if attrs .data_layout != "NCDHW" :
0 commit comments