
Commit d8f7ddc: update test

1 parent 25201ef

File tree: 4 files changed, +77 -66 lines

tests/python/contrib/test_msc/test_manager.py renamed to tests/python/contrib/test_msc/test_pipeline.py

Lines changed: 72 additions & 61 deletions
@@ -15,14 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 
-""" Test Managers in MSC. """
+""" Test Pipeline in MSC. """
 
 import json
 import pytest
 import torch
 
 import tvm.testing
-from tvm.contrib.msc.pipeline import MSCManager
+from tvm.contrib.msc.pipeline import MSCManager, TorchDynamic
 from tvm.contrib.msc.core.utils.namespace import MSCFramework
 from tvm.contrib.msc.core import utils as msc_utils
@@ -32,13 +32,13 @@
 )
 
 
-def _get_config(model_type, compile_type, inputs, outputs, atol=1e-1, rtol=1e-1):
+def _get_config(model_type, compile_type, inputs, outputs, dynamic=False, atol=1e-1, rtol=1e-1):
     """Get msc config"""
 
-    path = "test_manager_{}_{}".format(model_type, compile_type)
+    path = "test_pipe_{}_{}_{}".format(model_type, compile_type, "dynamic" if dynamic else "static")
     return {
         "workspace": msc_utils.msc_dir(path),
-        "verbose": "critical",
+        "verbose": "info",
         "model_type": model_type,
         "inputs": inputs,
         "outputs": outputs,
@@ -95,23 +95,29 @@ def _get_tf_graph():
         return None
 
 
-def _check_manager(manager, expected_info):
-    """Check the manager results"""
+def _check_pipeline(pipeline, expected_info, dynamic=False):
+    """Check the pipeline results"""
 
-    model_info = manager.runner.model_info
     passed, err = True, ""
-    if not manager.report["success"]:
+    if not pipeline.report["success"]:
         passed = False
-        err = "Failed to run pipe for {} -> {}".format(manager.model_type, manager.compile_type)
-    if not msc_utils.dict_equal(model_info, expected_info):
-        passed = False
-        err = "Model info {} mismatch with expected {}".format(model_info, expected_info)
-    manager.destory()
+        err = "Failed to run pipe for {} -> {}".format(pipeline.model_type, pipeline.compile_type)
+    if not dynamic:
+        model_info = pipeline.get_runtime().model_info
+        if not msc_utils.dict_equal(model_info, expected_info):
+            passed = False
+            err = "Model info {} mismatch with expected {}".format(model_info, expected_info)
+    pipeline.destory()
     if not passed:
-        raise Exception("{}\nReport:{}".format(err, json.dumps(manager.report, indent=2)))
+        raise Exception("{}\nReport:{}".format(err, json.dumps(pipeline.report, indent=2)))
+
 
+def _test_from_torch(
+    compile_type, expected_info, training=False, dynamic=False, atol=1e-1, rtol=1e-1
+):
+    if dynamic and not hasattr(torch, "compile"):
+        return
 
-def _test_from_torch(compile_type, expected_info, training=False, atol=1e-1, rtol=1e-1):
     torch_model = _get_torch_model("resnet50", training)
     if torch_model:
         if torch.cuda.is_available():
@@ -121,12 +127,13 @@ def _test_from_torch(compile_type, expected_info, training=False, atol=1e-1, rtol=1e-1):
             compile_type,
             inputs=[["input_0", [1, 3, 224, 224], "float32"]],
             outputs=["output"],
+            dynamic=dynamic,
             atol=atol,
             rtol=rtol,
         )
-        manager = MSCManager(torch_model, config)
-        manager.run_pipe()
-        _check_manager(manager, expected_info)
+        pipeline = TorchDynamic(torch_model, config) if dynamic else MSCManager(torch_model, config)
+        pipeline.run_pipe()
+        _check_pipeline(pipeline, expected_info, dynamic)
 
 
 def _test_from_tf(compile_type, expected_info, atol=1e-2, rtol=1e-2):
@@ -143,11 +150,12 @@ def _test_from_tf(compile_type, expected_info, atol=1e-2, rtol=1e-2):
         config["compile"]["profile"]["check"]["err_rate"] = -1
         manager = MSCManager(graphdef, config)
         manager.run_pipe()
-        _check_manager(manager, expected_info)
+        _check_pipeline(manager, expected_info)
 
 
-def test_tvm_manager():
-    """Test manager for tvm"""
+@pytest.mark.parametrize("dynamic", [False, True])
+def test_tvm_pipeline(dynamic):
+    """Test pipeline for tvm"""
 
     model_info = {
         "inputs": [
@@ -168,40 +176,42 @@ def test_tvm_manager():
             "msc.linear_bias": 1,
         },
     }
-    _test_from_torch(MSCFramework.TVM, model_info, training=False)
-
-    model_info = {
-        "inputs": [
-            {"name": "input", "shape": [1, 224, 224, 3], "dtype": "float32", "layout": "NHWC"}
-        ],
-        "outputs": [
-            {
-                "name": "MobilenetV2/Predictions/Reshape_1:0",
-                "shape": [1, 1001],
-                "dtype": "float32",
-                "layout": "NC",
-            }
-        ],
-        "nodes": {
-            "total": 138,
-            "input": 1,
-            "msc.conv2d_bias": 36,
-            "clip": 35,
-            "nn.conv2d": 17,
-            "nn.batch_norm": 17,
-            "get_item": 17,
-            "add": 10,
-            "nn.avg_pool2d": 1,
-            "squeeze": 1,
-            "reshape": 2,
-            "nn.softmax": 1,
-        },
-    }
-    _test_from_tf(MSCFramework.TVM, model_info)
-
-
-def test_torch_manager():
-    """Test manager for torch"""
+    _test_from_torch(MSCFramework.TVM, model_info, training=False, dynamic=dynamic)
+
+    if not dynamic:
+        model_info = {
+            "inputs": [
+                {"name": "input", "shape": [1, 224, 224, 3], "dtype": "float32", "layout": "NHWC"}
+            ],
+            "outputs": [
+                {
+                    "name": "MobilenetV2/Predictions/Reshape_1:0",
+                    "shape": [1, 1001],
+                    "dtype": "float32",
+                    "layout": "NC",
+                }
+            ],
+            "nodes": {
+                "total": 138,
+                "input": 1,
+                "msc.conv2d_bias": 36,
+                "clip": 35,
+                "nn.conv2d": 17,
+                "nn.batch_norm": 17,
+                "get_item": 17,
+                "add": 10,
+                "nn.avg_pool2d": 1,
+                "squeeze": 1,
+                "reshape": 2,
+                "nn.softmax": 1,
+            },
+        }
+        _test_from_tf(MSCFramework.TVM, model_info)
+
+
+@pytest.mark.parametrize("dynamic", [False, True])
+def test_torch_pipeline(dynamic):
+    """Test pipeline for torch"""
 
     model_info = {
         "inputs": [
@@ -222,10 +232,10 @@ def test_torch_manager():
             "msc.linear_bias": 1,
         },
     }
-    _test_from_torch(MSCFramework.TORCH, model_info, training=False)
+    _test_from_torch(MSCFramework.TORCH, model_info, training=False, dynamic=dynamic)
 
 
-def test_tensorflow_manager():
+def test_tensorflow_pipeline():
     """Test manager for tensorflow"""
 
     model_info = {
@@ -259,8 +269,9 @@ def test_tensorflow_manager():
 
 
 @requires_tensorrt
-def test_tensorrt_manager():
-    """Test manager for tensorrt"""
+@pytest.mark.parametrize("dynamic", [False, True])
+def test_tensorrt_pipeline(dynamic):
+    """Test pipeline for tensorrt"""
 
     model_info = {
         "inputs": [
@@ -269,7 +280,7 @@ def test_tensorrt_manager():
         "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": ""}],
         "nodes": {"total": 2, "input": 1, "msc_tensorrt": 1},
     }
-    _test_from_torch(MSCFramework.TENSORRT, model_info, training=False)
+    _test_from_torch(MSCFramework.TENSORRT, model_info, training=False, dynamic=dynamic)
 
 
 if __name__ == "__main__":
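
Note: the renamed test now funnels both front ends through one parametrized entry point, picking TorchDynamic when dynamic is true and MSCManager otherwise. The condensed sketch below shows that dispatch; it reuses _get_torch_model and _get_config from this file. The MSCFramework.TORCH model_type (the first _get_config argument is not visible in the hunk above), the pytest.skip calls in place of the test's silent returns, and the plain assert on the report are assumptions made for illustration, not the exact test code.

import pytest
import torch

from tvm.contrib.msc.pipeline import MSCManager, TorchDynamic
from tvm.contrib.msc.core.utils.namespace import MSCFramework


@pytest.mark.parametrize("dynamic", [False, True])
def test_resnet_pipeline_sketch(dynamic):
    """Static run goes through MSCManager, dynamic run through TorchDynamic."""

    # TorchDynamic builds on torch.compile, so older torch builds cannot run it.
    if dynamic and not hasattr(torch, "compile"):
        pytest.skip("torch.compile is not available")
    torch_model = _get_torch_model("resnet50", training=False)  # helper defined in this file
    if not torch_model:
        pytest.skip("torchvision is not available")
    config = _get_config(  # helper defined in this file
        MSCFramework.TORCH,  # assumed model_type for a torch model
        MSCFramework.TVM,
        inputs=[["input_0", [1, 3, 224, 224], "float32"]],
        outputs=["output"],
        dynamic=dynamic,
    )
    pipeline = TorchDynamic(torch_model, config) if dynamic else MSCManager(torch_model, config)
    pipeline.run_pipe()
    assert pipeline.report["success"]
    if not dynamic:
        # The test above only compares model_info in the static case.
        print(pipeline.get_runtime().model_info)
    pipeline.destory()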

tests/python/contrib/test_msc/test_plugin.py

Lines changed: 1 addition & 1 deletion
@@ -313,7 +313,7 @@ def _test_with_manager(plugins, compile_type, expected_info):
     }
     manager = MSCManager(model, config, plugins=plugins)
     report = manager.run_pipe()
-    model_info = manager.runner.model_info
+    model_info = manager.get_runtime().model_info
     manager.destory()
     assert report["success"], "Failed to run pipe for torch -> {}".format(compile_type)
     assert msc_utils.dict_equal(
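
Note: the only change in this file swaps the removed manager.runner attribute for the get_runtime() accessor. A minimal sketch of the resulting check pattern, assuming model, config, plugins and expected_info are prepared the way the surrounding _test_with_manager prepares them; the helper name and error messages here are illustrative.

from tvm.contrib.msc.pipeline import MSCManager
from tvm.contrib.msc.core import utils as msc_utils


def _run_and_check(model, config, plugins, expected_info):
    """Run the manager pipeline and compare the compiled model info."""

    manager = MSCManager(model, config, plugins=plugins)
    report = manager.run_pipe()
    # The runner is no longer read directly; fetch the runtime object instead.
    model_info = manager.get_runtime().model_info
    manager.destory()
    assert report["success"], "Failed to run pipe"
    assert msc_utils.dict_equal(model_info, expected_info), "Model info mismatch"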

tests/python/contrib/test_msc/test_runner.py

Lines changed: 2 additions & 2 deletions
@@ -100,7 +100,7 @@ def _test_from_torch(runner_cls, device, training=False, atol=1e-1, rtol=1e-1):
         golden = [msc_utils.cast_array(golden)]
         workspace.destory()
         for gol_r, out_r in zip(golden, outputs):
-            tvm.testing.assert_allclose(gol_r, out_r, atol=atol, rtol=rtol)
+            tvm.testing.assert_allclose(gol_r, msc_utils.cast_array(out_r), atol=atol, rtol=rtol)
 
 
 def test_tvm_runner_cpu():
@@ -162,7 +162,7 @@ def test_tensorflow_runner():
         outputs = runner.run([data], ret_type="list")
         workspace.destory()
         for gol_r, out_r in zip(golden, outputs):
-            tvm.testing.assert_allclose(gol_r, out_r, atol=1e-3, rtol=1e-3)
+            tvm.testing.assert_allclose(gol_r, msc_utils.cast_array(out_r), atol=1e-3, rtol=1e-3)
 
 
 if __name__ == "__main__":
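
Note: both runner tests now pass the runner outputs through msc_utils.cast_array before assert_allclose, presumably because some runners hand back framework tensors rather than plain arrays. A small sketch of the comparison loop; the helper name is illustrative, and golden is assumed to already hold cast arrays, as in the tests above.

import tvm.testing
from tvm.contrib.msc.core import utils as msc_utils


def _compare_outputs(golden, outputs, atol=1e-3, rtol=1e-3):
    """Compare runner outputs against golden values as plain arrays."""

    for gol_r, out_r in zip(golden, outputs):
        # out_r may be a framework tensor, so cast each element before comparing.
        tvm.testing.assert_allclose(gol_r, msc_utils.cast_array(out_r), atol=atol, rtol=rtol)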

tests/python/contrib/test_msc/test_tools.py

Lines changed: 2 additions & 2 deletions
@@ -144,7 +144,7 @@ def get_tools(tool_type, use_distill=False, run_type=MSCFramework.MSC):
                 }
             ],
         }
-        tools.append({"tool_type": ToolType.TRACKER, "tool_config": config, "apply_once": True})
+        tools.append({"tool_type": ToolType.TRACKER, "tool_config": config})
     if use_distill:
         config = {
             "plan_file": "msc_distiller.json",
@@ -180,7 +180,7 @@ def _get_torch_model(name, training=False):
 def _check_manager(manager, expected_info):
     """Check the manager results"""
 
-    model_info = manager.runner.model_info
+    model_info = manager.get_runtime().model_info
     passed, err = True, ""
     if not manager.report["success"]:
         passed = False
