
Commit 07e1eeb

Add TPS method
1 parent 01dfb6c commit 07e1eeb

File tree

1 file changed: +79 -1 lines changed


src/attentions.py

Lines changed: 79 additions & 1 deletion
@@ -193,4 +193,82 @@ def forward(self, x: torch.tensor) -> torch.tensor:
         out = out.transpose(1, 2).reshape(batch_size, seq_len, -1)
         # out.shape == (batch_size, seq_len, d_model)
         out = self.dropout(out)
-        return out
+        return out
+
+
+class TPS_SelfAttention(nn.Module):
+    def __init__(self, num_heads, model_dim, max_len, pow=2, LrEnb=0, LrMo=0, dropout=0.1):
+        super(TPS_SelfAttention, self).__init__()
+        assert model_dim % num_heads == 0
+        self.num_heads = num_heads
+        self.model_dim = model_dim
+        self.d_k = model_dim // num_heads
+        self.pow = pow
+        self.LrEnb = LrEnb
+        self.LrMo = LrMo
+        self.Max_Len = max_len
+
+        # Initialize multi-head attention module
+        self.attention = nn.MultiheadAttention(embed_dim=model_dim, num_heads=num_heads, dropout=dropout)
+
+        # Precompute temporal distance
+        t = torch.arange(0, max_len, dtype=torch.float)
+        t1 = t.repeat(max_len, 1)
+        t2 = t1.permute([1, 0])
+
+        if pow == 2:
+            dis1 = torch.exp(-1 * torch.pow((t2 - t1), 2) / 2)
+            self.dist = nn.Parameter(-1 * torch.pow((t2 - t1), 2) / 2, requires_grad=False)
+        else:
+            dis1 = torch.exp(-1 * torch.abs((t2 - t1)))
+            self.dist = nn.Parameter(-1 * torch.abs((t2 - t1)), requires_grad=False)
+
+        if LrEnb:
+            self.adj1 = nn.Parameter(dis1)  # Learnable temporal weighting
+
+        # Dropout layer for regularization
+        self.dropout = nn.Dropout(p=dropout)
+
+    def forward(self, q, k, v, mask=None):
+        # q, k, v expected to have shape: [seq_len, batch_size, embedding_dim]
+
+        # Multi-head attention forward pass
+        attn_output, attn_weights = self.attention(q, k, v, key_padding_mask=mask)
+
+        # Apply Gaussian-based temporal weighting
+        seq_len = q.size(0)  # Extract sequence length from q
+        batch_size = q.size(1)  # Extract batch size
+        num_heads = self.num_heads
+
+        # Dynamically compute temporal distance matrix based on seq_len
+        t = torch.arange(0, seq_len, dtype=torch.float, device=q.device)
+        t1 = t.repeat(seq_len, 1)
+        t2 = t1.permute([1, 0])
+
+        if self.pow == 2:
+            dist_matrix = torch.exp(-1 * torch.pow((t2 - t1), 2) / 2)
+        else:
+            dist_matrix = torch.exp(-1 * torch.abs((t2 - t1)))
+
+        # Expand dist_matrix to match the shape of attn_weights [batch_size, seq_len, seq_len]
+        expanded_dist = dist_matrix.unsqueeze(0).expand(batch_size, -1, -1)
+
+        # Apply Gaussian decay to attention weights
+        weighted_attn = attn_weights * expanded_dist
+
+        # Re-normalize attention scores so each row sums to 1
+        weighted_attn = weighted_attn / weighted_attn.sum(dim=-1, keepdim=True)
+
+        # Reshape `v` to [batch_size, seq_len, model_dim] for correct matrix multiplication
+        v = v.permute(1, 0, 2)  # Change `v` from [seq_len, batch_size, model_dim] to [batch_size, seq_len, model_dim]
+
+        # Matrix multiplication for attention output
+        output = torch.bmm(weighted_attn, v)  # Apply attention weights to the value matrix
+
+        # Reshape the output back to [seq_len, batch_size, model_dim]
+        output = output.permute(1, 0, 2)  # Change back to [seq_len, batch_size, model_dim]
+
+        # Apply dropout to the attention output
+        output = self.dropout(output)
+
+        return output, weighted_attn
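The core of the new class is the temporal decay it multiplies into the head-averaged attention weights: with pow=2, the score between positions i and j is scaled by exp(-(i - j)^2 / 2), so each position attends to itself with weight 1.0, to its neighbours with about 0.61, and to positions two steps away with about 0.14. A standalone sketch of the dist_matrix computed in forward (illustration only, not part of the commit; seq_len = 4 is arbitrary):

import torch

seq_len = 4
t = torch.arange(seq_len, dtype=torch.float)
t1 = t.repeat(seq_len, 1)   # t1[i][j] = j
t2 = t1.permute([1, 0])     # t2[i][j] = i, so (t2 - t1)[i][j] = i - j
dist_matrix = torch.exp(-1 * torch.pow((t2 - t1), 2) / 2)
print(dist_matrix)
# tensor([[1.0000, 0.6065, 0.1353, 0.0111],
#         [0.6065, 1.0000, 0.6065, 0.1353],
#         [0.1353, 0.6065, 1.0000, 0.6065],
#         [0.0111, 0.1353, 0.6065, 1.0000]])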

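For completeness, a minimal usage sketch of the new module, assuming src/ is importable as a package, that the file's existing torch / torch.nn imports are in place, and that inputs follow the [seq_len, batch_size, model_dim] layout documented in forward; the sizes below are illustrative, not part of the commit:

import torch
from src.attentions import TPS_SelfAttention

seq_len, batch_size, model_dim, num_heads = 50, 8, 64, 4

tps_attn = TPS_SelfAttention(num_heads=num_heads, model_dim=model_dim, max_len=seq_len)

x = torch.randn(seq_len, batch_size, model_dim)  # [seq_len, batch, model_dim] (batch_first=False)
out, weights = tps_attn(x, x, x)                 # self-attention: q = k = v, no padding mask

print(out.shape)      # torch.Size([50, 8, 64])  [seq_len, batch, model_dim]
print(weights.shape)  # torch.Size([8, 50, 50])  head-averaged weights after temporal decay

Because nn.MultiheadAttention averages the per-head weights by default, the decayed weighted_attn returned here is one [seq_len, seq_len] map per batch element rather than per head.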