@@ -175,4 +175,32 @@ def forward(self, x):
         x = x.permute(0, 2, 1)  # Change shape to (batch_size, in_features, seq_len)
         x = self.conv(x)  # Apply convolution
         x = x.permute(0, 2, 1)  # Change shape back to (batch_size, seq_len, in_features)
-        return x
+        return x
+
+
+# Temporal Positional Encoding (T-PE)
+class TemporalPositionalEncoding(nn.Module):
+    def __init__(self, d_model, max_len=896):  # Assuming 896 timesteps
+        super(TemporalPositionalEncoding, self).__init__()
+        pe = torch.zeros(max_len, d_model)  # (max_len, d_model) table of sinusoidal encodings
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+
+        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
+        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
+        self.register_buffer('pe', pe)  # saved with the module, but not a trainable parameter
+
+    def forward(self, x):
+        seq_len = x.size(1)
+        return self.pe[:seq_len, :].unsqueeze(0).expand(x.size(0), -1, -1)  # (batch, seq_len, d_model)
+
+
+# Variable Positional Encoding for handling multivariate data
+class VariablePositionalEncoding(nn.Module):
+    def __init__(self, d_model, num_variables):
+        super(VariablePositionalEncoding, self).__init__()
+        self.variable_embedding = nn.Embedding(num_variables, d_model)  # one learned vector per variable
+
+    def forward(self, x, variable_idx):
+        variable_embed = self.variable_embedding(variable_idx)
+        return x + variable_embed.unsqueeze(0)  # embedding broadcasts across the batch
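
Sanity-check sketch for TemporalPositionalEncoding (a review note, not part of the diff; the sizes d_model=64 and batch 8 are made up, and the torch/math imports are assumed to already exist at the top of the file). Note that forward returns the expanded encoding table rather than x plus the encoding, so the caller performs the addition:

    t_pe = TemporalPositionalEncoding(d_model=64, max_len=896)
    x = torch.randn(8, 896, 64)   # (batch, seq_len, d_model)
    out = x + t_pe(x)             # t_pe(x) is (8, 896, 64), already expanded to the batch
    print(out.shape)              # torch.Size([8, 896, 64])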
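
A similar sketch for VariablePositionalEncoding, showing the two indexing patterns its broadcasting supports (shapes again hypothetical):

    v_pe = VariablePositionalEncoding(d_model=64, num_variables=7)

    # One variable id for a whole sequence: embedding is (64,), unsqueezes to (1, 64)
    out = v_pe(torch.randn(8, 896, 64), torch.tensor(3))

    # One id per variable channel, with x laid out as (batch, num_variables, d_model):
    # embedding is (7, 64), unsqueezes to (1, 7, 64) and broadcasts over the batch
    out_v = v_pe(torch.randn(8, 7, 64), torch.arange(7))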