diff --git a/python/mxnet/gluon/contrib/rnn/rnn_cell.py b/python/mxnet/gluon/contrib/rnn/rnn_cell.py index 7ae8bfa71e26..3bd8e7810978 100644 --- a/python/mxnet/gluon/contrib/rnn/rnn_cell.py +++ b/python/mxnet/gluon/contrib/rnn/rnn_cell.py @@ -323,8 +323,8 @@ def hybrid_forward(self, F, inputs, states, i2h_weight, # pylint: enable= arguments-differ -def unroll(cell, inputs, begin_state, drop_inputs=0, drop_outputs=0, - layout='TNC', valid_length=None): +def dynamic_unroll(cell, inputs, begin_state, drop_inputs=0, drop_outputs=0, + layout='TNC', valid_length=None): """Unrolls an RNN cell across time steps. Currently, 'TNC' is a preferred layout. unroll on the input of this layout @@ -376,9 +376,9 @@ def unroll(cell, inputs, begin_state, drop_inputs=0, drop_outputs=0, >>> state_shape = (batch_size, input_size) >>> states = [mx.nd.normal(loc=0, scale=1, shape=state_shape) for i in range(2)] >>> valid_length = mx.nd.array([2, 3]) - >>> output, states = mx.gluon.contrib.rnn.rnn_cell.unroll(cell, rnn_data, states, - valid_length=valid_length, - layout='TNC') + >>> output, states = mx.gluon.contrib.rnn.rnn_cell.dynamic_unroll(cell, rnn_data, states, + valid_length=valid_length, + layout='TNC') >>> print(output) [[[ 0.00767238 0.00023103 0.03973929 -0.00925503 -0.05660512] [ 0.00881535 0.05428379 -0.02493718 -0.01834097 0.02189514]] diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py index b6df8eeab927..1e0555900f17 100644 --- a/tests/python/unittest/test_gluon_contrib.py +++ b/tests/python/unittest/test_gluon_contrib.py @@ -324,8 +324,9 @@ def __init__(self, cell_type, hidden_size, layout, prefix=None, params=None): def hybrid_forward(self, F, inputs, states, valid_length): if isinstance(valid_length, list) and len(valid_length) == 0: valid_length = None - return contrib.rnn.rnn_cell.unroll(self.cell, inputs, states, - valid_length=valid_length, layout=self.layout) + return contrib.rnn.rnn_cell.dynamic_unroll(self.cell, inputs, states, + valid_length=valid_length, + layout=self.layout) def check_unroll(cell_type, num_states, layout): batch_size = 20