
Commit 99bbfc8

Add T-PE and V-PE methods
1 parent 3e94b49 commit 99bbfc8

File tree

1 file changed: +29 −1 lines changed


src/positional_encodings.py

Lines changed: 29 additions & 1 deletion
@@ -175,4 +175,32 @@ def forward(self, x):
         x = x.permute(0, 2, 1)  # Change shape to (batch_size, in_features, seq_len)
         x = self.conv(x)  # Apply convolution
         x = x.permute(0, 2, 1)  # Change shape back to (batch_size, seq_len, in_features)
-        return x
+        return x
+
+
+# Temporal Positional Encoding (T-PE)
+class TemporalPositionalEncoding(nn.Module):
+    def __init__(self, d_model, max_len=896):  # Assuming 896 timesteps
+        super(TemporalPositionalEncoding, self).__init__()
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        self.register_buffer('pe', pe)
+
+    def forward(self, x):
+        seq_len = x.size(1)
+        return self.pe[:seq_len, :].unsqueeze(0).expand(x.size(0), -1, -1)
+
+
+# Variable Positional Encoding for handling multivariate data
+class VariablePositionalEncoding(nn.Module):
+    def __init__(self, d_model, num_variables):
+        super(VariablePositionalEncoding, self).__init__()
+        self.variable_embedding = nn.Embedding(num_variables, d_model)
+
+    def forward(self, x, variable_idx):
+        variable_embed = self.variable_embedding(variable_idx)
+        return x + variable_embed.unsqueeze(0)
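
For reference, a minimal usage sketch of how the two new encodings might be combined downstream. The dimensions, tensor shapes, the example variable index, and the step of adding the T-PE output to the input are assumptions for illustration only and are not part of this commit; it assumes the classes above are importable from src/positional_encodings.py.

    import torch

    # Hypothetical usage sketch (not part of this commit); dimensions are illustrative.
    d_model, num_variables, batch, seq_len = 64, 7, 8, 896

    t_pe = TemporalPositionalEncoding(d_model, max_len=896)
    v_pe = VariablePositionalEncoding(d_model, num_variables)

    x = torch.randn(batch, seq_len, d_model)   # projected input, shape (batch, seq_len, d_model)
    variable_idx = torch.tensor(2)             # assumed index of the variable/channel

    x = x + t_pe(x)                            # T-PE returns the encoding; add it to x
    x = v_pe(x, variable_idx)                  # V-PE adds the variable embedding to x
    print(x.shape)                             # torch.Size([8, 896, 64])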
