From 2f66d38f4843f17f1c2bbfedb7952c3073fe4ed6 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Tue, 9 Apr 2024 19:32:21 +0800 Subject: [PATCH 1/2] Apply SAITS embedding strategy to new added models (#343) * feat: make Crossformer, PatchTST, DLinear, ETSformer, FEDformer, Informer, Autoformer take the missing mask as a part of input; * refactor: separate n_steps of K and Q, this is necessary for Crossformer; --- pypots/imputation/autoformer/modules/core.py | 22 +++-- pypots/imputation/crossformer/modules/core.py | 20 +++-- .../crossformer/modules/submodules.py | 9 +- pypots/imputation/dlinear/model.py | 9 +- pypots/imputation/dlinear/modules/core.py | 87 +++++++++++++------ pypots/imputation/etsformer/modules/core.py | 14 ++- pypots/imputation/fedformer/modules/core.py | 18 ++-- pypots/imputation/informer/__init__.py | 1 - pypots/imputation/informer/modules/core.py | 24 +++-- pypots/imputation/patchtst/modules/core.py | 30 ++++--- pypots/nn/modules/transformer/attention.py | 5 +- tests/imputation/dlinear.py | 23 +++++ 12 files changed, 190 insertions(+), 72 deletions(-) diff --git a/pypots/imputation/autoformer/modules/core.py b/pypots/imputation/autoformer/modules/core.py index c3747fde..14cdb53c 100644 --- a/pypots/imputation/autoformer/modules/core.py +++ b/pypots/imputation/autoformer/modules/core.py @@ -5,6 +5,7 @@ # Created by Wenjie Du # License: BSD-3-Clause +import torch import torch.nn as nn from .submodules import ( @@ -38,7 +39,7 @@ def __init__( self.seq_len = n_steps self.n_layers = n_layers self.enc_embedding = DataEmbedding( - n_features, + n_features * 2, d_model, dropout=dropout, with_pos=False, @@ -63,28 +64,35 @@ def __init__( ) # for the imputation task, the output dim is the same as input dim - self.projection = nn.Linear(d_model, n_features) + self.output_projection = nn.Linear(d_model, n_features) def forward(self, inputs: dict, training: bool = True) -> dict: X, masks = inputs["X"], inputs["missing_mask"] - # embedding - enc_out = self.enc_embedding(X) # [B,T,C] + # WDU: the original Autoformer paper isn't proposed for imputation task. Hence the model doesn't take + # the missing mask into account, which means, in the process, the model doesn't know which part of + # the input data is missing, and this may hurt the model's imputation performance. Therefore, I add the + # embedding layers to project the concatenation of features and masks into a hidden space, as well as + # the output layers to project back from the hidden space to the original space. + + # the same as SAITS, concatenate the time series data and the missing mask for embedding + input_X = torch.cat([X, masks], dim=2) + enc_out = self.enc_embedding(input_X) # Autoformer encoder processing enc_out, attns = self.encoder(enc_out) # project back the original data space - dec_out = self.projection(enc_out) + output = self.output_projection(enc_out) - imputed_data = masks * X + (1 - masks) * dec_out + imputed_data = masks * X + (1 - masks) * output results = { "imputed_data": imputed_data, } if training: # `loss` is always the item for backward propagating to update the model - loss = calc_mse(dec_out, inputs["X_ori"], inputs["indicating_mask"]) + loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"]) results["loss"] = loss return results diff --git a/pypots/imputation/crossformer/modules/core.py b/pypots/imputation/crossformer/modules/core.py index 0cc9b07a..8eb04df6 100644 --- a/pypots/imputation/crossformer/modules/core.py +++ b/pypots/imputation/crossformer/modules/core.py @@ -33,6 +33,7 @@ def __init__( super().__init__() self.n_features = n_features + self.d_model = d_model # The padding operation to handle invisible sgemnet length pad_in_len = ceil(1.0 * n_steps / seg_len) * seg_len @@ -49,7 +50,7 @@ def __init__( 0, ) self.enc_pos_embedding = nn.Parameter( - torch.randn(1, n_features, in_seg_num, d_model) + torch.randn(1, d_model, in_seg_num, d_model) ) self.pre_norm = nn.LayerNorm(d_model) @@ -71,31 +72,40 @@ def __init__( ) self.head = FlattenHead(head_nf, n_steps, dropout) + self.embedding = nn.Linear(n_features * 2, d_model) + self.output_projection = nn.Linear(d_model, n_features) def forward(self, inputs: dict, training: bool = True) -> dict: X, masks = inputs["X"], inputs["missing_mask"] + # WDU: the original Crossformer paper isn't proposed for imputation task. Hence the model doesn't take + # the missing mask into account, which means, in the process, the model doesn't know which part of + # the input data is missing, and this may hurt the model's imputation performance. Therefore, I add the + # embedding layers to project the concatenation of features and masks into a hidden space, as well as + # the output layers to project back from the hidden space to the original space. # embedding - x_enc = self.enc_value_embedding(X.permute(0, 2, 1)) + input_X = self.embedding(torch.cat([X, masks], dim=2)) + x_enc = self.enc_value_embedding(input_X.permute(0, 2, 1)) # Crossformer processing x_enc = rearrange( - x_enc, "(b d) seg_num d_model -> b d seg_num d_model", d=self.n_features + x_enc, "(b d) seg_num d_model -> b d seg_num d_model", d=self.d_model ) x_enc += self.enc_pos_embedding x_enc = self.pre_norm(x_enc) enc_out, attns = self.encoder(x_enc) # project back the original data space dec_out = self.head(enc_out[-1].permute(0, 1, 3, 2)).permute(0, 2, 1) + output = self.output_projection(dec_out) - imputed_data = masks * X + (1 - masks) * dec_out + imputed_data = masks * X + (1 - masks) * output results = { "imputed_data": imputed_data, } if training: # `loss` is always the item for backward propagating to update the model - loss = calc_mse(dec_out, inputs["X_ori"], inputs["indicating_mask"]) + loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"]) results["loss"] = loss return results diff --git a/pypots/imputation/crossformer/modules/submodules.py b/pypots/imputation/crossformer/modules/submodules.py index 2a67a227..6a6f1c7b 100644 --- a/pypots/imputation/crossformer/modules/submodules.py +++ b/pypots/imputation/crossformer/modules/submodules.py @@ -144,11 +144,12 @@ def __init__( d_ff, depth, dropout, - seg_num=10, - factor=10, + seg_num, + factor, ): super().__init__() + d_k = d_model // n_heads if win_size > 1: self.merge_layer = SegMerging(d_model, win_size, nn.LayerNorm) else: @@ -158,7 +159,9 @@ def __init__( for i in range(depth): self.encode_layers.append( - TwoStageAttentionLayer(seg_num, factor, d_model, n_heads, d_ff, dropout) + TwoStageAttentionLayer( + seg_num, factor, d_model, n_heads, d_k, d_k, d_ff, dropout + ) ) def forward(self, x, attn_mask=None, tau=None, delta=None): diff --git a/pypots/imputation/dlinear/model.py b/pypots/imputation/dlinear/model.py index d5c5e84a..f6c89976 100644 --- a/pypots/imputation/dlinear/model.py +++ b/pypots/imputation/dlinear/model.py @@ -47,7 +47,11 @@ class DLinear(BaseNNImputer): The window size of moving average. individual : - Whether to share model across different features. + Whether to make a linear layer for each variate/channel/feature individually. + + d_model: + The dimension of the space in which the time-series data will be embedded and modeled. + It is necessary only for DLinear in the non-individual mode. batch_size : The batch size for training and evaluating the model. @@ -96,6 +100,7 @@ def __init__( n_features: int, moving_avg_window_size: int, individual: bool = False, + d_model: Optional[int] = None, batch_size: int = 32, epochs: int = 100, patience: int = None, @@ -120,6 +125,7 @@ def __init__( # model hype-parameters self.moving_avg_window_size = moving_avg_window_size self.individual = individual + self.d_model = d_model # set up the model self.model = _DLinear( @@ -127,6 +133,7 @@ def __init__( n_features, moving_avg_window_size, individual, + d_model, ) self._send_model_to_given_device() self._print_model_size() diff --git a/pypots/imputation/dlinear/modules/core.py b/pypots/imputation/dlinear/modules/core.py index e8e5ec35..18f33cec 100644 --- a/pypots/imputation/dlinear/modules/core.py +++ b/pypots/imputation/dlinear/modules/core.py @@ -5,6 +5,8 @@ # Created by Wenjie Du # License: BSD-3-Clause +from typing import Optional + import torch import torch.nn as nn @@ -19,6 +21,7 @@ def __init__( n_features: int, moving_avg_window_size: int, individual: bool = False, + d_model: Optional[int] = None, ): super().__init__() @@ -28,39 +31,48 @@ def __init__( self.individual = individual if individual: - self.Linear_Seasonal = nn.ModuleList() - self.Linear_Trend = nn.ModuleList() - - for i in range(self.n_features): - self.Linear_Seasonal.append(nn.Linear(self.n_steps, self.n_steps)) - self.Linear_Trend.append(nn.Linear(self.n_steps, self.n_steps)) - - self.Linear_Seasonal[i].weight = nn.Parameter( - (1 / self.n_steps) * torch.ones([self.n_steps, self.n_steps]) + # create linear layers for each feature individually + self.linear_seasonal = nn.ModuleList() + self.linear_trend = nn.ModuleList() + for i in range(n_features): + self.linear_seasonal.append(nn.Linear(n_steps, n_steps)) + self.linear_trend.append(nn.Linear(n_steps, n_steps)) + self.linear_seasonal[i].weight = nn.Parameter( + (1 / n_steps) * torch.ones([n_steps, n_steps]) ) - self.Linear_Trend[i].weight = nn.Parameter( - (1 / self.n_steps) * torch.ones([self.n_steps, self.n_steps]) + self.linear_trend[i].weight = nn.Parameter( + (1 / n_steps) * torch.ones([n_steps, n_steps]) ) else: - self.Linear_Seasonal = nn.Linear(self.n_steps, self.n_steps) - self.Linear_Trend = nn.Linear(self.n_steps, self.n_steps) - - self.Linear_Seasonal.weight = nn.Parameter( - (1 / self.n_steps) * torch.ones([self.n_steps, self.n_steps]) + if d_model is None: + raise ValueError( + "The argument d_model is necessary for DLinear in the non-individual mode." + ) + self.linear_seasonal = nn.Linear(n_steps, n_steps) + self.linear_trend = nn.Linear(n_steps, n_steps) + self.linear_seasonal.weight = nn.Parameter( + (1 / n_steps) * torch.ones([n_steps, n_steps]) ) - self.Linear_Trend.weight = nn.Parameter( - (1 / self.n_steps) * torch.ones([self.n_steps, self.n_steps]) + self.linear_trend.weight = nn.Parameter( + (1 / n_steps) * torch.ones([n_steps, n_steps]) ) + self.linear_seasonal_embedding = nn.Linear(n_features * 2, d_model) + self.linear_trend_embedding = nn.Linear(n_features * 2, d_model) + self.linear_seasonal_output = nn.Linear(d_model, n_features) + self.linear_trend_output = nn.Linear(d_model, n_features) + def forward(self, inputs: dict, training: bool = True) -> dict: X, masks = inputs["X"], inputs["missing_mask"] - # DLinear encoder processing + # input preprocessing and embedding for DLinear seasonal_init, trend_init = self.series_decomp(X) - seasonal_init, trend_init = seasonal_init.permute(0, 2, 1), trend_init.permute( - 0, 2, 1 - ) + + # DLinear processing if self.individual: + seasonal_init, trend_init = seasonal_init.permute( + 0, 2, 1 + ), trend_init.permute(0, 2, 1) seasonal_output = torch.zeros( [seasonal_init.size(0), seasonal_init.size(1), self.n_steps], dtype=seasonal_init.dtype, @@ -70,15 +82,36 @@ def forward(self, inputs: dict, training: bool = True) -> dict: dtype=trend_init.dtype, ).to(trend_init.device) for i in range(self.n_features): - seasonal_output[:, i, :] = self.Linear_Seasonal[i]( + seasonal_output[:, i, :] = self.linear_seasonal[i]( seasonal_init[:, i, :] ) - trend_output[:, i, :] = self.Linear_Trend[i](trend_init[:, i, :]) + trend_output[:, i, :] = self.linear_trend[i](trend_init[:, i, :]) + + seasonal_output = seasonal_output.permute(0, 2, 1) + trend_output = trend_output.permute(0, 2, 1) else: - seasonal_output = self.Linear_Seasonal(seasonal_init) - trend_output = self.Linear_Trend(trend_init) + # WDU: the original DLinear paper isn't proposed for imputation task. Hence the model doesn't take + # the missing mask into account, which means, in the process, the model doesn't know which part of + # the input data is missing, and this may hurt the model's imputation performance. Therefore, I add the + # embedding layers to project the concatenation of features and masks into a hidden space, as well as + # the output layers to project the seasonal and trend from the hidden space to the original space. + # But this is only for the non-individual mode. + seasonal_init = torch.cat([seasonal_init, masks], dim=2) + trend_init = torch.cat([trend_init, masks], dim=2) + seasonal_init = self.linear_seasonal_embedding(seasonal_init) + trend_init = self.linear_trend_embedding(trend_init) + seasonal_init, trend_init = seasonal_init.permute( + 0, 2, 1 + ), trend_init.permute(0, 2, 1) + + seasonal_output = self.linear_seasonal(seasonal_init) + trend_output = self.linear_trend(trend_init) + seasonal_output = seasonal_output.permute(0, 2, 1) + trend_output = trend_output.permute(0, 2, 1) + seasonal_output = self.linear_seasonal_output(seasonal_output) + trend_output = self.linear_trend_output(trend_output) + output = seasonal_output + trend_output - output = output.permute(0, 2, 1) imputed_data = masks * X + (1 - masks) * output results = { diff --git a/pypots/imputation/etsformer/modules/core.py b/pypots/imputation/etsformer/modules/core.py index 1906174c..57faa6de 100644 --- a/pypots/imputation/etsformer/modules/core.py +++ b/pypots/imputation/etsformer/modules/core.py @@ -5,6 +5,7 @@ # Created by Wenjie Du # License: BSD-3-Clause +import torch import torch.nn as nn from .submodules import ( @@ -36,7 +37,7 @@ def __init__( self.n_steps = n_steps self.enc_embedding = DataEmbedding( - n_features, + n_features * 2, d_model, dropout=dropout, ) @@ -76,8 +77,15 @@ def __init__( def forward(self, inputs: dict, training: bool = True) -> dict: X, masks = inputs["X"], inputs["missing_mask"] - # embedding - res = self.enc_embedding(X) + # WDU: the original ETSformer paper isn't proposed for imputation task. Hence the model doesn't take + # the missing mask into account, which means, in the process, the model doesn't know which part of + # the input data is missing, and this may hurt the model's imputation performance. Therefore, I add the + # embedding layers to project the concatenation of features and masks into a hidden space, as well as + # the output layers to project back from the hidden space to the original space. + + # the same as SAITS, concatenate the time series data and the missing mask for embedding + input_X = torch.cat([X, masks], dim=2) + res = self.enc_embedding(input_X) # ETSformer encoder processing level, growths, seasons = self.encoder(res, X, attn_mask=None) diff --git a/pypots/imputation/fedformer/modules/core.py b/pypots/imputation/fedformer/modules/core.py index 895cf8d4..00f5241a 100644 --- a/pypots/imputation/fedformer/modules/core.py +++ b/pypots/imputation/fedformer/modules/core.py @@ -5,6 +5,7 @@ # Created by Wenjie Du # License: BSD-3-Clause +import torch import torch.nn as nn from .submodules import MultiWaveletTransform, FourierBlock @@ -37,7 +38,7 @@ def __init__( super().__init__() self.enc_embedding = DataEmbedding( - n_features, + n_features * 2, d_model, dropout=dropout, ) @@ -75,17 +76,24 @@ def __init__( ], norm_layer=SeasonalLayerNorm(d_model), ) - self.projection = nn.Linear(d_model, n_features) + self.output_projection = nn.Linear(d_model, n_features) def forward(self, inputs: dict, training: bool = True) -> dict: X, masks = inputs["X"], inputs["missing_mask"] - # embedding - enc_out = self.enc_embedding(X) + # WDU: the original FEDformer paper isn't proposed for imputation task. Hence the model doesn't take + # the missing mask into account, which means, in the process, the model doesn't know which part of + # the input data is missing, and this may hurt the model's imputation performance. Therefore, I add the + # embedding layers to project the concatenation of features and masks into a hidden space, as well as + # the output layers to project back from the hidden space to the original space. + + # the same as SAITS, concatenate the time series data and the missing mask for embedding + input_X = torch.cat([X, masks], dim=2) + enc_out = self.enc_embedding(input_X) # FEDformer encoder processing enc_out, attns = self.encoder(enc_out) - output = self.projection(enc_out) + output = self.output_projection(enc_out) imputed_data = masks * X + (1 - masks) * output results = { diff --git a/pypots/imputation/informer/__init__.py b/pypots/imputation/informer/__init__.py index 557abbaf..298d2345 100644 --- a/pypots/imputation/informer/__init__.py +++ b/pypots/imputation/informer/__init__.py @@ -7,7 +7,6 @@ In Proceedings of the AAAI conference on artificial intelligence, volume 35, pages 11106–11115, 2021. `_ - """ # Created by Wenjie Du diff --git a/pypots/imputation/informer/modules/core.py b/pypots/imputation/informer/modules/core.py index 455a7b1a..e6240c63 100644 --- a/pypots/imputation/informer/modules/core.py +++ b/pypots/imputation/informer/modules/core.py @@ -5,11 +5,12 @@ # Created by Wenjie Du # License: BSD-3-Clause +import torch import torch.nn as nn from .submodules import ProbAttention, ConvLayer, InformerEncoderLayer, InformerEncoder -from ....nn.modules.transformer.embedding import DataEmbedding from ....nn.modules.transformer import MultiHeadAttention +from ....nn.modules.transformer.embedding import DataEmbedding from ....utils.metrics import calc_mse @@ -33,7 +34,7 @@ def __init__( self.seq_len = n_steps self.n_layers = n_layers self.enc_embedding = DataEmbedding( - n_features, + n_features * 2, d_model, dropout=dropout, ) @@ -59,28 +60,35 @@ def __init__( ) # for the imputation task, the output dim is the same as input dim - self.projection = nn.Linear(d_model, n_features) + self.output_projection = nn.Linear(d_model, n_features) def forward(self, inputs: dict, training: bool = True) -> dict: X, masks = inputs["X"], inputs["missing_mask"] - # embedding - enc_out = self.enc_embedding(X) + # WDU: the original Informer paper isn't proposed for imputation task. Hence the model doesn't take + # the missing mask into account, which means, in the process, the model doesn't know which part of + # the input data is missing, and this may hurt the model's imputation performance. Therefore, I add the + # embedding layers to project the concatenation of features and masks into a hidden space, as well as + # the output layers to project back from the hidden space to the original space. + + # the same as SAITS, concatenate the time series data and the missing mask for embedding + input_X = torch.cat([X, masks], dim=2) + enc_out = self.enc_embedding(input_X) # Informer encoder processing enc_out, attns = self.encoder(enc_out) # project back the original data space - dec_out = self.projection(enc_out) + output = self.output_projection(enc_out) - imputed_data = masks * X + (1 - masks) * dec_out + imputed_data = masks * X + (1 - masks) * output results = { "imputed_data": imputed_data, } if training: # `loss` is always the item for backward propagating to update the model - loss = calc_mse(dec_out, inputs["X_ori"], inputs["indicating_mask"]) + loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"]) results["loss"] = loss return results diff --git a/pypots/imputation/patchtst/modules/core.py b/pypots/imputation/patchtst/modules/core.py index 9013a802..c1fc97c7 100644 --- a/pypots/imputation/patchtst/modules/core.py +++ b/pypots/imputation/patchtst/modules/core.py @@ -5,6 +5,7 @@ # Created by Wenjie Du # License: BSD-3-Clause +import torch import torch.nn as nn from .submodules import PatchEmbedding, FlattenHead @@ -38,7 +39,9 @@ def __init__( self.n_steps = n_steps self.n_features = n_features self.n_layers = n_layers + self.d_model = d_model + self.embedding = nn.Linear(n_features * 2, d_model) self.patch_embedding = PatchEmbedding( d_model, patch_len, stride, padding, dropout ) @@ -57,38 +60,45 @@ def __init__( ] ) self.head = FlattenHead(head_nf, n_steps, dropout) + self.output_projection = nn.Linear(d_model, n_features) def forward(self, inputs: dict, training: bool = True) -> dict: X, masks = inputs["X"], inputs["missing_mask"] + # WDU: the original PatchTST paper isn't proposed for imputation task. Hence the model doesn't take + # the missing mask into account, which means, in the process, the model doesn't know which part of + # the input data is missing, and this may hurt the model's imputation performance. Therefore, I add the + # embedding layers to project the concatenation of features and masks into a hidden space, as well as + # the output layers to project back from the hidden space to the original space. + # do patching and embedding - x_enc = X.permute(0, 2, 1) - # u: [bs * n_features x patch_num x d_model] - enc_out = self.patch_embedding(x_enc) + input_X = self.embedding(torch.cat([X, masks], dim=2)) + enc_out = self.patch_embedding(input_X.permute(0, 2, 1)) # PatchTST encoder processing - # z: [bs * n_features x patch_num x d_model] + # z: [bs * d_model x patch_num x d_model] for i in range(self.n_layers): enc_out, _ = self.encoder[i](enc_out) - # z: [bs x n_features x patch_num x d_model] + # z: [bs x d_model x patch_num x d_model] enc_out = enc_out.reshape( - -1, self.n_features, enc_out.shape[-2], enc_out.shape[-1] + -1, self.d_model, enc_out.shape[-2], enc_out.shape[-1] ) - # z: [bs x n_features x d_model x patch_num] + # z: [bs x d_model x d_model x patch_num] enc_out = enc_out.permute(0, 1, 3, 2) # project back the original data space - dec_out = self.head(enc_out) # z: [bs x n_features x target_window] + dec_out = self.head(enc_out) # z: [bs x d_model x target_window] dec_out = dec_out.permute(0, 2, 1) + output = self.output_projection(dec_out) - imputed_data = masks * X + (1 - masks) * dec_out + imputed_data = masks * X + (1 - masks) * output results = { "imputed_data": imputed_data, } if training: # `loss` is always the item for backward propagating to update the model - loss = calc_mse(dec_out, inputs["X_ori"], inputs["indicating_mask"]) + loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"]) results["loss"] = loss return results diff --git a/pypots/nn/modules/transformer/attention.py b/pypots/nn/modules/transformer/attention.py index 1c23efd8..448abf7c 100644 --- a/pypots/nn/modules/transformer/attention.py +++ b/pypots/nn/modules/transformer/attention.py @@ -195,11 +195,12 @@ def forward( # keep useful variables batch_size, n_steps = q.size(0), q.size(1) + k_n_steps = k.size(1) # now separate the last dimension of q, k, v into different heads -> [batch_size, n_steps, n_heads, d_k or d_v] q = self.w_qs(q).view(batch_size, n_steps, self.n_heads, self.d_k) - k = self.w_ks(k).view(batch_size, n_steps, self.n_heads, self.d_k) - v = self.w_vs(v).view(batch_size, n_steps, self.n_heads, self.d_v) + k = self.w_ks(k).view(batch_size, k_n_steps, self.n_heads, self.d_k) + v = self.w_vs(v).view(batch_size, k_n_steps, self.n_heads, self.d_v) # transpose for self-attention calculation -> [batch_size, n_steps, d_k or d_v, n_heads] q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) diff --git a/tests/imputation/dlinear.py b/tests/imputation/dlinear.py index e2680b23..c1351305 100644 --- a/tests/imputation/dlinear.py +++ b/tests/imputation/dlinear.py @@ -47,15 +47,30 @@ class TestDLinear(unittest.TestCase): DATA["n_features"], moving_avg_window_size=3, individual=False, + d_model=128, epochs=EPOCHS, saving_path=saving_path, optimizer=optimizer, device=DEVICE, ) + individual_optimizer = Adam(lr=0.001, weight_decay=1e-5) + individual_dlinear = DLinear( + DATA["n_steps"], + DATA["n_features"], + moving_avg_window_size=3, + individual=True, + d_model=None, # d_model is useless for DLinear in the individual mode + epochs=EPOCHS, + saving_path=saving_path, + optimizer=individual_optimizer, + device=DEVICE, + ) + @pytest.mark.xdist_group(name="imputation-dlinear") def test_0_fit(self): self.dlinear.fit(TRAIN_SET, VAL_SET) + self.individual_dlinear.fit(TRAIN_SET, VAL_SET) @pytest.mark.xdist_group(name="imputation-dlinear") def test_1_impute(self): @@ -71,6 +86,14 @@ def test_1_impute(self): ) logger.info(f"DLinear test_MSE: {test_MSE}") + imputation_results = self.individual_dlinear.predict(TEST_SET) + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"Individual DLinear test_MSE: {test_MSE}") + @pytest.mark.xdist_group(name="imputation-dlinear") def test_2_parameters(self): assert hasattr(self.dlinear, "model") and self.dlinear.model is not None From 0df20d06ea732a46c8fee7183cf0a496d527bafe Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Tue, 9 Apr 2024 20:47:42 +0800 Subject: [PATCH 2/2] docs: update README; --- README.md | 69 ++++++++++++++++++++++++---------------------- pypots/__init__.py | 2 +- 2 files changed, 37 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index 86e52e6f..9b294fa2 100644 --- a/README.md +++ b/README.md @@ -192,39 +192,42 @@ The paper references are all listed at the bottom of this readme file. Please re 🌟 Since **v0.2**, all neural-network models in PyPOTS has got hyperparameter-optimization support. This functionality is implemented with the [Microsoft NNI](https://github.com/microsoft/nni) framework. -| ***`Imputation`*** | 🚥 | 🚥 | 🚥 | -|:----------------------:|:-----------:|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:--------:| -| **Type** | **Abbr.** | **Full name of the algorithm/model** | **Year** | -| Neural Net | SAITS | Self-Attention-based Imputation for Time Series [^1] | 2023 | -| Neural Net | Transformer | Attention is All you Need [^2];
Self-Attention-based Imputation for Time Series [^1];
Note: proposed in [^2], and re-implemented as an imputation model in [^1]. | 2017 | -| Neural Net | Crossformer | Transformer Utilizing Cross-Dimension Dependency for Multivariate Time Series Forecasting [^16] | 2023 | -| Neural Net | TimesNet | Temporal 2D-Variation Modeling for General Time Series Analysis [^14] | 2023 | -| Neural Net | PatchTST | A Time Series is Worth 64 Words: Long-Term Forecasting with Transformers [^18] | 2023 | -| Neural Net | DLinear | Are Transformers Effective for Time Series Forecasting? [^17] | 2023 | -| Neural Net | ETSformer | Exponential Smoothing Transformers for Time-series Forecasting [^19] | 2023 | -| Neural Net | FEDformer | Frequency Enhanced Decomposed Transformer for Long-term Series Forecasting [^20] | 2022 | -| Neural Net | Informer | Beyond Efficient Transformer for Long Sequence Time-Series Forecasting [^21] | 2021 | -| Neural Net | Autoformer | Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting [^15] | 2021 | -| Neural Net | CSDI | Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation [^12] | 2021 | -| Neural Net | US-GAN | Unsupervised GAN for Multivariate Time Series Imputation [^10] | 2021 | -| Neural Net | GP-VAE | Gaussian Process Variational Autoencoder [^11] | 2020 | -| Neural Net | BRITS | Bidirectional Recurrent Imputation for Time Series [^3] | 2018 | -| Neural Net | M-RNN | Multi-directional Recurrent Neural Network [^9] | 2019 | -| Naive | LOCF/NOCB | Last Observation Carried Forward / Next Observation Carried Backward | - | -| Naive | Median | Median Value Imputation | - | -| Naive | Mean | Mean Value Imputation | - | -| ***`Classification`*** | 🚥 | 🚥 | 🚥 | -| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** | -| Neural Net | BRITS | Bidirectional Recurrent Imputation for Time Series [^3] | 2018 | -| Neural Net | GRU-D | Recurrent Neural Networks for Multivariate Time Series with Missing Values [^4] | 2018 | -| Neural Net | Raindrop | Graph-Guided Network for Irregularly Sampled Multivariate Time Series [^5] | 2022 | -| ***`Clustering`*** | 🚥 | 🚥 | 🚥 | -| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** | -| Neural Net | CRLI | Clustering Representation Learning on Incomplete time-series data [^6] | 2021 | -| Neural Net | VaDER | Variational Deep Embedding with Recurrence [^7] | 2019 | -| ***`Forecasting`*** | 🚥 | 🚥 | 🚥 | -| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** | -| Probabilistic | BTTF | Bayesian Temporal Tensor Factorization [^8] | 2021 | +🔥 Note that Transformer, Crossformer, PatchTST, DLinear, ETSformer, FEDformer, Informer, Autoformer are not proposed as imputation methods in their original papers, +and they cannot accept POTS as input. **To make them applicable on POTS data, we apply the embedding strategy the same as we did in [SAITS paper](https://arxiv.org/pdf/2202.08516).** + +| ***`Imputation`*** | 🚥 | 🚥 | 🚥 | +|:----------------------:|:-----------:|:-----------------------------------------------------------------------------------------------:|:--------:| +| **Type** | **Abbr.** | **Full name of the algorithm/model** | **Year** | +| Neural Net | SAITS | Self-Attention-based Imputation for Time Series [^1] | 2023 | +| Neural Net | Transformer | Attention is All you Need [^2] | 2017 | +| Neural Net | Crossformer | Transformer Utilizing Cross-Dimension Dependency for Multivariate Time Series Forecasting [^16] | 2023 | +| Neural Net | TimesNet | Temporal 2D-Variation Modeling for General Time Series Analysis [^14] | 2023 | +| Neural Net | PatchTST | A Time Series is Worth 64 Words: Long-Term Forecasting with Transformers [^18] | 2023 | +| Neural Net | DLinear | Are Transformers Effective for Time Series Forecasting? [^17] | 2023 | +| Neural Net | ETSformer | Exponential Smoothing Transformers for Time-series Forecasting [^19] | 2023 | +| Neural Net | FEDformer | Frequency Enhanced Decomposed Transformer for Long-term Series Forecasting [^20] | 2022 | +| Neural Net | Informer | Beyond Efficient Transformer for Long Sequence Time-Series Forecasting [^21] | 2021 | +| Neural Net | Autoformer | Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting [^15] | 2021 | +| Neural Net | CSDI | Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation [^12] | 2021 | +| Neural Net | US-GAN | Unsupervised GAN for Multivariate Time Series Imputation [^10] | 2021 | +| Neural Net | GP-VAE | Gaussian Process Variational Autoencoder [^11] | 2020 | +| Neural Net | BRITS | Bidirectional Recurrent Imputation for Time Series [^3] | 2018 | +| Neural Net | M-RNN | Multi-directional Recurrent Neural Network [^9] | 2019 | +| Naive | LOCF/NOCB | Last Observation Carried Forward / Next Observation Carried Backward | - | +| Naive | Median | Median Value Imputation | - | +| Naive | Mean | Mean Value Imputation | - | +| ***`Classification`*** | 🚥 | 🚥 | 🚥 | +| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** | +| Neural Net | BRITS | Bidirectional Recurrent Imputation for Time Series [^3] | 2018 | +| Neural Net | GRU-D | Recurrent Neural Networks for Multivariate Time Series with Missing Values [^4] | 2018 | +| Neural Net | Raindrop | Graph-Guided Network for Irregularly Sampled Multivariate Time Series [^5] | 2022 | +| ***`Clustering`*** | 🚥 | 🚥 | 🚥 | +| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** | +| Neural Net | CRLI | Clustering Representation Learning on Incomplete time-series data [^6] | 2021 | +| Neural Net | VaDER | Variational Deep Embedding with Recurrence [^7] | 2019 | +| ***`Forecasting`*** | 🚥 | 🚥 | 🚥 | +| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** | +| Probabilistic | BTTF | Bayesian Temporal Tensor Factorization [^8] | 2021 | ## ❖ Citing PyPOTS diff --git a/pypots/__init__.py b/pypots/__init__.py index 8075ec06..566339bd 100644 --- a/pypots/__init__.py +++ b/pypots/__init__.py @@ -22,7 +22,7 @@ # # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer. # 'X.Y.dev0' is the canonical version of 'X.Y.dev' -__version__ = "0.3.2" +__version__ = "0.4" from . import imputation, classification, clustering, forecasting, optim, data, utils