1515# specific language governing permissions and limitations
1616# under the License.
1717
18- """ Test Managers in MSC. """
18+ """ Test Pipeline in MSC. """
1919
2020import json
2121import pytest
2222import torch
2323
2424import tvm .testing
25- from tvm .contrib .msc .pipeline import MSCManager
25+ from tvm .contrib .msc .pipeline import MSCManager , TorchDynamic
2626from tvm .contrib .msc .core .utils .namespace import MSCFramework
2727from tvm .contrib .msc .core import utils as msc_utils
2828
3232)
3333
3434
35- def _get_config (model_type , compile_type , inputs , outputs , atol = 1e-1 , rtol = 1e-1 ):
35+ def _get_config (model_type , compile_type , inputs , outputs , dynamic = False , atol = 1e-1 , rtol = 1e-1 ):
3636 """Get msc config"""
3737
38- path = "test_manager_ {}_{}" .format (model_type , compile_type )
38+ path = "test_pipe_ {}_{}_{} " .format (model_type , compile_type , "dynamic" if dynamic else "static" )
3939 return {
4040 "workspace" : msc_utils .msc_dir (path ),
41- "verbose" : "critical " ,
41+ "verbose" : "info " ,
4242 "model_type" : model_type ,
4343 "inputs" : inputs ,
4444 "outputs" : outputs ,
@@ -95,23 +95,29 @@ def _get_tf_graph():
9595 return None
9696
9797
98- def _check_manager ( manager , expected_info ):
99- """Check the manager results"""
98+ def _check_pipeline ( pipeline , expected_info , dynamic = False ):
99+ """Check the pipeline results"""
100100
101- model_info = manager .runner .model_info
102101 passed , err = True , ""
103- if not manager .report ["success" ]:
102+ if not pipeline .report ["success" ]:
104103 passed = False
105- err = "Failed to run pipe for {} -> {}" .format (manager .model_type , manager .compile_type )
106- if not msc_utils .dict_equal (model_info , expected_info ):
107- passed = False
108- err = "Model info {} mismatch with expected {}" .format (model_info , expected_info )
109- manager .destory ()
104+ err = "Failed to run pipe for {} -> {}" .format (pipeline .model_type , pipeline .compile_type )
105+ if not dynamic :
106+ model_info = pipeline .get_runtime ().model_info
107+ if not msc_utils .dict_equal (model_info , expected_info ):
108+ passed = False
109+ err = "Model info {} mismatch with expected {}" .format (model_info , expected_info )
110+ pipeline .destory ()
110111 if not passed :
111- raise Exception ("{}\n Report:{}" .format (err , json .dumps (manager .report , indent = 2 )))
112+ raise Exception ("{}\n Report:{}" .format (err , json .dumps (pipeline .report , indent = 2 )))
113+
112114
115+ def _test_from_torch (
116+ compile_type , expected_info , training = False , dynamic = False , atol = 1e-1 , rtol = 1e-1
117+ ):
118+ if dynamic and not hasattr (torch , "compile" ):
119+ return
113120
114- def _test_from_torch (compile_type , expected_info , training = False , atol = 1e-1 , rtol = 1e-1 ):
115121 torch_model = _get_torch_model ("resnet50" , training )
116122 if torch_model :
117123 if torch .cuda .is_available ():
@@ -121,12 +127,13 @@ def _test_from_torch(compile_type, expected_info, training=False, atol=1e-1, rto
121127 compile_type ,
122128 inputs = [["input_0" , [1 , 3 , 224 , 224 ], "float32" ]],
123129 outputs = ["output" ],
130+ dynamic = dynamic ,
124131 atol = atol ,
125132 rtol = rtol ,
126133 )
127- manager = MSCManager (torch_model , config )
128- manager .run_pipe ()
129- _check_manager ( manager , expected_info )
134+ pipeline = TorchDynamic ( torch_model , config ) if dynamic else MSCManager (torch_model , config )
135+ pipeline .run_pipe ()
136+ _check_pipeline ( pipeline , expected_info , dynamic )
130137
131138
132139def _test_from_tf (compile_type , expected_info , atol = 1e-2 , rtol = 1e-2 ):
@@ -143,11 +150,12 @@ def _test_from_tf(compile_type, expected_info, atol=1e-2, rtol=1e-2):
143150 config ["compile" ]["profile" ]["check" ]["err_rate" ] = - 1
144151 manager = MSCManager (graphdef , config )
145152 manager .run_pipe ()
146- _check_manager (manager , expected_info )
153+ _check_pipeline (manager , expected_info )
147154
148155
149- def test_tvm_manager ():
150- """Test manager for tvm"""
156+ @pytest .mark .parametrize ("dynamic" , [False , True ])
157+ def test_tvm_pipeline (dynamic ):
158+ """Test pipeline for tvm"""
151159
152160 model_info = {
153161 "inputs" : [
@@ -168,40 +176,42 @@ def test_tvm_manager():
168176 "msc.linear_bias" : 1 ,
169177 },
170178 }
171- _test_from_torch (MSCFramework .TVM , model_info , training = False )
172-
173- model_info = {
174- "inputs" : [
175- {"name" : "input" , "shape" : [1 , 224 , 224 , 3 ], "dtype" : "float32" , "layout" : "NHWC" }
176- ],
177- "outputs" : [
178- {
179- "name" : "MobilenetV2/Predictions/Reshape_1:0" ,
180- "shape" : [1 , 1001 ],
181- "dtype" : "float32" ,
182- "layout" : "NC" ,
183- }
184- ],
185- "nodes" : {
186- "total" : 138 ,
187- "input" : 1 ,
188- "msc.conv2d_bias" : 36 ,
189- "clip" : 35 ,
190- "nn.conv2d" : 17 ,
191- "nn.batch_norm" : 17 ,
192- "get_item" : 17 ,
193- "add" : 10 ,
194- "nn.avg_pool2d" : 1 ,
195- "squeeze" : 1 ,
196- "reshape" : 2 ,
197- "nn.softmax" : 1 ,
198- },
199- }
200- _test_from_tf (MSCFramework .TVM , model_info )
201-
202-
203- def test_torch_manager ():
204- """Test manager for torch"""
179+ _test_from_torch (MSCFramework .TVM , model_info , training = False , dynamic = dynamic )
180+
181+ if not dynamic :
182+ model_info = {
183+ "inputs" : [
184+ {"name" : "input" , "shape" : [1 , 224 , 224 , 3 ], "dtype" : "float32" , "layout" : "NHWC" }
185+ ],
186+ "outputs" : [
187+ {
188+ "name" : "MobilenetV2/Predictions/Reshape_1:0" ,
189+ "shape" : [1 , 1001 ],
190+ "dtype" : "float32" ,
191+ "layout" : "NC" ,
192+ }
193+ ],
194+ "nodes" : {
195+ "total" : 138 ,
196+ "input" : 1 ,
197+ "msc.conv2d_bias" : 36 ,
198+ "clip" : 35 ,
199+ "nn.conv2d" : 17 ,
200+ "nn.batch_norm" : 17 ,
201+ "get_item" : 17 ,
202+ "add" : 10 ,
203+ "nn.avg_pool2d" : 1 ,
204+ "squeeze" : 1 ,
205+ "reshape" : 2 ,
206+ "nn.softmax" : 1 ,
207+ },
208+ }
209+ _test_from_tf (MSCFramework .TVM , model_info )
210+
211+
212+ @pytest .mark .parametrize ("dynamic" , [False , True ])
213+ def test_torch_pipeline (dynamic ):
214+ """Test pipeline for torch"""
205215
206216 model_info = {
207217 "inputs" : [
@@ -222,10 +232,10 @@ def test_torch_manager():
222232 "msc.linear_bias" : 1 ,
223233 },
224234 }
225- _test_from_torch (MSCFramework .TORCH , model_info , training = False )
235+ _test_from_torch (MSCFramework .TORCH , model_info , training = False , dynamic = dynamic )
226236
227237
228- def test_tensorflow_manager ():
238+ def test_tensorflow_pipeline ():
229239 """Test manager for tensorflow"""
230240
231241 model_info = {
@@ -259,8 +269,9 @@ def test_tensorflow_manager():
259269
260270
261271@requires_tensorrt
262- def test_tensorrt_manager ():
263- """Test manager for tensorrt"""
272+ @pytest .mark .parametrize ("dynamic" , [False , True ])
273+ def test_tensorrt_pipeline (dynamic ):
274+ """Test pipeline for tensorrt"""
264275
265276 model_info = {
266277 "inputs" : [
@@ -269,7 +280,7 @@ def test_tensorrt_manager():
269280 "outputs" : [{"name" : "output" , "shape" : [1 , 1000 ], "dtype" : "float32" , "layout" : "" }],
270281 "nodes" : {"total" : 2 , "input" : 1 , "msc_tensorrt" : 1 },
271282 }
272- _test_from_torch (MSCFramework .TENSORRT , model_info , training = False )
283+ _test_from_torch (MSCFramework .TENSORRT , model_info , training = False , dynamic = dynamic )
273284
274285
275286if __name__ == "__main__" :
0 commit comments