Skip to content

Commit 504fc07

Browse files
margaretqianMargaret Qian
authored andcommitted
[ONNX] Update onnx shape op with slice index support (apache#10947)
* support shape op slice indices * lint Co-authored-by: Margaret Qian <[email protected]>
1 parent f69fa01 commit 504fc07

File tree

3 files changed

+11
-8
lines changed

3 files changed

+11
-8
lines changed

python/tvm/relay/frontend/common.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -609,13 +609,16 @@ def try_infer_value(val, on_success=None, on_failure=None, parameters=None):
609609
return val, False
610610

611611

612-
def shape_of(x, dtype="int64"):
612+
def shape_of(x, dtype="int64", start=None, end=None):
613613
"""Get shape of a tensor."""
614614

615615
ttype = infer_type(x).checked_type
616616
if not _ty.is_dynamic(ttype):
617617
shape = list(ttype.shape)
618-
return _expr.const(shape, dtype)
618+
start = start or 0 # default to first
619+
end = end or len(shape) # default to last
620+
shape_sliced = shape[start:end]
621+
return _expr.const(shape_sliced, dtype)
619622
return _op.shape_of(x, dtype)
620623

621624

python/tvm/relay/frontend/onnx.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1435,6 +1435,12 @@ class Shape(OnnxOpConverter):
14351435
def _impl_v1(cls, inputs, attr, params):
14361436
return shape_of(inputs[0], "int64")
14371437

1438+
@classmethod
1439+
def _impl_v15(cls, inputs, attr, params):
1440+
start = attr.get("start")
1441+
end = attr.get("end")
1442+
return shape_of(inputs[0], dtype="int64", start=start, end=end)
1443+
14381444

14391445
class CumSum(OnnxOpConverter):
14401446
"""Operator converter for CumSum."""

tests/python/frontend/onnx/test_forward.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5082,12 +5082,6 @@ def verify_eyelike(indata):
50825082
"test_round",
50835083
"test_sequence_insert_at_back",
50845084
"test_sequence_insert_at_front",
5085-
"test_shape_end_1",
5086-
"test_shape_end_negative_1",
5087-
"test_shape_start_1",
5088-
"test_shape_start_1_end_2",
5089-
"test_shape_start_1_end_negative_1",
5090-
"test_shape_start_negative_1",
50915085
"test_simple_rnn_batchwise",
50925086
"test_simple_rnn_defaults",
50935087
"test_simple_rnn_with_initial_bias",

0 commit comments

Comments
 (0)