-
Notifications
You must be signed in to change notification settings - Fork 18
/
transformer.py
298 lines (257 loc) · 12.5 KB
/
transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
import numpy as np
import torch
from einops import rearrange
from torch import nn
from utils import Module
class MultiHeadAttention(nn.Module):
    """Multi-head attention layer without dropout.

    Inspired by https://github.com/aladdinpersson/Machine-Learning-Collection
    (video walkthrough: https://youtu.be/U0s0f995w14).
    """

    def __init__(self, embed_dim, num_heads):
        """
        Arguments:
            embed_dim {int} -- Size of the embedding dimension
            num_heads {int} -- Number of attention heads (must divide embed_dim evenly)
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_size = embed_dim // num_heads
        assert embed_dim % num_heads == 0, "Embedding dimension needs to be divisible by the number of heads"

        # Bias-free input projections followed by a biased output projection.
        # The creation order is kept stable so that seeded initialisation is reproducible.
        self.values = nn.Linear(embed_dim, embed_dim, bias=False)
        self.keys = nn.Linear(embed_dim, embed_dim, bias=False)
        self.queries = nn.Linear(embed_dim, embed_dim, bias=False)
        self.fc_out = nn.Linear(embed_dim, embed_dim)

    def forward(self, values, keys, queries, mask):
        """Computes scaled dot-product attention over all heads.

        Arguments:
            values {torch.tensor} -- Values of shape (N, L, D)
            keys {torch.tensor} -- Keys of shape (N, L, D)
            queries {torch.tensor} -- Queries of shape (N, L, D)
            mask {torch.tensor} -- Attention mask of shape (N, L), entries equal to 0 are masked out.
                                   May be None to disable masking.

        Returns:
            torch.tensor -- Output of shape (N, query_len, embed_dim)
            torch.tensor -- Attention weights of shape (N, heads, query_len, key_len)
        """
        batch_size = queries.shape[0]
        query_len = queries.shape[1]
        heads, head_dim = self.num_heads, self.head_size

        def split_heads(projected, length):
            # (N, length, embed_dim) -> (N, length, heads, head_dim)
            return projected.reshape(batch_size, length, heads, head_dim)

        # Project the inputs, then split the embedding into one slice per head
        v = split_heads(self.values(values), values.shape[1])
        k = split_heads(self.keys(keys), keys.shape[1])
        q = split_heads(self.queries(queries), query_len)

        # Dot product of every query with every key, per sample and per head:
        # (N, query_len, heads, head_dim) x (N, key_len, heads, head_dim) -> (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", q, k)

        if mask is not None:
            # Broadcast the (N, L) mask across heads and queries. A large finite negative
            # number is used on purpose: -inf would yield NaN for fully masked rows.
            blocked = mask.unsqueeze(1).unsqueeze(1) == 0
            energy = energy.masked_fill(blocked, -1e20)

        # Softmax across the key axis. NOTE: the scale is sqrt(embed_dim), not sqrt(head_size).
        attention = torch.softmax(energy / self.embed_dim ** 0.5, dim=-1)

        # Weight the values by the attention scores and merge the heads again:
        # (N, heads, query_len, key_len) x (N, value_len, heads, head_dim) -> (N, query_len, heads, head_dim)
        context = torch.einsum("nhql,nlhd->nqhd", attention, v)
        merged = context.reshape(batch_size, query_len, heads * head_dim)

        # The output projection keeps the shape (N, query_len, embed_dim)
        return self.fc_out(merged), attention
class TransformerBlock(Module):
    """Single transformer block: multi-head attention followed by a one-layer feed-forward
    projection, each wrapped by either a skip connection or a GRU gate (GTrXL)."""

    def __init__(self, embed_dim, num_heads, config):
        """Transformer Block made of LayerNorms, Multi Head Attention and one fully connected feed forward projection.

        Arguments:
            embed_dim {int} -- Size of the embedding dimension
            num_heads {int} -- Number of attention heads
            config {dict} -- General config. Reads "layer_norm" (compared against "pre" and "post"),
                             the optional flag "gtrxl" and, if that flag is set, "gtrxl_bias".
        """
        super().__init__()

        # Attention
        self.attention = MultiHeadAttention(embed_dim, num_heads)

        # Optional GTrXL gating replaces both skip connections
        self.use_gtrxl = config.get("gtrxl", False)
        if self.use_gtrxl:
            gate_bias = config["gtrxl_bias"]
            self.gate1 = GRUGate(embed_dim, gate_bias)
            self.gate2 = GRUGate(embed_dim, gate_bias)

        # LayerNorms; the key/value norm is only created for the pre-norm variant
        self.layer_norm = config["layer_norm"]
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        if self.layer_norm == "pre":
            self.norm_kv = nn.LayerNorm(embed_dim)

        # Feed forward projection
        self.fc = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.ReLU())

    def forward(self, value, key, query, mask):
        """
        Arguments:
            value {torch.tensor} -- Values of shape (N, L, D)
            key {torch.tensor} -- Keys of shape (N, L, D)
            query {torch.tensor} -- Queries of shape (N, L, D)
            mask {torch.tensor} -- Attention mask of shape (N, L)

        Returns:
            torch.tensor -- Output
            torch.tensor -- Attention weights
        """
        pre_norm = self.layer_norm == "pre"
        post_norm = self.layer_norm == "post"

        # Pre-norm: normalize the attention inputs. The residual path below keeps using the raw query.
        if pre_norm:
            attn_query = self.norm1(query)
            value = self.norm_kv(value)
            # NOTE(review): the passed key is discarded and replaced by the normalized value,
            # i.e. key and value are presumably expected to be the same tensor -- confirm with callers.
            key = value
        else:
            attn_query = query

        # Forward MultiHeadAttention
        attention, attention_weights = self.attention(value, key, attn_query, mask)

        # First residual: GRU gate or plain skip connection
        h = self.gate1(query, attention) if self.use_gtrxl else attention + query

        # Post-norm: normalize the attention output (i.e. the projection input)
        if post_norm:
            h = self.norm1(h)

        # Pre-norm: normalize only the input of the projection, not the residual path
        projection_input = self.norm2(h) if pre_norm else h
        forward = self.fc(projection_input)

        # Second residual: GRU gate or plain skip connection
        out = self.gate2(h, forward) if self.use_gtrxl else forward + h

        # Post-norm: normalize the projection output
        if post_norm:
            out = self.norm2(out)

        return out, attention_weights
class SinusoidalPosition(nn.Module):
    """Relative positional encoding based on sinusoids.

    Positions are enumerated in descending order, so the last row of the returned
    table belongs to position 0.
    """

    def __init__(self, dim, min_timescale = 2., max_timescale = 1e4):
        """
        Arguments:
            dim {int} -- Size of the embedding dimension
            min_timescale {float} -- Step between the frequency indices in [0, dim). With the default
                                     of 2 half of the output holds sines and half cosines (default: {2.})
            max_timescale {float} -- Base of the geometric frequency progression (default: {1e4})
        """
        super().__init__()
        freqs = torch.arange(0, dim, min_timescale)
        inv_freqs = max_timescale ** (-freqs / dim)
        # Registered as a buffer so that it follows the module across devices and is part of the state dict
        self.register_buffer('inv_freqs', inv_freqs)

    def forward(self, seq_len):
        """
        Arguments:
            seq_len {int} -- Number of positions to encode

        Returns:
            {torch.tensor} -- Positional encoding of shape (seq_len, 2 * number of frequencies), where
                              row i encodes position seq_len - 1 - i as [sin | cos]
        """
        # Create the positions on the buffer's device; a default-device range would clash
        # with inv_freqs once the module has been moved (e.g. to a GPU).
        seq = torch.arange(seq_len - 1, -1, -1., device=self.inv_freqs.device)
        # Outer product via broadcasting: (seq_len, 1) * (1, num_freqs) -> (seq_len, num_freqs)
        sinusoidal_inp = seq.unsqueeze(-1) * self.inv_freqs.unsqueeze(0)
        return torch.cat((sinusoidal_inp.sin(), sinusoidal_inp.cos()), dim=-1)
class Transformer(nn.Module):
    """Transformer encoder architecture without dropout. Positional encoding can be either "relative", "learned" or "" (none)."""

    def __init__(self, config, input_dim, max_episode_steps) -> None:
        """Sets up the input embedding, positional encoding and the transformer blocks.

        Arguments:
            config {dict} -- Transformer config. Reads "num_blocks", "embed_dim", "num_heads" and
                             "positional_encoding"; the whole dict is also passed on to every block.
            input_dim {int} -- Dimension of the input
            max_episode_steps {int} -- Maximum number of steps in an episode
        """
        super().__init__()
        self.config = config
        self.num_blocks = config["num_blocks"]
        self.embed_dim = config["embed_dim"]
        self.num_heads = config["num_heads"]
        self.max_episode_steps = max_episode_steps
        self.activation = nn.ReLU()

        # Input embedding layer
        self.linear_embedding = nn.Linear(input_dim, self.embed_dim)
        nn.init.orthogonal_(self.linear_embedding.weight, np.sqrt(2))

        # Determine positional encoding; any other value means that none is used
        if config["positional_encoding"] == "relative":
            self.pos_embedding = SinusoidalPosition(dim = self.embed_dim)
        elif config["positional_encoding"] == "learned":
            # One learned vector per episode step: (max episode steps, embed dim)
            self.pos_embedding = nn.Parameter(torch.randn(self.max_episode_steps, self.embed_dim))

        # Instantiate transformer blocks
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(self.embed_dim, self.num_heads, config)
            for _ in range(self.num_blocks)])

    def forward(self, h, memories, mask, memory_indices):
        """
        Arguments:
            h {torch.tensor} -- Input (query)
            memories {torch.tensor} -- Whole episode memories of shape (N, L, num blocks, D)
            mask {torch.tensor} -- Attention mask (dtype: bool) of shape (N, L)
            memory_indices {torch.tensor} -- Memory window indices (dtype: long) of shape (N, L)

        Returns:
            {torch.tensor} -- Output of the entire transformer encoder
            {torch.tensor} -- Out memories (i.e. inputs to the transformer blocks), stacked along dim 1
        """
        # Feed embedding layer and activate
        h = self.activation(self.linear_embedding(h))

        # Add positional encoding to every transformer block input. unsqueeze(2) broadcasts the
        # encoding across the block axis, i.e. it is added to the memory of all blocks and not
        # only to that of the first one.
        positional_encoding = self.config["positional_encoding"]
        if positional_encoding == "relative":
            pos_embedding = self.pos_embedding(self.max_episode_steps)[memory_indices]
            memories = memories + pos_embedding.unsqueeze(2)
        elif positional_encoding == "learned":
            memories = memories + self.pos_embedding[memory_indices].unsqueeze(2)

        # Forward transformer blocks
        out_memories = []
        for i, block in enumerate(self.transformer_blocks):
            # The (detached) input of each block is what gets stored as memory
            out_memories.append(h.detach())
            block_memory = memories[:, :, i]
            # args: value, key, query, mask -- the query gets a sequence axis of length 1
            h, _ = block(block_memory, block_memory, h.unsqueeze(1), mask)
            # Drop only that sequence axis again. A bare squeeze() would also remove the batch
            # and embedding axes whenever they happen to have size 1.
            h = h.squeeze(1)
        return h, torch.stack(out_memories, dim=1)
class GRUGate(nn.Module):
    """
    Overview:
        GRU Gating Unit used in GTrXL.
        Inspired by https://github.com/dhruvramani/Transformers-RL/blob/master/layers.py
    """

    def __init__(self, input_dim: int, bg: float = 0.0):
        """
        Arguments:
            input_dim {int} -- Input dimension
            bg {float} -- Initial gate bias value. By setting bg > 0 we can explicitly initialize the gating mechanism to
                          be close to the identity map. This can greatly improve the learning speed and stability since it
                          initializes the agent close to a Markovian policy (ignore attention at the beginning). (default: {0.0})
        """
        super().__init__()
        self.Wr = nn.Linear(input_dim, input_dim, bias=False)
        self.Ur = nn.Linear(input_dim, input_dim, bias=False)
        self.Wz = nn.Linear(input_dim, input_dim, bias=False)
        self.Uz = nn.Linear(input_dim, input_dim, bias=False)
        self.Wg = nn.Linear(input_dim, input_dim, bias=False)
        self.Ug = nn.Linear(input_dim, input_dim, bias=False)
        # Learnable gate bias. float() guards against integer config values (e.g. 0 or 2):
        # torch.full would infer an integer dtype, which a trainable parameter cannot have.
        self.bg = nn.Parameter(torch.full([input_dim], float(bg)))
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        # Same layer order as above so that seeded initialisation stays reproducible
        for layer in (self.Wr, self.Ur, self.Wz, self.Uz, self.Wg, self.Ug):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """Gates the update y into the stream x.

        Arguments:
            x {torch.tensor} -- First input (the stream that is kept when the gate is closed)
            y {torch.tensor} -- Second input (the update that is gated in)

        Returns:
            {torch.tensor} -- Output, computed as (1 - z) * x + z * h
        """
        r = self.sigmoid(self.Wr(y) + self.Ur(x))  # reset gate
        # Update gate; subtracting bg pushes z towards 0 (output close to x) for bg > 0
        z = self.sigmoid(self.Wz(y) + self.Uz(x) - self.bg)
        h = self.tanh(self.Wg(y) + self.Ug(torch.mul(r, x)))  # candidate state
        return torch.mul(1 - z, x) + torch.mul(z, h)