|
27 | 27 | import tvm.testing |
28 | 28 | from tvm.relax.frontend.torch import from_fx |
29 | 29 | from tvm.relay.frontend import from_pytorch |
| 30 | +from tvm import relay |
| 31 | +from tvm.ir.module import IRModule |
30 | 32 | from tvm.contrib.msc.core.frontend import translate |
31 | 33 | from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen |
| 34 | +from tvm.contrib.msc.core import utils as msc_utils |
32 | 35 |
|
33 | 36 |
|
34 | 37 | def _valid_target(target): |
@@ -1057,5 +1060,37 @@ def forward(self, x, y): |
1057 | 1060 | verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")]) |
1058 | 1061 |
|
1059 | 1062 |
|
| 1063 | +def test_name_string_with_colon(): |
| 1064 | + """test name string with colons, |
| 1065 | + e.g., TFLite default input name 'serving_default_input:0' |
| 1066 | + """ |
| 1067 | + |
| 1068 | + dtype = "float32" |
| 1069 | + x = relay.var("input_0:0", shape=(3, 5), dtype=dtype) |
| 1070 | + y = relay.var("input_1:0", shape=(3, 5), dtype=dtype) |
| 1071 | + z = relay.add(x, y) |
| 1072 | + func = relay.Function([x, y], z) |
| 1073 | + mod = IRModule() |
| 1074 | + mod["main"] = func |
| 1075 | + |
| 1076 | + try: |
| 1077 | + graph, _ = translate.from_relay(mod) |
| 1078 | + except Exception as e: |
| 1079 | + raise RuntimeError(f"Translation from relay to graph failed: {e}") |
| 1080 | + inspect = graph.inspect() |
| 1081 | + |
| 1082 | + expected = { |
| 1083 | + "inputs": [ |
| 1084 | + {'name': 'input_0:0', 'shape': [3, 5], 'dtype': dtype, 'layout': ''}, |
| 1085 | + {'name': 'input_1:0', 'shape': [3, 5], 'dtype': dtype, 'layout': ''}], |
| 1086 | + "outputs": [{'name': 'add', 'shape': [3, 5], 'dtype': dtype, 'layout': ''}], |
| 1087 | + "nodes": {'total': 3, 'input': 2, 'add': 1}, |
| 1088 | + } |
| 1089 | + |
| 1090 | + assert msc_utils.dict_equal(inspect, expected), "Inspect {} mismatch with expected {}".format( |
| 1091 | + inspect, expected |
| 1092 | + ) |
| 1093 | + |
| 1094 | + |
1060 | 1095 | if __name__ == "__main__": |
1061 | 1096 | tvm.testing.main() |
0 commit comments