From 928a7b050cd028594bf692d34193406d7a61b312 Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Mon, 27 May 2019 11:58:21 -0700 Subject: [PATCH] fix gluon rnn cell single step unroll --- python/mxnet/gluon/rnn/rnn_cell.py | 5 ++- tests/python/unittest/test_gluon_rnn.py | 53 +++++++++++++------------ 2 files changed, 31 insertions(+), 27 deletions(-) diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py index 9154ccf6159a..71c7b3f84aa5 100644 --- a/python/mxnet/gluon/rnn/rnn_cell.py +++ b/python/mxnet/gluon/rnn/rnn_cell.py @@ -114,7 +114,10 @@ def _reverse_sequences(sequences, unroll_step, valid_length=None): reversed_sequences = F.SequenceReverse(F.stack(*sequences, axis=0), sequence_length=valid_length, use_sequence_length=True) - reversed_sequences = F.split(reversed_sequences, axis=0, num_outputs=unroll_step, squeeze_axis=True) + if unroll_step > 1 or F is symbol: + reversed_sequences = F.split(reversed_sequences, axis=0, num_outputs=unroll_step, squeeze_axis=True) + else: + reversed_sequences = [reversed_sequences[0]] return reversed_sequences diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py index 9d7892010839..309756b122e7 100644 --- a/tests/python/unittest/test_gluon_rnn.py +++ b/tests/python/unittest/test_gluon_rnn.py @@ -634,32 +634,33 @@ def test_layer_fill_shape(): def test_bidirectional_unroll_valid_length(): - # Test BidirectionalCell. - # In 1.3.1 version, after hybridize( ), BidirectionalCell would failed when pass valid_length to unroll( ). - - class BiLSTM(gluon.nn.HybridBlock): - def __init__(self, rnn_size, time_step, **kwargs): - super(BiLSTM, self).__init__(**kwargs) - self.time_step = time_step - with self.name_scope(): - self.bi_lstm = gluon.rnn.BidirectionalCell( - gluon.rnn.LSTMCell(rnn_size, prefix='rnn_l0_'), - gluon.rnn.LSTMCell(rnn_size, prefix='rnn_r0_'), - output_prefix='lstm_bi_') - - def hybrid_forward(self, F, inputs, valid_len): - outputs, states = self.bi_lstm.unroll(self.time_step, inputs, valid_length=valid_len, - layout='NTC', merge_outputs=True) - return outputs, states - - rnn_size, time_step = 100, 3 - net = BiLSTM(rnn_size, time_step) - net.initialize() - net.hybridize() - inputs_data = mx.nd.random.uniform(shape=(10, 3, 50)) - valid_len = mx.nd.array([1]*10) - outputs, _ = net(inputs_data, valid_len) - assert outputs.shape == (10, 3, 200) + def _check_bidirectional_unroll_valid_length(length): + class BiLSTM(gluon.nn.HybridBlock): + def __init__(self, rnn_size, time_step, **kwargs): + super(BiLSTM, self).__init__(**kwargs) + self.time_step = time_step + with self.name_scope(): + self.bi_lstm = gluon.rnn.BidirectionalCell( + gluon.rnn.LSTMCell(rnn_size, prefix='rnn_l0_'), + gluon.rnn.LSTMCell(rnn_size, prefix='rnn_r0_'), + output_prefix='lstm_bi_') + + def hybrid_forward(self, F, inputs, valid_len): + outputs, states = self.bi_lstm.unroll(self.time_step, inputs, valid_length=valid_len, + layout='NTC', merge_outputs=True) + return outputs, states + + rnn_size = 100 + net = BiLSTM(rnn_size, length) + net.initialize() + net.hybridize() + inputs_data = mx.nd.random.uniform(shape=(10, length, 50)) + valid_len = mx.nd.array([length]*10) + outputs, _ = net(inputs_data, valid_len) + assert outputs.shape == (10, length, 200) + + _check_bidirectional_unroll_valid_length(1) + _check_bidirectional_unroll_valid_length(3) if __name__ == '__main__':