Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1327] Allow RNN Layers to be initialized to fp16 #14219

Merged
merged 9 commits into from
Mar 12, 2019
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 35 additions & 24 deletions python/mxnet/gluon/rnn/rnn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, hidden_size, num_layers, layout,
i2h_bias_initializer, h2h_bias_initializer,
mode, projection_size, h2r_weight_initializer,
lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan,
**kwargs):
dtype, **kwargs):
super(_RNNLayer, self).__init__(**kwargs)
assert layout in ('TNC', 'NTC'), \
"Invalid layout %s; must be one of ['TNC' or 'NTC']"%layout
Expand All @@ -57,6 +57,7 @@ def __init__(self, hidden_size, num_layers, layout,
self._lstm_state_clip_min = lstm_state_clip_min
self._lstm_state_clip_max = lstm_state_clip_max
self._lstm_state_clip_nan = lstm_state_clip_nan
self._dtype = dtype

self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]

Expand All @@ -66,41 +67,41 @@ def __init__(self, hidden_size, num_layers, layout,
for j in ['l', 'r'][:self._dir]:
self._register_param('{}{}_i2h_weight'.format(j, i),
shape=(ng*nh, ni),
init=i2h_weight_initializer)
init=i2h_weight_initializer, dtype=dtype)
self._register_param('{}{}_h2h_weight'.format(j, i),
shape=(ng*nh, nh),
init=h2h_weight_initializer)
init=h2h_weight_initializer, dtype=dtype)
self._register_param('{}{}_i2h_bias'.format(j, i),
shape=(ng*nh,),
init=i2h_bias_initializer)
init=i2h_bias_initializer, dtype=dtype)
self._register_param('{}{}_h2h_bias'.format(j, i),
shape=(ng*nh,),
init=h2h_bias_initializer)
init=h2h_bias_initializer, dtype=dtype)
ni = nh * self._dir
else:
np = self._projection_size
for i in range(num_layers):
for j in ['l', 'r'][:self._dir]:
self._register_param('{}{}_i2h_weight'.format(j, i),
shape=(ng*nh, ni),
init=i2h_weight_initializer)
init=i2h_weight_initializer, dtype=dtype)
self._register_param('{}{}_h2h_weight'.format(j, i),
shape=(ng*nh, np),
init=h2h_weight_initializer)
init=h2h_weight_initializer, dtype=dtype)
self._register_param('{}{}_i2h_bias'.format(j, i),
shape=(ng*nh,),
init=i2h_bias_initializer)
init=i2h_bias_initializer, dtype=dtype)
self._register_param('{}{}_h2h_bias'.format(j, i),
shape=(ng*nh,),
init=h2h_bias_initializer)
init=h2h_bias_initializer, dtype=dtype)
self._register_param('{}{}_h2r_weight'.format(j, i),
shape=(np, nh),
init=h2r_weight_initializer)
init=h2r_weight_initializer, dtype=dtype)
ni = np * self._dir

def _register_param(self, name, shape, init):
def _register_param(self, name, shape, init, dtype):
p = self.params.get(name, shape=shape, init=init,
allow_deferred_init=True)
allow_deferred_init=True, dtype=dtype)
setattr(self, name, p)
return p

Expand Down Expand Up @@ -179,6 +180,10 @@ def _unfuse(self):

return stack

def cast(self, dtype):
super(_RNNLayer, self).cast(dtype)
self._dtype = dtype

