@@ -193,4 +193,82 @@ def forward(self, x: torch.tensor) -> torch.tensor:
         out = out.transpose(1, 2).reshape(batch_size, seq_len, -1)
         # out.shape == (batch_size, seq_len, d_model)
         out = self.dropout(out)
-        return out
+        return out
+
+
+class TPS_SelfAttention(nn.Module):
+    def __init__(self, num_heads, model_dim, max_len, pow=2, LrEnb=0, LrMo=0, dropout=0.1):
+        super(TPS_SelfAttention, self).__init__()
+        assert model_dim % num_heads == 0
+        self.num_heads = num_heads
+        self.model_dim = model_dim
+        self.d_k = model_dim // num_heads
+        self.pow = pow        # pow == 2 -> Gaussian (squared-distance) decay, otherwise exponential (absolute-distance) decay
+        self.LrEnb = LrEnb    # nonzero -> also register the decay matrix as a learnable parameter
+        self.LrMo = LrMo      # stored; not referenced elsewhere in this class
+        self.Max_Len = max_len
+
+        # Initialize multi-head attention module
+        self.attention = nn.MultiheadAttention(embed_dim=model_dim, num_heads=num_heads, dropout=dropout)
+
+        # Precompute temporal distance
+        t = torch.arange(0, max_len, dtype=torch.float)
+        t1 = t.repeat(max_len, 1)
+        t2 = t1.permute([1, 0])
+
+        if pow == 2:
+            dis1 = torch.exp(-1 * torch.pow((t2 - t1), 2) / 2)
+            self.dist = nn.Parameter(-1 * torch.pow((t2 - t1), 2) / 2, requires_grad=False)
+        else:
+            dis1 = torch.exp(-1 * torch.abs((t2 - t1)))
+            self.dist = nn.Parameter(-1 * torch.abs((t2 - t1)), requires_grad=False)
+
+        if LrEnb:
+            self.adj1 = nn.Parameter(dis1)  # Learnable temporal weighting
+
+        # Dropout layer for regularization
+        self.dropout = nn.Dropout(p=dropout)
+
+    def forward(self, q, k, v, mask=None):
+        # q, k, v expected to have shape: [seq_len, batch_size, embedding_dim]
+
+        # Multi-head attention forward pass (only attn_weights are reused below; attn_output is discarded)
+        attn_output, attn_weights = self.attention(q, k, v, key_padding_mask=mask)
+
+        # Apply Gaussian-based temporal weighting
+        seq_len = q.size(0)  # Extract sequence length from q
+        batch_size = q.size(1)  # Extract batch size
+        num_heads = self.num_heads
+
+        # Dynamically compute temporal distance matrix based on seq_len
+        t = torch.arange(0, seq_len, dtype=torch.float, device=q.device)
+        t1 = t.repeat(seq_len, 1)
+        t2 = t1.permute([1, 0])
+
+        if self.pow == 2:
+            dist_matrix = torch.exp(-1 * torch.pow((t2 - t1), 2) / 2)
+        else:
+            dist_matrix = torch.exp(-1 * torch.abs((t2 - t1)))
+
+        # Expand dist_matrix to match the shape of attn_weights [batch_size, seq_len, seq_len]
+        expanded_dist = dist_matrix.unsqueeze(0).expand(batch_size, -1, -1)
+
+        # Apply Gaussian decay to attention weights
+        weighted_attn = attn_weights * expanded_dist
+
+        # Normalize attention scores
+        weighted_attn = weighted_attn / weighted_attn.sum(dim=-1, keepdim=True)
+
+        # Reshape `v` to [batch_size, seq_len, model_dim] for correct matrix multiplication
+        v = v.permute(1, 0, 2)  # Change `v` from [seq_len, batch_size, model_dim] to [batch_size, seq_len, model_dim]
+
+        # Matrix multiplication for attention output
+        output = torch.bmm(weighted_attn, v)  # Apply attention weights to the value matrix
+
+        # Reshape the output back to [seq_len, batch_size, model_dim]
+        output = output.permute(1, 0, 2)  # Change back to [seq_len, batch_size, model_dim]
+
+        # Apply dropout to the attention output
+        output = self.dropout(output)
+
+        return output, weighted_attn
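
A minimal usage sketch (not part of the commit) of how the new TPS_SelfAttention module might be exercised. It assumes the file already has import torch and import torch.nn as nn, and that inputs follow the [seq_len, batch_size, model_dim] layout expected by nn.MultiheadAttention with batch_first=False; the dimension values below are illustrative only.

# Illustrative dimensions (assumptions, not taken from the commit)
seq_len, batch_size, model_dim, num_heads = 16, 4, 64, 8

attn = TPS_SelfAttention(num_heads=num_heads, model_dim=model_dim, max_len=seq_len)
attn.eval()  # disable dropout for a deterministic shape check

x = torch.randn(seq_len, batch_size, model_dim)  # [seq_len, batch_size, model_dim]
out, weights = attn(x, x, x)                     # self-attention: q = k = v

print(out.shape)      # torch.Size([16, 4, 64])  -> [seq_len, batch_size, model_dim]
print(weights.shape)  # torch.Size([4, 16, 16])  -> head-averaged attention weights after temporal decay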