This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1327] Allow RNN Layers to be initialized to fp16 (#14219)
* update rnn for fp16

* fix typo in test

* fix tests

* fix tests

* fix gpu tests

* Update test_gluon_rnn.py

* Update test_gluon_rnn.py

* trigger

* try removing checks for unix
ThomasDelteil authored and eric-haibin-lin committed Mar 12, 2019
1 parent 66c74cc commit 6aa8c27
Showing 2 changed files with 101 additions and 56 deletions.
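For orientation, here is a minimal usage sketch of what this change enables. It is my illustration, not part of the commit, and assumes an MXNet build with this patch applied: the Gluon RNN, LSTM, and GRU layers accept a dtype argument so that their parameters and default states are created in float16. Per the tests added below, the fp16 forward pass is only exercised on GPU.

# Hedged sketch (not part of the diff): an LSTM whose parameters and
# default states are float16. The forward call is gated on a GPU,
# since the tests below only cover fp16 RNN compute on GPU.
import mxnet as mx
from mxnet import gluon

layer = gluon.rnn.LSTM(10, num_layers=2, dtype='float16')
if mx.context.num_gpus() > 0:
    ctx = mx.gpu(0)
    layer.collect_params().initialize(ctx=ctx)
    x = mx.nd.ones((8, 3, 20), dtype='float16', ctx=ctx)   # TNC layout
    states = layer.begin_state(batch_size=3, ctx=ctx)      # float16 zeros
    out, new_states = layer(x, states)
    print(out.dtype)                                        # float16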
59 changes: 35 additions & 24 deletions python/mxnet/gluon/rnn/rnn_layer.py
@@ -37,7 +37,7 @@ def __init__(self, hidden_size, num_layers, layout,
i2h_bias_initializer, h2h_bias_initializer,
mode, projection_size, h2r_weight_initializer,
lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan,
**kwargs):
dtype, **kwargs):
super(_RNNLayer, self).__init__(**kwargs)
assert layout in ('TNC', 'NTC'), \
"Invalid layout %s; must be one of ['TNC' or 'NTC']"%layout
@@ -57,6 +57,7 @@ def __init__(self, hidden_size, num_layers, layout,
self._lstm_state_clip_min = lstm_state_clip_min
self._lstm_state_clip_max = lstm_state_clip_max
self._lstm_state_clip_nan = lstm_state_clip_nan
self._dtype = dtype

self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]

@@ -66,41 +67,41 @@ def __init__(self, hidden_size, num_layers, layout,
for j in ['l', 'r'][:self._dir]:
self._register_param('{}{}_i2h_weight'.format(j, i),
shape=(ng*nh, ni),
init=i2h_weight_initializer)
init=i2h_weight_initializer, dtype=dtype)
self._register_param('{}{}_h2h_weight'.format(j, i),
shape=(ng*nh, nh),
init=h2h_weight_initializer)
init=h2h_weight_initializer, dtype=dtype)
self._register_param('{}{}_i2h_bias'.format(j, i),
shape=(ng*nh,),
init=i2h_bias_initializer)
init=i2h_bias_initializer, dtype=dtype)
self._register_param('{}{}_h2h_bias'.format(j, i),
shape=(ng*nh,),
init=h2h_bias_initializer)
init=h2h_bias_initializer, dtype=dtype)
ni = nh * self._dir
else:
np = self._projection_size
for i in range(num_layers):
for j in ['l', 'r'][:self._dir]:
self._register_param('{}{}_i2h_weight'.format(j, i),
shape=(ng*nh, ni),
init=i2h_weight_initializer)
init=i2h_weight_initializer, dtype=dtype)
self._register_param('{}{}_h2h_weight'.format(j, i),
shape=(ng*nh, np),
init=h2h_weight_initializer)
init=h2h_weight_initializer, dtype=dtype)
self._register_param('{}{}_i2h_bias'.format(j, i),
shape=(ng*nh,),
init=i2h_bias_initializer)
init=i2h_bias_initializer, dtype=dtype)
self._register_param('{}{}_h2h_bias'.format(j, i),
shape=(ng*nh,),
init=h2h_bias_initializer)
init=h2h_bias_initializer, dtype=dtype)
self._register_param('{}{}_h2r_weight'.format(j, i),
shape=(np, nh),
init=h2r_weight_initializer)
init=h2r_weight_initializer, dtype=dtype)
ni = np * self._dir

def _register_param(self, name, shape, init):
def _register_param(self, name, shape, init, dtype):
p = self.params.get(name, shape=shape, init=init,
allow_deferred_init=True)
allow_deferred_init=True, dtype=dtype)
setattr(self, name, p)
return p

@@ -179,6 +180,10 @@ def _unfuse(self):

return stack

def cast(self, dtype):
super(_RNNLayer, self).cast(dtype)
self._dtype = dtype

def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
"""Initial state for this cell.
@@ -317,6 +322,8 @@ class RNN(_RNNLayer):
input_size: int, default 0
The number of expected features in the input x.
If not specified, it will be inferred from input.
dtype : str, default 'float32'
Type to initialize the parameters and default states to
prefix : str or None
Prefix of this `Block`.
params : ParameterDict or None
@@ -357,17 +364,17 @@ def __init__(self, hidden_size, num_layers=1, activation='relu',
layout='TNC', dropout=0, bidirectional=False,
i2h_weight_initializer=None, h2h_weight_initializer=None,
i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
input_size=0, **kwargs):
input_size=0, dtype='float32', **kwargs):
super(RNN, self).__init__(hidden_size, num_layers, layout,
dropout, bidirectional, input_size,
i2h_weight_initializer, h2h_weight_initializer,
i2h_bias_initializer, h2h_bias_initializer,
'rnn_'+activation, None, None, None, None, False,
**kwargs)
dtype, **kwargs)

def state_info(self, batch_size=0):
return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
'__layout__': 'LNC'}]
'__layout__': 'LNC', 'dtype': self._dtype}]


class LSTM(_RNNLayer):
@@ -432,6 +439,8 @@ class LSTM(_RNNLayer):
state_clip_nan : boolean, default False
Whether to stop NaN from propagating in state by clipping it to min/max.
If the clipping range is not specified, this option is ignored.
dtype : str, default 'float32'
Type to initialize the parameters and default states to
input_size: int, default 0
The number of expected features in the input x.
If not specified, it will be inferred from input.
@@ -477,26 +486,26 @@ def __init__(self, hidden_size, num_layers=1, layout='TNC',
i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
projection_size=None, h2r_weight_initializer=None,
state_clip_min=None, state_clip_max=None, state_clip_nan=False,
**kwargs):
dtype='float32', **kwargs):
super(LSTM, self).__init__(hidden_size, num_layers, layout,
dropout, bidirectional, input_size,
i2h_weight_initializer, h2h_weight_initializer,
i2h_bias_initializer, h2h_bias_initializer,
'lstm', projection_size, h2r_weight_initializer,
state_clip_min, state_clip_max, state_clip_nan,
**kwargs)
dtype, **kwargs)

def state_info(self, batch_size=0):
if self._projection_size is None:
return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
'__layout__': 'LNC'},
'__layout__': 'LNC', 'dtype': self._dtype},
{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
'__layout__': 'LNC'}]
'__layout__': 'LNC', 'dtype': self._dtype}]
else:
return [{'shape': (self._num_layers * self._dir, batch_size, self._projection_size),
'__layout__': 'LNC'},
'__layout__': 'LNC', 'dtype': self._dtype},
{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
'__layout__': 'LNC'}]
'__layout__': 'LNC', 'dtype': self._dtype}]


class GRU(_RNNLayer):
@@ -544,6 +553,8 @@ class GRU(_RNNLayer):
Initializer for the bias vector.
h2h_bias_initializer : str or Initializer
Initializer for the bias vector.
dtype : str, default 'float32'
Type to initialize the parameters and default states to
input_size: int, default 0
The number of expected features in the input x.
If not specified, it will be inferred from input.
@@ -586,14 +597,14 @@ def __init__(self, hidden_size, num_layers=1, layout='TNC',
dropout=0, bidirectional=False, input_size=0,
i2h_weight_initializer=None, h2h_weight_initializer=None,
i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
**kwargs):
dtype='float32', **kwargs):
super(GRU, self).__init__(hidden_size, num_layers, layout,
dropout, bidirectional, input_size,
i2h_weight_initializer, h2h_weight_initializer,
i2h_bias_initializer, h2h_bias_initializer,
'gru', None, None, None, None, False,
**kwargs)
dtype, **kwargs)

def state_info(self, batch_size=0):
return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
'__layout__': 'LNC'}]
'__layout__': 'LNC', 'dtype': self._dtype}]
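The cast override added above keeps self._dtype in line with the parameters, and state_info now reports that dtype, so default states follow a cast network. A hedged sketch of that interaction follows; again this is my illustration rather than part of the commit, and it assumes a GPU for any fp16 forward pass.

# Sketch (not part of the diff): cast() now updates the layer's _dtype,
# so states created by begin_state match the cast parameters.
import mxnet as mx
from mxnet import gluon

net = gluon.rnn.GRU(10, num_layers=2)      # created as float32
net.cast('float16')                        # parameters and _dtype become float16
if mx.context.num_gpus() > 0:
    ctx = mx.gpu(0)
    net.collect_params().initialize(ctx=ctx)
    states = net.begin_state(batch_size=3, ctx=ctx)
    print(states[0].dtype)                 # float16, matching the cast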
98 changes: 66 additions & 32 deletions tests/python/unittest/test_gluon_rnn.py
@@ -427,9 +427,15 @@ def hybrid_forward(self, F, seq):
assert_almost_equal(output1.asnumpy(), output2.asnumpy())


def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
layer.collect_params().initialize()
def check_rnn_layer_forward(layer, inputs, states=None, run_only=False, ctx=mx.cpu()):
layer.collect_params().initialize(ctx=ctx)
inputs = inputs.as_in_context(ctx)
inputs.attach_grad()
if states is not None:
if isinstance(states, (list, tuple)):
states = [s.as_in_context(ctx) for s in states]
else:
states = states.as_in_context(ctx)
with mx.autograd.record():
if states is None:
out = layer(inputs)
@@ -467,47 +473,76 @@ def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), rtol=1e-3, atol=1e-5)


@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
def test_rnn_layers():
check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20)))
check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)))
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2), mx.nd.ones((8, 3, 20)))
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), [mx.nd.ones((4, 3, 10)), mx.nd.ones((4, 3, 10))])
check_rnn_layer_forward(gluon.rnn.GRU(10, 2), mx.nd.ones((8, 3, 20)))
check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)))

check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dropout=0.5), mx.nd.ones((8, 3, 20)),
run_only=True)
check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True, dropout=0.5),
mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), run_only=True)
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5), mx.nd.ones((8, 3, 20)),
run_only=True)
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5),
mx.nd.ones((8, 3, 20)),
[mx.nd.ones((4, 3, 10)), mx.nd.ones((4, 3, 10))], run_only=True)
check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dropout=0.5), mx.nd.ones((8, 3, 20)),
run_only=True)
check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, dropout=0.5),
mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), run_only=True)

def run_rnn_layers(dtype, dtype2, ctx=mx.cpu()):

check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx=ctx)
check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype), ctx=ctx)
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2,dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx=ctx)
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2,dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), [mx.nd.ones((4, 3, 10), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype)],ctx=ctx)
check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dtype=dtype, ), mx.nd.ones((8, 3, 20), dtype=dtype),ctx=ctx)
check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype),ctx=ctx)


check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype, dropout=0.5), mx.nd.ones((8, 3, 20), dtype=dtype),
run_only=True, ctx=ctx)
check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True, dropout=0.5, dtype=dtype),
mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype), run_only=True, ctx=ctx)
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype),
run_only=True, ctx=ctx)
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5, dtype=dtype),
mx.nd.ones((8, 3, 20), dtype=dtype),
[mx.nd.ones((4, 3, 10), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype)], run_only=True, ctx=ctx)
check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dropout=0.5, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype),
run_only=True, ctx=ctx)
check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, dropout=0.5, dtype=dtype),
mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype), run_only=True, ctx=ctx)

net = gluon.nn.Sequential()
net.add(gluon.rnn.LSTM(10, bidirectional=True))
net.add(gluon.rnn.LSTM(10, bidirectional=True, dtype=dtype2))
net.add(gluon.nn.BatchNorm(axis=2))
net.add(gluon.nn.Flatten())
net.add(gluon.nn.Dense(3, activation='relu'))
net.collect_params().initialize()
net.collect_params().initialize(ctx=ctx)
net.cast(dtype)
with mx.autograd.record():
net(mx.nd.ones((2, 3, 10))).backward()
out = net(mx.nd.ones((2, 3, 10), dtype=dtype, ctx=ctx))
out.backward()
out = out.asnumpy()

net2 = gluon.nn.HybridSequential()
net2.add(gluon.rnn.LSTM(10, bidirectional=True))
net2.add(gluon.rnn.LSTM(10, bidirectional=True, dtype=dtype2))
net2.add(gluon.nn.BatchNorm(axis=2))
net2.add(gluon.nn.Flatten())
net2.add(gluon.nn.Dense(3, activation='relu'))
net2.hybridize()
net2.collect_params().initialize()
net2.collect_params().initialize(ctx=ctx)
net2.cast(dtype)
with mx.autograd.record():
out = net2(mx.nd.ones((2, 3, 10), dtype=dtype, ctx=ctx))
out.backward()
out = out.asnumpy()

net3 = gluon.nn.HybridSequential()
net3.add(gluon.rnn.LSTM(10, bidirectional=True, dtype=dtype))
net3.add(gluon.nn.BatchNorm(axis=2))
net3.add(gluon.nn.Flatten())
net3.add(gluon.nn.Dense(3, activation='relu'))
net3.hybridize()
net3.collect_params().initialize(ctx=ctx)
net3.cast(dtype2)
with mx.autograd.record():
net2(mx.nd.ones((2, 3, 10))).backward()
out = net3(mx.nd.ones((2, 3, 10), dtype=dtype2, ctx=ctx))
out.backward()
out = out.asnumpy()

def test_rnn_layers_fp32():
run_rnn_layers('float32', 'float32')

@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
@unittest.skipIf(mx.context.num_gpus() == 0, "RNN FP16 only implemented for GPU for now")
def test_rnn_layers_fp16():
run_rnn_layers('float16', 'float32', mx.gpu())


def test_rnn_unroll_variant_length():
Expand Down Expand Up @@ -590,8 +625,6 @@ def test_cell_fill_shape():
check_rnn_forward(cell, mx.nd.ones((2, 3, 7)))
assert cell.i2h_weight.shape[1] == 7, cell.i2h_weight.shape[1]


@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
def test_layer_fill_shape():
layer = gluon.rnn.LSTM(10)
layer.hybridize()
@@ -603,6 +636,7 @@ def test_layer_fill_shape():
def test_bidirectional_unroll_valid_length():
# Test BidirectionalCell.
    # In version 1.3.1, after hybridize(), BidirectionalCell would fail when valid_length was passed to unroll().

class BiLSTM(gluon.nn.HybridBlock):
def __init__(self, rnn_size, time_step, **kwargs):
super(BiLSTM, self).__init__(**kwargs)
