From 366a5842acde5176534079a0932dcb12b974e94c Mon Sep 17 00:00:00 2001 From: lss-1138 <57395990+lss-1138@users.noreply.github.com> Date: Thu, 10 Oct 2024 14:55:14 +0800 Subject: [PATCH 1/3] Add SegRNN model --- README.md | 4 + pypots/imputation/__init__.py | 2 + pypots/imputation/segrnn/__init__.py | 24 +++ pypots/imputation/segrnn/core.py | 59 ++++++ pypots/imputation/segrnn/data.py | 21 ++ pypots/imputation/segrnn/model.py | 296 +++++++++++++++++++++++++++ pypots/nn/modules/segrnn/__init__.py | 23 +++ pypots/nn/modules/segrnn/backbone.py | 79 +++++++ 8 files changed, 508 insertions(+) create mode 100644 pypots/imputation/segrnn/__init__.py create mode 100644 pypots/imputation/segrnn/core.py create mode 100644 pypots/imputation/segrnn/data.py create mode 100644 pypots/imputation/segrnn/model.py create mode 100644 pypots/nn/modules/segrnn/__init__.py create mode 100644 pypots/nn/modules/segrnn/backbone.py diff --git a/README.md b/README.md index 47ee47fa..e024cc27 100644 --- a/README.md +++ b/README.md @@ -126,6 +126,7 @@ The paper references and links are all listed at the bottom of this file. | Neural Net | iTransformer🧑‍🔧[^24] | ✅ | | | | | `2024 - ICLR` | | Neural Net | ModernTCN[^38] | ✅ | | | | | `2024 - ICLR` | | Neural Net | ImputeFormer🧑‍🔧[^34] | ✅ | | | | | `2024 - KDD` | +| Neural Net | SegRNN[^42] | ✅ | | | | | `2023 - arXiv` | | Neural Net | SAITS[^1] | ✅ | | | | | `2023 - ESWA` | | Neural Net | FreTS🧑‍🔧[^23] | ✅ | | | | | `2023 - NeurIPS` | | Neural Net | Koopa🧑‍🔧[^29] | ✅ | | | | | `2023 - NeurIPS` | @@ -509,3 +510,6 @@ Time-Series.AI [^41]: Xu, Z., Zeng, A., & Xu, Q. (2024). [FITS: Modeling Time Series with 10k parameters](https://openreview.net/forum?id=bWcnvZ3qMb). *ICLR 2024*. +[^42]: Lin, S., Lin, W., Wu, W., Zhao, F., Mo, R., & Zhang, H. (2023). +[Segrnn: Segment recurrent neural network for long-term time series forecasting](https://github.com/lss-1138/SegRNN) +*arXiv 2023*. diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py index 6600dcfd..9ce1d867 100644 --- a/pypots/imputation/__init__.py +++ b/pypots/imputation/__init__.py @@ -38,6 +38,7 @@ from .imputeformer import ImputeFormer from .timemixer import TimeMixer from .moderntcn import ModernTCN +from .segrnn import SegRNN # naive imputation methods from .locf import LOCF @@ -87,4 +88,5 @@ "Lerp", "TEFN", "CSAI", + "SegRNN", ] diff --git a/pypots/imputation/segrnn/__init__.py b/pypots/imputation/segrnn/__init__.py new file mode 100644 index 00000000..243345a8 --- /dev/null +++ b/pypots/imputation/segrnn/__init__.py @@ -0,0 +1,24 @@ +""" +The package including the modules of SegRNN. + +Refer to the paper +`Lin, Shengsheng and Lin, Weiwei and Wu, Wentai and Zhao, Feiyu and Mo, Ruichao and Zhang, Haotong. +Segrnn: Segment recurrent neural network for long-term time series forecasting. +arXiv preprint arXiv:2308.11200. +`_ + +Notes +----- +This implementation is inspired by the official one https://github.com/lss-1138/SegRNN + +""" + +# Created by Shengsheng Lin + + + +from .model import SegRNN + +__all__ = [ + "SegRNN", +] diff --git a/pypots/imputation/segrnn/core.py b/pypots/imputation/segrnn/core.py new file mode 100644 index 00000000..c1d5f3a2 --- /dev/null +++ b/pypots/imputation/segrnn/core.py @@ -0,0 +1,59 @@ +""" +The core wrapper assembles the submodules of SegRNN imputation model +and takes over the forward progress of the algorithm. +""" + +# Created by Shengsheng Lin + +from typing import Optional + +from typing import Callable +import torch.nn as nn + +from ...nn.modules.segrnn import BackboneSegRNN +from ...nn.modules.saits import SaitsLoss + +class _SegRNN(nn.Module): + def __init__( + self, + n_steps: int, + n_features: int, + seg_len: int = 24, + d_model: int = 512, + dropout: float = 0.5, + ORT_weight: float = 1, + MIT_weight: float = 1, + ): + super().__init__() + + self.n_steps = n_steps + self.n_features = n_features + self.seg_len = seg_len + self.d_model = d_model + self.dropout = dropout + + self.backbone = BackboneSegRNN(n_steps, n_features, seg_len, d_model, dropout) + + # apply SAITS loss function to Transformer on the imputation task + self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) + + def forward(self, inputs: dict, training: bool = True) -> dict: + X, missing_mask = inputs["X"], inputs["missing_mask"] + + reconstruction = self.backbone(X) + + imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction + results = { + "imputed_data": imputed_data, + } + + # if in training mode, return results with losses + if training: + X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] + loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) + results["ORT_loss"] = ORT_loss + results["MIT_loss"] = MIT_loss + # `loss` is always the item for backward propagating to update the model + results["loss"] = loss + + return results diff --git a/pypots/imputation/segrnn/data.py b/pypots/imputation/segrnn/data.py new file mode 100644 index 00000000..a9eb728e --- /dev/null +++ b/pypots/imputation/segrnn/data.py @@ -0,0 +1,21 @@ +""" +Dataset class for the imputation model SegRNN. +""" + +# Created by Shengsheng lin + +from typing import Union + +from pypots.imputation.saits.data import DatasetForSAITS + + +class DatasetForSegRNN(DatasetForSAITS): + def __init__( + self, + data: Union[dict, str], + return_X_ori: bool, + return_y: bool, + file_type: str = "hdf5", + rate: float = 0.2, + ): + super().__init__(data, return_X_ori, return_y, file_type, rate) diff --git a/pypots/imputation/segrnn/model.py b/pypots/imputation/segrnn/model.py new file mode 100644 index 00000000..6e687084 --- /dev/null +++ b/pypots/imputation/segrnn/model.py @@ -0,0 +1,296 @@ +""" +The implementation of SegRNN for the partially-observed time-series imputation task. + +""" + +# Created by Shengsheng Lin + +from typing import Union, Optional + +import numpy as np +import torch +from torch.utils.data import DataLoader + +from .core import _SegRNN +from .data import DatasetForSegRNN +from ..base import BaseNNImputer +from ...data.checking import key_in_data_set +from ...data.dataset import BaseDataset +from ...optim.adam import Adam +from ...optim.base import Optimizer + + +class SegRNN(BaseNNImputer): + """The PyTorch implementation of the SegRNN model. + SegRNN is originally proposed by Shengsheng Lin et al. in :cite:`lin2023segrnn`. + See detail in https://arxiv.org/abs/2308.11200 or https://github.com/lss-1138/SegRNN. + + Parameters + ---------- + n_steps : + The number of time steps in the time-series data sample. + + n_features : + The number of features in the time-series data sample. + + seg_len : + The segment length for input of RNN. + + d_model: + The dimension of RNN cell. + + dropout : + The dropout rate of the output layer of SegRNN. + + ORT_weight : + The weight for the ORT loss, the same as SAITS. + + MIT_weight : + The weight for the MIT loss, the same as SAITS. + + batch_size : + The batch size for training and evaluating the model. + + epochs : + The number of epochs for training the model. + + patience : + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. + + optimizer : + The optimizer for model training. + If not given, will use a default Adam optimizer. + + num_workers : + The number of subprocesses to use for data loading. + `0` means data loading will be in the main process, i.e. there won't be subprocesses. + + device : + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. + If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), + then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). + Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. + + saving_path : + The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during + training into a tensorboard file). Will not save if not given. + + model_saving_strategy : + The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"]. + No model will be saved when it is set as None. + The "best" strategy will only automatically save the best model after the training finished. + The "better" strategy will automatically save the model during training whenever the model performs + better than in previous epochs. + The "all" strategy will save every model after each epoch training. + + verbose : + Whether to print out the training logs during the training process. + """ + + def __init__( + self, + n_steps: int, + n_features: int, + seg_len: int = 24, + d_model: int = 512, + dropout: float = 0.5, + ORT_weight: float = 1, + MIT_weight: float = 1, + batch_size: int = 32, + epochs: int = 100, + patience: int = None, + optimizer: Optional[Optimizer] = Adam(), + num_workers: int = 0, + device: Optional[Union[str, torch.device, list]] = None, + saving_path: str = None, + model_saving_strategy: Optional[str] = "best", + verbose: bool = True, + ): + super().__init__( + batch_size, + epochs, + patience, + num_workers, + device, + saving_path, + model_saving_strategy, + verbose, + ) + + self.n_steps = n_steps + self.n_features = n_features + # model hype-parameters + self.seg_len = seg_len + self.d_model = d_model + self.dropout = dropout + self.ORT_weight = ORT_weight + self.MIT_weight = MIT_weight + + # set up the model + self.model = _SegRNN( + self.n_steps, + self.n_features, + self.seg_len, + self.d_model, + self.dropout, + self.ORT_weight, + self.MIT_weight, + ) + self._send_model_to_given_device() + self._print_model_size() + + # set up the optimizer + self.optimizer = optimizer + self.optimizer.init_optimizer(self.model.parameters()) + + def _assemble_input_for_training(self, data: list) -> dict: + ( + indices, + X, + missing_mask, + X_ori, + indicating_mask, + ) = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + "X_ori": X_ori, + "indicating_mask": indicating_mask, + } + + return inputs + + def _assemble_input_for_validating(self, data: list) -> dict: + return self._assemble_input_for_training(data) + + def _assemble_input_for_testing(self, data: list) -> dict: + indices, X, missing_mask = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + } + + return inputs + + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "hdf5", + ) -> None: + # Step 1: wrap the input data with classes Dataset and DataLoader + training_set = DatasetForSegRNN(train_set, return_X_ori=False, return_y=False, file_type=file_type) + training_loader = DataLoader( + training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + val_loader = None + if val_set is not None: + if not key_in_data_set("X_ori", val_set): + raise ValueError("val_set must contain 'X_ori' for model validation.") + val_set = DatasetForSegRNN(val_set, return_X_ori=True, return_y=False, file_type=file_type) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + # Step 2: train the model and freeze it + self._train_model(training_loader, val_loader) + self.model.load_state_dict(self.best_model_dict) + self.model.eval() # set the model as eval status to freeze it. + + # Step 3: save the model if necessary + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") + + def predict( + self, + test_set: Union[dict, str], + file_type: str = "hdf5", + ) -> dict: + """Make predictions for the input data with the trained model. + + Parameters + ---------- + test_set : dict or str + The dataset for model validating, should be a dictionary including keys as 'X', + or a path string locating a data file supported by PyPOTS (e.g. h5 file). + If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features], + which is time-series data for validating, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + + file_type : + The type of the given file if test_set is a path string. + + Returns + ------- + file_type : + The dictionary containing the clustering results and latent variables if necessary. + + """ + # Step 1: wrap the input data with classes Dataset and DataLoader + self.model.eval() # set the model as eval status to freeze it. + test_set = BaseDataset( + test_set, + return_X_ori=False, + return_X_pred=False, + return_y=False, + file_type=file_type, + ) + test_loader = DataLoader( + test_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + imputation_collector = [] + + # Step 2: process the data with the model + with torch.no_grad(): + for idx, data in enumerate(test_loader): + inputs = self._assemble_input_for_testing(data) + results = self.model.forward(inputs, training=False) + imputation_collector.append(results["imputed_data"]) + + # Step 3: output collection and return + imputation = torch.cat(imputation_collector).cpu().detach().numpy() + result_dict = { + "imputation": imputation, + } + return result_dict + + def impute( + self, + test_set: Union[dict, str], + file_type: str = "hdf5", + ) -> np.ndarray: + """Impute missing values in the given data with the trained model. + + Parameters + ---------- + test_set : + The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps), + n_features], or a path string locating a data file, e.g. h5 file. + + file_type : + The type of the given file if X is a path string. + + Returns + ------- + array-like, shape [n_samples, sequence length (n_steps), n_features], + Imputed data. + """ + + result_dict = self.predict(test_set, file_type=file_type) + return result_dict["imputation"] diff --git a/pypots/nn/modules/segrnn/__init__.py b/pypots/nn/modules/segrnn/__init__.py new file mode 100644 index 00000000..a367bf8d --- /dev/null +++ b/pypots/nn/modules/segrnn/__init__.py @@ -0,0 +1,23 @@ +""" +The package including the modules of SegRNN. + +Refer to the paper +`Lin, Shengsheng and Lin, Weiwei and Wu, Wentai and Zhao, Feiyu and Mo, Ruichao and Zhang, Haotong. +Segrnn: Segment recurrent neural network for long-term time series forecasting. +arXiv preprint arXiv:2308.11200. +`_ + +Notes +----- +This implementation is inspired by the official one https://github.com/lss-1138/SegRNN + +""" + +# Created by Shengsheng Lin + + +from .backbone import BackboneSegRNN + +__all__ = [ + "BackboneSegRNN", +] diff --git a/pypots/nn/modules/segrnn/backbone.py b/pypots/nn/modules/segrnn/backbone.py new file mode 100644 index 00000000..e0588403 --- /dev/null +++ b/pypots/nn/modules/segrnn/backbone.py @@ -0,0 +1,79 @@ +""" + +""" + +# Created by Shengsheng Lin + +from typing import Optional + +import torch +import torch.nn as nn + + +class BackboneSegRNN(nn.Module): + def __init__( + self, + n_steps: int, + n_features: int, + seg_len: int = 24, + d_model: int = 512, + dropout: float = 0.5 + ): + super().__init__() + + self.n_steps = n_steps + self.n_features = n_features + self.seg_len = seg_len + self.d_model = d_model + self.dropout = dropout + + + if n_steps % seg_len: + raise ValueError("The argument seg_len is necessary for SegRNN need to be divisible by the sequence length n_steps.") + + self.seg_num = self.n_steps // self.seg_len + self.valueEmbedding = nn.Sequential( + nn.Linear(self.seg_len, self.d_model), + nn.ReLU() + ) + self.rnn = nn.GRU(input_size=self.d_model, hidden_size=self.d_model, num_layers=1, bias=True, + batch_first=True, bidirectional=False) + self.pos_emb = nn.Parameter(torch.randn(self.seg_num, self.d_model // 2)) + self.channel_emb = nn.Parameter(torch.randn(self.n_features, self.d_model // 2)) + self.predict = nn.Sequential( + nn.Dropout(self.dropout), + nn.Linear(self.d_model, self.seg_len) + ) + + def forward(self, x): + # b:batch_size c:channel_size s:seq_len s:seq_len + # d:d_model w:seg_len n m:seg_num + batch_size = x.size(0) + + # normalization and permute b,s,c -> b,c,s + seq_last = x[:, -1:, :].detach() + x = (x - seq_last).permute(0, 2, 1) # b,c,s + + # segment and embedding b,c,s -> bc,n,w -> bc,n,d + x = self.valueEmbedding(x.reshape(-1, self.seg_num, self.seg_len)) + + # encoding + _, hn = self.rnn(x) # bc,n,d 1,bc,d + + # m,d//2 -> 1,m,d//2 -> c,m,d//2 + # c,d//2 -> c,1,d//2 -> c,m,d//2 + # c,m,d -> cm,1,d -> bcm, 1, d + pos_emb = torch.cat([ + self.pos_emb.unsqueeze(0).repeat(self.n_features, 1, 1), + self.channel_emb.unsqueeze(1).repeat(1, self.seg_num, 1) + ], dim=-1).view(-1, 1, self.d_model).repeat(batch_size,1,1) + + _, hy = self.rnn(pos_emb, hn.repeat(1, 1, self.seg_num).view(1, -1, self.d_model)) # bcm,1,d 1,bcm,d + + # 1,bcm,d -> 1,bcm,w -> b,c,s + y = self.predict(hy).view(-1, self.n_features, self.n_steps) + + # permute and denorm + y = y.permute(0, 2, 1) + seq_last + + return y From 0195cbbdb77e5949b2e0ed53cf30c8bd7f89235d Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Fri, 25 Oct 2024 20:20:25 +0800 Subject: [PATCH 2/3] Update docs (#538) * docs: add FITS into the algo table; * docs: update the reference of TEFN; * docs: update pytorch intersphinx mapping link; * docs: update docs for new added models CSAI and SegRNN; --- README.md | 10 ++++++--- README_zh.md | 8 +++++++ docs/index.rst | 4 ++++ docs/pypots.imputation.rst | 18 ++++++++++++++++ docs/references.bib | 6 ++++++ pypots/imputation/csai/model.py | 38 +++++++++++++++++++++++++-------- 6 files changed, 72 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index e024cc27..fcfda433 100644 --- a/README.md +++ b/README.md @@ -126,7 +126,6 @@ The paper references and links are all listed at the bottom of this file. | Neural Net | iTransformer🧑‍🔧[^24] | ✅ | | | | | `2024 - ICLR` | | Neural Net | ModernTCN[^38] | ✅ | | | | | `2024 - ICLR` | | Neural Net | ImputeFormer🧑‍🔧[^34] | ✅ | | | | | `2024 - KDD` | -| Neural Net | SegRNN[^42] | ✅ | | | | | `2023 - arXiv` | | Neural Net | SAITS[^1] | ✅ | | | | | `2023 - ESWA` | | Neural Net | FreTS🧑‍🔧[^23] | ✅ | | | | | `2023 - NeurIPS` | | Neural Net | Koopa🧑‍🔧[^29] | ✅ | | | | | `2023 - NeurIPS` | @@ -137,6 +136,8 @@ The paper references and links are all listed at the bottom of this file. | Neural Net | MICN🧑‍🔧[^27] | ✅ | | | | | `2023 - ICLR` | | Neural Net | DLinear🧑‍🔧[^17] | ✅ | | | | | `2023 - AAAI` | | Neural Net | TiDE🧑‍🔧[^28] | ✅ | | | | | `2023 - TMLR` | +| Neural Net | CSAI[^42] | ✅ | | | | | `2023 - arXiv` | +| Neural Net | SegRNN🧑‍🔧[^43] | ✅ | | | | | `2023 - arXiv` | | Neural Net | SCINet🧑‍🔧[^30] | ✅ | | | | | `2022 - NeurIPS` | | Neural Net | Nonstationary Tr.🧑‍🔧[^25] | ✅ | | | | | `2022 - NeurIPS` | | Neural Net | FiLM🧑‍🔧[^22] | ✅ | | | | | `2022 - NeurIPS` | @@ -510,6 +511,9 @@ Time-Series.AI [^41]: Xu, Z., Zeng, A., & Xu, Q. (2024). [FITS: Modeling Time Series with 10k parameters](https://openreview.net/forum?id=bWcnvZ3qMb). *ICLR 2024*. -[^42]: Lin, S., Lin, W., Wu, W., Zhao, F., Mo, R., & Zhang, H. (2023). -[Segrnn: Segment recurrent neural network for long-term time series forecasting](https://github.com/lss-1138/SegRNN) +[^42]: Qian, L., Ibrahim, Z., Ellis, H. L., Zhang, A., Zhang, Y., Wang, T., & Dobson, R. (2023). +[Knowledge Enhanced Conditional Imputation for Healthcare Time-series](https://arxiv.org/abs/2312.16713). +*arXiv 2023*. +[^43]: Lin, S., Lin, W., Wu, W., Zhao, F., Mo, R., & Zhang, H. (2023). +[SegRNN: Segment Recurrent Neural Network for Long-Term Time Series Forecasting](https://arxiv.org/abs/2308.11200). *arXiv 2023*. diff --git a/README_zh.md b/README_zh.md index 55978e01..d85af8ba 100644 --- a/README_zh.md +++ b/README_zh.md @@ -121,6 +121,8 @@ PyPOTS当前支持多变量POTS数据的插补, 预测, 分类, 聚类以及异 | Neural Net | MICN🧑‍🔧[^27] | ✅ | | | | | `2023 - ICLR` | | Neural Net | DLinear🧑‍🔧[^17] | ✅ | | | | | `2023 - AAAI` | | Neural Net | TiDE🧑‍🔧[^28] | ✅ | | | | | `2023 - TMLR` | +| Neural Net | CSAI[^42] | ✅ | | | | | `2023 - arXiv` | +| Neural Net | SegRNN🧑‍🔧[^43] | ✅ | | | | | `2023 - arXiv` | | Neural Net | SCINet🧑‍🔧[^30] | ✅ | | | | | `2022 - NeurIPS` | | Neural Net | Nonstationary Tr.🧑‍🔧[^25] | ✅ | | | | | `2022 - NeurIPS` | | Neural Net | FiLM🧑‍🔧[^22] | ✅ | | | | | `2022 - NeurIPS` | @@ -482,3 +484,9 @@ Time-Series.AI [^41]: Xu, Z., Zeng, A., & Xu, Q. (2024). [FITS: Modeling Time Series with 10k parameters](https://openreview.net/forum?id=bWcnvZ3qMb). *ICLR 2024*. +[^42]: Qian, L., Ibrahim, Z., Ellis, H. L., Zhang, A., Zhang, Y., Wang, T., & Dobson, R. (2023). +[Knowledge Enhanced Conditional Imputation for Healthcare Time-series](https://arxiv.org/abs/2312.16713). +*arXiv 2023*. +[^43]: Lin, S., Lin, W., Wu, W., Zhao, F., Mo, R., & Zhang, H. (2023). +[SegRNN: Segment Recurrent Neural Network for Long-Term Time Series Forecasting](https://arxiv.org/abs/2308.11200). +*arXiv 2023*. diff --git a/docs/index.rst b/docs/index.rst index 4aaeb762..fd3b9219 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -165,6 +165,10 @@ The paper references are all listed at the bottom of this readme file. +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+ | Neural Net | TiDE🧑‍🔧 :cite:`das2023tide` | ✅ | | | | | ``2023 - TMLR`` | +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+ +| Neural Net | CSAI :cite:`qian2023csai` | ✅ | | | | | ``2023 - arXiv`` | ++----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+ +| Neural Net | SegRNN🧑‍🔧 :cite:`lin2023segrnn` | ✅ | | | | | ``2023 - arXiv`` | ++----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+ | Neural Net | SCINet🧑‍🔧 :cite:`liu2022scinet` | ✅ | | | | | ``2022 - NeurIPS`` | +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+ | Neural Net | Nonstationary Tr🧑‍🔧 :cite:`liu2022nonstationary` | ✅ | | | | | ``2022 - NeurIPS`` | diff --git a/docs/pypots.imputation.rst b/docs/pypots.imputation.rst index a7b47b07..73e55646 100644 --- a/docs/pypots.imputation.rst +++ b/docs/pypots.imputation.rst @@ -28,6 +28,24 @@ pypots.imputation.tefn :show-inheritance: :inherited-members: +pypots.imputation.csai +------------------------------------ + +.. automodule:: pypots.imputation.csai + :members: + :undoc-members: + :show-inheritance: + :inherited-members: + +pypots.imputation.segrnn +------------------------------------ + +.. automodule:: pypots.imputation.segrnn + :members: + :undoc-members: + :show-inheritance: + :inherited-members: + pypots.imputation.fits ------------------------------------ diff --git a/docs/references.bib b/docs/references.bib index ce0014ea..10a9632c 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -779,3 +779,9 @@ @inproceedings{jin2024timellm url={https://openreview.net/forum?id=Unb5CVPtae} } +@article{qian2023csai, +title={Knowledge Enhanced Conditional Imputation for Healthcare Time-series}, +author={Qian, Linglong and Ibrahim, Zina and Ellis, Hugh Logan and Zhang, Ao and Zhang, Yuezhou and Wang, Tao and Dobson, Richard}, +journal={arXiv preprint arXiv:2312.16713}, +year={2023} +} \ No newline at end of file diff --git a/pypots/imputation/csai/model.py b/pypots/imputation/csai/model.py index b61286c5..010561a2 100644 --- a/pypots/imputation/csai/model.py +++ b/pypots/imputation/csai/model.py @@ -21,7 +21,8 @@ class CSAI(BaseNNImputer): - """ + """The PyTorch implementation of the CSAI model :cite:`qian2023csai`. + Parameters ---------- n_steps : @@ -58,29 +59,48 @@ class CSAI(BaseNNImputer): The number of epochs for training the model. patience : - The patience for the early-stopping mechanism. Given a positive integer, training will stop when no improvement is observed after the specified number of epochs. If set to None, early-stopping is disabled. + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. optimizer : - The optimizer used for model training. Defaults to the Adam optimizer if not specified. + The optimizer for model training. + If not given, will use a default Adam optimizer. num_workers : - The number of subprocesses used for data loading. Setting this to `0` means that data loading is performed in the main process without using subprocesses. + The number of subprocesses to use for data loading. + `0` means data loading will be in the main process, i.e. there won't be subprocesses. device : - The device for the model to run on, which can be a string, a :class:`torch.device` object, or a list of devices. If not provided, the model will attempt to use available CUDA devices first, then default to CPUs. + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. + If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), + then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). + Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. saving_path : - The path for saving model checkpoints and tensorboard files during training. If not provided, models will not be saved automatically. + The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during + training into a tensorboard file). Will not save if not given. model_saving_strategy : - The strategy for saving model checkpoints. Can be one of [None, "best", "better", "all"]. "best" saves the best model after training, "better" saves any model that improves during training, and "all" saves models after each epoch. If set to None, no models will be saved. + The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"]. + No model will be saved when it is set as None. + The "best" strategy will only automatically save the best model after the training finished. + The "better" strategy will automatically save the model during training whenever the model performs + better than in previous epochs. + The "all" strategy will save every model after each epoch training. verbose : - Whether to print training logs during the training process. + Whether to print out the training logs during the training process. Notes ----- - CSAI (Consistent Sequential Imputation) is a bidirectional model designed for time-series imputation. It employs a forward and backward GRU network to handle missing data, using consistency and reconstruction losses to improve accuracy. The model supports various training configurations, such as interval computations, early-stopping, and multiple devices for training. Results can be saved based on the specified saving strategy, and tensorboard files are generated for tracking the model's performance over time. + CSAI (Consistent Sequential Imputation) is a bidirectional model designed for time-series imputation. + It employs a forward and backward GRU network to handle missing data, using consistency and reconstruction losses + to improve accuracy. The model supports various training configurations, such as interval computations, + early-stopping, and multiple devices for training. Results can be saved based on the specified saving strategy, + and tensorboard files are generated for tracking the model's performance over time. """ From bf045b4889637586b4fbf084dc7ee0470768f51f Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Fri, 25 Oct 2024 21:15:47 +0800 Subject: [PATCH 3/3] refactor: clean linting issues; --- pypots/classification/csai/__init__.py | 2 +- pypots/classification/csai/core.py | 40 ++--- pypots/classification/csai/data.py | 31 ++-- pypots/classification/csai/model.py | 131 +++++++------- pypots/imputation/csai/__init__.py | 2 +- pypots/imputation/csai/core.py | 38 +++-- pypots/imputation/csai/data.py | 227 ++++++++++++------------- pypots/imputation/csai/model.py | 167 +++++++++--------- pypots/imputation/segrnn/__init__.py | 1 - pypots/imputation/segrnn/core.py | 1 + pypots/nn/modules/csai/__init__.py | 10 +- pypots/nn/modules/csai/backbone.py | 59 +++---- pypots/nn/modules/csai/layers.py | 29 ++-- pypots/nn/modules/segrnn/backbone.py | 53 +++--- 14 files changed, 383 insertions(+), 408 deletions(-) diff --git a/pypots/classification/csai/__init__.py b/pypots/classification/csai/__init__.py index 5ea14ae3..d29b9575 100644 --- a/pypots/classification/csai/__init__.py +++ b/pypots/classification/csai/__init__.py @@ -17,4 +17,4 @@ __all__ = [ "CSAI", -] \ No newline at end of file +] diff --git a/pypots/classification/csai/core.py b/pypots/classification/csai/core.py index 30f052f6..97a1fecc 100644 --- a/pypots/classification/csai/core.py +++ b/pypots/classification/csai/core.py @@ -17,10 +17,10 @@ # self.bcelogits = nn.BCEWithLogitsLoss() # def forward(self, y_score, y_out, targets, smooth=1): - + # #comment out if your model contains a sigmoid or equivalent activation layer -# # inputs = F.sigmoid(inputs) - +# # inputs = F.sigmoid(inputs) + # #flatten label and prediction tensors # BCE = self.bcelogits(y_out, targets) @@ -30,23 +30,23 @@ # dice_loss = 1 - (2.*intersection + smooth)/(y_score.sum() + targets.sum() + smooth) # Dice_BCE = BCE + dice_loss - + # return BCE, Dice_BCE class _BCSAI(nn.Module): def __init__( - self, - n_steps: int, - n_features: int, - rnn_hidden_size: int, - imputation_weight: float, - consistency_weight: float, - classification_weight: float, - n_classes: int, - step_channels: int, - dropout: float = 0.5, - intervals=None, + self, + n_steps: int, + n_features: int, + rnn_hidden_size: int, + imputation_weight: float, + consistency_weight: float, + classification_weight: float, + n_classes: int, + step_channels: int, + dropout: float = 0.5, + intervals=None, ): super().__init__() self.n_steps = n_steps @@ -107,12 +107,12 @@ def forward(self, inputs: dict, training: bool = True) -> dict: b_classification_loss = F.nll_loss(torch.log(b_prediction), inputs["labels"]) # f_classification_loss, _ = criterion(f_prediction, f_logits, inputs["labels"].unsqueeze(1).float()) # b_classification_loss, _ = criterion(b_prediction, b_logits, inputs["labels"].unsqueeze(1).float()) - classification_loss = (f_classification_loss + b_classification_loss) + classification_loss = f_classification_loss + b_classification_loss loss = ( - self.consistency_weight * consistency_loss + - self.imputation_weight * reconstruction_loss + - self.classification_weight * classification_loss + self.consistency_weight * consistency_loss + + self.imputation_weight * reconstruction_loss + + self.classification_weight * classification_loss ) results["loss"] = loss @@ -120,4 +120,4 @@ def forward(self, inputs: dict, training: bool = True) -> dict: results["f_reconstruction"] = f_reconstruction results["b_reconstruction"] = b_reconstruction - return results \ No newline at end of file + return results diff --git a/pypots/classification/csai/data.py b/pypots/classification/csai/data.py index caeb5005..cd829882 100644 --- a/pypots/classification/csai/data.py +++ b/pypots/classification/csai/data.py @@ -6,22 +6,22 @@ # License: BSD-3-Clause from typing import Union -from ...imputation.csai.data import DatasetForCSAI as DatasetForCSAI_Imputation - +from ...imputation.csai.data import DatasetForCSAI as DatasetForCSAI_Imputation class DatasetForCSAI(DatasetForCSAI_Imputation): - def __init__(self, - data: Union[dict, str], - file_type: str = "hdf5", - return_y: bool = True, - removal_percent: float = 0.0, - increase_factor: float = 0.1, - compute_intervals: bool = False, - replacement_probabilities = None, - normalise_mean : list = [], - normalise_std: list = [], - training: bool = True + def __init__( + self, + data: Union[dict, str], + file_type: str = "hdf5", + return_y: bool = True, + removal_percent: float = 0.0, + increase_factor: float = 0.1, + compute_intervals: bool = False, + replacement_probabilities=None, + normalise_mean: list = [], + normalise_std: list = [], + training: bool = True, ): super().__init__( data=data, @@ -34,6 +34,5 @@ def __init__(self, replacement_probabilities=replacement_probabilities, normalise_mean=normalise_mean, normalise_std=normalise_std, - training=training - ) - \ No newline at end of file + training=training, + ) diff --git a/pypots/classification/csai/model.py b/pypots/classification/csai/model.py index fb9bd5b5..11c7e117 100644 --- a/pypots/classification/csai/model.py +++ b/pypots/classification/csai/model.py @@ -1,4 +1,3 @@ - """ """ @@ -19,7 +18,6 @@ class CSAI(BaseNNClassifier): - """ The PyTorch implementation of the CSAI model. @@ -87,7 +85,7 @@ class CSAI(BaseNNClassifier): verbose : Whether to print out the training logs during the training process. - + """ def __init__( @@ -99,33 +97,33 @@ def __init__( consistency_weight: float, classification_weight: float, n_classes: int, - removal_percent: int, - increase_factor: float, - compute_intervals: bool, - step_channels:int, - batch_size: int, - epochs: int, + removal_percent: int, + increase_factor: float, + compute_intervals: bool, + step_channels: int, + batch_size: int, + epochs: int, dropout: float = 0.5, - patience: Union[int, None] = None, - optimizer: Optimizer = Adam(), - num_workers: int = 0, - device: Optional[Union[str, torch.device, list]] = None, + patience: Union[int, None] = None, + optimizer: Optimizer = Adam(), + num_workers: int = 0, + device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, - model_saving_strategy: Union[str, None] = "best", - verbose: bool = True + model_saving_strategy: Union[str, None] = "best", + verbose: bool = True, ): super().__init__( - n_classes, - batch_size, - epochs, + n_classes, + batch_size, + epochs, patience, - num_workers, + num_workers, device, - saving_path, - model_saving_strategy, + saving_path, + model_saving_strategy, verbose, ) - + self.n_steps = n_steps self.n_features = n_features self.rnn_hidden_size = rnn_hidden_size @@ -138,8 +136,8 @@ def __init__( self.compute_intervals = compute_intervals self.dropout = dropout self.intervals = None - - # Initialise empty model + + # Initialise empty model self.model = _BCSAI( n_steps=self.n_steps, n_features=self.n_features, @@ -161,19 +159,10 @@ def __init__( def _assemble_input_for_training(self, data: list, training=True) -> dict: # extract data - sample = data['sample'] - ( - indices, - X, - missing_mask, - deltas, - last_obs, - back_X, - back_missing_mask, - back_deltas, - back_last_obs, - labels - ) = self._send_data_to_given_device(sample) + sample = data["sample"] + (indices, X, missing_mask, deltas, last_obs, back_X, back_missing_mask, back_deltas, back_last_obs, labels) = ( + self._send_data_to_given_device(sample) + ) inputs = { "indices": indices, @@ -195,10 +184,10 @@ def _assemble_input_for_training(self, data: list, training=True) -> dict: def _assemble_input_for_validating(self, data: list) -> dict: return self._assemble_input_for_training(data) - + def _assemble_input_for_testing(self, data: list) -> dict: # extract data - sample = data['sample'] + sample = data["sample"] ( indices, X, @@ -231,30 +220,30 @@ def _assemble_input_for_testing(self, data: list) -> dict: # "X_ori": X_ori, # "indicating_mask": indicating_mask, } - + return inputs - + def fit( - self, - train_set, - val_set= None, - file_type: str = "hdf5", - )-> None: + self, + train_set, + val_set=None, + file_type: str = "hdf5", + ) -> None: # Create dataset self.training_set = DatasetForCSAI( - data=train_set, - file_type=file_type, - return_y=True, - removal_percent=self.removal_percent, - increase_factor=self.increase_factor, - compute_intervals=self.compute_intervals, - ) + data=train_set, + file_type=file_type, + return_y=True, + removal_percent=self.removal_percent, + increase_factor=self.increase_factor, + compute_intervals=self.compute_intervals, + ) self.intervals = self.training_set.intervals self.replacement_probabilities = self.training_set.replacement_probabilities self.mean_set = self.training_set.mean_set self.std_set = self.training_set.std_set - + train_loader = DataLoader( self.training_set, batch_size=self.batch_size, @@ -297,7 +286,7 @@ def fit( self._print_model_size() # set up the optimizer - self.optimizer.init_optimizer(self.model.parameters()) + self.optimizer.init_optimizer(self.model.parameters()) # train the model self._train_model(train_loader, val_loader) @@ -306,13 +295,12 @@ def fit( self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") - def predict( - self, - test_set: Union[dict, str], - file_type: str = "hdf5", - ) -> dict: - + self, + test_set: Union[dict, str], + file_type: str = "hdf5", + ) -> dict: + self.model.eval() test_set = DatasetForCSAI( data=test_set, @@ -339,20 +327,19 @@ def predict( for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) results = self.model.forward(inputs, training=False) - classificaion_results.append(results['classification_pred']) - - + classificaion_results.append(results["classification_pred"]) + classification = torch.cat(classificaion_results).cpu().detach().numpy() result_dict = { "classification": classification, - } + } return result_dict - + def classify( - self, - test_set, - file_type: str = "hdf5", - ): - + self, + test_set, + file_type: str = "hdf5", + ): + result_dict = self.predict(test_set, file_type) - return result_dict['classification'] \ No newline at end of file + return result_dict["classification"] diff --git a/pypots/imputation/csai/__init__.py b/pypots/imputation/csai/__init__.py index 529cb0ce..790c28b2 100644 --- a/pypots/imputation/csai/__init__.py +++ b/pypots/imputation/csai/__init__.py @@ -20,4 +20,4 @@ __all__ = [ "CSAI", -] \ No newline at end of file +] diff --git a/pypots/imputation/csai/core.py b/pypots/imputation/csai/core.py index fc9ea6f9..1d59621c 100644 --- a/pypots/imputation/csai/core.py +++ b/pypots/imputation/csai/core.py @@ -65,15 +65,17 @@ class _BCSAI(nn.Module): BCSAI is a bidirectional imputation model that uses forward and backward GRU cells to handle time-series data. It computes consistency and reconstruction losses to improve imputation accuracy. During training, the forward and backward reconstructions are combined, and losses are used to update the model. In evaluation mode, the model also outputs original data and indicating masks for further analysis. """ - def __init__(self, - n_steps, - n_features, - rnn_hidden_size, - step_channels, - consistency_weight, - imputation_weight, - intervals=None, - ): + + def __init__( + self, + n_steps, + n_features, + rnn_hidden_size, + step_channels, + consistency_weight, + imputation_weight, + intervals=None, + ): super().__init__() self.n_steps = n_steps self.n_features = n_features @@ -82,18 +84,18 @@ def __init__(self, self.intervals = intervals self.consistency_weight = consistency_weight self.imputation_weight = imputation_weight - + self.model = BackboneBCSAI(n_steps, n_features, rnn_hidden_size, step_channels, intervals) - def forward(self, inputs:dict, training:bool = True) -> dict: + def forward(self, inputs: dict, training: bool = True) -> dict: ( - imputed_data, - f_reconstruction, - b_reconstruction, - f_hidden_states, - b_hidden_states, - consistency_loss, - reconstruction_loss, + imputed_data, + f_reconstruction, + b_reconstruction, + f_hidden_states, + b_hidden_states, + consistency_loss, + reconstruction_loss, ) = self.model(inputs) results = { diff --git a/pypots/imputation/csai/data.py b/pypots/imputation/csai/data.py index 6e4e481a..26a99a5d 100644 --- a/pypots/imputation/csai/data.py +++ b/pypots/imputation/csai/data.py @@ -14,10 +14,11 @@ from ...data.utils import parse_delta from sklearn.preprocessing import StandardScaler + def normalize_csai( - data, - mean: list = None, - std: list = None, + data, + mean: list = None, + std: list = None, compute_intervals: bool = False, ): """ @@ -111,7 +112,7 @@ def normalize_csai( def compute_last_obs(data, masks): """ Compute the last observed values for each time step. - + Parameters: - data (np.array): Original data array of shape [T, D]. - masks (np.array): Binary masks indicating where data is not NaN, of shape [T, D]. @@ -122,21 +123,19 @@ def compute_last_obs(data, masks): T, D = masks.shape last_obs = np.full((T, D), np.nan) # Initialize last observed values with NaNs last_obs_val = np.full(D, np.nan) # Initialize last observed values for first time step with NaNs - + for t in range(1, T): # Start from t=1, keeping first row as NaN mask = masks[t - 1] # Update last observed values based on previous time step last_obs_val[mask] = data[t - 1, mask] # Assign last observed values to the current time step - last_obs[t] = last_obs_val - + last_obs[t] = last_obs_val + return last_obs + def adjust_probability_vectorized( - obs_count: Union[int, float], - avg_count: Union[int, float], - base_prob: float, - increase_factor: float = 0.5 + obs_count: Union[int, float], avg_count: Union[int, float], base_prob: float, increase_factor: float = 0.5 ) -> float: """ Adjusts the base probability based on observed and average counts using a scaling factor. @@ -164,10 +163,10 @@ def adjust_probability_vectorized( Notes ----- - This function adjusts a base probability based on the observed count (`obs_count`) compared to the average count - (`avg_count`). If the observed count is lower than the average, the probability is increased proportionally, - but capped at a maximum of 1.0. Conversely, if the observed count exceeds the average, the probability is reduced, - but not below 0. The `increase_factor` controls the sensitivity of the probability adjustment when the observed + This function adjusts a base probability based on the observed count (`obs_count`) compared to the average count + (`avg_count`). If the observed count is lower than the average, the probability is increased proportionally, + but capped at a maximum of 1.0. Conversely, if the observed count exceeds the average, the probability is reduced, + but not below 0. The `increase_factor` controls the sensitivity of the probability adjustment when the observed count is less than the average count. """ if obs_count < avg_count: @@ -177,16 +176,11 @@ def adjust_probability_vectorized( # Decrease probability when observed count exceeds average count return max(base_prob * (obs_count / avg_count) / increase_factor, 0.0) -def non_uniform_sample( - data, - removal_percent, - pre_replacement_probabilities=None, - increase_factor=0.5 - ): - + +def non_uniform_sample(data, removal_percent, pre_replacement_probabilities=None, increase_factor=0.5): """ - Process time-series data by randomly removing a certain percentage of observed values based on pre-defined - replacement probabilities, and compute the necessary features such as forward and backward deltas, masks, + Process time-series data by randomly removing a certain percentage of observed values based on pre-defined + replacement probabilities, and compute the necessary features such as forward and backward deltas, masks, and last observed values. This function generates records for each time series and returns them as PyTorch tensors for further usage. @@ -194,25 +188,25 @@ def non_uniform_sample( Parameters ---------- data : np.ndarray - The input data with shape [N, T, D], where N is the number of samples, T is the number of time steps, + The input data with shape [N, T, D], where N is the number of samples, T is the number of time steps, and D is the number of features. Missing values should be indicated with NaNs. - + removal_percent : float The percentage of observed values to be removed randomly from the dataset. - + pre_replacement_probabilities : np.ndarray, optional - Pre-defined replacement probabilities for each feature. If provided, this will be used to determine + Pre-defined replacement probabilities for each feature. If provided, this will be used to determine which values to remove. - + increase_factor : float, default=0.5 A factor to adjust replacement probabilities based on the observation count for each feature. Returns ------- tensor_dict : dict of torch.Tensors - A dictionary of PyTorch tensors including 'values', 'last_obs_f', 'last_obs_b', 'masks', 'deltas_f', + A dictionary of PyTorch tensors including 'values', 'last_obs_f', 'last_obs_b', 'masks', 'deltas_f', 'deltas_b', 'evals', and 'eval_masks'. - + replacement_probabilities : np.ndarray The computed or provided replacement probabilities for each feature. """ @@ -224,16 +218,16 @@ def non_uniform_sample( observations_per_feature = np.sum(~np.isnan(data), axis=(0, 1)) average_observations = np.mean(observations_per_feature) replacement_probabilities = np.full(D, removal_percent / 100) - + if increase_factor > 0: for feature_idx in range(D): replacement_probabilities[feature_idx] = adjust_probability_vectorized( observations_per_feature[feature_idx], average_observations, replacement_probabilities[feature_idx], - increase_factor=increase_factor + increase_factor=increase_factor, ) - + total_observations = np.sum(observations_per_feature) total_replacement_target = total_observations * removal_percent / 100 @@ -269,34 +263,36 @@ def non_uniform_sample( last_obs_b = compute_last_obs(values[i, ::-1, :], masks[::-1, :]) # Append the record for this sample - recs.append({ - 'values': np.nan_to_num(values[i, :, :]), - 'last_obs_f': np.nan_to_num(last_obs_f), - 'last_obs_b': np.nan_to_num(last_obs_b), - 'masks': masks.astype('int32'), - 'evals': np.nan_to_num(evals), - 'eval_masks': eval_masks.astype('int32'), - 'deltas_f': deltas_f, - 'deltas_b': deltas_b - }) + recs.append( + { + "values": np.nan_to_num(values[i, :, :]), + "last_obs_f": np.nan_to_num(last_obs_f), + "last_obs_b": np.nan_to_num(last_obs_b), + "masks": masks.astype("int32"), + "evals": np.nan_to_num(evals), + "eval_masks": eval_masks.astype("int32"), + "deltas_f": deltas_f, + "deltas_b": deltas_b, + } + ) # Convert records to PyTorch tensors tensor_dict = { - 'values': torch.FloatTensor(np.array([r['values'] for r in recs])), - 'last_obs_f': torch.FloatTensor(np.array([r['last_obs_f'] for r in recs])), - 'last_obs_b': torch.FloatTensor(np.array([r['last_obs_b'] for r in recs])), - 'masks': torch.FloatTensor(np.array([r['masks'] for r in recs])), - 'deltas_f': torch.FloatTensor(np.array([r['deltas_f'] for r in recs])), - 'deltas_b': torch.FloatTensor(np.array([r['deltas_b'] for r in recs])), - 'evals': torch.FloatTensor(np.array([r['evals'] for r in recs])), - 'eval_masks': torch.FloatTensor(np.array([r['eval_masks'] for r in recs])) + "values": torch.FloatTensor(np.array([r["values"] for r in recs])), + "last_obs_f": torch.FloatTensor(np.array([r["last_obs_f"] for r in recs])), + "last_obs_b": torch.FloatTensor(np.array([r["last_obs_b"] for r in recs])), + "masks": torch.FloatTensor(np.array([r["masks"] for r in recs])), + "deltas_f": torch.FloatTensor(np.array([r["deltas_f"] for r in recs])), + "deltas_b": torch.FloatTensor(np.array([r["deltas_b"] for r in recs])), + "evals": torch.FloatTensor(np.array([r["evals"] for r in recs])), + "eval_masks": torch.FloatTensor(np.array([r["eval_masks"] for r in recs])), } return tensor_dict, replacement_probabilities class DatasetForCSAI(BaseDataset): - """" + """ " Parameters ---------- data : @@ -337,25 +333,25 @@ class DatasetForCSAI(BaseDataset): The DatasetForCSAI class is designed for bidirectional imputation of time-series data, handling both forward and backward directions to improve imputation accuracy. It supports on-the-fly data normalization and missing value simulation, making it suitable for training and evaluating deep learning models like CSAI. The class can work with large datasets stored on disk, leveraging lazy-loading to minimize memory usage, and supports both training and testing scenarios, adjusting data handling as needed. """ - def __init__(self, - data: Union[dict, str], - return_X_ori: bool, - return_y: bool, - file_type: str = "hdf5", - removal_percent: float = 0.0, - increase_factor: float = 0.1, - compute_intervals: bool = False, - replacement_probabilities = None, - normalise_mean : list = [], - normalise_std: list = [], - training: bool = True - ): - super().__init__(data = data, - return_X_ori = return_X_ori, - return_X_pred = False, - return_y = return_y, - file_type = file_type) - + + def __init__( + self, + data: Union[dict, str], + return_X_ori: bool, + return_y: bool, + file_type: str = "hdf5", + removal_percent: float = 0.0, + increase_factor: float = 0.1, + compute_intervals: bool = False, + replacement_probabilities=None, + normalise_mean: list = [], + normalise_std: list = [], + training: bool = True, + ): + super().__init__( + data=data, return_X_ori=return_X_ori, return_X_pred=False, return_y=return_y, file_type=file_type + ) + self.removal_percent = removal_percent self.increase_factor = increase_factor self.compute_intervals = compute_intervals @@ -366,26 +362,25 @@ def __init__(self, if not isinstance(self.data, str): self.normalized_data, self.mean_set, self.std_set, self.intervals = normalize_csai( - self.data['X'], - self.normalise_mean, - self.normalise_std, - compute_intervals, - ) + self.data["X"], + self.normalise_mean, + self.normalise_std, + compute_intervals, + ) self.processed_data, self.replacement_probabilities = non_uniform_sample( - self.normalized_data, - removal_percent, - replacement_probabilities, - increase_factor, - ) - self.forward_X = self.processed_data['values'] - self.forward_missing_mask = self.processed_data['masks'] + self.normalized_data, + removal_percent, + replacement_probabilities, + increase_factor, + ) + self.forward_X = self.processed_data["values"] + self.forward_missing_mask = self.processed_data["masks"] self.backward_X = torch.flip(self.forward_X, dims=[1]) self.backward_missing_mask = torch.flip(self.forward_missing_mask, dims=[1]) - self.X_ori = self.processed_data['evals'] - self.indicating_mask = self.processed_data['eval_masks'] - + self.X_ori = self.processed_data["evals"] + self.indicating_mask = self.processed_data["eval_masks"] def _fetch_data_from_array(self, idx: int) -> Iterable: """Fetch data from self.X if it is given. @@ -415,7 +410,6 @@ def _fetch_data_from_array(self, idx: int) -> Iterable: label (optional) : tensor, The target label of the time-series sample. """ - sample = [ torch.tensor(idx), @@ -438,11 +432,11 @@ def _fetch_data_from_array(self, idx: int) -> Iterable: sample.append(self.y[idx].to(torch.long)) return { - 'sample': sample, - 'replacement_probabilities': self.replacement_probabilities, - 'mean_set': self.mean_set, - 'std_set': self.std_set, - 'intervals': self.intervals + "sample": sample, + "replacement_probabilities": self.replacement_probabilities, + "mean_set": self.mean_set, + "std_set": self.std_set, + "intervals": self.intervals, } def _fetch_data_from_file(self, idx: int) -> Iterable: @@ -465,28 +459,28 @@ def _fetch_data_from_file(self, idx: int) -> Iterable: X = torch.from_numpy(self.file_handle["X"][idx]) normalized_data, mean_set, std_set, intervals = normalize_csai( - X, - self.normalise_mean, - self.normalise_std, - self.compute_intervals, - ) - + X, + self.normalise_mean, + self.normalise_std, + self.compute_intervals, + ) + processed_data, replacement_probabilities = non_uniform_sample( - normalized_data, - self.removal_percent, - self.replacement_probabilities, - self.increase_factor, - ) - forward_X = processed_data['values'] - forward_missing_mask = processed_data['masks'] + normalized_data, + self.removal_percent, + self.replacement_probabilities, + self.increase_factor, + ) + forward_X = processed_data["values"] + forward_missing_mask = processed_data["masks"] backward_X = torch.flip(forward_X, dims=[1]) backward_missing_mask = torch.flip(forward_missing_mask, dims=[1]) - X_ori = self.processed_data['evals'] - indicating_mask = self.processed_data['eval_masks'] - + X_ori = self.processed_data["evals"] + indicating_mask = self.processed_data["eval_masks"] + if self.return_y: - y = self.processed_data['labels'] + y = self.processed_data["labels"] sample = [ torch.tensor(idx), @@ -499,7 +493,7 @@ def _fetch_data_from_file(self, idx: int) -> Iterable: backward_X, backward_missing_mask, processed_data["deltas_b"], - processed_data["last_obs_b"] + processed_data["last_obs_b"], ] if self.return_X_ori: @@ -510,10 +504,9 @@ def _fetch_data_from_file(self, idx: int) -> Iterable: sample.append(y) return { - 'sample': sample, - 'replacement_probabilities': replacement_probabilities, - 'mean_set': mean_set, - 'std_set': std_set, - 'intervals': intervals + "sample": sample, + "replacement_probabilities": replacement_probabilities, + "mean_set": mean_set, + "std_set": std_set, + "intervals": intervals, } - diff --git a/pypots/imputation/csai/model.py b/pypots/imputation/csai/model.py index 010561a2..19c48960 100644 --- a/pypots/imputation/csai/model.py +++ b/pypots/imputation/csai/model.py @@ -103,35 +103,36 @@ class CSAI(BaseNNImputer): and tensorboard files are generated for tracking the model's performance over time. """ - - def __init__(self, - n_steps: int, - n_features: int, - rnn_hidden_size: int, - imputation_weight: float, - consistency_weight: float, - removal_percent: int, - increase_factor: float, - compute_intervals: bool, - step_channels:int, - batch_size: int, - epochs: int, - patience: Union[int, None ]= None, - optimizer: Optional[Optimizer] = Adam(), - num_workers: int = 0, - device: Union[str, torch.device, list, None ]= None, - saving_path: str = None, - model_saving_strategy: Union[str, None] = "best", - verbose: bool = True, + + def __init__( + self, + n_steps: int, + n_features: int, + rnn_hidden_size: int, + imputation_weight: float, + consistency_weight: float, + removal_percent: int, + increase_factor: float, + compute_intervals: bool, + step_channels: int, + batch_size: int, + epochs: int, + patience: Union[int, None] = None, + optimizer: Optional[Optimizer] = Adam(), + num_workers: int = 0, + device: Union[str, torch.device, list, None] = None, + saving_path: str = None, + model_saving_strategy: Union[str, None] = "best", + verbose: bool = True, ): super().__init__( - batch_size, - epochs, - patience, - num_workers, - device, - saving_path, - model_saving_strategy, + batch_size, + epochs, + patience, + num_workers, + device, + saving_path, + model_saving_strategy, verbose, ) @@ -145,39 +146,31 @@ def __init__(self, self.step_channels = step_channels self.compute_intervals = compute_intervals self.intervals = None - - # Initialise model + + # Initialise model self.model = _BCSAI( - self.n_steps, - self.n_features, - self.rnn_hidden_size, + self.n_steps, + self.n_features, + self.rnn_hidden_size, self.step_channels, - self.consistency_weight, + self.consistency_weight, self.imputation_weight, self.intervals, ) self._send_model_to_given_device() self._print_model_size() - + # set up the optimizer self.optimizer = optimizer def _assemble_input_for_training(self, data: list, training=True) -> dict: # extract data - sample = data['sample'] - - ( - indices, - X, - missing_mask, - deltas, - last_obs, - back_X, - back_missing_mask, - back_deltas, - back_last_obs - ) = self._send_data_to_given_device(sample) + sample = data["sample"] + + (indices, X, missing_mask, deltas, last_obs, back_X, back_missing_mask, back_deltas, back_last_obs) = ( + self._send_data_to_given_device(sample) + ) # assemble input data inputs = { @@ -200,7 +193,7 @@ def _assemble_input_for_training(self, data: list, training=True) -> dict: def _assemble_input_for_validating(self, data: list) -> dict: # extract data - sample = data['sample'] + sample = data["sample"] ( indices, X, @@ -237,23 +230,17 @@ def _assemble_input_for_validating(self, data: list) -> dict: def _assemble_input_for_testing(self, data: list) -> dict: return self._assemble_input_for_validating(data) - + def fit( - self, - train_set, - val_set=None, - file_type: str = "hdf5", - )-> None: - + self, + train_set, + val_set=None, + file_type: str = "hdf5", + ) -> None: + self.training_set = DatasetForCSAI( - train_set, - False, - False, - file_type, - self.removal_percent, - self.increase_factor, - self.compute_intervals - ) + train_set, False, False, file_type, self.removal_percent, self.increase_factor, self.compute_intervals + ) self.intervals = self.training_set.intervals self.replacement_probabilities = self.training_set.replacement_probabilities self.mean_set = self.training_set.mean_set @@ -268,15 +255,15 @@ def fit( ) if val_set is not None: val_set = DatasetForCSAI( - val_set, + val_set, True, - False, - file_type, - self.removal_percent, - self.increase_factor, + False, + file_type, + self.removal_percent, + self.increase_factor, self.compute_intervals, - self.replacement_probabilities, - self.mean_set, + self.replacement_probabilities, + self.mean_set, self.std_set, False, ) @@ -290,11 +277,11 @@ def fit( # Reset the model self.model = _BCSAI( - self.n_steps, - self.n_features, - self.rnn_hidden_size, + self.n_steps, + self.n_features, + self.rnn_hidden_size, self.step_channels, - self.consistency_weight, + self.consistency_weight, self.imputation_weight, self.intervals, ) @@ -314,25 +301,25 @@ def fit( self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") def predict( - self, - test_set: Union[dict, str], + self, + test_set: Union[dict, str], file_type: str = "hdf5", ) -> dict: - + self.model.eval() test_set = DatasetForCSAI( - test_set, - True, - False, - file_type, - self.removal_percent, - self.increase_factor, - self.compute_intervals, - self.replacement_probabilities, - self.mean_set, - self.std_set, - False, - ) + test_set, + True, + False, + file_type, + self.removal_percent, + self.increase_factor, + self.compute_intervals, + self.replacement_probabilities, + self.mean_set, + self.std_set, + False, + ) test_loader = DataLoader( test_set, @@ -345,7 +332,7 @@ def predict( imputation_collector = [] x_ori_collector = [] indicating_mask_collector = [] - + with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) diff --git a/pypots/imputation/segrnn/__init__.py b/pypots/imputation/segrnn/__init__.py index 243345a8..49ad0717 100644 --- a/pypots/imputation/segrnn/__init__.py +++ b/pypots/imputation/segrnn/__init__.py @@ -16,7 +16,6 @@ # Created by Shengsheng Lin - from .model import SegRNN __all__ = [ diff --git a/pypots/imputation/segrnn/core.py b/pypots/imputation/segrnn/core.py index c1d5f3a2..b4978099 100644 --- a/pypots/imputation/segrnn/core.py +++ b/pypots/imputation/segrnn/core.py @@ -13,6 +13,7 @@ from ...nn.modules.segrnn import BackboneSegRNN from ...nn.modules.saits import SaitsLoss + class _SegRNN(nn.Module): def __init__( self, diff --git a/pypots/nn/modules/csai/__init__.py b/pypots/nn/modules/csai/__init__.py index 64c57392..c7964bb2 100644 --- a/pypots/nn/modules/csai/__init__.py +++ b/pypots/nn/modules/csai/__init__.py @@ -23,9 +23,9 @@ "BackboneCSAI", "BackboneBCSAI", "FeatureRegression", - "Decay", - "Decay_obs", - "PositionalEncoding", - "Conv1dWithInit", - "TorchTransformerEncoder" + "Decay", + "Decay_obs", + "PositionalEncoding", + "Conv1dWithInit", + "TorchTransformerEncoder", ] diff --git a/pypots/nn/modules/csai/backbone.py b/pypots/nn/modules/csai/backbone.py index 57600db9..b163a611 100644 --- a/pypots/nn/modules/csai/backbone.py +++ b/pypots/nn/modules/csai/backbone.py @@ -12,6 +12,7 @@ from .layers import FeatureRegression, Decay, Decay_obs, PositionalEncoding, Conv1dWithInit, TorchTransformerEncoder from ....utils.metrics import calc_mae + class BackboneCSAI(nn.Module): """ Attributes @@ -101,27 +102,27 @@ def __init__(self, n_steps, n_features, rnn_hidden_size, step_channels, medians_ self.step_channels = step_channels self.input_size = n_features self.hidden_size = rnn_hidden_size - self.temp_decay_h = Decay(input_size=self.input_size, output_size=self.hidden_size, diag = False) - self.temp_decay_x = Decay(input_size=self.input_size, output_size=self.input_size, diag = True) + self.temp_decay_h = Decay(input_size=self.input_size, output_size=self.hidden_size, diag=False) + self.temp_decay_x = Decay(input_size=self.input_size, output_size=self.input_size, diag=True) self.hist = nn.Linear(self.hidden_size, self.input_size) self.feat_reg_v = FeatureRegression(self.input_size) self.weight_combine = nn.Linear(self.input_size * 2, self.input_size) self.weighted_obs = Decay_obs(self.input_size, self.input_size) self.gru = nn.GRUCell(self.input_size * 2, self.hidden_size) - + self.pos_encoder = PositionalEncoding(self.step_channels) self.input_projection = Conv1dWithInit(self.input_size, self.step_channels, 1) self.output_projection1 = Conv1dWithInit(self.step_channels, self.hidden_size, 1) - self.output_projection2 = Conv1dWithInit(self.n_steps*2, 1, 1) + self.output_projection2 = Conv1dWithInit(self.n_steps * 2, 1, 1) self.time_layer = TorchTransformerEncoder(channels=self.step_channels) self.reset_parameters() - + def reset_parameters(self): for weight in self.parameters(): if len(weight.size()) == 1: continue - stv = 1. / math.sqrt(weight.size(1)) + stv = 1.0 / math.sqrt(weight.size(1)) nn.init.uniform_(weight, -stv, stv) def forward(self, x, mask, deltas, last_obs, h=None): @@ -139,7 +140,7 @@ def forward(self, x, mask, deltas, last_obs, h=None): data_last_obs = self.pos_encoder(data_last_obs.permute(1, 0, 2)).permute(1, 0, 2) data_decay_factor = self.pos_encoder(data_decay_factor.permute(1, 0, 2)).permute(1, 0, 2) - + data = torch.cat([data_last_obs, data_decay_factor], dim=1) data = self.time_layer(data) @@ -158,16 +159,16 @@ def forward(self, x, mask, deltas, last_obs, h=None): # Decayed Hidden States gamma_h = self.temp_decay_h(d_t) h = h * gamma_h - + # history based estimation - x_h = self.hist(h) - + x_h = self.hist(h) + x_r_t = (m_t * x_t) + ((1 - m_t) * x_h) # feature based estimation xu = self.feat_reg_v(x_r_t) gamma_x = self.temp_decay_x(d_t) - + beta = self.weight_combine(torch.cat([gamma_x, m_t], dim=1)) x_comb_t = beta * xu + (1 - beta) * x_h @@ -194,20 +195,20 @@ def __init__(self, n_steps, n_features, rnn_hidden_size, step_channels, medians_ self.model_f = BackboneCSAI(n_steps, n_features, rnn_hidden_size, step_channels, medians_df) self.model_b = BackboneCSAI(n_steps, n_features, rnn_hidden_size, step_channels, medians_df) - + def forward(self, xdata): # Fetching forward data from xdata - x = xdata['forward']['X'] - m = xdata['forward']['missing_mask'] - d_f = xdata['forward']['deltas'] - last_obs_f = xdata['forward']['last_obs'] + x = xdata["forward"]["X"] + m = xdata["forward"]["missing_mask"] + d_f = xdata["forward"]["deltas"] + last_obs_f = xdata["forward"]["last_obs"] # Fetching backward data from xdata - x_b = xdata['backward']['X'] - m_b = xdata['backward']['missing_mask'] - d_b = xdata['backward']['deltas'] - last_obs_b = xdata['backward']['last_obs'] + x_b = xdata["backward"]["X"] + m_b = xdata["backward"]["missing_mask"] + d_b = xdata["backward"]["deltas"] + last_obs_b = xdata["backward"]["last_obs"] # Call forward model ( @@ -227,7 +228,7 @@ def forward(self, xdata): # Averaging the imputations and prediction x_imp = (f_imputed_data + b_imputed_data.flip(dims=[1])) / 2 - imputed_data = (x * m)+ ((1-m) * x_imp) + imputed_data = (x * m) + ((1 - m) * x_imp) # average consistency loss consistency_loss = torch.abs(f_imputed_data - b_imputed_data.flip(dims=[1])).mean() * 1e-1 @@ -235,11 +236,11 @@ def forward(self, xdata): # Merge the regression loss reconstruction_loss = f_reconstruction_loss + b_reconstruction_loss return ( - imputed_data, - f_reconstruction, - b_reconstruction, - f_hidden_states, - b_hidden_states, - consistency_loss, - reconstruction_loss, - ) + imputed_data, + f_reconstruction, + b_reconstruction, + f_hidden_states, + b_hidden_states, + consistency_loss, + reconstruction_loss, + ) diff --git a/pypots/nn/modules/csai/layers.py b/pypots/nn/modules/csai/layers.py index d603eef1..39752262 100644 --- a/pypots/nn/modules/csai/layers.py +++ b/pypots/nn/modules/csai/layers.py @@ -27,11 +27,11 @@ def build(self, input_size): self.W = Parameter(torch.Tensor(input_size, input_size)) self.b = Parameter(torch.Tensor(input_size)) m = torch.ones(input_size, input_size) - torch.eye(input_size, input_size) - self.register_buffer('m', m) + self.register_buffer("m", m) self.reset_parameters() def reset_parameters(self): - stdv = 1. / math.sqrt(self.W.size(0)) + stdv = 1.0 / math.sqrt(self.W.size(0)) self.W.data.uniform_(-stdv, stdv) if self.b is not None: self.b.data.uniform_(-stdv, stdv) @@ -40,6 +40,7 @@ def forward(self, x): z_h = F.linear(x, self.W * Variable(self.m), self.b) return z_h + class Decay(nn.Module): def __init__(self, input_size, output_size, diag=False): super(Decay, self).__init__() @@ -51,13 +52,13 @@ def build(self, input_size, output_size): self.b = Parameter(torch.Tensor(output_size)) if self.diag == True: - assert(input_size == output_size) + assert input_size == output_size m = torch.eye(input_size, input_size) - self.register_buffer('m', m) + self.register_buffer("m", m) self.reset_parameters() def reset_parameters(self): - stdv = 1. / math.sqrt(self.W.size(0)) + stdv = 1.0 / math.sqrt(self.W.size(0)) self.W.data.uniform_(-stdv, stdv) if self.b is not None: self.b.data.uniform_(-stdv, stdv) @@ -70,6 +71,7 @@ def forward(self, d): gamma = torch.exp(-gamma) return gamma + class Decay_obs(nn.Module): def __init__(self, input_size, output_size): super(Decay_obs, self).__init__() @@ -87,12 +89,13 @@ def forward(self, delta_diff): weight_diff = sign * weight_diff # Using a tanh activation to squeeze values between -1 and 1 weight_diff = torch.tanh(weight_diff) - # This will move the weight values towards 1 if delta_diff is negative + # This will move the weight values towards 1 if delta_diff is negative # and towards 0 if delta_diff is positive weight = 0.5 * (1 - weight_diff) return weight + class TorchTransformerEncoder(nn.Module): def __init__(self, heads=8, layers=1, channels=64): super(TorchTransformerEncoder, self).__init__() @@ -100,19 +103,21 @@ def __init__(self, heads=8, layers=1, channels=64): d_model=channels, nhead=heads, dim_feedforward=64, activation="gelu" ) self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=layers) - + def forward(self, x): return self.transformer_encoder(x) - + + class Conv1dWithInit(nn.Module): def __init__(self, in_channels, out_channels, kernel_size): super(Conv1dWithInit, self).__init__() self.conv = nn.Conv1d(in_channels, out_channels, kernel_size) nn.init.kaiming_normal_(self.conv.weight) - + def forward(self, x): return self.conv(x) + class PositionalEncoding(nn.Module): def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): @@ -124,12 +129,12 @@ def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): pe = torch.zeros(max_len, 1, d_model) pe[:, 0, 0::2] = torch.sin(position * div_term) pe[:, 0, 1::2] = torch.cos(position * div_term) - self.register_buffer('pe', pe) + self.register_buffer("pe", pe) def forward(self, x): """ Arguments: x: Tensor, shape ``[seq_len, batch_size, embedding_dim]`` """ - x = x + self.pe[:x.size(0)] - return self.dropout(x) \ No newline at end of file + x = x + self.pe[: x.size(0)] + return self.dropout(x) diff --git a/pypots/nn/modules/segrnn/backbone.py b/pypots/nn/modules/segrnn/backbone.py index e0588403..5b7f2fda 100644 --- a/pypots/nn/modules/segrnn/backbone.py +++ b/pypots/nn/modules/segrnn/backbone.py @@ -11,14 +11,7 @@ class BackboneSegRNN(nn.Module): - def __init__( - self, - n_steps: int, - n_features: int, - seg_len: int = 24, - d_model: int = 512, - dropout: float = 0.5 - ): + def __init__(self, n_steps: int, n_features: int, seg_len: int = 24, d_model: int = 512, dropout: float = 0.5): super().__init__() self.n_steps = n_steps @@ -27,23 +20,24 @@ def __init__( self.d_model = d_model self.dropout = dropout - if n_steps % seg_len: - raise ValueError("The argument seg_len is necessary for SegRNN need to be divisible by the sequence length n_steps.") + raise ValueError( + "The argument seg_len is necessary for SegRNN need to be divisible by the sequence length n_steps." + ) self.seg_num = self.n_steps // self.seg_len - self.valueEmbedding = nn.Sequential( - nn.Linear(self.seg_len, self.d_model), - nn.ReLU() + self.valueEmbedding = nn.Sequential(nn.Linear(self.seg_len, self.d_model), nn.ReLU()) + self.rnn = nn.GRU( + input_size=self.d_model, + hidden_size=self.d_model, + num_layers=1, + bias=True, + batch_first=True, + bidirectional=False, ) - self.rnn = nn.GRU(input_size=self.d_model, hidden_size=self.d_model, num_layers=1, bias=True, - batch_first=True, bidirectional=False) self.pos_emb = nn.Parameter(torch.randn(self.seg_num, self.d_model // 2)) self.channel_emb = nn.Parameter(torch.randn(self.n_features, self.d_model // 2)) - self.predict = nn.Sequential( - nn.Dropout(self.dropout), - nn.Linear(self.d_model, self.seg_len) - ) + self.predict = nn.Sequential(nn.Dropout(self.dropout), nn.Linear(self.d_model, self.seg_len)) def forward(self, x): # b:batch_size c:channel_size s:seq_len s:seq_len @@ -52,23 +46,30 @@ def forward(self, x): # normalization and permute b,s,c -> b,c,s seq_last = x[:, -1:, :].detach() - x = (x - seq_last).permute(0, 2, 1) # b,c,s + x = (x - seq_last).permute(0, 2, 1) # b,c,s # segment and embedding b,c,s -> bc,n,w -> bc,n,d x = self.valueEmbedding(x.reshape(-1, self.seg_num, self.seg_len)) # encoding - _, hn = self.rnn(x) # bc,n,d 1,bc,d + _, hn = self.rnn(x) # bc,n,d 1,bc,d # m,d//2 -> 1,m,d//2 -> c,m,d//2 # c,d//2 -> c,1,d//2 -> c,m,d//2 # c,m,d -> cm,1,d -> bcm, 1, d - pos_emb = torch.cat([ - self.pos_emb.unsqueeze(0).repeat(self.n_features, 1, 1), - self.channel_emb.unsqueeze(1).repeat(1, self.seg_num, 1) - ], dim=-1).view(-1, 1, self.d_model).repeat(batch_size,1,1) + pos_emb = ( + torch.cat( + [ + self.pos_emb.unsqueeze(0).repeat(self.n_features, 1, 1), + self.channel_emb.unsqueeze(1).repeat(1, self.seg_num, 1), + ], + dim=-1, + ) + .view(-1, 1, self.d_model) + .repeat(batch_size, 1, 1) + ) - _, hy = self.rnn(pos_emb, hn.repeat(1, 1, self.seg_num).view(1, -1, self.d_model)) # bcm,1,d 1,bcm,d + _, hy = self.rnn(pos_emb, hn.repeat(1, 1, self.seg_num).view(1, -1, self.d_model)) # bcm,1,d 1,bcm,d # 1,bcm,d -> 1,bcm,w -> b,c,s y = self.predict(hy).view(-1, self.n_features, self.n_steps)