From dff03b41148193e2e5fe5c95e202466b65222e3a Mon Sep 17 00:00:00 2001 From: An Wang Date: Wed, 10 Nov 2021 11:23:07 -0800 Subject: [PATCH 1/2] split fix --- python/tvm/relay/frontend/onnx.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 3c88f659f6f0..5813f6305ace 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1461,9 +1461,8 @@ class Split(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): splits = attr.get("split", None) - if splits is not None: + if splits is not None and len(splits) > 1: indices = [] - attr["indices_or_sections"] = [] index = 0 for i in splits[:-1]: index += i From 589a048ad918ae9d88fa4ba260124c61bec91f4e Mon Sep 17 00:00:00 2001 From: An Wang Date: Wed, 10 Nov 2021 14:22:07 -0800 Subject: [PATCH 2/2] add default split test case --- tests/python/frontend/onnx/test_forward.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index dd1c77330986..f8870edcb6d1 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1966,6 +1966,9 @@ def verify_split(indata, outdatas, split, axis=0, pass_split=True, opset=11): verify_split([1, 2, 3], [[1], [2], [3]], False, 0, False) # Split a single value to a single value verify_split([1], [[1]], [1], pass_split=True) + # Test that the default case modifies nothing when split list has length one + verify_split([[1.0, 2.0]], [[1.0, 2.0]], [2], 1) + verify_split([[1.0, 2.0]], [[1.0, 2.0]], [1], 0) @tvm.testing.parametrize_targets