From cb24887af4f9a3c8b4a9d42d3e68900650d78cec Mon Sep 17 00:00:00 2001
From: Wenjie Du
Date: Sat, 30 Mar 2024 01:02:10 +0800
Subject: [PATCH 1/2] feat: add DLinear as an imputation model;

---
 pypots/imputation/__init__.py                 |   2 +
 pypots/imputation/dlinear/__init__.py         |  17 +
 pypots/imputation/dlinear/data.py             |  24 ++
 pypots/imputation/dlinear/model.py            | 296 ++++++++++++++++++
 pypots/imputation/dlinear/modules/__init__.py |   6 +
 pypots/imputation/dlinear/modules/core.py     |  93 ++++++
 tests/imputation/dlinear.py                   | 122 ++++++++
 7 files changed, 560 insertions(+)
 create mode 100644 pypots/imputation/dlinear/__init__.py
 create mode 100644 pypots/imputation/dlinear/data.py
 create mode 100644 pypots/imputation/dlinear/model.py
 create mode 100644 pypots/imputation/dlinear/modules/__init__.py
 create mode 100644 pypots/imputation/dlinear/modules/core.py
 create mode 100644 tests/imputation/dlinear.py

diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py
index 64a8a758..2d408d58 100644
--- a/pypots/imputation/__init__.py
+++ b/pypots/imputation/__init__.py
@@ -14,6 +14,7 @@
 from .transformer import Transformer
 from .timesnet import TimesNet
 from .autoformer import Autoformer
+from .dlinear import DLinear
 from .patchtst import PatchTST
 from .usgan import USGAN
@@ -28,6 +29,7 @@
     "Transformer",
     "TimesNet",
     "PatchTST",
+    "DLinear",
     "Autoformer",
     "BRITS",
     "MRNN",
diff --git a/pypots/imputation/dlinear/__init__.py b/pypots/imputation/dlinear/__init__.py
new file mode 100644
index 00000000..0b179e70
--- /dev/null
+++ b/pypots/imputation/dlinear/__init__.py
@@ -0,0 +1,17 @@
+"""
+The package of the partially-observed time-series imputation model DLinear.
+
+Refer to the paper "Zeng, A., Chen, M., Zhang, L., & Xu, Q. (2023).
+Are transformers effective for time series forecasting? AAAI 2023.".
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+
+from .model import DLinear
+
+__all__ = [
+    "DLinear",
+]
diff --git a/pypots/imputation/dlinear/data.py b/pypots/imputation/dlinear/data.py
new file mode 100644
index 00000000..1884054f
--- /dev/null
+++ b/pypots/imputation/dlinear/data.py
@@ -0,0 +1,24 @@
+"""
+Dataset class for DLinear.
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+from typing import Union
+
+from ..saits.data import DatasetForSAITS
+
+
+class DatasetForDLinear(DatasetForSAITS):
+    """DLinear uses the same data strategy as SAITS: it needs the masked imputation task (MIT) for training."""
+
+    def __init__(
+        self,
+        data: Union[dict, str],
+        return_X_ori: bool,
+        return_labels: bool,
+        file_type: str = "h5py",
+        rate: float = 0.2,
+    ):
+        super().__init__(data, return_X_ori, return_labels, file_type, rate)
diff --git a/pypots/imputation/dlinear/model.py b/pypots/imputation/dlinear/model.py
new file mode 100644
index 00000000..33bce716
--- /dev/null
+++ b/pypots/imputation/dlinear/model.py
@@ -0,0 +1,296 @@
+"""
+The implementation of DLinear for the partially-observed time-series imputation task.
+
+Refer to the paper "Zeng, A., Chen, M., Zhang, L., & Xu, Q. (2023).
+Are transformers effective for time series forecasting? AAAI 2023".
+
+Notes
+-----
+Partial implementation uses code from https://github.com/thuml/Time-Series-Library
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+from typing import Union, Optional
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+
+from .data import DatasetForDLinear
+from .modules.core import _DLinear
+from ..base import BaseNNImputer
+from ...data.base import BaseDataset
+from ...data.checking import check_X_ori_in_val_set
+from ...optim.adam import Adam
+from ...optim.base import Optimizer
+from ...utils.logging import logger
+
+
+class DLinear(BaseNNImputer):
+    """The PyTorch implementation of the DLinear model.
+    DLinear is originally proposed by Zeng et al. in :cite:`zeng2023dlinear`.
+
+    Parameters
+    ----------
+    n_steps :
+        The number of time steps in the time-series data sample.
+
+    n_features :
+        The number of features in the time-series data sample.
+
+    moving_avg_window_size :
+        The window size of the moving-average block used for series decomposition.
+
+    individual :
+        Whether to use a separate pair of linear layers for each feature.
+        If False, a single pair of linear layers is shared across all features.
+
+    batch_size :
+        The batch size for training and evaluating the model.
+
+    epochs :
+        The number of epochs for training the model.
+
+    patience :
+        The patience for the early-stopping mechanism. Given a positive integer, the training process will be
+        stopped when the model does not perform better after that number of epochs.
+        Leaving it default as None will disable the early-stopping.
+
+    optimizer :
+        The optimizer for model training.
+        If not given, will use a default Adam optimizer.
+
+    num_workers :
+        The number of subprocesses to use for data loading.
+        `0` means data loading will be in the main process, i.e. there won't be subprocesses.
+
+    device :
+        The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them.
+        If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple),
+        then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models.
+        If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')], the
+        model will be trained in parallel on the multiple devices (so far only parallel training on CUDA devices is supported).
+        Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future.
+
+    saving_path :
+        The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during
+        training into a tensorboard file). Will not save if not given.
+
+    model_saving_strategy :
+        The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"].
+        No model will be saved when it is set as None.
+        The "best" strategy will only automatically save the best model after training is finished.
+        The "better" strategy will automatically save the model during training whenever the model performs
+        better than in previous epochs.
+        The "all" strategy will save every model after each training epoch.
+
+    References
+    ----------
+    .. [1] Zeng, Ailing, Muxi Chen, Lei Zhang, and Qiang Xu.
+       "Are transformers effective for time series forecasting?".
+       In Proceedings of the AAAI Conference on Artificial Intelligence, vol. 37, no. 9, pp. 11121-11128. 2023.
+
+    """
+
+    def __init__(
+        self,
+        n_steps: int,
+        n_features: int,
+        moving_avg_window_size: int,
+        individual: bool = False,
+        batch_size: int = 32,
+        epochs: int = 100,
+        patience: Optional[int] = None,
+        optimizer: Optional[Optimizer] = Adam(),
+        num_workers: int = 0,
+        device: Optional[Union[str, torch.device, list]] = None,
+        saving_path: Optional[str] = None,
+        model_saving_strategy: Optional[str] = "best",
+    ):
+        super().__init__(
+            batch_size,
+            epochs,
+            patience,
+            num_workers,
+            device,
+            saving_path,
+            model_saving_strategy,
+        )
+
+        self.n_steps = n_steps
+        self.n_features = n_features
+        # model hyper-parameters
+        self.moving_avg_window_size = moving_avg_window_size
+        self.individual = individual
+
+        # set up the model
+        self.model = _DLinear(
+            n_steps,
+            n_features,
+            moving_avg_window_size,
+            individual,
+        )
+        self._send_model_to_given_device()
+        self._print_model_size()
+
+        # set up the optimizer
+        self.optimizer = optimizer
+        self.optimizer.init_optimizer(self.model.parameters())
+
+    def _assemble_input_for_training(self, data: list) -> dict:
+        (
+            indices,
+            X,
+            missing_mask,
+            X_ori,
+            indicating_mask,
+        ) = self._send_data_to_given_device(data)
+
+        inputs = {
+            "X": X,
+            "missing_mask": missing_mask,
+            "X_ori": X_ori,
+            "indicating_mask": indicating_mask,
+        }
+
+        return inputs
+
+    def _assemble_input_for_validating(self, data: list) -> dict:
+        return self._assemble_input_for_training(data)
+
+    def _assemble_input_for_testing(self, data: list) -> dict:
+        indices, X, missing_mask = self._send_data_to_given_device(data)
+
+        inputs = {
+            "X": X,
+            "missing_mask": missing_mask,
+        }
+
+        return inputs
+
+    def fit(
+        self,
+        train_set: Union[dict, str],
+        val_set: Optional[Union[dict, str]] = None,
+        file_type: str = "h5py",
+    ) -> None:
+        # Step 1: wrap the input data with classes Dataset and DataLoader
+        training_set = DatasetForDLinear(
+            train_set, return_X_ori=False, return_labels=False, file_type=file_type
+        )
+        training_loader = DataLoader(
+            training_set,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=self.num_workers,
+        )
+        val_loader = None
+        if val_set is not None:
+            if not check_X_ori_in_val_set(val_set):
+                raise ValueError("val_set must contain 'X_ori' for model validation.")
+            val_set = DatasetForDLinear(
+                val_set, return_X_ori=True, return_labels=False, file_type=file_type
+            )
+            val_loader = DataLoader(
+                val_set,
+                batch_size=self.batch_size,
+                shuffle=False,
+                num_workers=self.num_workers,
+            )
+
+        # Step 2: train the model and freeze it
+        self._train_model(training_loader, val_loader)
+        self.model.load_state_dict(self.best_model_dict)
+        self.model.eval()  # set the model as eval status to freeze it.
+
+        # Step 3: save the model if necessary
+        self._auto_save_model_if_necessary(confirm_saving=True)
+
+    def predict(
+        self,
+        test_set: Union[dict, str],
+        file_type: str = "h5py",
+    ) -> dict:
+        """Make predictions for the input data with the trained model.
+
+        Parameters
+        ----------
+        test_set : dict or str
+            The dataset for model testing, should be a dictionary including the key 'X',
+            or a path string locating a data file supported by PyPOTS (e.g. h5 file).
+            If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
+            which is time-series data for testing and can contain missing values.
+            If it is a path string, the path should point to a data file, e.g.
+            a h5 file, which contains key-value pairs like a dict and has to include the key 'X'.
+
+        file_type : str
+            The type of the given file if test_set is a path string.
+
+        Returns
+        -------
+        result_dict : dict
+            The dictionary containing the imputation results.
+
+        """
+        # Step 1: wrap the input data with classes Dataset and DataLoader
+        self.model.eval()  # set the model as eval status to freeze it.
+        test_set = BaseDataset(
+            test_set, return_X_ori=False, return_labels=False, file_type=file_type
+        )
+        test_loader = DataLoader(
+            test_set,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+        )
+        imputation_collector = []
+
+        # Step 2: process the data with the model
+        with torch.no_grad():
+            for idx, data in enumerate(test_loader):
+                inputs = self._assemble_input_for_testing(data)
+                results = self.model.forward(inputs, training=False)
+                imputation_collector.append(results["imputed_data"])
+
+        # Step 3: output collection and return
+        imputation = torch.cat(imputation_collector).cpu().detach().numpy()
+        result_dict = {
+            "imputation": imputation,
+        }
+        return result_dict
+
+    def impute(
+        self,
+        X: Union[dict, str],
+        file_type="h5py",
+    ) -> np.ndarray:
+        """Impute missing values in the given data with the trained model.
+
+        Warnings
+        --------
+        The method impute is deprecated. Please use `predict()` instead.
+
+        Parameters
+        ----------
+        X :
+            The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps),
+            n_features], or a path string locating a data file, e.g. h5 file.
+
+        file_type :
+            The type of the given file if X is a path string.
+
+        Returns
+        -------
+        array-like, shape [n_samples, sequence length (time steps), n_features]
+            Imputed data.
+        """
+        logger.warning(
+            "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead."
+ ) + + results_dict = self.predict(X, file_type=file_type) + return results_dict["imputation"] diff --git a/pypots/imputation/dlinear/modules/__init__.py b/pypots/imputation/dlinear/modules/__init__.py new file mode 100644 index 00000000..ceaa7ee3 --- /dev/null +++ b/pypots/imputation/dlinear/modules/__init__.py @@ -0,0 +1,6 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause diff --git a/pypots/imputation/dlinear/modules/core.py b/pypots/imputation/dlinear/modules/core.py new file mode 100644 index 00000000..e8e5ec35 --- /dev/null +++ b/pypots/imputation/dlinear/modules/core.py @@ -0,0 +1,93 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch +import torch.nn as nn + +from ...autoformer.modules.submodules import SeriesDecompositionBlock +from ....utils.metrics import calc_mse + + +class _DLinear(nn.Module): + def __init__( + self, + n_steps: int, + n_features: int, + moving_avg_window_size: int, + individual: bool = False, + ): + super().__init__() + + self.n_steps = n_steps + self.n_features = n_features + self.series_decomp = SeriesDecompositionBlock(moving_avg_window_size) + self.individual = individual + + if individual: + self.Linear_Seasonal = nn.ModuleList() + self.Linear_Trend = nn.ModuleList() + + for i in range(self.n_features): + self.Linear_Seasonal.append(nn.Linear(self.n_steps, self.n_steps)) + self.Linear_Trend.append(nn.Linear(self.n_steps, self.n_steps)) + + self.Linear_Seasonal[i].weight = nn.Parameter( + (1 / self.n_steps) * torch.ones([self.n_steps, self.n_steps]) + ) + self.Linear_Trend[i].weight = nn.Parameter( + (1 / self.n_steps) * torch.ones([self.n_steps, self.n_steps]) + ) + else: + self.Linear_Seasonal = nn.Linear(self.n_steps, self.n_steps) + self.Linear_Trend = nn.Linear(self.n_steps, self.n_steps) + + self.Linear_Seasonal.weight = nn.Parameter( + (1 / self.n_steps) * torch.ones([self.n_steps, self.n_steps]) + ) + self.Linear_Trend.weight = nn.Parameter( + (1 / self.n_steps) * torch.ones([self.n_steps, self.n_steps]) + ) + + def forward(self, inputs: dict, training: bool = True) -> dict: + X, masks = inputs["X"], inputs["missing_mask"] + + # DLinear encoder processing + seasonal_init, trend_init = self.series_decomp(X) + seasonal_init, trend_init = seasonal_init.permute(0, 2, 1), trend_init.permute( + 0, 2, 1 + ) + if self.individual: + seasonal_output = torch.zeros( + [seasonal_init.size(0), seasonal_init.size(1), self.n_steps], + dtype=seasonal_init.dtype, + ).to(seasonal_init.device) + trend_output = torch.zeros( + [trend_init.size(0), trend_init.size(1), self.n_steps], + dtype=trend_init.dtype, + ).to(trend_init.device) + for i in range(self.n_features): + seasonal_output[:, i, :] = self.Linear_Seasonal[i]( + seasonal_init[:, i, :] + ) + trend_output[:, i, :] = self.Linear_Trend[i](trend_init[:, i, :]) + else: + seasonal_output = self.Linear_Seasonal(seasonal_init) + trend_output = self.Linear_Trend(trend_init) + output = seasonal_output + trend_output + output = output.permute(0, 2, 1) + + imputed_data = masks * X + (1 - masks) * output + results = { + "imputed_data": imputed_data, + } + + if training: + # `loss` is always the item for backward propagating to update the model + loss = calc_mse(output, inputs["X_ori"], inputs["indicating_mask"]) + results["loss"] = loss + + return results diff --git a/tests/imputation/dlinear.py b/tests/imputation/dlinear.py new file mode 100644 index 00000000..e2680b23 --- /dev/null +++ b/tests/imputation/dlinear.py @@ -0,0 +1,122 @@ +""" +Test cases for DLinear imputation 
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+
+import os.path
+import unittest
+
+import numpy as np
+import pytest
+
+from pypots.imputation import DLinear
+from pypots.optim import Adam
+from pypots.utils.logging import logger
+from pypots.utils.metrics import calc_mse
+from tests.global_test_config import (
+    DATA,
+    EPOCHS,
+    DEVICE,
+    TRAIN_SET,
+    VAL_SET,
+    TEST_SET,
+    H5_TRAIN_SET_PATH,
+    H5_VAL_SET_PATH,
+    H5_TEST_SET_PATH,
+    RESULT_SAVING_DIR_FOR_IMPUTATION,
+    check_tb_and_model_checkpoints_existence,
+)
+
+
+class TestDLinear(unittest.TestCase):
+    logger.info("Running tests for an imputation model DLinear...")
+
+    # set the log and model saving path
+    saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "DLinear")
+    model_save_name = "saved_dlinear_model.pypots"
+
+    # initialize an Adam optimizer
+    optimizer = Adam(lr=0.001, weight_decay=1e-5)
+
+    # initialize a DLinear model
+    dlinear = DLinear(
+        DATA["n_steps"],
+        DATA["n_features"],
+        moving_avg_window_size=3,
+        individual=False,
+        epochs=EPOCHS,
+        saving_path=saving_path,
+        optimizer=optimizer,
+        device=DEVICE,
+    )
+
+    @pytest.mark.xdist_group(name="imputation-dlinear")
+    def test_0_fit(self):
+        self.dlinear.fit(TRAIN_SET, VAL_SET)
+
+    @pytest.mark.xdist_group(name="imputation-dlinear")
+    def test_1_impute(self):
+        imputation_results = self.dlinear.predict(TEST_SET)
+        assert not np.isnan(
+            imputation_results["imputation"]
+        ).any(), "Output still has missing values after running predict()."
+
+        test_MSE = calc_mse(
+            imputation_results["imputation"],
+            DATA["test_X_ori"],
+            DATA["test_X_indicating_mask"],
+        )
+        logger.info(f"DLinear test_MSE: {test_MSE}")
+
+    @pytest.mark.xdist_group(name="imputation-dlinear")
+    def test_2_parameters(self):
+        assert hasattr(self.dlinear, "model") and self.dlinear.model is not None
+
+        assert hasattr(self.dlinear, "optimizer") and self.dlinear.optimizer is not None
+
+        assert hasattr(self.dlinear, "best_loss")
+        self.assertNotEqual(self.dlinear.best_loss, float("inf"))
+
+        assert (
+            hasattr(self.dlinear, "best_model_dict")
+            and self.dlinear.best_model_dict is not None
+        )
+
+    @pytest.mark.xdist_group(name="imputation-dlinear")
+    def test_3_saving_path(self):
+        # whether the root saving dir exists, which should be created by save_log_into_tb_file
+        assert os.path.exists(
+            self.saving_path
+        ), f"file {self.saving_path} does not exist"
+
+        # check if the tensorboard file and model checkpoints exist
+        check_tb_and_model_checkpoints_existence(self.dlinear)
+
+        # save the trained model into file, and check if the path exists
+        saved_model_path = os.path.join(self.saving_path, self.model_save_name)
+        self.dlinear.save(saved_model_path)
+
+        # test loading the saved model; not strictly necessary, but worth checking
+        self.dlinear.load(saved_model_path)
+
+    @pytest.mark.xdist_group(name="imputation-dlinear")
+    def test_4_lazy_loading(self):
+        self.dlinear.fit(H5_TRAIN_SET_PATH, H5_VAL_SET_PATH)
+        imputation_results = self.dlinear.predict(H5_TEST_SET_PATH)
+        assert not np.isnan(
+            imputation_results["imputation"]
+        ).any(), "Output still has missing values after running predict()."
+
+        test_MSE = calc_mse(
+            imputation_results["imputation"],
+            DATA["test_X_ori"],
+            DATA["test_X_indicating_mask"],
+        )
+        logger.info(f"Lazy-loading DLinear test_MSE: {test_MSE}")
+
+
+if __name__ == "__main__":
+    unittest.main()

From a812c42235c4a6a85288745005050cd6449f7fdd Mon Sep 17 00:00:00 2001
From: Wenjie Du
Date: Sat, 30 Mar 2024 10:27:39 +0800
Subject: [PATCH 2/2] docs: add references for Autoformer, PatchTST, and
 TimesNet;

---
 pypots/imputation/autoformer/model.py | 11 +++++------
 pypots/imputation/patchtst/model.py   | 11 +++++------
 pypots/imputation/timesnet/model.py   | 11 +++++------
 3 files changed, 15 insertions(+), 18 deletions(-)

diff --git a/pypots/imputation/autoformer/model.py b/pypots/imputation/autoformer/model.py
index 350277a2..b6e6c96d 100644
--- a/pypots/imputation/autoformer/model.py
+++ b/pypots/imputation/autoformer/model.py
@@ -101,13 +101,12 @@ class Autoformer(BaseNNImputer):
         better than in previous epochs.
         The "all" strategy will save every model after each epoch training.

-    Attributes
+    References
     ----------
-    model : :class:`torch.nn.Module`
-        The underlying Transformer model.
-
-    optimizer : :class:`pypots.optim.Optimizer`
-        The optimizer for model training.
+    .. [1] Wu, Haixu, Jiehui Xu, Jianmin Wang, and Mingsheng Long.
+       "Autoformer: Decomposition transformers with auto-correlation for long-term series forecasting".
+       Advances in Neural Information Processing Systems 34 (2021): 22419-22430.

     """
diff --git a/pypots/imputation/patchtst/model.py b/pypots/imputation/patchtst/model.py
index d0ba98ca..fbd6567e 100644
--- a/pypots/imputation/patchtst/model.py
+++ b/pypots/imputation/patchtst/model.py
@@ -111,13 +111,12 @@ class PatchTST(BaseNNImputer):
         better than in previous epochs.
         The "all" strategy will save every model after each epoch training.

-    Attributes
+    References
     ----------
-    model : :class:`torch.nn.Module`
-        The underlying Transformer model.
-
-    optimizer : :class:`pypots.optim.Optimizer`
-        The optimizer for model training.
+    .. [1] Nie, Yuqi, Nam H. Nguyen, Phanwadee Sinthong, and Jayant Kalagnanam.
+       "A time series is worth 64 words: Long-term forecasting with transformers".
+       ICLR 2023.

     """
diff --git a/pypots/imputation/timesnet/model.py b/pypots/imputation/timesnet/model.py
index 9e93d2f9..8b648d2d 100644
--- a/pypots/imputation/timesnet/model.py
+++ b/pypots/imputation/timesnet/model.py
@@ -103,13 +103,12 @@ class TimesNet(BaseNNImputer):
         better than in previous epochs.
         The "all" strategy will save every model after each epoch training.

-    Attributes
+    References
    ----------
-    model : :class:`torch.nn.Module`
-        The underlying Transformer model.
-
-    optimizer : :class:`pypots.optim.Optimizer`
-        The optimizer for model training.
+    .. [1] Wu, Haixu, Tengge Hu, Yong Liu, Hang Zhou, Jianmin Wang, and Mingsheng Long.
+       "TimesNet: Temporal 2D-variation modeling for general time series analysis".
+       ICLR 2023.

     """
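Note for reviewers: the core module in PATCH 1/2 reuses Autoformer's SeriesDecompositionBlock. For those unfamiliar with it, below is a minimal self-contained sketch of the decomposition idea, assuming the block behaves like the moving-average decomposition in the Autoformer paper; the function name and toy shapes are illustrative, not part of this patch.

import torch

def series_decomp(x: torch.Tensor, window: int):
    # trend = moving average over the time dimension; seasonal = residual.
    # Pad both ends with the edge values so the averaged sequence keeps its length.
    front = x[:, :1, :].repeat(1, (window - 1) // 2, 1)
    end = x[:, -1:, :].repeat(1, window // 2, 1)
    padded = torch.cat([front, x, end], dim=1)
    trend = torch.nn.functional.avg_pool1d(
        padded.permute(0, 2, 1), kernel_size=window, stride=1
    ).permute(0, 2, 1)
    return x - trend, trend

x = torch.randn(8, 24, 5)  # [batch, n_steps, n_features]
seasonal, trend = series_decomp(x, window=3)
# DLinear then maps each component over time with its own linear layer and sums
# them, much as _DLinear.forward does after permute(0, 2, 1):
lin_s, lin_t = torch.nn.Linear(24, 24), torch.nn.Linear(24, 24)
out = (
    lin_s(seasonal.permute(0, 2, 1)) + lin_t(trend.permute(0, 2, 1))
).permute(0, 2, 1)
assert out.shape == x.shape

The 1/n_steps weight initialization in core.py makes both linear layers start out as a uniform average over all time steps, so training begins from a neutral smoothing baseline rather than a random projection.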
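And a minimal usage sketch of the new imputer, mirroring the flow of tests/imputation/dlinear.py; the random toy data and the hyper-parameter values here are illustrative, not part of this patch.

import numpy as np
from pypots.imputation import DLinear

# toy data: 100 samples, 24 steps, 5 features, with roughly 10% of values missing
X = np.random.randn(100, 24, 5)
X[np.random.rand(*X.shape) < 0.1] = np.nan

model = DLinear(n_steps=24, n_features=5, moving_avg_window_size=3, epochs=5)
model.fit({"X": X})  # MIT training: a fraction of observed values is masked internally
imputation = model.predict({"X": X})["imputation"]
assert not np.isnan(imputation).any()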