@@ -4008,6 +4008,23 @@ def _impl_v1(cls, inputs, attr, params):
40084008 for var in else_free_vars :
40094009 graph_scope ._nodes .update ({var .name_hint : var })
40104010
4011+ # Sometimes the PyTorch-to-ONNX exporter inserts spurious If statements whose branches
4012+ # produce outputs of different (dynamic) rank. Often these don't contribute anything.
4013+ # If we see a dynamic-rank output, try to unify the ranks so we can continue without breaking.
4014+ if not isinstance (then_expr , _expr .Tuple ) and not isinstance (else_expr , _expr .Tuple ):
4015+ then_shape = infer_shape (then_expr )
4016+ else_shape = infer_shape (else_expr )
4017+ if len (then_shape ) != len (else_shape ):
4018+ warning_msg = (
4019+ "If statement produced outputs with different rank. "
4020+ "Attempting to unify ranks but this may produce incorrect results."
4021+ )
4022+ warnings .warn (warning_msg )
4023+ if len (then_shape ) < len (else_shape ):
4024+ then_expr = _op .broadcast_to_like (then_expr , else_expr )
4025+ else :
4026+ else_expr = _op .broadcast_to_like (else_expr , then_expr )
4027+
40114028 # Now we can construct the relay if statement and return.
40124029 ret = _expr .If (cond , then_expr , else_expr )
40134030 if len (then_branch .output ) > 1 :
@@ -5565,6 +5582,66 @@ def _impl_v11(cls, inputs, attr, params):
55655582 return _op .concatenate (inputs [0 ], axis = axis )
55665583
55675584
class SplitToSequence(OnnxOpConverter):
    """Operator converter for SplitToSequence op."""

    @classmethod
    def _impl_v11(cls, inputs, attr, params):
        """Split inputs[0] along ``axis`` and return the pieces as a relay Tuple.

        Parameters follow the standard converter signature: ``inputs`` holds the
        relay expressions for the op's inputs (``inputs[1]`` is the optional
        ``split`` tensor), ``attr`` the ONNX node attributes, ``params`` the
        graph initializers (unused here).

        Returns a ``_expr.Tuple`` of tensor slices. When ``keepdims`` is 0 the
        split axis is squeezed off each slice.
        """
        axis = attr.get("axis", 0)
        keepdims = attr.get("keepdims", 1)

        input_tensor = inputs[0]
        input_shape = infer_shape(input_tensor)
        split = inputs[1]

        # If split is not provided, split into size-1 chunks along axis.
        # NOTE: assumes input_shape[axis] is static — TODO confirm for dynamic shapes.
        if split is None:
            output = _op.split(input_tensor, input_shape[axis], axis=axis)
        else:
            # For now we only support constant valued split.
            assert isinstance(
                split, _expr.Constant
            ), "Only constant split supported for SplitToSequence"
            split = split.data.numpy()
            if len(split.shape) == 1 and split.shape[0] > 1:
                # A 1-D split tensor holds per-chunk lengths; relay.split wants
                # cut indices instead, so take the running sum and drop the
                # final (out of range) index.
                split = np.cumsum(split)
                split = split[:-1]
            else:
                # Per the ONNX spec a scalar split is the length of EACH chunk,
                # not a section count, so convert it to explicit cut indices.
                # This also allows a smaller final chunk when the axis length
                # is not an exact multiple of the chunk size.
                chunk_size = int(split)
                dim = input_shape[axis]
                split = list(range(chunk_size, dim, chunk_size))

            output = _op.split(input_tensor, split, axis=axis)

        # If keepdims is set to 0 remove split axis. Note that this is
        # an inconsistency with the onnx spec but is needed for pytorch compatibility.
        if not keepdims:
            output = [_op.squeeze(tensor_slice, axis=[axis]) for tensor_slice in output]
        return _expr.Tuple(list(output))
5628+
5629+
class SequenceAt(OnnxOpConverter):
    """Operator converter for sequence at op."""

    @classmethod
    def _impl_v11(cls, inputs, attr, params):
        """Select one tensor from the input sequence at a constant position."""
        sequence, position = inputs[0], inputs[1]
        # Only positions known at compile time can be resolved into an index.
        assert isinstance(
            position, _expr.Constant
        ), "Only constant position supported for SequenceAt"
        index = int(position.data.numpy())
        return sequence[index]
5643+
5644+
55685645# compatible operators that do NOT require any conversion.
55695646_identity_list = []
55705647
@@ -5793,6 +5870,8 @@ def _get_convert_map(opset):
57935870 "SequenceConstruct" : SequenceConstruct .get_converter (opset ),
57945871 "SequenceInsert" : SequenceInsert .get_converter (opset ),
57955872 "ConcatFromSequence" : ConcatFromSequence .get_converter (opset ),
5873+ "SplitToSequence" : SplitToSequence .get_converter (opset ),
5874+ "SequenceAt" : SequenceAt .get_converter (opset ),
57965875 }
57975876
57985877
0 commit comments