diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py index c43dc8527fd4..6dfec43a8b5f 100644 --- a/python/mxnet/gluon/rnn/rnn_layer.py +++ b/python/mxnet/gluon/rnn/rnn_layer.py @@ -37,7 +37,7 @@ def __init__(self, hidden_size, num_layers, layout, i2h_bias_initializer, h2h_bias_initializer, mode, projection_size, h2r_weight_initializer, lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan, - **kwargs): + dtype, **kwargs): super(_RNNLayer, self).__init__(**kwargs) assert layout in ('TNC', 'NTC'), \ "Invalid layout %s; must be one of ['TNC' or 'NTC']"%layout @@ -57,6 +57,7 @@ def __init__(self, hidden_size, num_layers, layout, self._lstm_state_clip_min = lstm_state_clip_min self._lstm_state_clip_max = lstm_state_clip_max self._lstm_state_clip_nan = lstm_state_clip_nan + self._dtype = dtype self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode] @@ -66,16 +67,16 @@ def __init__(self, hidden_size, num_layers, layout, for j in ['l', 'r'][:self._dir]: self._register_param('{}{}_i2h_weight'.format(j, i), shape=(ng*nh, ni), - init=i2h_weight_initializer) + init=i2h_weight_initializer, dtype=dtype) self._register_param('{}{}_h2h_weight'.format(j, i), shape=(ng*nh, nh), - init=h2h_weight_initializer) + init=h2h_weight_initializer, dtype=dtype) self._register_param('{}{}_i2h_bias'.format(j, i), shape=(ng*nh,), - init=i2h_bias_initializer) + init=i2h_bias_initializer, dtype=dtype) self._register_param('{}{}_h2h_bias'.format(j, i), shape=(ng*nh,), - init=h2h_bias_initializer) + init=h2h_bias_initializer, dtype=dtype) ni = nh * self._dir else: np = self._projection_size @@ -83,24 +84,24 @@ def __init__(self, hidden_size, num_layers, layout, for j in ['l', 'r'][:self._dir]: self._register_param('{}{}_i2h_weight'.format(j, i), shape=(ng*nh, ni), - init=i2h_weight_initializer) + init=i2h_weight_initializer, dtype=dtype) self._register_param('{}{}_h2h_weight'.format(j, i), shape=(ng*nh, np), - init=h2h_weight_initializer) + init=h2h_weight_initializer, dtype=dtype) self._register_param('{}{}_i2h_bias'.format(j, i), shape=(ng*nh,), - init=i2h_bias_initializer) + init=i2h_bias_initializer, dtype=dtype) self._register_param('{}{}_h2h_bias'.format(j, i), shape=(ng*nh,), - init=h2h_bias_initializer) + init=h2h_bias_initializer, dtype=dtype) self._register_param('{}{}_h2r_weight'.format(j, i), shape=(np, nh), - init=h2r_weight_initializer) + init=h2r_weight_initializer, dtype=dtype) ni = np * self._dir - def _register_param(self, name, shape, init): + def _register_param(self, name, shape, init, dtype): p = self.params.get(name, shape=shape, init=init, - allow_deferred_init=True) + allow_deferred_init=True, dtype=dtype) setattr(self, name, p) return p @@ -179,6 +180,10 @@ def _unfuse(self): return stack + def cast(self, dtype): + super(_RNNLayer, self).cast(dtype) + self._dtype = dtype + def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs): """Initial state for this cell. @@ -317,6 +322,8 @@ class RNN(_RNNLayer): input_size: int, default 0 The number of expected features in the input x. If not specified, it will be inferred from input. + dtype : str, default 'float32' + Type to initialize the parameters and default states to prefix : str or None Prefix of this `Block`. params : ParameterDict or None @@ -357,17 +364,17 @@ def __init__(self, hidden_size, num_layers=1, activation='relu', layout='TNC', dropout=0, bidirectional=False, i2h_weight_initializer=None, h2h_weight_initializer=None, i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', - input_size=0, **kwargs): + input_size=0, dtype='float32', **kwargs): super(RNN, self).__init__(hidden_size, num_layers, layout, dropout, bidirectional, input_size, i2h_weight_initializer, h2h_weight_initializer, i2h_bias_initializer, h2h_bias_initializer, 'rnn_'+activation, None, None, None, None, False, - **kwargs) + dtype, **kwargs) def state_info(self, batch_size=0): return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size), - '__layout__': 'LNC'}] + '__layout__': 'LNC', 'dtype': self._dtype}] class LSTM(_RNNLayer): @@ -432,6 +439,8 @@ class LSTM(_RNNLayer): state_clip_nan : boolean, default False Whether to stop NaN from propagating in state by clipping it to min/max. If the clipping range is not specified, this option is ignored. + dtype : str, default 'float32' + Type to initialize the parameters and default states to input_size: int, default 0 The number of expected features in the input x. If not specified, it will be inferred from input. @@ -477,26 +486,26 @@ def __init__(self, hidden_size, num_layers=1, layout='TNC', i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', projection_size=None, h2r_weight_initializer=None, state_clip_min=None, state_clip_max=None, state_clip_nan=False, - **kwargs): + dtype='float32', **kwargs): super(LSTM, self).__init__(hidden_size, num_layers, layout, dropout, bidirectional, input_size, i2h_weight_initializer, h2h_weight_initializer, i2h_bias_initializer, h2h_bias_initializer, 'lstm', projection_size, h2r_weight_initializer, state_clip_min, state_clip_max, state_clip_nan, - **kwargs) + dtype, **kwargs) def state_info(self, batch_size=0): if self._projection_size is None: return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size), - '__layout__': 'LNC'}, + '__layout__': 'LNC', 'dtype': self._dtype}, {'shape': (self._num_layers * self._dir, batch_size, self._hidden_size), - '__layout__': 'LNC'}] + '__layout__': 'LNC', 'dtype': self._dtype}] else: return [{'shape': (self._num_layers * self._dir, batch_size, self._projection_size), - '__layout__': 'LNC'}, + '__layout__': 'LNC', 'dtype': self._dtype}, {'shape': (self._num_layers * self._dir, batch_size, self._hidden_size), - '__layout__': 'LNC'}] + '__layout__': 'LNC', 'dtype': self._dtype}] class GRU(_RNNLayer): @@ -544,6 +553,8 @@ class GRU(_RNNLayer): Initializer for the bias vector. h2h_bias_initializer : str or Initializer Initializer for the bias vector. + dtype : str, default 'float32' + Type to initialize the parameters and default states to input_size: int, default 0 The number of expected features in the input x. If not specified, it will be inferred from input. @@ -586,14 +597,14 @@ def __init__(self, hidden_size, num_layers=1, layout='TNC', dropout=0, bidirectional=False, input_size=0, i2h_weight_initializer=None, h2h_weight_initializer=None, i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', - **kwargs): + dtype='float32', **kwargs): super(GRU, self).__init__(hidden_size, num_layers, layout, dropout, bidirectional, input_size, i2h_weight_initializer, h2h_weight_initializer, i2h_bias_initializer, h2h_bias_initializer, 'gru', None, None, None, None, False, - **kwargs) + dtype, **kwargs) def state_info(self, batch_size=0): return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size), - '__layout__': 'LNC'}] + '__layout__': 'LNC', 'dtype': self._dtype}] diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py index edc43d21b36b..b410362c8fd1 100644 --- a/tests/python/unittest/test_gluon_rnn.py +++ b/tests/python/unittest/test_gluon_rnn.py @@ -427,9 +427,15 @@ def hybrid_forward(self, F, seq): assert_almost_equal(output1.asnumpy(), output2.asnumpy()) -def check_rnn_layer_forward(layer, inputs, states=None, run_only=False): - layer.collect_params().initialize() +def check_rnn_layer_forward(layer, inputs, states=None, run_only=False, ctx=mx.cpu()): + layer.collect_params().initialize(ctx=ctx) + inputs = inputs.as_in_context(ctx) inputs.attach_grad() + if states is not None: + if isinstance(states, (list, tuple)): + states = [s.as_in_context(ctx) for s in states] + else: + states = states.as_in_context(ctx) with mx.autograd.record(): if states is None: out = layer(inputs) @@ -467,47 +473,76 @@ def check_rnn_layer_forward(layer, inputs, states=None, run_only=False): mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), rtol=1e-3, atol=1e-5) -@assert_raises_cudnn_not_satisfied(min_version='5.1.10') -def test_rnn_layers(): - check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20))) - check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10))) - check_rnn_layer_forward(gluon.rnn.LSTM(10, 2), mx.nd.ones((8, 3, 20))) - check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), [mx.nd.ones((4, 3, 10)), mx.nd.ones((4, 3, 10))]) - check_rnn_layer_forward(gluon.rnn.GRU(10, 2), mx.nd.ones((8, 3, 20))) - check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10))) - - check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dropout=0.5), mx.nd.ones((8, 3, 20)), - run_only=True) - check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True, dropout=0.5), - mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), run_only=True) - check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5), mx.nd.ones((8, 3, 20)), - run_only=True) - check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5), - mx.nd.ones((8, 3, 20)), - [mx.nd.ones((4, 3, 10)), mx.nd.ones((4, 3, 10))], run_only=True) - check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dropout=0.5), mx.nd.ones((8, 3, 20)), - run_only=True) - check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, dropout=0.5), - mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), run_only=True) + +def run_rnn_layers(dtype, dtype2, ctx=mx.cpu()): + + check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx=ctx) + check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype), ctx=ctx) + check_rnn_layer_forward(gluon.rnn.LSTM(10, 2,dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx=ctx) + check_rnn_layer_forward(gluon.rnn.LSTM(10, 2,dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), [mx.nd.ones((4, 3, 10), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype)],ctx=ctx) + check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dtype=dtype, ), mx.nd.ones((8, 3, 20), dtype=dtype),ctx=ctx) + check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype),ctx=ctx) + + + check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype, dropout=0.5), mx.nd.ones((8, 3, 20), dtype=dtype), + run_only=True, ctx=ctx) + check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True, dropout=0.5, dtype=dtype), + mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype), run_only=True, ctx=ctx) + check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), + run_only=True, ctx=ctx) + check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5, dtype=dtype), + mx.nd.ones((8, 3, 20), dtype=dtype), + [mx.nd.ones((4, 3, 10), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype)], run_only=True, ctx=ctx) + check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dropout=0.5, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), + run_only=True, ctx=ctx) + check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, dropout=0.5, dtype=dtype), + mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype), run_only=True, ctx=ctx) net = gluon.nn.Sequential() - net.add(gluon.rnn.LSTM(10, bidirectional=True)) + net.add(gluon.rnn.LSTM(10, bidirectional=True, dtype=dtype2)) net.add(gluon.nn.BatchNorm(axis=2)) net.add(gluon.nn.Flatten()) net.add(gluon.nn.Dense(3, activation='relu')) - net.collect_params().initialize() + net.collect_params().initialize(ctx=ctx) + net.cast(dtype) with mx.autograd.record(): - net(mx.nd.ones((2, 3, 10))).backward() + out = net(mx.nd.ones((2, 3, 10), dtype=dtype, ctx=ctx)) + out.backward() + out = out.asnumpy() net2 = gluon.nn.HybridSequential() - net2.add(gluon.rnn.LSTM(10, bidirectional=True)) + net2.add(gluon.rnn.LSTM(10, bidirectional=True, dtype=dtype2)) net2.add(gluon.nn.BatchNorm(axis=2)) net2.add(gluon.nn.Flatten()) net2.add(gluon.nn.Dense(3, activation='relu')) net2.hybridize() - net2.collect_params().initialize() + net2.collect_params().initialize(ctx=ctx) + net2.cast(dtype) + with mx.autograd.record(): + out = net2(mx.nd.ones((2, 3, 10), dtype=dtype, ctx=ctx)) + out.backward() + out = out.asnumpy() + + net3 = gluon.nn.HybridSequential() + net3.add(gluon.rnn.LSTM(10, bidirectional=True, dtype=dtype)) + net3.add(gluon.nn.BatchNorm(axis=2)) + net3.add(gluon.nn.Flatten()) + net3.add(gluon.nn.Dense(3, activation='relu')) + net3.hybridize() + net3.collect_params().initialize(ctx=ctx) + net3.cast(dtype2) with mx.autograd.record(): - net2(mx.nd.ones((2, 3, 10))).backward() + out = net3(mx.nd.ones((2, 3, 10), dtype=dtype2, ctx=ctx)) + out.backward() + out = out.asnumpy() + +def test_rnn_layers_fp32(): + run_rnn_layers('float32', 'float32') + +@assert_raises_cudnn_not_satisfied(min_version='5.1.10') +@unittest.skipIf(mx.context.num_gpus() == 0, "RNN FP16 only implemented for GPU for now") +def test_rnn_layers_fp16(): + run_rnn_layers('float16', 'float32', mx.gpu()) def test_rnn_unroll_variant_length(): @@ -590,8 +625,6 @@ def test_cell_fill_shape(): check_rnn_forward(cell, mx.nd.ones((2, 3, 7))) assert cell.i2h_weight.shape[1] == 7, cell.i2h_weight.shape[1] - -@assert_raises_cudnn_not_satisfied(min_version='5.1.10') def test_layer_fill_shape(): layer = gluon.rnn.LSTM(10) layer.hybridize() @@ -603,6 +636,7 @@ def test_layer_fill_shape(): def test_bidirectional_unroll_valid_length(): # Test BidirectionalCell. # In 1.3.1 version, after hybridize( ), BidirectionalCell would failed when pass valid_length to unroll( ). + class BiLSTM(gluon.nn.HybridBlock): def __init__(self, rnn_size, time_step, **kwargs): super(BiLSTM, self).__init__(**kwargs)