Commit 29241c8

Jianjian.Guan authored and committed
[Frontend] [ONNX] Support sequence_lens of GRU.
Support converting the sequence_lens input of GRU.
1 parent c547bbb commit 29241c8

File tree

3 files changed: +103 -11 lines


python/tvm/relay/frontend/common.py

Lines changed: 53 additions & 3 deletions
@@ -737,6 +737,7 @@ def gru_cell(
     n_act=_op.tanh,
     backwards=False,
     linear_before_reset=True,
+    sequence_lens=None,
 ):
     """
     Common implementation of GRU cell for all frontends of TVM
@@ -765,15 +766,52 @@ def gru_cell(
         activation function for new gate. it is tanh by default
     backwards : bool
         Flag for reverse pass of GRU
-
+    linear_before_reset : bool
+        Flag for applying the linear transformation before multiplying by the output of the reset gate.
+    sequence_lens : relay.op
+        Tensor specifying lengths of the sequences in a batch.
+        Shape = (batch_size)
     Returns
     -------
     result : List[relay.Expr], relay.Expr, relay.Expr
         The sequence of computed result, final hidden and cell state
     """
 
     outputs_list = []
-    for x_t in input_seqs if not backwards else reversed(input_seqs):
+
+    seq_len = len(input_seqs)
+    input_dtype = infer_type(input_seqs[0]).checked_type.dtype
+
+    if sequence_lens is not None:
+        shape = infer_shape(sequence_lens)
+        dtype = infer_type(sequence_lens).checked_type.dtype
+
+        arange = _op.arange(_op.const(0), _op.const(seq_len), dtype=dtype)
+        arange = _op.expand_dims(arange, 1)
+        sequence_lens = _op.broadcast_to(sequence_lens, [seq_len, shape[0]])
+
+        # cast to data dtype
+        mask = _op.less(arange, sequence_lens)
+        mask = _op.cast(mask, dtype=input_dtype)
+        mask = _op.expand_dims(mask, 2)
+        mask_seqs = unbind(mask)
+
+        res_mask = _op.greater_equal(arange, sequence_lens)
+        res_mask = _op.cast(res_mask, dtype=input_dtype)
+        res_mask = _op.expand_dims(res_mask, 2)
+        res_mask_seqs = unbind(res_mask)
+
+        if backwards:
+            # need a mask to keep initial_h_B correct
+            initial_h = hidden_state
+            initial_h_mask = _op.equal(arange, sequence_lens)
+            initial_h_mask = _op.cast(initial_h_mask, dtype=input_dtype)
+            initial_h_mask = _op.expand_dims(initial_h_mask, 2)
+            initial_h_mask_seqs = unbind(initial_h_mask)
+
+    output = _op.zeros(infer_shape(hidden_state), input_dtype)
+    for i in range(seq_len) if not backwards else reversed(range(seq_len)):
+        x_t = input_seqs[i]
         xwt = _op.nn.dense(x_t, w_inp)
         if linear_before_reset:
             hwt = _op.nn.dense(hidden_state, w_hid)
@@ -806,9 +844,21 @@ def gru_cell(
 
         hidden_state = (hidden_state - n_gate) * z_gate + n_gate
 
+        if sequence_lens is not None:
+            hidden_state = hidden_state * mask_seqs[i]
+
         outputs_list.append(hidden_state)  # [seq_num, (batch, hidden_size)]
 
-    return outputs_list, hidden_state
+        if sequence_lens is not None:
+            output = output * res_mask_seqs[i] + hidden_state
+        else:
+            output = hidden_state
+
+        # make sure initial_h_B is correct
+        if backwards and sequence_lens is not None:
+            hidden_state = hidden_state + initial_h * initial_h_mask_seqs[i]
+
+    return outputs_list, output
 
 
 def lstm_cell(
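
The masking scheme added above is easiest to see in plain NumPy. The sketch below is illustrative only (seq_len, batch, and lens are invented values, not part of the commit): mask[t] is 1.0 while a batch entry is still inside its sequence, res_mask[t] turns 1.0 once it has ended so output can freeze that entry's last valid hidden state, and for the backward pass an equality mask fires exactly one step past each entry's end so initial_h can be re-injected there.

# Minimal NumPy sketch of gru_cell's masks (illustrative values, not TVM code)
import numpy as np

seq_len, batch = 4, 3
lens = np.array([4, 2, 3])                        # per-entry sequence lengths

arange = np.arange(seq_len).reshape(seq_len, 1)   # mirrors _op.expand_dims(arange, 1)
lens_b = np.broadcast_to(lens, (seq_len, batch))  # mirrors _op.broadcast_to

mask = (arange < lens_b).astype("float32")             # step still active
res_mask = (arange >= lens_b).astype("float32")        # sequence already finished
initial_h_mask = (arange == lens_b).astype("float32")  # first step past the end

# Per step i: hidden_state * mask_seqs[i] zeroes finished entries, and
# output = output * res_mask_seqs[i] + hidden_state freezes each entry's
# last valid hidden state, which is what gru_cell now returns as output.
print(mask[:, 1])      # [1. 1. 0. 0.] -- entry 1 runs for only 2 steps
print(res_mask[:, 1])  # [0. 0. 1. 1.]

Because the backward pass iterates reversed(range(seq_len)), entry 1 above hits its equality mask at i == 2 before reaching any valid step; adding initial_h * initial_h_mask_seqs[i] there makes the reverse recurrence start from the caller's initial state at that entry's own last valid step, which is what keeps initial_h_B correct.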

python/tvm/relay/frontend/onnx.py

Lines changed: 12 additions & 6 deletions
@@ -3126,8 +3126,7 @@ def _inputs_helper(cls, inputs, layout):
         Wp = inputs[1]
         Rp = inputs[2]
         Bp = inputs[3]
-        # Sequence length currently unused as it can be inferred from shapes.
-        # sequence_lens = inputs['sequence_lens']
+        sequence_lens = inputs[4]
         Hp_0 = inputs[5]
 
         num_directions = infer_shape(Wp)[0]
@@ -3158,11 +3157,11 @@ def _inputs_helper(cls, inputs, layout):
         Bs = None
         if Bp is not None:
             Bs = _op.split(Bp, num_directions)
-        return X_steps, H_ts, Ws, Rs, Bs, num_directions
+        return X_steps, H_ts, Ws, Rs, Bs, num_directions, sequence_lens
 
     @classmethod
     def _impl_common(cls, inputs, attr, layout):
-        X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout)
+        X_steps, H_ts, Ws, Rs, Bs, num_directions, _ = cls._inputs_helper(inputs, layout)
         acts = cls._get_activations(attr, 1, num_directions, "RNN")
 
         weights_dicts = []
@@ -3261,7 +3260,7 @@ def _default_activations(cls, num_directions):
 
     @classmethod
     def _impl_common(cls, inputs, attr, layout):
-        X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout)
+        X_steps, H_ts, Ws, Rs, Bs, num_directions, _ = cls._inputs_helper(inputs, layout)
         acts = cls._get_activations(attr, 3, num_directions, "LSTM")
 
         # cell state
@@ -3346,6 +3345,7 @@ def bidir_gru_cell(
         input_seqs,
         weight_dicts,
         acts,
+        sequence_lens=None,
     ):
         """
         Bidirectional GRU cell
@@ -3356,6 +3356,7 @@ def bidir_gru_cell(
             **weight_dicts[0],
             rz_act=acts[0],
             n_act=acts[1],
+            sequence_lens=sequence_lens,
         )
 
         reverse_outputs, rev_H_t = gru_cell(
@@ -3364,6 +3365,7 @@ def bidir_gru_cell(
             rz_act=acts[2],
             n_act=acts[3],
             backwards=True,
+            sequence_lens=sequence_lens,
         )
 
         final_outputs = []
@@ -3383,7 +3385,9 @@ def _default_activations(cls, num_directions):
 
     @classmethod
     def _impl_common(cls, inputs, attr, layout):
-        X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout)
+        X_steps, H_ts, Ws, Rs, Bs, num_directions, sequence_lens = cls._inputs_helper(
+            inputs, layout
+        )
         acts = cls._get_activations(attr, 2, num_directions, "GRU")
         linear_before_reset = attr.get("linear_before_reset", 0)
 
@@ -3412,6 +3416,7 @@ def _impl_common(cls, inputs, attr, layout):
                 input_seqs=X_steps,
                 weight_dicts=weights_dicts,
                 acts=acts,
+                sequence_lens=sequence_lens,
             )
         else:
             # outputs shape = [seqs_num, (batch_size, hidden_size)]
@@ -3420,6 +3425,7 @@ def _impl_common(cls, inputs, attr, layout):
                 **weights_dicts[0],
                 rz_act=acts[0],
                 n_act=acts[1],
+                sequence_lens=sequence_lens,
            )
 
             # output shape = (seqs_num, num_directions, batch_size, hidden_size)
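
In the ONNX operator spec, sequence_lens is the optional fifth input (index 4) of GRU, which is why _inputs_helper now returns it rather than skipping it. Below is a hedged end-to-end sketch of feeding such a model through this frontend; every shape, tensor name, and size is invented for illustration and is not taken from the commit.

# Build a one-direction ONNX GRU that carries a sequence_lens input, then
# import it with the converter touched by this commit.
import onnx
from onnx import TensorProto, helper

from tvm import relay

seq_len, batch, input_size, hidden_size = 4, 3, 8, 16

gru = helper.make_node(
    "GRU",
    # input order fixed by the ONNX spec: X, W, R, B, sequence_lens, initial_h
    inputs=["X", "W", "R", "B", "sequence_lens", "initial_h"],
    outputs=["Y", "Y_h"],
    hidden_size=hidden_size,
)

graph = helper.make_graph(
    [gru],
    "gru_with_sequence_lens",
    inputs=[
        helper.make_tensor_value_info("X", TensorProto.FLOAT, [seq_len, batch, input_size]),
        helper.make_tensor_value_info("W", TensorProto.FLOAT, [1, 3 * hidden_size, input_size]),
        helper.make_tensor_value_info("R", TensorProto.FLOAT, [1, 3 * hidden_size, hidden_size]),
        helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, 6 * hidden_size]),
        helper.make_tensor_value_info("sequence_lens", TensorProto.INT32, [batch]),
        helper.make_tensor_value_info("initial_h", TensorProto.FLOAT, [1, batch, hidden_size]),
    ],
    outputs=[
        helper.make_tensor_value_info("Y", TensorProto.FLOAT, [seq_len, 1, batch, hidden_size]),
        helper.make_tensor_value_info("Y_h", TensorProto.FLOAT, [1, batch, hidden_size]),
    ],
)

# Before this commit inputs[4] was ignored; now it reaches gru_cell's masking.
mod, params = relay.frontend.from_onnx(helper.make_model(graph))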

tests/python/frontend/onnx/test_forward.py

Lines changed: 38 additions & 2 deletions
@@ -3897,6 +3897,7 @@ def verify_rnn(
     atol=1e-5,
     target=None,
     dev=None,
+    use_sequence_lens=False,
 ):
     """verify_rnn"""
     if rnn_type == "RNN":
@@ -3954,10 +3955,16 @@ def register(np_arr, name, shape=None):
         )
         register(b_np, "B")
 
+    if use_sequence_lens:
+        sequence_np = np.random.uniform(0, seq_length, size=(batch_size)).astype("int32")
+        register(sequence_np, "sequence_lens")
+
     if use_initial_state:
         assert use_bias is True, "Initial states must have bias specified."
-        sequence_np = np.repeat(seq_length, batch_size).astype("int32")
-        register(sequence_np, "sequence_lens")
+
+        if not use_sequence_lens:
+            sequence_np = np.repeat(seq_length, batch_size).astype("int32")
+            register(sequence_np, "sequence_lens")
 
     if layout == 1:
         initial_h_np = np.random.uniform(size=(batch_size, directions, hidden_size)).astype(
@@ -4211,6 +4218,35 @@ def verify_rnn_helper(target, dev, rnn_type):
         #     dev=dev,
         # )
 
+        # Testing with sequence_lens
+        if rnn_type == "GRU":
+            verify_rnn(
+                seq_length=2,
+                batch_size=1,
+                input_size=16,
+                hidden_size=32,
+                use_bias=True,
+                use_initial_state=True,
+                rnn_type=rnn_type,
+                directions=directions,
+                target=target,
+                dev=dev,
+                use_sequence_lens=True,
+            )
+            verify_rnn(
+                seq_length=8,
+                batch_size=8,
+                input_size=16,
+                hidden_size=32,
+                use_bias=True,
+                use_initial_state=True,
+                rnn_type=rnn_type,
+                directions=directions,
+                target=target,
+                dev=dev,
+                use_sequence_lens=True,
+            )
+
         # Testing with peepholes
         if rnn_type == "LSTM":
             verify_rnn(
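
One subtlety in the new use_sequence_lens branch: np.random.uniform samples floats from the half-open interval [0, seq_length) and the int32 cast truncates toward zero, so the registered lengths land in 0..seq_length - 1 and a zero-length sequence is possible. A small reproduction (the seed is an assumption added here so the sketch prints a stable value; the test itself is unseeded):

# What the new branch feeds the model as "sequence_lens"
import numpy as np

np.random.seed(0)  # seeded only to make this sketch reproducible
seq_length, batch_size = 8, 8
sequence_np = np.random.uniform(0, seq_length, size=(batch_size)).astype("int32")
print(sequence_np)  # [4 5 4 4 3 5 3 7] -- every entry strictly < seq_length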
