Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Added ResidualCell (ModifierCell) for vertical connections in stacked…
Browse files Browse the repository at this point in the history
… RNNs (#6267)
  • Loading branch information
fhieber authored and piiswrong committed May 16, 2017
1 parent 565609c commit 36420de
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
15 changes: 15 additions & 0 deletions python/mxnet/rnn/rnn_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,21 @@ def __call__(self, inputs, states):
return output, states


class ResidualCell(ModifierCell):
    """Modifier cell that adds a residual (skip) connection around a base cell.

    Follows the residual connections described in Wu et al., 2016
    (https://arxiv.org/abs/1609.08144): for every step, the value returned is
    the base cell's output with the step's input added element-wise. States
    are returned exactly as the base cell produced them.
    """

    def __init__(self, base_cell):
        super(ResidualCell, self).__init__(base_cell)

    def __call__(self, inputs, states):
        """Run the base cell for one step and add ``inputs`` to its output.

        Parameters
        ----------
        inputs : symbol fed to the base cell and added back onto its output.
            NOTE(review): the element-wise add presumably requires ``inputs``
            and the base cell's output to have the same shape -- confirm.
        states : states forwarded unchanged to the base cell.

        Returns
        -------
        A ``(output, states)`` pair where ``output`` is the residual sum and
        ``states`` are the states returned by the base cell.
        """
        cell_output, next_states = self.base_cell(inputs, states)
        # Derive the name from the base cell's output so each unrolled step
        # gets its own distinct "<base output name>_plus_residual" symbol.
        residual_name = "%s_plus_residual" % cell_output.name
        summed = symbol.elemwise_add(cell_output, inputs, name=residual_name)
        return summed, next_states


class BidirectionalCell(BaseRNNCell):
"""Bidirectional RNN cell
Expand Down
30 changes: 29 additions & 1 deletion tests/python/unittest/test_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,38 @@ def test_gru():
assert outs == [(10, 100), (10, 100), (10, 100)]


def test_residual():
    """Check parameter names, output names, shapes and values of ResidualCell.

    A GRU cell with 50 hidden units is wrapped in a ResidualCell and unrolled
    for two steps. Every weight and bias is set to zero and every input to
    ones; the test expects each step's output to equal its input (all ones),
    i.e. only the residual term contributes.
    """
    cell = mx.rnn.ResidualCell(mx.rnn.GRUCell(50, prefix='rnn_'))
    inputs = [mx.sym.Variable('rnn_t%d_data' % i) for i in range(2)]
    step_outputs, _ = cell.unroll(2, inputs)
    grouped = mx.sym.Group(step_outputs)

    # The wrapper must not introduce parameters of its own.
    assert sorted(cell.params._params.keys()) == \
        ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']
    assert grouped.list_outputs() == \
        ['rnn_t0_out_plus_residual_output', 'rnn_t1_out_plus_residual_output']

    # Only the output shapes are checked; argument and auxiliary shapes are
    # not needed here.
    _, out_shapes, _ = grouped.infer_shape(rnn_t0_data=(10, 50),
                                           rnn_t1_data=(10, 50))
    assert out_shapes == [(10, 50), (10, 50)]

    results = grouped.eval(rnn_t0_data=mx.nd.ones((10, 50)),
                           rnn_t1_data=mx.nd.ones((10, 50)),
                           rnn_i2h_weight=mx.nd.zeros((150, 50)),
                           rnn_i2h_bias=mx.nd.zeros((150,)),
                           rnn_h2h_weight=mx.nd.zeros((150, 50)),
                           rnn_h2h_bias=mx.nd.zeros((150,)))
    expected_outputs = np.ones((10, 50))
    assert np.array_equal(results[0].asnumpy(), expected_outputs)
    assert np.array_equal(results[1].asnumpy(), expected_outputs)


def test_stack():
cell = mx.rnn.SequentialRNNCell()
for i in range(5):
cell.add(mx.rnn.LSTMCell(100, prefix='rnn_stack%d_'%i))
if i == 1:
cell.add(mx.rnn.ResidualCell(mx.rnn.LSTMCell(100, prefix='rnn_stack%d_' % i)))
else:
cell.add(mx.rnn.LSTMCell(100, prefix='rnn_stack%d_'%i))
inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
outputs, _ = cell.unroll(3, inputs)
outputs = mx.sym.Group(outputs)
Expand Down

0 comments on commit 36420de

Please sign in to comment.