Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

fix gluon bi-rnn cell single step unroll #15081

Merged
merged 1 commit into from
May 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion python/mxnet/gluon/rnn/rnn_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,10 @@ def _reverse_sequences(sequences, unroll_step, valid_length=None):
reversed_sequences = F.SequenceReverse(F.stack(*sequences, axis=0),
sequence_length=valid_length,
use_sequence_length=True)
reversed_sequences = F.split(reversed_sequences, axis=0, num_outputs=unroll_step, squeeze_axis=True)
if unroll_step > 1 or F is symbol:
reversed_sequences = F.split(reversed_sequences, axis=0, num_outputs=unroll_step, squeeze_axis=True)
else:
reversed_sequences = [reversed_sequences[0]]

return reversed_sequences

Expand Down
53 changes: 27 additions & 26 deletions tests/python/unittest/test_gluon_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,32 +634,33 @@ def test_layer_fill_shape():


def test_bidirectional_unroll_valid_length():
# Test BidirectionalCell.
# In version 1.3.1, after hybridize(), BidirectionalCell would fail when valid_length was passed to unroll().

class BiLSTM(gluon.nn.HybridBlock):
def __init__(self, rnn_size, time_step, **kwargs):
super(BiLSTM, self).__init__(**kwargs)
self.time_step = time_step
with self.name_scope():
self.bi_lstm = gluon.rnn.BidirectionalCell(
gluon.rnn.LSTMCell(rnn_size, prefix='rnn_l0_'),
gluon.rnn.LSTMCell(rnn_size, prefix='rnn_r0_'),
output_prefix='lstm_bi_')

def hybrid_forward(self, F, inputs, valid_len):
outputs, states = self.bi_lstm.unroll(self.time_step, inputs, valid_length=valid_len,
layout='NTC', merge_outputs=True)
return outputs, states

rnn_size, time_step = 100, 3
net = BiLSTM(rnn_size, time_step)
net.initialize()
net.hybridize()
inputs_data = mx.nd.random.uniform(shape=(10, 3, 50))
valid_len = mx.nd.array([1]*10)
outputs, _ = net(inputs_data, valid_len)
assert outputs.shape == (10, 3, 200)
def _check_bidirectional_unroll_valid_length(length):
    """Unroll a hybridized bidirectional LSTM for ``length`` steps with ``valid_length``.

    Builds a ``BidirectionalCell`` from two ``LSTMCell`` instances, feeds it a
    random batch of shape ``(10, length, 50)`` in 'NTC' layout with every
    valid length equal to ``length``, and asserts that the merged output has
    shape ``(10, length, 200)``.

    Parameters
    ----------
    length : int
        Number of time steps to unroll; also used as the sequence dimension
        of the input and as each sample's valid length.
    """
    batch_size, input_dim, rnn_size = 10, 50, 100

    class BiLSTM(gluon.nn.HybridBlock):
        """Hybrid block wrapping a bidirectional LSTM cell unrolled for a fixed step count."""

        def __init__(self, rnn_size, time_step, **kwargs):
            super(BiLSTM, self).__init__(**kwargs)
            self.time_step = time_step
            with self.name_scope():
                # Left cell first, then right cell, matching the positional
                # order BidirectionalCell receives them in.
                left_cell = gluon.rnn.LSTMCell(rnn_size, prefix='rnn_l0_')
                right_cell = gluon.rnn.LSTMCell(rnn_size, prefix='rnn_r0_')
                self.bi_lstm = gluon.rnn.BidirectionalCell(
                    left_cell, right_cell, output_prefix='lstm_bi_')

        def hybrid_forward(self, F, inputs, valid_len):
            merged, final_states = self.bi_lstm.unroll(
                self.time_step,
                inputs,
                valid_length=valid_len,
                layout='NTC',
                merge_outputs=True)
            return merged, final_states

    net = BiLSTM(rnn_size, length)
    net.initialize()
    net.hybridize()

    inputs_data = mx.nd.random.uniform(shape=(batch_size, length, input_dim))
    valid_len = mx.nd.array([length] * batch_size)
    outputs, _ = net(inputs_data, valid_len)

    # Last dimension is 2 * rnn_size (presumably the two directions
    # concatenated -- confirm against BidirectionalCell).
    assert outputs.shape == (batch_size, length, 2 * rnn_size)

_check_bidirectional_unroll_valid_length(1)
_check_bidirectional_unroll_valid_length(3)


if __name__ == '__main__':
Expand Down