diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h index 5762c9635206..6c39a8d0a16a 100644 --- a/src/contrib/msc/core/utils.h +++ b/src/contrib/msc/core/utils.h @@ -142,7 +142,7 @@ class StringUtils { */ TVM_DLL static const std::tuple SplitOnce(const String& src_string, const String& sep, - bool from_left = true); + bool from_left = false); /*! * \brief Get the tokens between left and right. diff --git a/tests/python/contrib/test_msc/test_translate_relay.py b/tests/python/contrib/test_msc/test_translate_relay.py index 39a45035a5b2..6c47b8b39545 100644 --- a/tests/python/contrib/test_msc/test_translate_relay.py +++ b/tests/python/contrib/test_msc/test_translate_relay.py @@ -27,8 +27,11 @@ import tvm.testing from tvm.relax.frontend.torch import from_fx from tvm.relay.frontend import from_pytorch +from tvm import relay +from tvm.ir.module import IRModule from tvm.contrib.msc.core.frontend import translate from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen +from tvm.contrib.msc.core import utils as msc_utils def _valid_target(target): @@ -1057,5 +1060,38 @@ def forward(self, x, y): verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")]) +def test_name_string_with_colon(): + """test name string with colons, + e.g., TFLite default input name 'serving_default_input:0' + """ + + dtype = "float32" + x_var = relay.var("input_0:0", shape=(3, 5), dtype=dtype) + y_var = relay.var("input_1:0", shape=(3, 5), dtype=dtype) + z_add = relay.add(x_var, y_var) + func = relay.Function([x_var, y_var], z_add) + mod = IRModule() + mod["main"] = func + + try: + graph, _ = translate.from_relay(mod) + except Exception as err: + raise RuntimeError(f"Translation from relay to graph failed: {err}") + inspect = graph.inspect() + + expected = { + "inputs": [ + {"name": "input_0:0", "shape": [3, 5], "dtype": dtype, "layout": ""}, + {"name": "input_1:0", "shape": [3, 5], "dtype": dtype, "layout": ""}, + ], + "outputs": [{"name": "add", "shape": [3, 5], "dtype": dtype, "layout": ""}], + "nodes": {"total": 3, "input": 2, "add": 1}, + } + + assert msc_utils.dict_equal(inspect, expected), "Inspect {} mismatch with expected {}".format( + inspect, expected + ) + + if __name__ == "__main__": tvm.testing.main()