@@ -363,3 +363,232 @@ def _create_norm_layer(self, name):
             gamma_initializer=self.norm_gamma_initializer,
             epsilon=self.norm_epsilon,
             name=name)
+
+
+@tf.keras.utils.register_keras_serializable(package='Addons')
+class LayerNormSimpleRNNCell(keras.layers.SimpleRNNCell):
+    """Cell class for LayerNormSimpleRNN.
+
+    References:
+    [1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton.
+        "Layer Normalization." ArXiv:1607.06450 [Cs, Stat],
+        July 21, 2016. http://arxiv.org/abs/1607.06450
+
+    Arguments:
+      units: Positive integer, dimensionality of the output space.
+      activation: Activation function to use.
+        Default: hyperbolic tangent (`tanh`).
+        If you pass `None`, no activation is applied
+        (i.e. "linear" activation: `a(x) = x`).
+      use_bias: Boolean (default `True`), whether the layer uses a bias
+        vector.
+      layernorm_epsilon: Float (default `1e-5`), small float added to the
+        variance to avoid dividing by zero.
+      kernel_initializer: Initializer for the `kernel` weights matrix,
+        used for the linear transformation of the inputs. Default:
+        `glorot_uniform`.
+      recurrent_initializer: Initializer for the `recurrent_kernel`
+        weights matrix, used for the linear transformation of the recurrent
+        state. Default: `orthogonal`.
+      bias_initializer: Initializer for the bias vector (`use_bias=True`).
+        Default: `zeros`.
+      gamma_initializer: Initializer for the gamma vector of the layer
+        normalization layer. Default: `ones`.
+      kernel_regularizer: Regularizer function applied to the `kernel`
+        weights matrix. Default: `None`.
+      recurrent_regularizer: Regularizer function applied to the
+        `recurrent_kernel` weights matrix. Default: `None`.
+      bias_regularizer: Regularizer function applied to the bias vector
+        (`use_bias=True`). Default: `None`.
+      gamma_regularizer: Regularizer function applied to the gamma vector
+        of the layer normalization layer. Default: `None`.
+      kernel_constraint: Constraint function applied to the `kernel`
+        weights matrix. Default: `None`.
+      recurrent_constraint: Constraint function applied to the
+        `recurrent_kernel` weights matrix. Default: `None`.
+      bias_constraint: Constraint function applied to the bias vector
+        (`use_bias=True`). Default: `None`.
+      gamma_constraint: Constraint function applied to the gamma vector
+        of the layer normalization layer. Default: `None`.
+      dropout: Float between 0 and 1. Fraction of the units to drop for the
+        linear transformation of the inputs. Default: 0.
+      recurrent_dropout: Float between 0 and 1. Fraction of the units to
+        drop for the linear transformation of the recurrent state.
+        Default: 0.
+
+    Call arguments:
+      inputs: A 2D tensor, with shape `[batch, feature]`.
+      states: A 2D tensor with shape `[batch, units]`, which is the state
+        from the previous time step. For timestep 0, the initial state
+        provided by the user will be fed to the cell.
+      training: Python boolean indicating whether the layer should behave in
+        training mode or in inference mode. Only relevant when `dropout` or
+        `recurrent_dropout` is used.
+
+    Examples:
+
+    ```python
+    import numpy as np
+    import tensorflow.keras as keras
+    import tensorflow_addons as tfa
+
+    inputs = np.random.random([32, 10, 8]).astype(np.float32)
+    rnn = keras.layers.RNN(tfa.rnn.LayerNormSimpleRNNCell(4))
+
+    output = rnn(inputs)  # The output has shape `[32, 4]`.
+
+    rnn = keras.layers.RNN(
+        tfa.rnn.LayerNormSimpleRNNCell(4),
+        return_sequences=True,
+        return_state=True)
+
+    # whole_sequence_output has shape `[32, 10, 4]`.
+    # final_state has shape `[32, 4]`.
+    whole_sequence_output, final_state = rnn(inputs)
+    ```
+    """
+
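# Illustrative sketch (an assumed usage, not part of the diff): the
# constructor arguments and the single-timestep call documented above.
# The shapes, the l2 regularizer, and the dropout rate are arbitrary choices.
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow_addons as tfa

cell = tfa.rnn.LayerNormSimpleRNNCell(
    units=4,
    use_bias=True,
    layernorm_epsilon=1e-5,
    kernel_regularizer=keras.regularizers.l2(1e-4),
    dropout=0.1)

# Called directly, the cell consumes one timestep: a [batch, feature] input
# and a list with one [batch, units] state tensor; it returns the output and
# the new state list.
x_t = tf.random.uniform([32, 8])
state_0 = [tf.zeros([32, 4])]
y_t, state_1 = cell(x_t, state_0, training=False)  # y_t has shape [32, 4]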
+    def __init__(self,
+                 units,
+                 activation='tanh',
+                 use_bias=True,
+                 layernorm_epsilon=1e-05,
+                 kernel_initializer='glorot_uniform',
+                 recurrent_initializer='orthogonal',
+                 bias_initializer='zeros',
+                 gamma_initializer='ones',
+                 kernel_regularizer=None,
+                 recurrent_regularizer=None,
+                 bias_regularizer=None,
+                 gamma_regularizer=None,
+                 kernel_constraint=None,
+                 recurrent_constraint=None,
+                 bias_constraint=None,
+                 gamma_constraint=None,
+                 dropout=0.,
+                 recurrent_dropout=0.,
+                 **kwargs):
+        super(LayerNormSimpleRNNCell, self).__init__(
+            units,
+            activation=activation,
+            use_bias=use_bias,
+            kernel_initializer=kernel_initializer,
+            recurrent_initializer=recurrent_initializer,
+            bias_initializer=bias_initializer,
+            kernel_regularizer=kernel_regularizer,
+            recurrent_regularizer=recurrent_regularizer,
+            bias_regularizer=bias_regularizer,
+            kernel_constraint=kernel_constraint,
+            recurrent_constraint=recurrent_constraint,
+            bias_constraint=bias_constraint,
+            dropout=dropout,
+            recurrent_dropout=recurrent_dropout,
+            **kwargs)
+        self.layernorm = keras.layers.LayerNormalization(
+            axis=-1,
+            epsilon=layernorm_epsilon,
+            center=False,
+            scale=True,
+            beta_initializer=None,
+            gamma_initializer=gamma_initializer,
+            beta_regularizer=None,
+            gamma_regularizer=gamma_regularizer,
+            beta_constraint=None,
+            gamma_constraint=gamma_constraint,
+            **kwargs)
+
+    def build(self, input_shape):
+        super(LayerNormSimpleRNNCell, self).build(input_shape)
+        self.layernorm.build((None, self.units))
+
+    def call(self, inputs, states, training=None):
+        """Formulas.
+
+        Notation:
+          y_t : Cell output at t (`output`)
+          y_{t-1} : Previous cell output at t-1 (`prev_output`)
+          x_t : The new input at t (`inputs`)
+          W_xh : Weight matrix for inputs x_t (`self.kernel`)
+          W_hh : Weights for prev. outputs y_{t-1} (`self.recurrent_kernel`)
+          b : Bias term for centering (`self.bias`)
+          d1 : Dropout function for x_t (`inputs * dp_mask`)
+          d2 : Dropout function for y_{t-1} (`prev_output * rec_dp_mask`)
+          ln : Scaling function from layer normalization (`self.layernorm`)
+          f : Activation function (`self.activation`)
+
+        Case 1:
+          Keras' SimpleRNN, only with bias and activation
+            y_t = f(x_t * W_xh + y_{t-1} * W_hh + b)
+          or
+            net = x_t * W_xh + y_{t-1} * W_hh
+            y_t = f(net + b)
+
+        Case 2:
+          addons' LayerNormSimpleRNNCell, like case 1 but with layer
+          normalization (scaling only).
+            y_t = f(ln(x_t * W_xh + y_{t-1} * W_hh) + b)
+          or
+            net = x_t * W_xh + y_{t-1} * W_hh
+            y_t = f(ln(net) + b)
+
+          Layer normalization with scaling and centering in one go (see
+          Ba et al. (2016), page 3, formula 4,
+          https://arxiv.org/abs/1607.06450) is the same as layer
+          normalization with scaling only, followed by centering directly
+          afterwards.
+
+        Case 3:
+          Keras' SimpleRNN, with dropout, bias, and activation
+            y_t = f(d1(x_t) * W_xh + d2(y_{t-1}) * W_hh + b)
+          or
+            net = d1(x_t) * W_xh + d2(y_{t-1}) * W_hh
+            y_t = f(net + b)
+
+        Case 4:
+          addons' LayerNormSimpleRNNCell, like case 3 but with layer
+          normalization (scaling only).
+            y_t = f(ln(d1(x_t) * W_xh + d2(y_{t-1}) * W_hh) + b)
+          or
+            net = d1(x_t) * W_xh + d2(y_{t-1}) * W_hh
+            y_t = f(ln(net) + b)
+        """
+        prev_output = states[0]
+        dp_mask = self.get_dropout_mask_for_cell(inputs, training)
+        rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
+            prev_output, training)
+
+        if dp_mask is not None:
+            h = keras.backend.dot(inputs * dp_mask, self.kernel)
+        else:
+            h = keras.backend.dot(inputs, self.kernel)
+
+        # don't add the bias to "h" here;
+        # add it to "output" after scaling with layer normalization
+
+        if rec_dp_mask is not None:
+            prev_output = prev_output * rec_dp_mask
+        output = h + keras.backend.dot(prev_output,
+                                       self.recurrent_kernel)  # "net"
+
+        output = self.layernorm(output)
+
+        if self.bias is not None:
+            output = keras.backend.bias_add(output, self.bias)
+
+        if self.activation is not None:
+            output = self.activation(output)
+
+        return output, [output]
+
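# A small numerical check of the equivalence used in `call` above (a
# standalone numpy sketch, not part of the diff): layer normalization with
# scaling and centering in one go (Ba et al. 2016, formula 4) equals
# scale-only layer normalization followed by adding the same bias afterwards,
# which is why `self.bias` is added to `output` only after `self.layernorm`.
import numpy as np

eps = 1e-5
net = np.random.random([32, 4]).astype(np.float32)  # x_t*W_xh + y_{t-1}*W_hh
gamma = np.random.random([4]).astype(np.float32)    # layer-norm scale
b = np.random.random([4]).astype(np.float32)        # cell bias / beta

mean = net.mean(axis=-1, keepdims=True)
var = net.var(axis=-1, keepdims=True)

ln_scale_and_center = gamma * (net - mean) / np.sqrt(var + eps) + b
ln_scale_only = gamma * (net - mean) / np.sqrt(var + eps)

np.testing.assert_allclose(ln_scale_and_center, ln_scale_only + b, rtol=1e-6)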
+    # use SimpleRNNCell's get_initial_state method
+
+    def get_config(self):
+        cell_config = super(LayerNormSimpleRNNCell, self).get_config()
+        del cell_config['name']
+
+        ln_config = self.layernorm.get_config()
+        ln_config = {
+            k: v for k, v in ln_config.items()
+            if k in ["epsilon", "gamma_initializer",
+                     "gamma_regularizer", "gamma_constraint"]}
+
+        ln_config['layernorm_epsilon'] = ln_config.pop("epsilon")
+        return dict(list(cell_config.items()) + list(ln_config.items()))
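# A short round-trip sketch for `get_config` above (an assumed usage, not
# part of the diff): the layer-norm settings are folded into the cell's own
# config under their constructor names, so a new cell can be rebuilt from
# that single dict.
import tensorflow_addons as tfa

cell = tfa.rnn.LayerNormSimpleRNNCell(4, layernorm_epsilon=1e-4)
config = cell.get_config()
# `config` now holds the SimpleRNNCell keys plus 'layernorm_epsilon',
# 'gamma_initializer', 'gamma_regularizer', and 'gamma_constraint'.
assert config['layernorm_epsilon'] == 1e-4
restored = tfa.rnn.LayerNormSimpleRNNCell(**config)
assert restored.units == 4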