def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
"""Initial state for this cell.

Expand Down Expand Up @@ -317,6 +322,8 @@ class RNN(_RNNLayer):
input_size: int, default 0
The number of expected features in the input x.
If not specified, it will be inferred from input.
dtype : str, default 'float32'
Type to initialize the parameters and default states to
prefix : str or None
Prefix of this `Block`.
params : ParameterDict or None
Expand Down Expand Up @@ -357,17 +364,17 @@ def __init__(self, hidden_size, num_layers=1, activation='relu',
layout='TNC', dropout=0, bidirectional=False,
i2h_weight_initializer=None, h2h_weight_initializer=None,
i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
input_size=0, **kwargs):
input_size=0, dtype='float32', **kwargs):
super(RNN, self).__init__(hidden_size, num_layers, layout,
dropout, bidirectional, input_size,
i2h_weight_initializer, h2h_weight_initializer,
i2h_bias_initializer, h2h_bias_initializer,
'rnn_'+activation, None, None, None, None, False,
**kwargs)
dtype, **kwargs)

def state_info(self, batch_size=0):
return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
'__layout__': 'LNC'}]
'__layout__': 'LNC', 'dtype': self._dtype}]


class LSTM(_RNNLayer):
Expand Down Expand Up @@ -432,6 +439,8 @@ class LSTM(_RNNLayer):
state_clip_nan : boolean, default False
Whether to stop NaN from propagating in state by clipping it to min/max.
If the clipping range is not specified, this option is ignored.
dtype : str, default 'float32'
Type to initialize the parameters and default states to
input_size: int, default 0
The number of expected features in the input x.
If not specified, it will be inferred from input.
Expand Down Expand Up @@ -477,26 +486,26 @@ def __init__(self, hidden_size, num_layers=1, layout='TNC',
i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
projection_size=None, h2r_weight_initializer=None,
state_clip_min=None, state_clip_max=None, state_clip_nan=False,
**kwargs):
dtype='float32', **kwargs):
super(LSTM, self).__init__(hidden_size, num_layers, layout,
dropout, bidirectional, input_size,
i2h_weight_initializer, h2h_weight_initializer,
i2h_bias_initializer, h2h_bias_initializer,
'lstm', projection_size, h2r_weight_initializer,
state_clip_min, state_clip_max, state_clip_nan,
**kwargs)
dtype, **kwargs)

def state_info(self, batch_size=0):
if self._projection_size is None:
return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
'__layout__': 'LNC'},
'__layout__': 'LNC', 'dtype': self._dtype},
{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
'__layout__': 'LNC'}]
'__layout__': 'LNC', 'dtype': self._dtype}]
else:
return [{'shape': (self._num_layers * self._dir, batch_size, self._projection_size),
'__layout__': 'LNC'},
'__layout__': 'LNC', 'dtype': self._dtype},
{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
'__layout__': 'LNC'}]
'__layout__': 'LNC', 'dtype': self._dtype}]


class GRU(_RNNLayer):
Expand Down Expand Up @@ -544,6 +553,8 @@ class GRU(_RNNLayer):
Initializer for the bias vector.
h2h_bias_initializer : str or Initializer
Initializer for the bias vector.
dtype : str, default 'float32'
Type to initialize the parameters and default states to
input_size: int, default 0
The number of expected features in the input x.
If not specified, it will be inferred from input.
Expand Down Expand Up @@ -586,14 +597,14 @@ def __init__(self, hidden_size, num_layers=1, layout='TNC',
dropout=0, bidirectional=False, input_size=0,
i2h_weight_initializer=None, h2h_weight_initializer=None,
i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
**kwargs):
dtype='float32', **kwargs):
super(GRU, self).__init__(hidden_size, num_layers, layout,
dropout, bidirectional, input_size,
i2h_weight_initializer, h2h_weight_initializer,
i2h_bias_initializer, h2h_bias_initializer,
'gru', None, None, None, None, False,
**kwargs)
dtype, **kwargs)

def state_info(self, batch_size=0):
return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
'__layout__': 'LNC'}]
'__layout__': 'LNC', 'dtype': self._dtype}]
96 changes: 66 additions & 30 deletions tests/python/unittest/test_gluon_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,9 +427,15 @@ def hybrid_forward(self, F, seq):
assert_almost_equal(output1.asnumpy(), output2.asnumpy())


def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
layer.collect_params().initialize()
def check_rnn_layer_forward(layer, inputs, states=None, run_only=False, ctx=mx.cpu()):
layer.collect_params().initialize(ctx=ctx)
inputs.attach_grad()
inputs = inputs.as_in_context(ctx)
if states is not None:
if isinstance(states, (list, tuple)):
states = [s.as_in_context(ctx) for s in states]
else:
states = states.as_in_context(ctx)
with mx.autograd.record():
if states is None:
out = layer(inputs)
Expand Down Expand Up @@ -467,47 +473,77 @@ def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), rtol=1e-3, atol=1e-5)


@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
def test_rnn_layers():
check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20)))
check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)))
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2), mx.nd.ones((8, 3, 20)))
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), [mx.nd.ones((4, 3, 10)), mx.nd.ones((4, 3, 10))])
check_rnn_layer_forward(gluon.rnn.GRU(10, 2), mx.nd.ones((8, 3, 20)))
check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)))

check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dropout=0.5), mx.nd.ones((8, 3, 20)),
run_only=True)
check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True, dropout=0.5),
mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), run_only=True)
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5), mx.nd.ones((8, 3, 20)),
run_only=True)
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5),
mx.nd.ones((8, 3, 20)),
[mx.nd.ones((4, 3, 10)), mx.nd.ones((4, 3, 10))], run_only=True)
check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dropout=0.5), mx.nd.ones((8, 3, 20)),
run_only=True)
check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, dropout=0.5),
mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), run_only=True)

def run_rnn_layers(dtype, dtype2, ctx=mx.cpu()):

check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx=ctx)
check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype), ctx=ctx)
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2,dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx=ctx)
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2,dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), [mx.nd.ones((4, 3, 10), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype)],ctx=ctx)
check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dtype=dtype, ), mx.nd.ones((8, 3, 20)),ctx=ctx)
check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype),ctx=ctx)


check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype, dropout=0.5), mx.nd.ones((8, 3, 20), dtype=dtype),
run_only=True, ctx=ctx)
check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True, dropout=0.5, dtype=dtype),
mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype), run_only=True, ctx=ctx)
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype),
run_only=True, ctx=ctx)
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5, dtype=dtype),
mx.nd.ones((8, 3, 20), dtype=dtype),
[mx.nd.ones((4, 3, 10), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype)], run_only=True, ctx=ctx)
check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dropout=0.5, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype),
run_only=True, ctx=ctx)
check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, dropout=0.5, dtype=dtype),
mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype), run_only=True, ctx=ctx)

net = gluon.nn.Sequential()
net.add(gluon.rnn.LSTM(10, bidirectional=True))
net.add(gluon.rnn.LSTM(10, bidirectional=True, dtype=dtype2))
net.add(gluon.nn.BatchNorm(axis=2))
net.add(gluon.nn.Flatten())
net.add(gluon.nn.Dense(3, activation='relu'))
net.collect_params().initialize()
net.collect_params().initialize(ctx=ctx)
net.cast(dtype)
with mx.autograd.record():
net(mx.nd.ones((2, 3, 10))).backward()
out = net(mx.nd.ones((2, 3, 10), dtype=dtype, ctx=ctx))
out.backward()
out = out.asnumpy()

net2 = gluon.nn.HybridSequential()
net2.add(gluon.rnn.LSTM(10, bidirectional=True))
net2.add(gluon.rnn.LSTM(10, bidirectional=True, dtype=dtype2))
net2.add(gluon.nn.BatchNorm(axis=2))
net2.add(gluon.nn.Flatten())
net2.add(gluon.nn.Dense(3, activation='relu'))
net2.hybridize()
net2.collect_params().initialize()
net2.collect_params().initialize(ctx=ctx)
net2.cast(dtype)
with mx.autograd.record():
out = net2(mx.nd.ones((2, 3, 10), dtype=dtype, ctx=ctx))
out.backward()
out = out.asnumpy()

net3 = gluon.nn.HybridSequential()
net3.add(gluon.rnn.LSTM(10, bidirectional=True, dtype=dtype))
net3.add(gluon.nn.BatchNorm(axis=2))
net3.add(gluon.nn.Flatten())
net3.add(gluon.nn.Dense(3, activation='relu'))
net3.hybridize()
net3.collect_params().initialize(ctx=ctx)
net3.cast(dtype2)
with mx.autograd.record():
net2(mx.nd.ones((2, 3, 10))).backward()
out = net3(mx.nd.ones((2, 3, 10), dtype=dtype2, ctx=ctx))
out.backward()
out = out.asnumpy()

@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
def test_rnn_layers_fp32():
run_rnn_layers('float32', 'float32')

@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
@unittest.skipIf(mx.context.num_gpus() == 0, "RNN FP16 only implemented for GPU for now")
def test_rnn_layers_fp16():
run_rnn_layers('float16', 'float32', mx.gpu())


def test_rnn_unroll_variant_length():
Expand Down