diff --git a/python/mxnet/gluon/contrib/rnn/rnn_cell.py b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
index 0cbc9eaac375..3bd8e7810978 100644
--- a/python/mxnet/gluon/contrib/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
@@ -22,6 +22,7 @@
 from ...rnn import BidirectionalCell, SequentialRNNCell, ModifierCell, HybridRecurrentCell
 from ...rnn.rnn_cell import _format_sequence, _get_begin_state, _mask_sequence_variable_length
 from ... import tensor_types
+from ....base import _as_list
 
 class VariationalDropoutCell(ModifierCell):
     """
@@ -320,3 +321,117 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
         return next_r, [next_r, next_c]
     # pylint: enable= arguments-differ
+
+
+def dynamic_unroll(cell, inputs, begin_state, drop_inputs=0, drop_outputs=0,
+                   layout='TNC', valid_length=None):
+    """Unrolls an RNN cell across time steps.
+
+    Currently, 'TNC' is the preferred layout; unrolling an input in this
+    layout runs much faster.
+
+    Parameters
+    ----------
+    cell : an object whose base class is RNNCell.
+        The RNN cell to run on the input sequence.
+    inputs : Symbol
+        It should have shape (batch_size, length, ...) if `layout` is 'NTC',
+        or (length, batch_size, ...) if `layout` is 'TNC'.
+    begin_state : nested list of Symbol
+        The initial states of the RNN sequence.
+    drop_inputs : float, default 0.
+        The dropout rate for inputs. Won't apply dropout if it equals 0.
+    drop_outputs : float, default 0.
+        The dropout rate for outputs. Won't apply dropout if it equals 0.
+    layout : str, optional
+        `layout` of the input symbol. Only used if inputs
+        is a single Symbol.
+    valid_length : Symbol, NDArray or None
+        `valid_length` specifies the length of the sequences in the batch without padding.
+        This option is especially useful for building sequence-to-sequence models where
+        the input and output sequences would potentially be padded.
+        If `valid_length` is None, all sequences are assumed to have the same length.
+        If `valid_length` is a Symbol or NDArray, it should have shape (batch_size,).
+        The ith element will be the length of the ith sequence in the batch.
+        The last valid state will be returned and the padded outputs will be masked with 0.
+        Note that `valid_length` must be smaller than or equal to `length`.
+
+    Returns
+    -------
+    outputs : Symbol
+        The output of the RNN from this unrolling.
+
+    states : list of Symbol
+        The new state of this RNN after this unrolling.
+        The type of this symbol is the same as the output of `begin_state`.
+
+    Examples
+    --------
+    >>> seq_len = 3
+    >>> batch_size = 2
+    >>> input_size = 5
+    >>> cell = mx.gluon.rnn.LSTMCell(input_size, prefix='rnn_')
+    >>> cell.initialize(ctx=mx.cpu())
+    >>> rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, input_size))
+    >>> state_shape = (batch_size, input_size)
+    >>> states = [mx.nd.normal(loc=0, scale=1, shape=state_shape) for i in range(2)]
+    >>> valid_length = mx.nd.array([2, 3])
+    >>> output, states = mx.gluon.contrib.rnn.rnn_cell.dynamic_unroll(cell, rnn_data, states,
+    ...                                                               valid_length=valid_length,
+    ...                                                               layout='TNC')
+    >>> print(output)
+    [[[ 0.00767238  0.00023103  0.03973929 -0.00925503 -0.05660512]
+      [ 0.00881535  0.05428379 -0.02493718 -0.01834097  0.02189514]]
+     [[-0.00676967  0.01447039  0.01287002 -0.00574152 -0.05734247]
+      [ 0.01568508  0.02650866 -0.04270559 -0.04328435  0.00904011]]
+     [[ 0.          0.          0.          0.          0.        ]
+      [ 0.01055336  0.02734251 -0.03153727 -0.03742751 -0.01378113]]]
+
+    """
+
+    # Merge is always True, so we don't need the sequence length.
+    inputs, axis, F, _ = _format_sequence(0, inputs, layout, True)
+    # Move the sequence axis to the front so foreach iterates over time steps.
+    if axis != 0:
+        axes = list(range(len(layout)))
+        tmp = axes[0]
+        axes[0] = axes[axis]
+        axes[axis] = tmp
+        inputs = F.transpose(inputs, axes=axes)
+    states = begin_state
+
+    if drop_inputs:
+        inputs = F.Dropout(inputs, p=drop_inputs, axes=(axis,))
+
+    if valid_length is None:
+        def loop_body(inputs, states):
+            return cell(inputs, states)
+    else:
+        # Append a step counter so loop_body knows the current time step.
+        states = list(_as_list(states))
+        states.append(F.zeros((1,)))
+        def loop_body(inputs, states):
+            cell_states = states[:-1]
+            iter_no = states[-1]
+            out, new_states = cell(inputs, cell_states)
+            for i, state in enumerate(cell_states):
+                # Keep the old state once a sequence is past its valid length.
+                new_states[i] = F.where(F.broadcast_greater(valid_length, iter_no),
+                                        new_states[i], state)
+            new_states.append(iter_no + 1)
+            return out, new_states
+
+    outputs, states = F.contrib.foreach(loop_body, inputs, states)
+    if drop_outputs:
+        outputs = F.Dropout(outputs, p=drop_outputs, axes=(axis,))
+    if valid_length is not None:
+        if axis != 0:
+            outputs = F.transpose(outputs, axes)
+        outputs = F.SequenceMask(outputs, sequence_length=valid_length,
+                                 use_sequence_length=True, axis=axis)
+        # The last state is the iteration number; we don't need it.
+        return outputs, states[:-1]
+    else:
+        if axis != 0:
+            outputs = F.transpose(outputs, axes)
+        return outputs, states
diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py
index 6901e8bd12fe..1e0555900f17 100644
--- a/tests/python/unittest/test_gluon_contrib.py
+++ b/tests/python/unittest/test_gluon_contrib.py
@@ -17,12 +17,14 @@
 from __future__ import print_function
 
 import mxnet as mx
+import copy
+from mxnet import gluon
 from mxnet.gluon import contrib
 from mxnet.gluon import nn
 from mxnet.gluon.contrib.nn import (
     Concurrent, HybridConcurrent, Identity, SparseEmbedding,
     PixelShuffle1D, PixelShuffle2D, PixelShuffle3D)
-from mxnet.test_utils import almost_equal
+from mxnet.test_utils import almost_equal, default_context, assert_almost_equal
 from common import setup_module, with_seed, teardown
 import numpy as np
 from numpy.testing import assert_allclose
@@ -313,6 +315,97 @@ def test_sampler():
     assert list(interval_sampler) == [0, 3, 6, 9]
 
 
+class TestRNNLayer(gluon.HybridBlock):
+    def __init__(self, cell_type, hidden_size, layout, prefix=None, params=None):
+        super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
+        self.cell = cell_type(hidden_size, prefix='rnn_')
+        self.layout = layout
+
+    def hybrid_forward(self, F, inputs, states, valid_length):
+        if isinstance(valid_length, list) and len(valid_length) == 0:
+            valid_length = None
+        return contrib.rnn.rnn_cell.dynamic_unroll(self.cell, inputs, states,
+                                                   valid_length=valid_length,
+                                                   layout=self.layout)
+
+def check_unroll(cell_type, num_states, layout):
+    batch_size = 20
+    input_size = 50
+    hidden_size = 30
+    seq_len = 10
+    if layout == 'TNC':
+        rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, input_size))
+    elif layout == 'NTC':
+        rnn_data = mx.nd.normal(loc=0, scale=1, shape=(batch_size, seq_len, input_size))
+    else:
+        raise ValueError("Unsupported layout: " + layout)
+    valid_length = mx.nd.round(mx.nd.random.uniform(low=1, high=10, shape=(batch_size,)))
+    state_shape = (batch_size, hidden_size)
+    states = [mx.nd.normal(loc=0, scale=1, shape=state_shape) for i in range(num_states)]
+
+    cell = cell_type(hidden_size, prefix='rnn_')
+    cell.initialize(ctx=default_context())
+    # Run the cell once so its parameters are created before we copy them.
+    if layout == 'TNC':
+        cell(rnn_data[0], states)
+    else:
+        cell(rnn_data[:, 0, :], states)
+    params1 = cell.collect_params()
+    orig_params1 = copy.deepcopy(params1)
+
+    trainer = gluon.Trainer(params1, 'sgd', {'learning_rate': 0.03})
+    with mx.autograd.record():
+        res1, states1 = cell.unroll(seq_len, rnn_data, states, valid_length=valid_length,
+                                    layout=layout, merge_outputs=True)
+    res1.backward()
+    trainer.step(batch_size)
+
+    configs = [
+        lambda layer: None,
+        lambda layer: layer.hybridize(),
+        lambda layer: layer.hybridize({'inline_limit': 0}),
+        lambda layer: layer.hybridize({'static_alloc': True}),
+        lambda layer: layer.hybridize({'static_alloc': True, 'static_shape': True})]
+    # We can't pass None to a hybrid block, but it accepts an empty list,
+    # so we use an empty list to represent valid_length when it is None.
+    if valid_length is None:
+        valid_length = []
+    for config in configs:
+        layer = TestRNNLayer(cell_type, hidden_size, layout)
+        layer.initialize(ctx=default_context())
+        config(layer)
+        res2, states2 = layer(rnn_data, states, valid_length)
+        params2 = layer.collect_params()
+        for key, val in orig_params1.items():
+            params2[key].set_data(copy.deepcopy(val.data()))
+
+        trainer = gluon.Trainer(params2, 'sgd', {'learning_rate': 0.03})
+        with mx.autograd.record():
+            res2, states2 = layer(rnn_data, states, valid_length)
+        assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
+        assert len(states1) == len(states2)
+        for i in range(len(states1)):
+            assert_almost_equal(states1[i].asnumpy(), states2[i].asnumpy(),
+                                rtol=0.001, atol=0.0001)
+        res2.backward()
+        trainer.step(batch_size)
+
+        for key, val in params1.items():
+            weight1 = val.data()
+            weight2 = params2[key].data()
+            assert_almost_equal(weight1.asnumpy(), weight2.asnumpy(),
+                                rtol=0.001, atol=0.0001)
+
+
+@with_seed()
+def test_contrib_unroll():
+    cell_types = [(gluon.rnn.RNNCell, 1), (gluon.rnn.LSTMCell, 2),
+                  (gluon.rnn.GRUCell, 1)]
+    for cell_type, num_states in cell_types:
+        check_unroll(cell_type, num_states, 'TNC')
+        check_unroll(cell_type, num_states, 'NTC')
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
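
Reviewer note (not part of the patch): the heart of the `valid_length` handling in `dynamic_unroll` is the `F.where(F.broadcast_greater(valid_length, iter_no), new_states[i], state)` update inside `loop_body`. Once the step counter reaches a sequence's valid length, its state stops changing. MXNet's `where` selects whole rows when given a 1-D condition, which is what makes this work on `(batch_size, hidden)` states. Below is a minimal imperative sketch of that mechanism with made-up toy shapes and plain NDArray ops instead of `foreach`:

```python
import mxnet as mx

# Hypothetical toy setup: 3 sequences with valid lengths 1, 2 and 3.
valid_length = mx.nd.array([1., 2., 3.])
state = mx.nd.zeros((3, 4))  # (batch_size, hidden_size)

for t in range(3):
    # Stand-in for the cell update; a real RNN cell would compute this
    # from the step's input.
    candidate = state + 1
    # 1-D mask per batch row: 1 while t < valid_length[i], else 0.
    keep = mx.nd.broadcast_greater(valid_length, mx.nd.array([t]))
    # where() with a 1-D condition selects row-wise, so finished
    # sequences carry their old state forward unchanged.
    state = mx.nd.where(keep, candidate, state)

print(state)  # rows saturate at 1, 2 and 3 respectively
```

This is also why the patch threads an extra `F.zeros((1,))` state through `foreach`: the loop has no built-in step index, so the counter rides along in the state list and is stripped off with `states[:-1]` before returning.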