-
Notifications
You must be signed in to change notification settings - Fork 0
/
gru.py
88 lines (81 loc) · 3.7 KB
/
gru.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import tensorflow as tf
# Custom GRUCell
class GRUCell(tf.keras.layers.Layer):
    """GRU cell whose previous hidden state is scaled by a per-sample mask.

    Each step input is the concatenation ``[x_t, mask_t]`` along axis 1 and is
    split into two equal halves. The mask half is averaged over its features
    (sum / ``input_dim``), broadcast to ``units`` columns, and multiplied into
    the previous hidden state before any gate is computed, so a mask of 0
    resets the state and a mask of 1 keeps it. The step output is
    ``concat(h_t, mask)`` (``2 * units`` features), which forwards the mask to
    a stacked cell whose ``input_dim`` equals this cell's ``units``.

    Args:
        input_dim: Number of features in ``x_t`` (half the step-input width).
        units: Size of the hidden state.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``,
            ``dtype``).
    """

    def __init__(self, input_dim, units, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.units = units
        # NOTE(review): never read or updated in this class; kept (including the
        # spelling) in case code outside this file references it.
        self.itteration = 0
        # tf.keras.layers.RNN requires the cell to declare its state size.
        self.state_size = units
        # The step output is concat(h_t, mask); both halves are `units` wide.
        # Without this, RNN assumes the output is `state_size` wide.
        self.output_size = 2 * units

    def build(self, input_shape):
        """Create the update-, reset- and memory-gate weights.

        ``input_shape`` is not used; shapes come from ``input_dim`` and
        ``units``. Weights are created per gate in the order w, u, b; keep the
        order and names stable for checkpoint compatibility.
        """
        # Update gate
        self.w_z, self.u_z, self.b_z = self._gate_weights(
            "update", recurrent_regularizer="l2"
        )
        # Reset gate
        # NOTE(review): this is the only recurrent kernel without L2
        # regularization; confirm that is intentional.
        self.w_r, self.u_r, self.b_r = self._gate_weights(
            "reset", recurrent_regularizer=None
        )
        # Memory content
        self.w_h, self.u_h, self.b_h = self._gate_weights(
            "memory", recurrent_regularizer="l2"
        )

    def _gate_weights(self, gate, recurrent_regularizer):
        """Create and return the ``(w, u, b)`` weights for the gate ``gate``."""
        w = self.add_weight(
            name=f"w_{gate}",
            shape=(self.input_dim, self.units),
            initializer="random_normal",
            regularizer="l2",
        )
        u = self.add_weight(
            name=f"u_{gate}",
            shape=(self.units, self.units),
            initializer="random_normal",
            regularizer=recurrent_regularizer,
        )
        b = self.add_weight(
            name=f"b_{gate}",
            shape=(self.units,),
            initializer="zeros",
            regularizer=None,
        )
        return w, u, b

    def call(self, inputs, hidden_states):
        """Run one GRU step on a masked previous state.

        Args:
            inputs: Step input ``[x_t, mask_t]``, split into two equal halves
                along axis 1.
            hidden_states: Sequence whose first element is the previous hidden
                state.

        Returns:
            ``(output, [h_t])`` where ``output`` is ``concat(h_t, mask)`` along
            axis 1.
        """
        x_t, mask = tf.split(inputs, 2, axis=1)
        # Sum the mask features into every one of `units` columns, then divide
        # by input_dim (i.e. average them). `ones` takes the mask's dtype so
        # non-float32 compute dtypes (float64, mixed-precision float16) do not
        # hit a matmul dtype mismatch.
        ones = tf.ones((mask.shape[-1], self.units), dtype=mask.dtype)
        mask = tf.matmul(mask, ones) / self.input_dim
        # Mask the hidden state to reset it at timesteps with finished
        # environments.
        h_prev = hidden_states[0] * mask

        # Update and reset gates.
        z_t = tf.nn.sigmoid(
            tf.matmul(x_t, self.w_z) + tf.matmul(h_prev, self.u_z) + self.b_z
        )
        r_t = tf.nn.sigmoid(
            tf.matmul(x_t, self.w_r) + tf.matmul(h_prev, self.u_r) + self.b_r
        )
        # Candidate memory content; the reset gate scales the previous state.
        h_candidate = tf.nn.tanh(
            tf.matmul(x_t, self.w_h)
            + tf.matmul(r_t * h_prev, self.u_h)
            + self.b_h
        )
        # In this convention z_t weights the candidate, (1 - z_t) the old state.
        h_t = z_t * h_candidate + (1 - z_t) * h_prev

        output = tf.concat((h_t, mask), axis=1)
        return output, [h_t]

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({"input_dim": self.input_dim, "units": self.units})
        return config