Skip to content

Commit c022eab

Browse files
psunnroot
authored andcommitted
[MSC] add test for variable names with colons
1 parent af5d309 commit c022eab

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

tests/python/contrib/test_msc/test_translate_relay.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,11 @@
2727
import tvm.testing
2828
from tvm.relax.frontend.torch import from_fx
2929
from tvm.relay.frontend import from_pytorch
30+
from tvm import relay
31+
from tvm.ir.module import IRModule
3032
from tvm.contrib.msc.core.frontend import translate
3133
from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen
34+
from tvm.contrib.msc.core import utils as msc_utils
3235

3336

3437
def _valid_target(target):
@@ -1057,5 +1060,37 @@ def forward(self, x, y):
10571060
verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")])
10581061

10591062

1063+
def test_name_string_with_colon():
1064+
"""test name string with colons,
1065+
e.g., TFLite default input name 'serving_default_input:0'
1066+
"""
1067+
1068+
dtype = "float32"
1069+
x = relay.var("input_0:0", shape=(3, 5), dtype=dtype)
1070+
y = relay.var("input_1:0", shape=(3, 5), dtype=dtype)
1071+
z = relay.add(x, y)
1072+
func = relay.Function([x, y], z)
1073+
mod = IRModule()
1074+
mod["main"] = func
1075+
1076+
try:
1077+
graph, _ = translate.from_relay(mod)
1078+
except Exception as e:
1079+
raise RuntimeError(f"Translation from relay to graph failed: {e}")
1080+
inspect = graph.inspect()
1081+
1082+
expected = {
1083+
"inputs": [
1084+
{'name': 'input_0:0', 'shape': [3, 5], 'dtype': dtype, 'layout': ''},
1085+
{'name': 'input_1:0', 'shape': [3, 5], 'dtype': dtype, 'layout': ''}],
1086+
"outputs": [{'name': 'add', 'shape': [3, 5], 'dtype': dtype, 'layout': ''}],
1087+
"nodes": {'total': 3, 'input': 2, 'add': 1},
1088+
}
1089+
1090+
assert msc_utils.dict_equal(inspect, expected), "Inspect {} mismatch with expected {}".format(
1091+
inspect, expected
1092+
)
1093+
1094+
10601095
if __name__ == "__main__":
10611096
tvm.testing.main()

0 commit comments

Comments
 (0)