Skip to content

Commit c547bbb

Browse files
author
Josh Fromm
authored
[Relay][Frontend][Onnx] SequenceAt and SplitToSequence Operators (#13602)
* Add support for SequenceAt and SplitToSequence to onnx importer * Formatting * Change keepdims comparison * Only unify non-tuples in If
1 parent 12311dc commit c547bbb

File tree

3 files changed

+98
-14
lines changed

3 files changed

+98
-14
lines changed

python/tvm/relay/frontend/onnx.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4008,6 +4008,23 @@ def _impl_v1(cls, inputs, attr, params):
40084008
for var in else_free_vars:
40094009
graph_scope._nodes.update({var.name_hint: var})
40104010

4011+
# Sometimes pytorch to onnx will insert silly if statements that produce dynamic ranks.
4012+
# Often these don't contribute anything. If we see a dynamic rank output, try to unify
4013+
# them so we can continue without breaking.
4014+
if not isinstance(then_expr, _expr.Tuple) and not isinstance(else_expr, _expr.Tuple):
4015+
then_shape = infer_shape(then_expr)
4016+
else_shape = infer_shape(else_expr)
4017+
if len(then_shape) != len(else_shape):
4018+
warning_msg = (
4019+
"If statement produced outputs with different rank. "
4020+
"Attempting to unify ranks but this may produce incorrect results."
4021+
)
4022+
warnings.warn(warning_msg)
4023+
if len(then_shape) < len(else_shape):
4024+
then_expr = _op.broadcast_to_like(then_expr, else_expr)
4025+
else:
4026+
else_expr = _op.broadcast_to_like(else_expr, then_expr)
4027+
40114028
# Now we can construct the relay if statement and return.
40124029
ret = _expr.If(cond, then_expr, else_expr)
40134030
if len(then_branch.output) > 1:
@@ -5565,6 +5582,66 @@ def _impl_v11(cls, inputs, attr, params):
55655582
return _op.concatenate(inputs[0], axis=axis)
55665583

55675584

5585+
class SplitToSequence(OnnxOpConverter):
    """Operator converter for split to sequence op."""

    @classmethod
    def _impl_v11(cls, inputs, attr, params):
        split_axis = attr.get("axis", 0)
        keep_split_dim = attr.get("keepdims", 1)

        data = inputs[0]
        data_shape = infer_shape(data)
        split_arg = inputs[1]

        if split_arg is None:
            # No split provided: slice the tensor into single elements along the axis.
            pieces = _op.split(data, data_shape[split_axis], axis=split_axis)
            # If keepdims is 0, then we need to squeeze off the axis.
            if not keep_split_dim:
                pieces = [_op.squeeze(piece, axis=[split_axis]) for piece in pieces]
            return _expr.Tuple(list(pieces))

        # Otherwise, split based on the provided split value.
        # For now we only support constant valued split.
        assert isinstance(
            split_arg, _expr.Constant
        ), "Only constant split supported for SplitToSequence"
        split_value = split_arg.data.numpy()
        if len(split_value.shape) == 1 and split_value.shape[0] > 1:
            # A 1D split tensor holds section sizes; relay expects split indices,
            # so take the cumulative sum and drop the final (invalid) index.
            split_value = np.cumsum(split_value)[:-1]
        else:
            # Scalar-like split: interpret it as an integer number of sections.
            split_value = int(split_value)

        pieces = _op.split(data, split_value, axis=split_axis)

        # If keepdims is set to 0 remove split axis. Note that this is
        # an inconsistency with the onnx spec but is needed for pytorch compatibility.
        if not keep_split_dim:
            pieces = [_op.squeeze(piece, axis=[split_axis]) for piece in pieces]
        return _expr.Tuple(list(pieces))
5629+
5630+
class SequenceAt(OnnxOpConverter):
    """Operator converter for sequence at op."""

    @classmethod
    def _impl_v11(cls, inputs, attr, params):
        sequence, position = inputs[0], inputs[1]
        assert isinstance(
            position, _expr.Constant
        ), "Only constant position supported for SequenceAt"
        # Convert the constant position tensor to a plain python integer index.
        index = int(position.data.numpy())
        return sequence[index]
5643+
5644+
55685645
# compatible operators that do NOT require any conversion.
55695646
_identity_list = []
55705647

@@ -5793,6 +5870,8 @@ def _get_convert_map(opset):
57935870
"SequenceConstruct": SequenceConstruct.get_converter(opset),
57945871
"SequenceInsert": SequenceInsert.get_converter(opset),
57955872
"ConcatFromSequence": ConcatFromSequence.get_converter(opset),
5873+
"SplitToSequence": SplitToSequence.get_converter(opset),
5874+
"SequenceAt": SequenceAt.get_converter(opset),
57965875
}
57975876

