Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add contrib unroll.
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-da committed Jul 31, 2018
1 parent b2fd3b1 commit a5e4d06
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 1 deletion.
63 changes: 63 additions & 0 deletions python/mxnet/gluon/contrib/rnn/rnn_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from ...rnn import BidirectionalCell, SequentialRNNCell, ModifierCell, HybridRecurrentCell
from ...rnn.rnn_cell import _format_sequence, _get_begin_state, _mask_sequence_variable_length
from ... import tensor_types
from .... import symbol, ndarray
from ....base import _as_list

class VariationalDropoutCell(ModifierCell):
"""
Expand Down Expand Up @@ -315,3 +317,64 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,

return next_r, [next_r, next_c]
# pylint: enable= arguments-differ


def _format_sequence(inputs, layout, in_layout=None):
    """Locate the time axis of `inputs` and pick the matching operator namespace.

    NOTE(review): this definition shadows the `_format_sequence` imported from
    `...rnn.rnn_cell` at the top of the module and takes different arguments --
    confirm that no other code in this module still relies on the imported one.

    Parameters
    ----------
    inputs : Symbol or NDArray
        The sequence tensor. Must not be None and must be one of `tensor_types`.
    layout : str
        Target layout; the positions of 'T' and 'N' give the time and batch axes.
    in_layout : str, optional
        Layout `inputs` currently has. When omitted, `inputs` is taken to be in
        `layout` already and no axes are swapped.

    Returns
    -------
    tuple
        ``(inputs, axis, F, batch_size)`` where `axis` is the index of 'T' in
        `layout`, `F` is the `symbol` or `ndarray` module, and `batch_size` is
        read from ``inputs.shape`` for NDArray inputs and is 0 for Symbol inputs.
    """
    assert inputs is not None, \
        "unroll(inputs=None) has been deprecated. " \
        "Please create input variables outside unroll."

    axis = layout.find('T')
    batch_axis = layout.find('N')
    in_axis = axis if in_layout is None else in_layout.find('T')
    assert isinstance(inputs, tensor_types)

    # The batch size is only read from NDArray inputs; for Symbol inputs it stays 0.
    is_symbolic = isinstance(inputs, symbol.Symbol)
    F = symbol if is_symbolic else ndarray
    batch_size = 0 if is_symbolic else inputs.shape[batch_axis]

    if axis != in_axis:
        inputs = F.swapaxes(inputs, dim1=axis, dim2=in_axis)

    return inputs, axis, F, batch_size


def unroll(cell, inputs, begin_state, drop_inputs=0, drop_outputs=0,
           layout='NTC', valid_length=None):
    """Unroll `cell` along the time axis of `inputs` using ``F.contrib.foreach``.

    Parameters
    ----------
    cell : callable
        Invoked once per time step as ``cell(step_input, states)`` and expected
        to return ``(output, new_states)``.
    inputs : Symbol or NDArray
        Input sequence laid out according to `layout`.
    begin_state : Symbol/NDArray or list of them
        Initial recurrent states. The caller's object is never modified.
    drop_inputs : float, default 0
        Dropout probability applied to `inputs`; a falsy value disables it.
    drop_outputs : float, default 0
        Dropout probability applied to the outputs; a falsy value disables it.
    layout : str, default 'NTC'
        Layout of `inputs` and of the returned outputs; the position of 'T'
        marks the time axis.
    valid_length : Symbol or NDArray, optional
        Number of valid steps, compared against the running step counter.
        Presumably one entry per batch element -- confirm against callers.
        When given, a state stops updating once its sequence has ended and the
        outputs are masked with ``F.SequenceMask``.

    Returns
    -------
    outputs : Symbol or NDArray
        Per-step outputs, in `layout`.
    states : Symbol/NDArray or list of them
        States after the final step. With `valid_length` these are the states
        at each sequence's last valid step.
    """
    inputs, axis, F, _ = _format_sequence(inputs, layout)
    states = begin_state

    if drop_inputs:
        inputs = F.Dropout(inputs, p=drop_inputs, axes=(axis,))

    # foreach slices its data along axis 0, so the time axis has to be moved
    # there first; otherwise any layout other than 'T...' iterates over the
    # wrong dimension (e.g. the batch dimension for the default 'NTC').
    if axis != 0:
        inputs = F.swapaxes(inputs, dim1=0, dim2=axis)

    if valid_length is None:
        def loop_body(inputs, states):
            return cell(inputs, states)
    else:
        # Work on a private copy: _as_list hands back the caller's own list
        # when begin_state already is one, and the step counter is appended
        # to it below.
        states = list(_as_list(begin_state))
        states.append(F.zeros((1,)))

        def loop_body(inputs, states):
            cell_states = states[:-1]
            iter_no = states[-1]
            out, new_states = cell(inputs, cell_states)
            # Sequences that have already ended keep their previous state so
            # that the final state is the one from the last valid step.
            still_valid = F.broadcast_greater(valid_length, iter_no)
            new_states = [F.where(still_valid, new, old)
                          for new, old in zip(new_states, cell_states)]
            new_states.append(iter_no + 1)
            return out, new_states

    outputs, states = F.contrib.foreach(loop_body, inputs, states)

    # Restore the caller's layout before any op that is parameterised by `axis`.
    if axis != 0:
        outputs = F.swapaxes(outputs, dim1=0, dim2=axis)
    if drop_outputs:
        outputs = F.Dropout(outputs, p=drop_outputs, axes=(axis,))
    if valid_length is None:
        return outputs, states

    outputs = F.SequenceMask(outputs, sequence_length=valid_length,
                             use_sequence_length=True, axis=axis)
    # The last state is the iteration number. We don't need it.
    return outputs, states[:-1]
82 changes: 81 additions & 1 deletion tests/python/unittest/test_gluon_contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

from __future__ import print_function
import mxnet as mx
import copy
from mxnet import gluon
from mxnet.gluon import contrib
from mxnet.gluon import nn
from mxnet.gluon.contrib.nn import Concurrent, HybridConcurrent, Identity, SparseEmbedding
from mxnet.test_utils import almost_equal
from mxnet.test_utils import almost_equal, default_context, assert_almost_equal
from common import setup_module, with_seed, teardown
import numpy as np
from numpy.testing import assert_allclose
Expand Down Expand Up @@ -228,6 +230,84 @@ def test_sampler():
assert list(interval_sampler) == [0, 3, 6, 9]


class TestRNNLayer(gluon.HybridBlock):
    """Hybrid block that owns one RNN cell and runs it through the contrib `unroll`."""

    def __init__(self, cell_type, hidden_size, prefix=None, params=None):
        super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
        self.cell = cell_type(hidden_size, prefix='rnn_')

    def hybrid_forward(self, F, inputs, states, valid_length):
        # An empty list is the caller's stand-in for "no valid_length".
        no_lengths = isinstance(valid_length, list) and not valid_length
        lengths = None if no_lengths else valid_length
        return contrib.rnn.rnn_cell.unroll(self.cell, inputs, states,
                                           layout='TNC', valid_length=lengths)

def check_unroll(cell_type, num_states):
    """Check that the contrib `unroll` agrees with `cell.unroll` for one cell type.

    A reference cell is unrolled with `cell.unroll` and trained for one SGD
    step. A hybridized `TestRNNLayer` is then given the reference cell's
    original parameters, run on the same data and trained for one step; its
    outputs, final states and updated parameters must match the reference.

    Parameters
    ----------
    cell_type : callable
        Cell constructor, called as ``cell_type(hidden_size, prefix='rnn_')``.
    num_states : int
        Number of state arrays to create for the cell.
    """
    batch_size = 1
    input_size = 5
    hidden_size = 3
    seq_len = 1
    tolerance = {'rtol': 0.001, 'atol': 0.0001}

    data_shape = (seq_len, batch_size, input_size)
    state_shape = (batch_size, hidden_size)
    rnn_data = mx.nd.normal(loc=0, scale=1, shape=data_shape)
    valid_length = mx.nd.round(mx.nd.random.uniform(low=1, high=10, shape=batch_size))
    states = [mx.nd.normal(loc=0, scale=1, shape=state_shape) for _ in range(num_states)]

    # Reference: the cell's own unroll followed by one SGD step.
    cell = cell_type(hidden_size, prefix='rnn_')
    cell.initialize(ctx=default_context())
    cell(rnn_data[0], states)
    params1 = cell.collect_params()
    # Snapshot taken before the trainer step below updates params1.
    orig_params1 = copy.deepcopy(params1)

    trainer = gluon.Trainer(params1, 'sgd', {'learning_rate' : 0.03})
    with mx.autograd.record():
        res1, states1 = cell.unroll(seq_len, rnn_data, states, valid_length=valid_length,
                                    layout='TNC', merge_outputs=True)
    res1.backward()
    trainer.step(batch_size)

    # Only static_alloc is exercised; the plain and the
    # static_alloc + static_shape configurations are currently disabled.
    configs = [{'static_alloc': True}]

    # We can't pass None to a hybrid block, but it accepts an empty list,
    # so we use an empty list to represent valid_length if it's None.
    if valid_length is None:
        valid_length = []

    for config in configs:
        layer = TestRNNLayer(cell_type, hidden_size)
        layer.initialize(ctx=default_context())
        layer.hybridize(**config)
        # Forward pass before the parameters are overwritten with the
        # reference values; its results are replaced below.
        res2, states2 = layer(rnn_data, states, valid_length)
        params2 = layer.collect_params()
        for key, val in orig_params1.items():
            params2[key].set_data(copy.deepcopy(val.data()))

        trainer = gluon.Trainer(params2, 'sgd', {'learning_rate' : 0.03})
        with mx.autograd.record():
            res2, states2 = layer(rnn_data, states, valid_length)
        assert_almost_equal(res1.asnumpy(), res2.asnumpy(), **tolerance)
        assert len(states1) == len(states2)
        for expected_state, actual_state in zip(states1, states2):
            assert_almost_equal(expected_state.asnumpy(), actual_state.asnumpy(),
                                **tolerance)
        res2.backward()
        trainer.step(batch_size)

        # After one identical step from identical weights, parameters must agree.
        for key, val in params1.items():
            weight1 = val.data()
            weight2 = params2[key].data()
            assert_almost_equal(weight1.asnumpy(), weight2.asnumpy(), **tolerance)


@with_seed()
def test_contrib_unroll():
    """Run `check_unroll` for the RNN, LSTM and GRU cells."""
    # Each entry pairs a cell class with the number of state arrays it takes.
    cases = (
        (gluon.rnn.RNNCell, 1),
        (gluon.rnn.LSTMCell, 2),
        (gluon.rnn.GRUCell, 1),
    )
    for cell_cls, state_count in cases:
        check_unroll(cell_cls, state_count)


if __name__ == '__main__':
    # When this file is executed as a script, hand the module to nose's runner.
    import nose
    nose.runmodule()

0 comments on commit a5e4d06

Please sign in to comment.