@@ -33,6 +33,7 @@
 from tvm.contrib import graph_executor, utils
 from tvm.runtime.vm import VirtualMachine
 
+from tvm.relay import Any, GlobalVar
 from tvm.relay.transform import FirstOrderGradient, InferType
 from tvm.relay.transform.transform import ToMixedPrecision
 
@@ -88,7 +89,7 @@ def set_func_attr(func, compile_name, symbol_name):
     return func
 
 
-def run_and_verify_func(config, target="cuda", run_module=True, data_type="float16"):
+def run_and_verify_func(config, target="cuda", run_module=True, data_type="float32"):
     """Test a Relay func by compiling, running, and comparing TVM and TRT outputs.
 
     Parameters
@@ -277,6 +278,9 @@ def test_tensorrt_not_compatible(run_module):
     results = func(x_data)
 
 
+@pytest.mark.xfail(
+    reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901")
+)
 def test_tensorrt_serialize_graph_executor(run_module):
     import mxnet as mx
     from mxnet.gluon.model_zoo.vision import get_model
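
For context on the decorator added above: pytest.mark.xfail keeps a known-broken test in the suite while reporting its failure as XFAIL rather than FAIL, and the reason string points reviewers at the tracking issue. A minimal standalone sketch (the test body and assertion are illustrative, not from this file):

    import pytest

    @pytest.mark.xfail(
        reason="Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901"
    )
    def test_known_regression():
        # A failure here is reported as XFAIL; an unexpected pass is reported
        # as XPASS, a hint that the tracking issue may have been fixed.
        assert 1 + 1 == 3  # placeholder failing assertion
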
@@ -331,6 +335,9 @@ def load_graph():
     assert_result_dict_holds(result_dict)
 
 
+@pytest.mark.xfail(
+    reason=("Currently failing test. See tracking issue https://github.com/apache/tvm/issues/8901")
+)
 def test_tensorrt_serialize_vm(run_module):
     import mxnet as mx
     from mxnet.gluon.model_zoo.vision import get_model
@@ -473,12 +480,7 @@ def get_graph(x_shape=(1, 8, 8, 32), k_shape=(3, 3, 32, 16)):
         x = relay.var("x", shape=(x_shape), dtype="float32")
         kernel = relay.var("kernel", shape=(k_shape), dtype="float32")
         out = relay.nn.conv2d(
-            x,
-            kernel,
-            channels=16,
-            kernel_size=(3, 3),
-            data_layout="NHWC",
-            kernel_layout="HWIO",
+            x, kernel, channels=16, kernel_size=(3, 3), data_layout="NHWC", kernel_layout="HWIO"
         )
         f = relay.Function([x, kernel], out)
         return f, {"x": x_shape, "kernel": k_shape}, ["kernel"]
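
For readers unfamiliar with these get_graph helpers: each returns a Relay function, a dict of input shapes, and the names of inputs to treat as bound parameters, which run_and_verify_func then compiles and executes. A minimal sketch of building and running the conv2d function above on CPU (llvm target assumed; the helper's actual TVM-vs-TensorRT comparison is not reproduced here):

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.contrib import graph_executor

    # Rebuild the NHWC conv2d function from the hunk above.
    x = relay.var("x", shape=(1, 8, 8, 32), dtype="float32")
    kernel = relay.var("kernel", shape=(3, 3, 32, 16), dtype="float32")
    out = relay.nn.conv2d(
        x, kernel, channels=16, kernel_size=(3, 3), data_layout="NHWC", kernel_layout="HWIO"
    )
    mod = tvm.IRModule.from_expr(relay.Function([x, kernel], out))

    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm")

    rt = graph_executor.GraphModule(lib["default"](tvm.cpu()))
    rt.set_input("x", np.random.uniform(-1, 1, (1, 8, 8, 32)).astype("float32"))
    rt.set_input("kernel", np.random.uniform(-1, 1, (3, 3, 32, 16)).astype("float32"))
    rt.run()
    print(rt.get_output(0).shape)  # (1, 6, 6, 16): 3x3 valid conv over an 8x8 input
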
@@ -602,13 +604,7 @@ def get_graph(
                 count_include_pad=count_include_pad,
             )
         else:
-            out = op(
-                x,
-                pool_size=pool_size,
-                strides=strides,
-                padding=padding,
-                ceil_mode=ceil_mode,
-            )
+            out = op(x, pool_size=pool_size, strides=strides, padding=padding, ceil_mode=ceil_mode)
         f = relay.Function([x], out)
         return f, {"x": x_shape}, []
 
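
As a reference for the ceil_mode flag exercised in this helper: pooling output size follows the usual window-count arithmetic, with ceil_mode rounding up instead of down so that trailing elements still get a (partial) window. A small sketch of that formula (the helper name and per-side pad arguments are ours, not TVM's):

    import math

    def pool_out_size(in_size, pool, stride, pad_lo, pad_hi, ceil_mode=False):
        # Number of window positions along one spatial axis.
        span = in_size + pad_lo + pad_hi - pool
        return (math.ceil(span / stride) if ceil_mode else span // stride) + 1

    # 8-wide input, 3-wide pool, stride 2, no padding:
    assert pool_out_size(8, 3, 2, 0, 0, ceil_mode=False) == 3  # floor(5 / 2) + 1
    assert pool_out_size(8, 3, 2, 0, 0, ceil_mode=True) == 4   # ceil(5 / 2) + 1
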
@@ -726,11 +722,7 @@ def get_graph(x_shape, indices_or_sections, axis):
 
 def test_conv2d_transpose(run_module):
     def get_graph(
-        x_shape=(1, 32, 8, 8),
-        k_shape=(32, 16, 3, 3),
-        groups=1,
-        padding=(0, 0),
-        strides=(1, 1),
+        x_shape=(1, 32, 8, 8), k_shape=(32, 16, 3, 3), groups=1, padding=(0, 0), strides=(1, 1)
     ):
         x = relay.var("x", shape=(x_shape), dtype="float32")
         kernel = relay.var("kernel", shape=(k_shape), dtype="float32")
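
A quick shape check on the defaults above, reading k_shape as pairing the 32 input channels with 16 output channels (the exact kernel layout is relay's conv2d_transpose default, not restated here): spatial size grows by the standard transposed-convolution formula, sketched below under the assumption of dilation 1 and no output_padding.

    def deconv_out_size(in_size, kernel, stride=1, pad=0):
        # Transposed-convolution output size (dilation=1, no output_padding).
        return (in_size - 1) * stride - 2 * pad + kernel

    # Defaults above: 8x8 spatial input, 3x3 kernel, strides (1, 1), padding (0, 0)
    # -> 10x10 output, with channels going 32 -> 16 per k_shape.
    assert deconv_out_size(8, 3) == 10
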
@@ -1009,24 +1001,10 @@ def get_graph(x_shape, param_shape, axis=1, epsilon=1e-5):
         gamma = relay.var("gamma", shape=(param_shape), dtype="float32")
         beta = relay.var("beta", shape=(param_shape), dtype="float32")
         out = relay.nn.layer_norm(
-            x,
-            gamma=gamma,
-            beta=beta,
-            axis=axis,
-            epsilon=epsilon,
-            center=True,
-            scale=True,
+            x, gamma=gamma, beta=beta, axis=axis, epsilon=epsilon, center=True, scale=True
         )
         f = relay.Function([x, gamma, beta], out)
-        return (
-            f,
-            {
-                "x": x_shape,
-                "beta": param_shape,
-                "gamma": param_shape,
-            },
-            ["beta", "gamma"],
-        )
+        return (f, {"x": x_shape, "beta": param_shape, "gamma": param_shape}, ["beta", "gamma"])
 
     run_and_verify_func(get_graph((1, 32, 8, 8), (32,)), run_module=run_module)
     run_and_verify_func(
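
For reference, relay.nn.layer_norm with center=True and scale=True computes standard layer normalization over the given axis. A NumPy sketch of the same math, matching the helper's axis and epsilon arguments (the reference implementation is ours, for illustration):

    import numpy as np

    def layer_norm_ref(x, gamma, beta, axis=1, epsilon=1e-5):
        # Normalize over `axis`, then apply the learned scale (gamma) and shift (beta).
        mean = x.mean(axis=axis, keepdims=True)
        var = x.var(axis=axis, keepdims=True)
        x_hat = (x - mean) / np.sqrt(var + epsilon)
        # gamma and beta have param_shape (the length of the normalized axis);
        # reshape them so they broadcast along that axis.
        shape = [1] * x.ndim
        shape[axis] = -1
        return x_hat * gamma.reshape(shape) + beta.reshape(shape)
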
@@ -1170,20 +1148,9 @@ def test_strided_slice(run_module):
     def get_graph(x_shape, begin, end, strides=None, slice_mode="size"):
         x = relay.var("x", shape=(x_shape), dtype="float32")
         if strides:
-            out = relay.strided_slice(
-                x,
-                begin,
-                end,
-                strides,
-                slice_mode=slice_mode,
-            )
+            out = relay.strided_slice(x, begin, end, strides, slice_mode=slice_mode)
         else:
-            out = relay.strided_slice(
-                x,
-                begin,
-                end,
-                slice_mode=slice_mode,
-            )
+            out = relay.strided_slice(x, begin, end, slice_mode=slice_mode)
         f = relay.Function([x], out)
         return f, {"x": x_shape}, []
 
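
On the slice_mode="size" default used by this helper: in Relay, slice_mode="size" interprets each entry of end as the length of the slice starting at begin (with -1 meaning "all remaining elements", and strides ignored), while slice_mode="end" matches NumPy-style stop indices. A hedged NumPy sketch of the "size" semantics:

    import numpy as np

    def strided_slice_size_ref(x, begin, size):
        # slice_mode="size": each `size` entry is a length, not a stop index;
        # -1 means "all remaining elements along that axis".
        slices = []
        for b, s in zip(begin, size):
            slices.append(slice(b, None if s == -1 else b + s))
        return x[tuple(slices)]

    x = np.arange(24).reshape(2, 3, 4)
    assert strided_slice_size_ref(x, (0, 1, 0), (1, 2, -1)).shape == (1, 2, 4)
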
@@ -1292,13 +1259,7 @@ def get_graph(
                 count_include_pad=count_include_pad,
             )
         else:
-            out = op(
-                x,
-                pool_size=pool_size,
-                strides=strides,
-                padding=padding,
-                ceil_mode=ceil_mode,
-            )
+            out = op(x, pool_size=pool_size, strides=strides, padding=padding, ceil_mode=ceil_mode)
         f = relay.Function([x], out)
         return f, {"x": x_shape}, []
 