57985877

python/tvm/relay/op/_transform.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -432,8 +432,6 @@ def _concatenate_shape_func(inputs, axis):
432432
for i in const_range(ndim):
433433
if i != axis:
434434
out[i] = inputs[0][i]
435-
for j in const_range(1, len(inputs)):
436-
assert out[i] == inputs[j][i], "Dims mismatch in the inputs of concatenate."
437435
else:
438436
out[i] = int64(0)
439437
for j in const_range(len(inputs)):

tests/python/frontend/onnx/test_forward.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7043,7 +7043,7 @@ def verify_linear_regressor(a_shape, c_shape, i_shape, targets=1, batch=1):
70437043
def test_sequence(target, dev):
70447044
"""test_sequence"""
70457045

7046-
def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=None, new_axis=None):
7046+
def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=0, new_axis=None):
70477047
tensor_shape = list(tensor_shape)
70487048
tensor_values = []
70497049
for i in range(num_tensors):
@@ -7062,20 +7062,30 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=None, new_ax
70627062
outputs=["sequence"],
70637063
)
70647064

7065-
insert_inputs = ["sequence", input_tensor_names[0]]
7066-
position_node = None
7067-
if position is not None:
7068-
insert_inputs.append("position")
7069-
position_node = make_constant_node("position", TensorProto.INT32, (), [position])
7065+
position_node = make_constant_node("position", TensorProto.INT32, (), [position])
70707066

70717067
# Test sequence insertion.
70727068
insert_node = helper.make_node(
7073-
"SequenceInsert", inputs=insert_inputs, outputs=["inserted_sequence"]
7069+
"SequenceInsert",
7070+
inputs=["sequence", input_tensor_names[0], "position"],
7071+
outputs=["inserted_sequence"],
70747072
)
70757073

70767074
# Test sequence concatenation.
70777075
concat_node = helper.make_node(
7078-
"ConcatFromSequence", inputs=["inserted_sequence"], outputs=["output"], axis=axis
7076+
"ConcatFromSequence",
7077+
inputs=["inserted_sequence"],
7078+
outputs=["concat_sequence"],
7079+
axis=axis,
7080+
)
7081+
7082+
# Test splitting a tensor into a sequence.
7083+
split_node = helper.make_node(
7084+
"SplitToSequence", inputs=["concat_sequence"], outputs=["split_sequence"], axis=axis
7085+
)
7086+
7087+
at_node = helper.make_node(
7088+
"SequenceAt", inputs=["split_sequence", "position"], outputs=["output"]
70797089
)
70807090

70817091
if new_axis is not None:
@@ -7097,10 +7107,7 @@ def verify_sequence_ops(tensor_shape, num_tensors, axis=0, position=None, new_ax
70977107
output_shape[axis] = (num_tensors + 1) * output_shape[axis]
70987108
graph_outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape)]
70997109

7100-
graph_nodes = []
7101-
if position_node is not None:
7102-
graph_nodes.append(position_node)
7103-
graph_nodes += [construct_node, insert_node, concat_node]
7110+
graph_nodes = [position_node, construct_node, insert_node, concat_node, split_node, at_node]
71047111

71057112
graph = helper.make_graph(
71067113
graph_nodes,

0 commit comments

Comments
 (0)