diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 3c88f659f6f0..5813f6305ace 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1461,9 +1461,8 @@ class Split(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): splits = attr.get("split", None) - if splits is not None: + if splits is not None and len(splits) > 1: indices = [] - attr["indices_or_sections"] = [] index = 0 for i in splits[:-1]: index += i diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index dd1c77330986..f8870edcb6d1 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1966,6 +1966,9 @@ def verify_split(indata, outdatas, split, axis=0, pass_split=True, opset=11): verify_split([1, 2, 3], [[1], [2], [3]], False, 0, False) # Split a single value to a single value verify_split([1], [[1]], [1], pass_split=True) + # Test that the default case modifies nothing when split list has length one + verify_split([[1.0, 2.0]], [[1.0, 2.0]], [2], 1) + verify_split([[1.0, 2.0]], [[1.0, 2.0]], [1], 0) @tvm.testing.parametrize_targets