Commit cdb43ff

LayernormSimpleRNN moved to addons (#841)
* LayernormSimpleRNN moved to addons
* code-format run
* use super instead of calling the parent class
* deactivate layernorm's bias term (beta) for centering, and apply the normal self.bias term after scaling with layernorm for centering; docstring with explanatory formulas added to the cell's call method
* use_layernorm=True set as default
* import aligned with cell.py, examples in docstring corrected
* import aligned with cell_test.py
* code for LayernormSimpleRNN moved into cell.py and cell_test.py
* pylint errors corrected
* bazel's timeout increased from small to large for cell_test.py
* test with training deactivated
* non-ascii char replaced
* dict syntax for python2 changed
* Renamed to LayerNorm...
* direct parent class call replaced with super
* error due to import change corrected
* uncomment line
* unit test added
* Name change in unit test file
* Still the class name change
* deleted dtype and trainable args for parent class
* remove self for super parent class calls
* compare arrays with assertAllEqual
* use_layernorm removed
* dict removed from return statement
* LayerNormSimpleRNN removed, use kwargs, comments removed
* forward **kwargs to other layers
* a more pythonic dict loop
1 parent 52b8079 commit cdb43ff
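
The commit message above notes that layernorm's own bias (beta) is deactivated and the cell's regular `self.bias` is added after the scaling step instead. The following is not part of the commit, just a minimal numpy sketch of why that split is equivalent to layer normalization that scales and centers in one go; the `layer_norm` helper is hypothetical and defined only for this check:

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # plain layer normalization over the last axis
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

net = np.random.randn(2, 4).astype(np.float32)   # pre-activation "net"
gamma = np.random.randn(4).astype(np.float32)    # layer-norm scale
b = np.random.randn(4).astype(np.float32)        # the cell's own bias

# scaling-only layer norm (beta disabled), bias added afterwards ...
scale_then_bias = layer_norm(net, gamma, beta=0.0) + b
# ... matches layer norm that scales and centers in one go (beta = b)
scale_and_center = layer_norm(net, gamma, beta=b)

np.testing.assert_allclose(scale_then_bias, scale_and_center, rtol=1e-6)
```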

3 files changed: +291 -0 lines changed

Diff for: tensorflow_addons/rnn/__init__.py (+1 line)
@@ -16,3 +16,4 @@
 
 from tensorflow_addons.rnn.cell import LayerNormLSTMCell
 from tensorflow_addons.rnn.cell import NASCell
+from tensorflow_addons.rnn.cell import LayerNormSimpleRNNCell
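
With this import in place, the new cell is exposed at the package level, as the docstring example in cell.py assumes. A minimal smoke check, assuming a tensorflow_addons build that already contains this commit:

```python
# assumes a tensorflow_addons build that includes this commit
import tensorflow_addons as tfa

cell = tfa.rnn.LayerNormSimpleRNNCell(units=4)
print(type(cell).__name__)  # LayerNormSimpleRNNCell
```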

Diff for: tensorflow_addons/rnn/cell.py (+229 lines)
@@ -363,3 +363,232 @@ def _create_norm_layer(self, name):
             gamma_initializer=self.norm_gamma_initializer,
             epsilon=self.norm_epsilon,
             name=name)
+
+
+@tf.keras.utils.register_keras_serializable(package='Addons')
+class LayerNormSimpleRNNCell(keras.layers.SimpleRNNCell):
+    """Cell class for LayerNormSimpleRNN.
+
+    References:
+      [1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton.
+          "Layer Normalization." ArXiv:1607.06450 [Cs, Stat],
+          July 21, 2016. http://arxiv.org/abs/1607.06450
+
+    Arguments:
+      units: Positive integer, dimensionality of the output space.
+      activation: Activation function to use.
+        Default: hyperbolic tangent (`tanh`).
+        If you pass `None`, no activation is applied
+        (ie. "linear" activation: `a(x) = x`).
+      use_bias: Boolean, (default `True`), whether the layer uses a bias
+        vector.
+      layernorm_epsilon: Float, (default `1e-5`), small float added to the
+        variance to avoid dividing by zero.
+      kernel_initializer: Initializer for the `kernel` weights matrix,
+        used for the linear transformation of the inputs. Default:
+        `glorot_uniform`.
+      recurrent_initializer: Initializer for the `recurrent_kernel`
+        weights matrix, used for the linear transformation of the recurrent
+        state. Default: `orthogonal`.
+      bias_initializer: Initializer for the bias vector (`use_bias=True`).
+        Default: `zeros`.
+      gamma_initializer: Initializer for the gamma vector of the layer
+        normalization layer. Default: `ones`.
+      kernel_regularizer: Regularizer function applied to the `kernel` weights
+        matrix. Default: `None`.
+      recurrent_regularizer: Regularizer function applied to the
+        `recurrent_kernel` weights matrix. Default: `None`.
+      bias_regularizer: Regularizer function applied to the bias vector
+        (`use_bias=True`). Default: `None`.
+      gamma_regularizer: Regularizer function applied to the gamma vector
+        of the layer normalization layer. Default: `None`.
+      kernel_constraint: Constraint function applied to the `kernel` weights
+        matrix. Default: `None`.
+      recurrent_constraint: Constraint function applied to the
+        `recurrent_kernel` weights matrix. Default: `None`.
+      bias_constraint: Constraint function applied to the bias vector
+        (`use_bias=True`). Default: `None`.
+      gamma_constraint: Constraint function applied to the gamma vector
+        of the layer normalization layer. Default: `None`.
+      dropout: Float between 0 and 1. Fraction of the units to drop for the
+        linear transformation of the inputs. Default: 0.
+      recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
+        for the linear transformation of the recurrent state. Default: 0.
+
+    Call arguments:
+      inputs: A 2D tensor, with shape of `[batch, feature]`.
+      states: A 2D tensor with shape of `[batch, units]`, which is the state
+        from the previous time step. For timestep 0, the initial state provided
+        by the user will be fed to the cell.
+      training: Python boolean indicating whether the layer should behave in
+        training mode or in inference mode. Only relevant when `dropout` or
+        `recurrent_dropout` is used.
+
+    Examples:
+
+    ```python
+    import numpy as np
+    import tensorflow.keras as keras
+    import tensorflow_addons as tfa
+
+    inputs = np.random.random([32, 10, 8]).astype(np.float32)
+    rnn = keras.layers.RNN(tfa.rnn.LayerNormSimpleRNNCell(4))
+
+    output = rnn(inputs)  # The output has shape `[32, 4]`.
+
+    rnn = keras.layers.RNN(
+        tfa.rnn.LayerNormSimpleRNNCell(4),
+        return_sequences=True,
+        return_state=True)
+
+    # whole_sequence_output has shape `[32, 10, 4]`.
+    # final_state has shape `[32, 4]`.
+    whole_sequence_output, final_state = rnn(inputs)
+    ```
+    """
+
+    def __init__(self,
+                 units,
+                 activation='tanh',
+                 use_bias=True,
+                 layernorm_epsilon=1e-05,
+                 kernel_initializer='glorot_uniform',
+                 recurrent_initializer='orthogonal',
+                 bias_initializer='zeros',
+                 gamma_initializer='ones',
+                 kernel_regularizer=None,
+                 recurrent_regularizer=None,
+                 bias_regularizer=None,
+                 gamma_regularizer=None,
+                 kernel_constraint=None,
+                 recurrent_constraint=None,
+                 bias_constraint=None,
+                 gamma_constraint=None,
+                 dropout=0.,
+                 recurrent_dropout=0.,
+                 **kwargs):
+        super(LayerNormSimpleRNNCell, self).__init__(
+            units,
+            activation=activation,
+            use_bias=use_bias,
+            kernel_initializer=kernel_initializer,
+            recurrent_initializer=recurrent_initializer,
+            bias_initializer=bias_initializer,
+            kernel_regularizer=kernel_regularizer,
+            recurrent_regularizer=recurrent_regularizer,
+            bias_regularizer=bias_regularizer,
+            kernel_constraint=kernel_constraint,
+            recurrent_constraint=recurrent_constraint,
+            bias_constraint=bias_constraint,
+            dropout=dropout,
+            recurrent_dropout=recurrent_dropout,
+            **kwargs)
+        self.layernorm = keras.layers.LayerNormalization(
+            axis=-1,
+            epsilon=layernorm_epsilon,
+            center=False,
+            scale=True,
+            beta_initializer=None,
+            gamma_initializer=gamma_initializer,
+            beta_regularizer=None,
+            gamma_regularizer=gamma_regularizer,
+            beta_constraint=None,
+            gamma_constraint=gamma_constraint,
+            **kwargs)
+
+    def build(self, input_shape):
+        super(LayerNormSimpleRNNCell, self).build(input_shape)
+        self.layernorm.build((None, self.units))
+
+    def call(self, inputs, states, training=None):
+        """Formulas.
+
+        Notation:
+          y_t : Cell output at t (`output`)
+          y_{t-1} : Previous cell output at t-1 (`prev_output`)
+          x_t : The new input at t (`inputs`)
+          W_xh : Weight matrix for inputs x_t (`self.kernel`)
+          W_hh : Weights for prev. outputs y_{t-1} (`self.recurrent_kernel`)
+          b : Bias term for centering (`self.bias`)
+          d1 : Dropout function for x_t (`inputs * dp_mask`)
+          d2 : Dropout function for y_{t-1} (`prev_output * rec_dp_mask`)
+          ln : Scaling function from layer normalization (`self.layernorm`)
+          f : Activation function (`self.activation`)
+
+        Case 1:
+          Keras' SimpleRNN, only with bias and activation
+            y_t = f(x_t * W_xh + y_{t-1} * W_hh + b)
+          or
+            net = x_t * W_xh + y_{t-1} * W_hh
+            y_t = f(net + b)
+
+        Case 2:
+          addons' LayerNormSimpleRNNCell. Like case 1 but with layer
+          normalization (only scaling).
+            y_t = f(ln(x_t * W_xh + y_{t-1} * W_hh) + b)
+          or
+            net = x_t * W_xh + y_{t-1} * W_hh
+            y_t = f(ln(net) + b)
+
+          Layer normalization with scaling and centering in one go (see Ba et
+          al. (2016), page 3, formula 4, https://arxiv.org/abs/1607.06450)
+          is the same as layer normalization with scaling only, followed by
+          centering directly afterwards.
+
+        Case 3:
+          Keras' SimpleRNN, with dropout, bias, and activation
+            y_t = f(d1(x_t) * W_xh + d2(y_{t-1}) * W_hh + b)
+          or
+            net = d1(x_t) * W_xh + d2(y_{t-1}) * W_hh
+            y_t = f(net + b)
+
+        Case 4:
+          addons' LayerNormSimpleRNNCell. Like case 3 but with layer
+          normalization (only scaling).
+            y_t = f(ln(d1(x_t) * W_xh + d2(y_{t-1}) * W_hh) + b)
+          or
+            net = d1(x_t) * W_xh + d2(y_{t-1}) * W_hh
+            y_t = f(ln(net) + b)
+        """
+        prev_output = states[0]
+        dp_mask = self.get_dropout_mask_for_cell(inputs, training)
+        rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
+            prev_output, training)
+
+        if dp_mask is not None:
+            h = keras.backend.dot(inputs * dp_mask, self.kernel)
+        else:
+            h = keras.backend.dot(inputs, self.kernel)
+
+        # don't add bias to "h" here
+        # add bias after scaling with layer normalization to "output"
+
+        if rec_dp_mask is not None:
+            prev_output = prev_output * rec_dp_mask
+        output = h + keras.backend.dot(prev_output,
+                                       self.recurrent_kernel)  # "net"
+
+        output = self.layernorm(output)
+
+        if self.bias is not None:
+            output = keras.backend.bias_add(output, self.bias)
+
+        if self.activation is not None:
+            output = self.activation(output)
+
+        return output, [output]
+
+    # use SimpleRNNCell's get_initial_state method
+
+    def get_config(self):
+        cell_config = super(LayerNormSimpleRNNCell, self).get_config()
+        del cell_config['name']
+
+        ln_config = self.layernorm.get_config()
+        ln_config = {
+            k: v for k, v in ln_config.items()
+            if k in ["epsilon", "gamma_initializer",
+                     "gamma_regularizer", "gamma_constraint"]}
+
+        ln_config['layernorm_epsilon'] = ln_config.pop("epsilon")
+        return dict(list(cell_config.items()) + list(ln_config.items()))
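
The `call` docstring above describes case 2 as `y_t = f(ln(net) + b)`. As a sanity check (not part of the commit), one step of the cell can be replayed by hand against that formula. This sketch assumes a tensorflow_addons build that contains this commit, the default `tanh` activation, and dropout disabled:

```python
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

# assumes a tensorflow_addons build that includes this commit
cell = tfa.rnn.LayerNormSimpleRNNCell(units=3)
x_t = tf.random.normal([1, 4])   # single-timestep input
y_prev = tf.zeros([1, 3])        # previous cell output

y_t, _ = cell(x_t, [y_prev])     # builds the cell on first call

# replay case 2 by hand: y_t = f(ln(x_t * W_xh + y_{t-1} * W_hh) + b)
net = tf.matmul(x_t, cell.kernel) + tf.matmul(y_prev, cell.recurrent_kernel)
manual = tf.tanh(cell.layernorm(net) + cell.bias)

np.testing.assert_allclose(y_t.numpy(), manual.numpy(), rtol=1e-5, atol=1e-6)
```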

Diff for: tensorflow_addons/rnn/cell_test.py (+61 lines)
@@ -20,6 +20,7 @@
 
 from tensorflow_addons.utils import test_utils
 from tensorflow_addons.rnn import cell as rnn_cell
+from tensorflow_addons.rnn import LayerNormSimpleRNNCell
 
 
 @test_utils.run_all_in_graph_and_eager_modes
@@ -292,5 +293,65 @@ def test_config(self):
         self.assertEqual(config, restored_config)
 
 
+@test_utils.run_all_in_graph_and_eager_modes
+class LayerNormSimpleRNNTest(tf.test.TestCase):
+    def test_constraints_layernorm_rnn(self):
+        embedding_dim = 4
+        k_constraint = keras.constraints.max_norm(0.01)
+        r_constraint = keras.constraints.max_norm(0.01)
+        b_constraint = keras.constraints.max_norm(0.01)
+        g_constraint = keras.constraints.max_norm(0.01)
+        layer = keras.layers.RNN(
+            LayerNormSimpleRNNCell(
+                units=5,
+                kernel_constraint=k_constraint,
+                recurrent_constraint=r_constraint,
+                bias_constraint=b_constraint,
+                gamma_constraint=g_constraint),
+            input_shape=(None, embedding_dim),
+            return_sequences=False)
+        layer.build((None, None, embedding_dim))
+        self.assertEqual(layer.cell.kernel.constraint, k_constraint)
+        self.assertEqual(layer.cell.recurrent_kernel.constraint, r_constraint)
+        self.assertEqual(layer.cell.bias.constraint, b_constraint)
+        self.assertEqual(layer.cell.layernorm.gamma.constraint, g_constraint)
+
+    def test_with_masking_layer_layernorm_rnn(self):
+        inputs = np.random.random((2, 3, 4))
+        targets = np.abs(np.random.random((2, 3, 5)))
+        targets /= targets.sum(axis=-1, keepdims=True)
+        model = keras.models.Sequential()
+        model.add(keras.layers.Masking(input_shape=(3, 4)))
+        model.add(
+            keras.layers.RNN(
+                LayerNormSimpleRNNCell(units=5),
+                return_sequences=True,
+                unroll=False))
+        model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
+        model.fit(inputs, targets, epochs=1, batch_size=2, verbose=1)
+
+    def test_regularizers_layernorm_rnn(self):
+        embedding_dim = 4
+        layer = keras.layers.RNN(
+            LayerNormSimpleRNNCell(
+                units=5,
+                kernel_regularizer=keras.regularizers.l1(0.01),
+                recurrent_regularizer=keras.regularizers.l1(0.01),
+                bias_regularizer='l2',
+                gamma_regularizer='l2'),
+            input_shape=(None, embedding_dim),
+            return_sequences=False)
+        layer.build((None, None, 2))
+        self.assertEqual(len(layer.losses), 4)
+
+    def test_configs_layernorm(self):
+        config = {'layernorm_epsilon': 1e-6}
+        cell1 = LayerNormSimpleRNNCell(units=8, **config)
+        config1 = cell1.get_config()
+        cell2 = LayerNormSimpleRNNCell(**config1)
+        config2 = cell2.get_config()
+        assert config1 == config2
+
+
 if __name__ == "__main__":
     tf.test.main()
