From 71d3a4b6f0be16f44ae580a8ec3abae0cb0aa70e Mon Sep 17 00:00:00 2001 From: Margaret Qian Date: Fri, 8 Apr 2022 13:49:52 -0700 Subject: [PATCH 1/2] support shape op slice indices --- python/tvm/relay/frontend/common.py | 7 +++++-- python/tvm/relay/frontend/onnx.py | 7 +++++++ tests/python/frontend/onnx/test_forward.py | 6 ------ 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index eeede181f6f9..c3e1789220f6 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -609,13 +609,16 @@ def try_infer_value(val, on_success=None, on_failure=None, parameters=None): return val, False -def shape_of(x, dtype="int64"): +def shape_of(x, dtype="int64", start=None, end=None): """Get shape of a tensor.""" ttype = infer_type(x).checked_type if not _ty.is_dynamic(ttype): shape = list(ttype.shape) - return _expr.const(shape, dtype) + start = start or 0 # default to first + end = end or len(shape) # default to last + shape_sliced = shape[start:end] + return _expr.const(shape_sliced, dtype) return _op.shape_of(x, dtype) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 04fb17abbb19..c19e5adc01b4 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1394,6 +1394,13 @@ class Shape(OnnxOpConverter): def _impl_v1(cls, inputs, attr, params): return shape_of(inputs[0], "int64") + @classmethod + def _impl_v15(cls, inputs, attr, params): + start = attr.get("start") + end = attr.get("end") + return shape_of(inputs[0], dtype="int64", start=start, end=end) + + class CumSum(OnnxOpConverter): """Operator converter for CumSum.""" diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 91775d27b2de..0dea96dcf5e9 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -5082,12 +5082,6 @@ def verify_eyelike(indata): "test_round", "test_sequence_insert_at_back", "test_sequence_insert_at_front", - "test_shape_end_1", - "test_shape_end_negative_1", - "test_shape_start_1", - "test_shape_start_1_end_2", - "test_shape_start_1_end_negative_1", - "test_shape_start_negative_1", "test_simple_rnn_batchwise", "test_simple_rnn_defaults", "test_simple_rnn_with_initial_bias", From 1e8b38e505e191ae3a57a5149543629d0cb4783f Mon Sep 17 00:00:00 2001 From: Margaret Qian Date: Fri, 8 Apr 2022 14:09:20 -0700 Subject: [PATCH 2/2] lint --- python/tvm/relay/frontend/common.py | 4 ++-- python/tvm/relay/frontend/onnx.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index c3e1789220f6..7a1e98402996 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -615,8 +615,8 @@ def shape_of(x, dtype="int64", start=None, end=None): ttype = infer_type(x).checked_type if not _ty.is_dynamic(ttype): shape = list(ttype.shape) - start = start or 0 # default to first - end = end or len(shape) # default to last + start = start or 0 # default to first + end = end or len(shape) # default to last shape_sliced = shape[start:end] return _expr.const(shape_sliced, dtype) return _op.shape_of(x, dtype) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index c19e5adc01b4..36ee3aea0f89 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1399,7 +1399,6 @@ def _impl_v15(cls, inputs, attr, params): start = attr.get("start") end = attr.get("end") return shape_of(inputs[0], dtype="int64", start=start, end=end) - class CumSum(OnnxOpConverter):