Skip to content

Commit

Permalink
Merge pull request #425 from LinglongQian/main
Browse files Browse the repository at this point in the history
ETSformer hyperparameters mismatch during NNI tuning
  • Loading branch information
WenjieDu authored May 28, 2024
2 parents 8e25601 + ab63adb commit 7e38d08
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 16 deletions.
16 changes: 8 additions & 8 deletions pypots/imputation/etsformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,14 @@ class ETSformer(BaseNNImputer):

def __init__(
self,
n_steps,
n_features,
n_e_layers,
n_d_layers,
d_model,
n_heads,
d_ffn,
top_k,
n_steps: int,
n_features: int,
n_e_layers: int,
n_d_layers: int,
d_model: int,
n_heads: int,
d_ffn: int,
top_k: int,
dropout: float = 0,
ORT_weight: float = 1,
MIT_weight: float = 1,
Expand Down
14 changes: 7 additions & 7 deletions pypots/imputation/fedformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,13 @@ class FEDformer(BaseNNImputer):

def __init__(
self,
n_steps,
n_features,
n_layers,
d_model,
n_heads,
d_ffn,
moving_avg_window_size,
n_steps: int,
n_features: int,
n_layers: int,
d_model: int,
n_heads: int,
d_ffn: int,
moving_avg_window_size: int,
dropout: float = 0,
version="Fourier",
modes=32,
Expand Down
2 changes: 1 addition & 1 deletion pypots/nn/modules/etsformer/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def forward(self, x):
f = fft.rfftfreq(t)[self.low_freq :]

x_freq, index_tuple = self.topk_freq(x_freq)
f = repeat(f, "f -> b f d", b=x_freq.size(0), d=x_freq.size(2))
f = repeat(f, "f -> b f d", b=x_freq.size(0), d=x_freq.size(2)).to(x_freq.device)
f = rearrange(f[index_tuple], "b f d -> b f () d").to(x_freq.device)

return self.extrapolate(x_freq, f, t)
Expand Down

0 comments on commit 7e38d08

Please sign in to comment.