2525from tvm .relay import testing
2626from tvm .relay .op .contrib import tensorrt
2727import numpy as np
28- from typing import List , Tuple
29-
30- # from tvm import script
31- # from tvm._ffi import register_func
32- # from tvm.runtime import Module
28+ from typing import List
3329from tvm ._ffi import register_func
34- from tvm .relay .testing .init import Initializer
3530from tvm .target import Target
3631from tvm .runtime import Module
3732from tvm .meta_schedule .arg_info import TensorInfo
@@ -94,25 +89,11 @@ def verify_meta_schedule_with_tensorrt(
9489):
9590 if use_meta_sched :
9691 # With meta_schedule
97- dev = "nvidia/geforce-rtx-2080 "
92+ dev = "cuda "
9893
9994 # Build
10095 if use_trt :
101-
102- def relay_build_with_tensorrt (
103- mod : Module ,
104- target : Target ,
105- params : dict ,
106- ) -> List [BuilderResult ]:
107- from tvm .relay .op .contrib .tensorrt import partition_for_tensorrt
108-
109- mod , config = partition_for_tensorrt (mod , params )
110- with tvm .transform .PassContext (
111- opt_level = 3 , config = {"relay.ext.tensorrt.options" : config }
112- ):
113- return tvm .relay .build_module ._build_module_no_factory (
114- mod , "cuda" , "llvm" , params
115- )
96+ from tvm .meta_schedule .testing import relay_build_with_tensorrt
11697
11798 builder = LocalBuilder (f_build = relay_build_with_tensorrt )
11899 else :
@@ -122,7 +103,6 @@ def relay_build_without_tensorrt(
122103 target : Target ,
123104 params : dict ,
124105 ) -> List [BuilderResult ]:
125- # @Sung: Weird. Cannot pass keyword arg
126106 return tvm .relay .build_module ._build_module_no_factory (mod , "cuda" , "llvm" , params )
127107
128108 builder = LocalBuilder (f_build = relay_build_without_tensorrt )
@@ -235,7 +215,7 @@ def test_conv2d_relu():
235215 "model_name" ,
236216 ["resnet-50" , "mobilenet" ],
237217)
238- @pytest .mark .parametrize ("batch_size" , [1 , 8 , 16 ])
218+ @pytest .mark .parametrize ("batch_size" , [1 ])
239219@pytest .mark .parametrize ("use_meta_sched" , [True ])
240220@pytest .mark .parametrize ("use_trt" , [True , False ])
241221def test_relay_model (model_name : str , batch_size : int , use_meta_sched : bool , use_trt : bool ):
@@ -246,6 +226,5 @@ def test_relay_model(model_name: str, batch_size: int, use_meta_sched: bool, use
246226 )
247227
248228
249- # @sunggg: memory verification error at test_relay_model("resnet-50", 1, use_meta_sched=False, use_trt=True)
250229if __name__ == "__main__" :
251230 sys .exit (pytest .main ([__file__ ] + sys .argv [1 :]))
0 commit comments