Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,13 +609,16 @@ def try_infer_value(val, on_success=None, on_failure=None, parameters=None):
return val, False


def shape_of(x, dtype="int64"):
def shape_of(x, dtype="int64", start=None, end=None):
"""Get shape of a tensor."""

ttype = infer_type(x).checked_type
if not _ty.is_dynamic(ttype):
shape = list(ttype.shape)
return _expr.const(shape, dtype)
start = start or 0 # default to first
end = end or len(shape) # default to last
shape_sliced = shape[start:end]
return _expr.const(shape_sliced, dtype)
return _op.shape_of(x, dtype)


Expand Down
6 changes: 6 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1394,6 +1394,12 @@ class Shape(OnnxOpConverter):
def _impl_v1(cls, inputs, attr, params):
return shape_of(inputs[0], "int64")

@classmethod
def _impl_v15(cls, inputs, attr, params):
start = attr.get("start")
end = attr.get("end")
return shape_of(inputs[0], dtype="int64", start=start, end=end)


class CumSum(OnnxOpConverter):
"""Operator converter for CumSum."""
Expand Down
6 changes: 0 additions & 6 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -5082,12 +5082,6 @@ def verify_eyelike(indata):
"test_round",
"test_sequence_insert_at_back",
"test_sequence_insert_at_front",
"test_shape_end_1",
"test_shape_end_negative_1",
"test_shape_start_1",
"test_shape_start_1_end_2",
"test_shape_start_1_end_negative_1",
"test_shape_start_negative_1",
"test_simple_rnn_batchwise",
"test_simple_rnn_defaults",
"test_simple_rnn_with_initial_bias",
Expand Down