-
Notifications
You must be signed in to change notification settings - Fork 8
/
nalu.py
132 lines (111 loc) · 5.45 KB
/
nalu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from keras.engine import Layer
from keras.engine import InputSpec
from keras import initializers
from keras import regularizers
from keras import constraints
from keras import backend as K
from keras.utils.generic_utils import get_custom_objects
class NALU(Layer):
    """Neural Arithmetic Logic Unit (NALU) layer.

    The layer derives a single weight matrix
    `W = tanh(W_hat) * sigmoid(M_hat)` and uses it on two paths:

    - additive path:        `a = dot(x, W)`
    - multiplicative path:  `m = exp(dot(log(|x| + epsilon), W))`

    With gating enabled the paths are blended as `g * a + (1 - g) * m`,
    where `g = sigmoid(dot(x, G))`. With gating disabled the output is
    `a + m`.

    # Arguments
        units: Output dimension.
        use_gating: Bool, determines whether to use the gating
            mechanism between the additive and multiplicative paths.
        kernel_W_initializer: Initializer for `W_hat` weights.
        kernel_M_initializer: Initializer for `M_hat` weights.
        gate_initializer: Initializer for gate `G` weights.
        kernel_W_regularizer: Regularizer for `W_hat` weights.
        kernel_M_regularizer: Regularizer for `M_hat` weights.
        gate_regularizer: Regularizer for gate `G` weights.
        kernel_W_constraint: Constraint on `W_hat` weights.
        kernel_M_constraint: Constraint on `M_hat` weights.
        gate_constraint: Constraint on gate `G` weights.
        epsilon: Small constant added to `|x|` to prevent `log(0)`.
        **kwargs: Standard layer keyword arguments (e.g. `name`,
            `trainable`, `input_shape`), forwarded to `Layer`.
            `input_dim` is accepted as shorthand for
            `input_shape=(input_dim,)`.

    # Input shape
        Tensor with at least 2 dimensions; the last axis holds the features.

    # Output shape
        Same as the input shape, with the last axis replaced by `units`.

    # Reference
        - [Neural Arithmetic Logic Units](https://arxiv.org/abs/1808.00508)
    """

    def __init__(self, units,
                 use_gating=True,
                 kernel_W_initializer='glorot_uniform',
                 kernel_M_initializer='glorot_uniform',
                 gate_initializer='glorot_uniform',
                 kernel_W_regularizer=None,
                 kernel_M_regularizer=None,
                 gate_regularizer=None,
                 kernel_W_constraint=None,
                 kernel_M_constraint=None,
                 gate_constraint=None,
                 epsilon=1e-7,
                 **kwargs):
        # Same convenience as the built-in `Dense` layer: allow `input_dim`
        # as a shorthand for a one-element `input_shape`.
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)

        # Forward the remaining keyword arguments to the base class.
        # Dropping them would silently discard `name`, `trainable`, etc.,
        # and break the `get_config()` -> `from_config()` round-trip, since
        # the base config produced by `get_config()` carries those keys.
        super(NALU, self).__init__(**kwargs)

        self.units = units
        self.use_gating = use_gating
        self.epsilon = epsilon

        self.kernel_W_initializer = initializers.get(kernel_W_initializer)
        self.kernel_M_initializer = initializers.get(kernel_M_initializer)
        self.gate_initializer = initializers.get(gate_initializer)

        self.kernel_W_regularizer = regularizers.get(kernel_W_regularizer)
        self.kernel_M_regularizer = regularizers.get(kernel_M_regularizer)
        self.gate_regularizer = regularizers.get(gate_regularizer)

        self.kernel_W_constraint = constraints.get(kernel_W_constraint)
        self.kernel_M_constraint = constraints.get(kernel_M_constraint)
        self.gate_constraint = constraints.get(gate_constraint)

        self.supports_masking = True

    def build(self, input_shape):
        """Create the `W_hat`, `M_hat` and (optionally) `G` weights.

        # Arguments
            input_shape: Shape of the input; must have at least 2 entries.
                The last entry is used as the input feature dimension.
        """
        assert len(input_shape) >= 2
        input_dim = input_shape[-1]

        self.W_hat = self.add_weight(shape=(input_dim, self.units),
                                     name='W_hat',
                                     initializer=self.kernel_W_initializer,
                                     regularizer=self.kernel_W_regularizer,
                                     constraint=self.kernel_W_constraint)

        self.M_hat = self.add_weight(shape=(input_dim, self.units),
                                     name='M_hat',
                                     initializer=self.kernel_M_initializer,
                                     regularizer=self.kernel_M_regularizer,
                                     constraint=self.kernel_M_constraint)

        if self.use_gating:
            self.G = self.add_weight(shape=(input_dim, self.units),
                                     name='G',
                                     initializer=self.gate_initializer,
                                     regularizer=self.gate_regularizer,
                                     constraint=self.gate_constraint)
        else:
            # No gate weights are created when gating is disabled.
            self.G = None

        # Pin the size of the last axis for subsequent calls.
        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True

    def call(self, inputs, **kwargs):
        """Apply the NALU transformation to `inputs`.

        # Arguments
            inputs: Input tensor.

        # Returns
            `g * a + (1 - g) * m` when gating is enabled, otherwise `a + m`.
        """
        # Shared effective weight matrix for both paths.
        W = K.tanh(self.W_hat) * K.sigmoid(self.M_hat)
        # Multiplicative path, evaluated in log-space; `epsilon` keeps the
        # logarithm finite when an input element is zero.
        m = K.exp(K.dot(K.log(K.abs(inputs) + self.epsilon), W))
        # Additive path.
        a = K.dot(inputs, W)

        if self.use_gating:
            g = K.sigmoid(K.dot(inputs, self.G))
            outputs = g * a + (1. - g) * m
        else:
            outputs = a + m

        return outputs

    def compute_output_shape(self, input_shape):
        """Return `input_shape` with its last entry replaced by `units`."""
        assert input_shape and len(input_shape) >= 2
        assert input_shape[-1]
        output_shape = list(input_shape)
        output_shape[-1] = self.units
        return tuple(output_shape)

    def get_config(self):
        """Return the layer configuration as a serializable dict.

        The layer-specific entries are merged on top of the base `Layer`
        configuration.
        """
        config = {
            'units': self.units,
            'use_gating': self.use_gating,
            'kernel_W_initializer': initializers.serialize(self.kernel_W_initializer),
            'kernel_M_initializer': initializers.serialize(self.kernel_M_initializer),
            'gate_initializer': initializers.serialize(self.gate_initializer),
            'kernel_W_regularizer': regularizers.serialize(self.kernel_W_regularizer),
            'kernel_M_regularizer': regularizers.serialize(self.kernel_M_regularizer),
            'gate_regularizer': regularizers.serialize(self.gate_regularizer),
            'kernel_W_constraint': constraints.serialize(self.kernel_W_constraint),
            'kernel_M_constraint': constraints.serialize(self.kernel_M_constraint),
            'gate_constraint': constraints.serialize(self.gate_constraint),
            'epsilon': self.epsilon
        }

        base_config = super(NALU, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
# Register the layer in Keras' global custom-objects mapping under the
# key 'NALU'.
get_custom_objects()['NALU'] = NALU