diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index eeede181f6f9..7a1e98402996 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -609,13 +609,16 @@ def try_infer_value(val, on_success=None, on_failure=None, parameters=None): return val, False -def shape_of(x, dtype="int64"): +def shape_of(x, dtype="int64", start=None, end=None): """Get shape of a tensor.""" ttype = infer_type(x).checked_type if not _ty.is_dynamic(ttype): shape = list(ttype.shape) - return _expr.const(shape, dtype) + start = start or 0 # default to first + end = end or len(shape) # default to last + shape_sliced = shape[start:end] + return _expr.const(shape_sliced, dtype) return _op.shape_of(x, dtype) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 04fb17abbb19..36ee3aea0f89 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1394,6 +1394,12 @@ class Shape(OnnxOpConverter): def _impl_v1(cls, inputs, attr, params): return shape_of(inputs[0], "int64") + @classmethod + def _impl_v15(cls, inputs, attr, params): + start = attr.get("start") + end = attr.get("end") + return shape_of(inputs[0], dtype="int64", start=start, end=end) + class CumSum(OnnxOpConverter): """Operator converter for CumSum.""" diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 91775d27b2de..0dea96dcf5e9 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -5082,12 +5082,6 @@ def verify_eyelike(indata): "test_round", "test_sequence_insert_at_back", "test_sequence_insert_at_front", - "test_shape_end_1", - "test_shape_end_negative_1", - "test_shape_start_1", - "test_shape_start_1_end_2", - "test_shape_start_1_end_negative_1", - "test_shape_start_negative_1", "test_simple_rnn_batchwise", "test_simple_rnn_defaults", "test_simple_rnn_with_initial_bias",