From 642c21f238519181f5228ff986ba32d2712ae6bc Mon Sep 17 00:00:00 2001
From: Wenjie Du
Date: Fri, 10 May 2024 01:21:46 +0800
Subject: [PATCH 1/2] feat: add MICN modules;

---
 pypots/nn/modules/fedformer/layers.py |  26 ++++
 pypots/nn/modules/micn/__init__.py    |  24 ++++
 pypots/nn/modules/micn/backbone.py    |  50 ++++++++
 pypots/nn/modules/micn/layers.py      | 169 ++++++++++++++++++++++
 4 files changed, 269 insertions(+)
 create mode 100644 pypots/nn/modules/micn/__init__.py
 create mode 100644 pypots/nn/modules/micn/backbone.py
 create mode 100644 pypots/nn/modules/micn/layers.py

diff --git a/pypots/nn/modules/fedformer/layers.py b/pypots/nn/modules/fedformer/layers.py
index d02bc996..36522bf9 100644
--- a/pypots/nn/modules/fedformer/layers.py
+++ b/pypots/nn/modules/fedformer/layers.py
@@ -17,6 +17,7 @@
 from torch import Tensor
 from torch import nn
 
+from ..autoformer.layers import MovingAvgBlock
 from ..transformer.attention import AttentionOperator
 
 
@@ -952,3 +953,28 @@ def forward(
             out_ft / self.in_channels / self.out_channels, n=xq.size(-1)
         )
         return out, None
+
+
+class SeriesDecompositionMultiBlock(nn.Module):
+    """
+    Series decomposition block from FEDformer,
+    i.e. series_decomp_multi from https://github.com/MAZiqing/FEDformer
+
+    """
+
+    def __init__(self, kernel_size):
+        super().__init__()
+        self.moving_avg = [MovingAvgBlock(kernel, stride=1) for kernel in kernel_size]
+        self.layer = torch.nn.Linear(1, len(kernel_size))
+
+    def forward(self, x):
+        moving_mean = []
+        for func in self.moving_avg:
+            moving_avg = func(x)
+            moving_mean.append(moving_avg.unsqueeze(-1))
+        moving_mean = torch.cat(moving_mean, dim=-1)
+        moving_mean = torch.sum(
+            moving_mean * nn.Softmax(-1)(self.layer(x.unsqueeze(-1))), dim=-1
+        )
+        res = x - moving_mean
+        return res, moving_mean
diff --git a/pypots/nn/modules/micn/__init__.py b/pypots/nn/modules/micn/__init__.py
new file mode 100644
index 00000000..7a707839
--- /dev/null
+++ b/pypots/nn/modules/micn/__init__.py
@@ -0,0 +1,24 @@
+"""
+The package including the modules of MICN.
+
+Refer to the paper
+`Huiqiang Wang, Jian Peng, Feihu Huang, Jince Wang, Junhui Chen, and Yifei Xiao
+"MICN: Multi-scale Local and Global Context Modeling for Long-term Series Forecasting".
+In the Eleventh International Conference on Learning Representations, 2023.
+`_
+
+Notes
+-----
+This implementation is inspired by the official one https://github.com/wanghq21/MICN
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+
+from .backbone import BackboneMICN
+
+__all__ = [
+    "BackboneMICN",
+]
diff --git a/pypots/nn/modules/micn/backbone.py b/pypots/nn/modules/micn/backbone.py
new file mode 100644
index 00000000..3c828254
--- /dev/null
+++ b/pypots/nn/modules/micn/backbone.py
@@ -0,0 +1,50 @@
+"""
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+import torch.nn as nn
+
+from .layers import SeasonalPrediction
+from ..fedformer.layers import SeriesDecompositionMultiBlock
+
+
+class BackboneMICN(nn.Module):
+    def __init__(
+        self,
+        n_steps,
+        n_features,
+        n_pred_steps,
+        n_pred_features,
+        n_layers,
+        d_model,
+        conv_kernel=[12, 24],
+    ):
+        super().__init__()
+        self.n_steps = n_steps
+        self.n_features = n_features
+        self.n_pred_steps = n_pred_steps
+        self.n_pred_features = n_pred_features
+
+        decomp_kernel = []  # kernel of decomposition operation
+        isometric_kernel = []  # kernel of isometric convolution
+        for ii in conv_kernel:
+            if ii % 2 == 0:  # the kernel of decomposition operation must be odd
+                decomp_kernel.append(ii + 1)
+                isometric_kernel.append((n_steps + n_pred_steps + ii) // ii)
+            else:
+                decomp_kernel.append(ii)
+                isometric_kernel.append((n_steps + n_pred_steps + ii - 1) // ii)
+
+        self.decomp_multi = SeriesDecompositionMultiBlock(decomp_kernel)
+
+        self.conv_trans = SeasonalPrediction(
+            embedding_size=d_model,
+            d_layers=n_layers,
+            decomp_kernel=decomp_kernel,
+            c_out=n_pred_features,
+            conv_kernel=conv_kernel,
+            isometric_kernel=isometric_kernel,
+        )
diff --git a/pypots/nn/modules/micn/layers.py b/pypots/nn/modules/micn/layers.py
new file mode 100644
index 00000000..8189d72e
--- /dev/null
+++ b/pypots/nn/modules/micn/layers.py
@@ -0,0 +1,169 @@
+"""
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+import torch
+import torch.fft
+import torch.nn as nn
+
+from ..autoformer import SeriesDecompositionBlock
+
+
+class MIC(nn.Module):
+    """
+    MIC layer to extract local and global features
+    """
+
+    def __init__(
+        self,
+        feature_size=512,
+        decomp_kernel=[32],
+        conv_kernel=[24],
+        isometric_kernel=[18, 6],
+    ):
+        super().__init__()
+        self.conv_kernel = conv_kernel
+
+        # isometric convolution
+        self.isometric_conv = nn.ModuleList(
+            [
+                nn.Conv1d(
+                    in_channels=feature_size,
+                    out_channels=feature_size,
+                    kernel_size=i,
+                    padding=0,
+                    stride=1,
+                )
+                for i in isometric_kernel
+            ]
+        )
+
+        # downsampling convolution: padding=i//2, stride=i
+        self.conv = nn.ModuleList(
+            [
+                nn.Conv1d(
+                    in_channels=feature_size,
+                    out_channels=feature_size,
+                    kernel_size=i,
+                    padding=i // 2,
+                    stride=i,
+                )
+                for i in conv_kernel
+            ]
+        )
+
+        # upsampling convolution
+        self.conv_trans = nn.ModuleList(
+            [
+                nn.ConvTranspose1d(
+                    in_channels=feature_size,
+                    out_channels=feature_size,
+                    kernel_size=i,
+                    padding=0,
+                    stride=i,
+                )
+                for i in conv_kernel
+            ]
+        )
+
+        self.decomp = nn.ModuleList(
+            [SeriesDecompositionBlock(k) for k in decomp_kernel]
+        )
+        self.merge = torch.nn.Conv2d(
+            in_channels=feature_size,
+            out_channels=feature_size,
+            kernel_size=(len(self.conv_kernel), 1),
+        )
+
+        # feedforward network
+        self.conv1 = nn.Conv1d(
+            in_channels=feature_size, out_channels=feature_size * 4, kernel_size=1
+        )
+        self.conv2 = nn.Conv1d(
+            in_channels=feature_size * 4, out_channels=feature_size, kernel_size=1
+        )
+        self.norm1 = nn.LayerNorm(feature_size)
+        self.norm2 = nn.LayerNorm(feature_size)
+
+        self.norm = torch.nn.LayerNorm(feature_size)
+        self.act = torch.nn.Tanh()
+        self.drop = torch.nn.Dropout(0.05)
+
+    def conv_trans_conv(self, input, conv1d, conv1d_trans, isometric):
+        batch, seq_len, channel = input.shape
+        x = input.permute(0, 2, 1)
+
+        # downsampling convolution
+        x1 = self.drop(self.act(conv1d(x)))
+        x = x1
+
+        # isometric convolution
+        zeros = torch.zeros(
+            (x.shape[0], x.shape[1], x.shape[2] - 1), device=input.device
+        )
+        x = torch.cat((zeros, x), dim=-1)
+        x = self.drop(self.act(isometric(x)))
+        x = self.norm((x + x1).permute(0, 2, 1)).permute(0, 2, 1)
+
+        # upsampling convolution
+        x = self.drop(self.act(conv1d_trans(x)))
+        x = x[:, :, :seq_len]  # truncate
+
+        x = self.norm(x.permute(0, 2, 1) + input)
+        return x
+
+    def forward(self, src):
+        # multi-scale
+        multi = []
+        for i in range(len(self.conv_kernel)):
+            src_out, trend1 = self.decomp[i](src)
+            src_out = self.conv_trans_conv(
+                src_out, self.conv[i], self.conv_trans[i], self.isometric_conv[i]
+            )
+            multi.append(src_out)
+
+        # merge
+        mg = torch.tensor([], device=src.device)
+        for i in range(len(self.conv_kernel)):
+            mg = torch.cat((mg, multi[i].unsqueeze(1)), dim=1)
+        mg = self.merge(mg.permute(0, 3, 1, 2)).squeeze(-2).permute(0, 2, 1)
+
+        y = self.norm1(mg)
+        y = self.conv2(self.conv1(y.transpose(-1, 1))).transpose(-1, 1)
+
+        return self.norm2(mg + y)
+
+
+class SeasonalPrediction(nn.Module):
+    def __init__(
+        self,
+        embedding_size=512,
+        d_layers=1,
+        decomp_kernel=[32],
+        c_out=1,
+        conv_kernel=[2, 4],
+        isometric_kernel=[18, 6],
+    ):
+        super().__init__()
+
+        self.mic = nn.ModuleList(
+            [
+                MIC(
+                    feature_size=embedding_size,
+                    decomp_kernel=decomp_kernel,
+                    conv_kernel=conv_kernel,
+                    isometric_kernel=isometric_kernel,
+                )
+                for _ in range(d_layers)
+            ]
+        )
+
+        self.projection = nn.Linear(embedding_size, c_out)
+
+    def forward(self, dec):
+        for mic_layer in self.mic:
+            dec = mic_layer(dec)
+        return self.projection(dec)

From 1bac98fe95341a7ca64ad0ecab38200f387702ea Mon Sep 17 00:00:00 2001
From: Wenjie Du
Date: Sun, 12 May 2024 23:17:40 +0800
Subject: [PATCH 2/2] feat: add MICN as an imputation model;

---
 pypots/imputation/__init__.py      |   2 +
 pypots/imputation/micn/__init__.py |  24 +++
 pypots/imputation/micn/core.py     |  95 +++++++++
 pypots/imputation/micn/data.py     |  24 +++
 pypots/imputation/micn/model.py    | 308 +++++++++++++++++++++++++++++
 pypots/nn/modules/micn/backbone.py |  21 +-
 tests/imputation/micn.py           | 124 ++++++++++++
 7 files changed, 584 insertions(+), 14 deletions(-)
 create mode 100644 pypots/imputation/micn/__init__.py
 create mode 100644 pypots/imputation/micn/core.py
 create mode 100644 pypots/imputation/micn/data.py
 create mode 100644 pypots/imputation/micn/model.py
 create mode 100644 tests/imputation/micn.py

diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py
index 0c63ca9a..800b16e3 100644
--- a/pypots/imputation/__init__.py
+++ b/pypots/imputation/__init__.py
@@ -29,6 +29,7 @@
 from .scinet import SCINet
 from .revinscinet import RevIN_SCINet
 from .koopa import Koopa
+from .micn import MICN
 
 # naive imputation methods
 from .locf import LOCF
@@ -60,6 +61,7 @@
     "SCINet",
     "RevIN_SCINet",
     "Koopa",
+    "MICN",
     # naive imputation methods
     "LOCF",
     "Mean",
diff --git a/pypots/imputation/micn/__init__.py b/pypots/imputation/micn/__init__.py
new file mode 100644
index 00000000..9e5a9246
--- /dev/null
+++ b/pypots/imputation/micn/__init__.py
@@ -0,0 +1,24 @@
+"""
+The package of the partially-observed time-series imputation model MICN.
+
+Refer to the paper
+`Huiqiang Wang, Jian Peng, Feihu Huang, Jince Wang, Junhui Chen, and Yifei Xiao
+"MICN: Multi-scale Local and Global Context Modeling for Long-term Series Forecasting".
+In the Eleventh International Conference on Learning Representations, 2023.
+`_
+
+Notes
+-----
+This implementation is inspired by the official one https://github.com/wanghq21/MICN
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+
+from .model import MICN
+
+__all__ = [
+    "MICN",
+]
diff --git a/pypots/imputation/micn/core.py b/pypots/imputation/micn/core.py
new file mode 100644
index 00000000..a37cbaf8
--- /dev/null
+++ b/pypots/imputation/micn/core.py
@@ -0,0 +1,95 @@
+"""
+The core wrapper assembles the submodules of MICN imputation model
+and takes over the forward progress of the algorithm.
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+import torch.nn as nn
+
+from ...nn.modules.fedformer.layers import SeriesDecompositionMultiBlock
+from ...nn.modules.micn import BackboneMICN
+from ...nn.modules.saits import SaitsLoss, SaitsEmbedding
+
+
+class _MICN(nn.Module):
+    def __init__(
+        self,
+        n_steps: int,
+        n_features: int,
+        n_layers: int,
+        d_model: int,
+        dropout: float,
+        conv_kernel: list = None,
+        ORT_weight: float = 1,
+        MIT_weight: float = 1,
+    ):
+        super().__init__()
+
+        self.saits_embedding = SaitsEmbedding(
+            n_features * 2,
+            d_model,
+            with_pos=True,
+            dropout=dropout,
+        )
+
+        decomp_kernel = []  # kernel of decomposition operation
+        isometric_kernel = []  # kernel of isometric convolution
+        for ii in conv_kernel:
+            if ii % 2 == 0:  # the kernel of decomposition operation must be odd
+                decomp_kernel.append(ii + 1)
+                isometric_kernel.append((n_steps + n_steps + ii) // ii)
+            else:
+                decomp_kernel.append(ii)
+                isometric_kernel.append((n_steps + n_steps + ii - 1) // ii)
+
+        self.decomp_multi = SeriesDecompositionMultiBlock(decomp_kernel)
+        self.backbone = BackboneMICN(
+            n_steps,
+            n_features,
+            n_steps,
+            n_features,
+            n_layers,
+            d_model,
+            decomp_kernel,
+            isometric_kernel,
+            conv_kernel,
+        )
+
+        # for the imputation task, the output dim is the same as input dim
+        self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)
+
+    def forward(self, inputs: dict, training: bool = True) -> dict:
+        X, missing_mask = inputs["X"], inputs["missing_mask"]
+
+        seasonal_init, trend_init = self.decomp_multi(X)
+
+        # WDU: the original MICN paper doesn't propose the model for the imputation task. Hence the model doesn't
+        # take the missing mask into account, i.e. in the process the model doesn't know which part of
+        # the input data is missing, and this may hurt the model's imputation performance. Therefore, I apply the
+        # SAITS embedding method to project the concatenation of features and masks into a hidden space, as well as
+        # the output layers to project back from the hidden space to the original space.
+        enc_out = self.saits_embedding(seasonal_init, missing_mask)
+
+        # MICN encoder processing
+        reconstruction = self.backbone(enc_out)
+        reconstruction = reconstruction + trend_init
+
+        imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction
+        results = {
+            "imputed_data": imputed_data,
+        }
+
+        # if in training mode, return results with losses
+        if training:
+            X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
+            loss, ORT_loss, MIT_loss = self.saits_loss_func(
+                reconstruction, X_ori, missing_mask, indicating_mask
+            )
+            results["ORT_loss"] = ORT_loss
+            results["MIT_loss"] = MIT_loss
+            # `loss` is always the item for backward propagating to update the model
+            results["loss"] = loss
+
+        return results
diff --git a/pypots/imputation/micn/data.py b/pypots/imputation/micn/data.py
new file mode 100644
index 00000000..0dd75690
--- /dev/null
+++ b/pypots/imputation/micn/data.py
@@ -0,0 +1,24 @@
+"""
+Dataset class for MICN.
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+from typing import Union
+
+from ..saits.data import DatasetForSAITS
+
+
+class DatasetForMICN(DatasetForSAITS):
+    """Actually MICN uses the same data strategy as SAITS, needs MIT for training."""
+
+    def __init__(
+        self,
+        data: Union[dict, str],
+        return_X_ori: bool,
+        return_y: bool,
+        file_type: str = "hdf5",
+        rate: float = 0.2,
+    ):
+        super().__init__(data, return_X_ori, return_y, file_type, rate)
diff --git a/pypots/imputation/micn/model.py b/pypots/imputation/micn/model.py
new file mode 100644
index 00000000..a4b7d607
--- /dev/null
+++ b/pypots/imputation/micn/model.py
@@ -0,0 +1,308 @@
+"""
+The implementation of MICN for the partially-observed time-series imputation task.
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+from typing import Union, Optional
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+
+from .core import _MICN
+from .data import DatasetForMICN
+from ..base import BaseNNImputer
+from ...data.checking import key_in_data_set
+from ...data.dataset import BaseDataset
+from ...optim.adam import Adam
+from ...optim.base import Optimizer
+
+
+class MICN(BaseNNImputer):
+    """The PyTorch implementation of the MICN model.
+    MICN is originally proposed by Wang et al. in :cite:`wang2023micn`.
+
+    Parameters
+    ----------
+    n_steps :
+        The number of time steps in the time-series data sample.
+
+    n_features :
+        The number of features in the time-series data sample.
+
+    n_layers :
+        The number of layers in the MICN model.
+
+    d_model :
+        The dimension of the model.
+
+    conv_kernel :
+        The kernel size for the convolutional layers in the model. It should be a list of integers,
+        and the maximum value in the list should be less than or equal to the minimum value
+        of n_steps and n_features.
+
+    dropout :
+        The dropout rate for the model.
+
+    ORT_weight :
+        The weight for the ORT loss, the same as SAITS.
+
+    MIT_weight :
+        The weight for the MIT loss, the same as SAITS.
+
+    batch_size :
+        The batch size for training and evaluating the model.
+
+    epochs :
+        The number of epochs for training the model.
+
+    patience :
+        The patience for the early-stopping mechanism. Given a positive integer, the training process will be
+        stopped when the model does not perform better after that number of epochs.
+        Leaving it default as None will disable the early-stopping.
+
+    optimizer :
+        The optimizer for model training.
+        If not given, will use a default Adam optimizer.
+
+    num_workers :
+        The number of subprocesses to use for data loading.
+        `0` means data loading will be in the main process, i.e. there won't be subprocesses.
+
+    device :
+        The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them.
+        If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple),
+        then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models.
+        If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')],
+        the model will be trained in parallel on the multiple devices (so far, parallel training is only supported
+        on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future.
+
+    saving_path :
+        The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during
+        training into a tensorboard file). Will not save if not given.
+
+    model_saving_strategy :
+        The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"].
+        No model will be saved when it is set as None.
+        The "best" strategy will only automatically save the best model after the training finished.
+        The "better" strategy will automatically save the model during training whenever the model performs
+        better than in previous epochs.
+        The "all" strategy will save every model after each epoch training.
+
+    """
+
+    def __init__(
+        self,
+        n_steps: int,
+        n_features: int,
+        n_layers: int,
+        d_model: int,
+        conv_kernel: list,
+        dropout: float = 0,
+        ORT_weight: float = 1,
+        MIT_weight: float = 1,
+        batch_size: int = 32,
+        epochs: int = 100,
+        patience: int = None,
+        optimizer: Optional[Optimizer] = Adam(),
+        num_workers: int = 0,
+        device: Optional[Union[str, torch.device, list]] = None,
+        saving_path: str = None,
+        model_saving_strategy: Optional[str] = "best",
+    ):
+        super().__init__(
+            batch_size,
+            epochs,
+            patience,
+            num_workers,
+            device,
+            saving_path,
+            model_saving_strategy,
+        )
+
+        assert isinstance(conv_kernel, list), "conv_kernel must be a list."
+        assert max(conv_kernel) <= min(
+            n_steps, n_features
+        ), "The maximum value in conv_kernel must be <= the minimum value of n_steps and n_features."
+
+        self.n_steps = n_steps
+        self.n_features = n_features
+        # model hyper-parameters
+        self.n_layers = n_layers
+        self.d_model = d_model
+        self.dropout = dropout
+        self.conv_kernel = conv_kernel
+        self.ORT_weight = ORT_weight
+        self.MIT_weight = MIT_weight
+
+        # set up the model
+        self.model = _MICN(
+            self.n_steps,
+            self.n_features,
+            self.n_layers,
+            self.d_model,
+            self.dropout,
+            self.conv_kernel,
+            self.ORT_weight,
+            self.MIT_weight,
+        )
+        self._send_model_to_given_device()
+        self._print_model_size()
+
+        # set up the optimizer
+        self.optimizer = optimizer
+        self.optimizer.init_optimizer(self.model.parameters())
+
+    def _assemble_input_for_training(self, data: list) -> dict:
+        (
+            indices,
+            X,
+            missing_mask,
+            X_ori,
+            indicating_mask,
+        ) = self._send_data_to_given_device(data)
+
+        inputs = {
+            "X": X,
+            "missing_mask": missing_mask,
+            "X_ori": X_ori,
+            "indicating_mask": indicating_mask,
+        }
+
+        return inputs
+
+    def _assemble_input_for_validating(self, data: list) -> dict:
+        return self._assemble_input_for_training(data)
+
+    def _assemble_input_for_testing(self, data: list) -> dict:
+        indices, X, missing_mask = self._send_data_to_given_device(data)
+
+        inputs = {
+            "X": X,
+            "missing_mask": missing_mask,
+        }
+
+        return inputs
+
+    def fit(
+        self,
+        train_set: Union[dict, str],
+        val_set: Optional[Union[dict, str]] = None,
+        file_type: str = "hdf5",
+    ) -> None:
+        # Step 1: wrap the input data with classes Dataset and DataLoader
+        training_set = DatasetForMICN(
+            train_set, return_X_ori=False, return_y=False, file_type=file_type
+        )
+        training_loader = DataLoader(
+            training_set,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=self.num_workers,
+        )
+        val_loader = None
+        if val_set is not None:
+            if not key_in_data_set("X_ori", val_set):
+                raise ValueError("val_set must contain 'X_ori' for model validation.")
+            val_set = DatasetForMICN(
+                val_set, return_X_ori=True, return_y=False, file_type=file_type
+            )
+            val_loader = DataLoader(
+                val_set,
+                batch_size=self.batch_size,
+                shuffle=False,
+                num_workers=self.num_workers,
+            )
+
+        # Step 2: train the model and freeze it
+        self._train_model(training_loader, val_loader)
+        self.model.load_state_dict(self.best_model_dict)
+        self.model.eval()  # set the model as eval status to freeze it.
+
+        # Step 3: save the model if necessary
+        self._auto_save_model_if_necessary(confirm_saving=True)
+
+    def predict(
+        self,
+        test_set: Union[dict, str],
+        file_type: str = "hdf5",
+    ) -> dict:
+        """Make predictions for the input data with the trained model.
+
+        Parameters
+        ----------
+        test_set : dict or str
+            The dataset for model testing, should be a dictionary including the key 'X',
+            or a path string locating a data file supported by PyPOTS (e.g. h5 file).
+            If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
+            which is time-series data for testing and can contain missing values.
+            If it is a path string, the path should point to a data file, e.g. an h5 file, which contains
+            key-value pairs like a dict, and it has to include the key 'X'.
+
+        file_type :
+            The type of the given file if test_set is a path string.
+
+        Returns
+        -------
+        result_dict :
+            The dictionary containing the imputation results under the key 'imputation'.
+
+        """
+        # Step 1: wrap the input data with classes Dataset and DataLoader
+        self.model.eval()  # set the model as eval status to freeze it.
+        test_set = BaseDataset(
+            test_set,
+            return_X_ori=False,
+            return_X_pred=False,
+            return_y=False,
+            file_type=file_type,
+        )
+        test_loader = DataLoader(
+            test_set,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+        )
+        imputation_collector = []
+
+        # Step 2: process the data with the model
+        with torch.no_grad():
+            for idx, data in enumerate(test_loader):
+                inputs = self._assemble_input_for_testing(data)
+                results = self.model.forward(inputs, training=False)
+                imputation_collector.append(results["imputed_data"])
+
+        # Step 3: output collection and return
+        imputation = torch.cat(imputation_collector).cpu().detach().numpy()
+        result_dict = {
+            "imputation": imputation,
+        }
+        return result_dict
+
+    def impute(
+        self,
+        test_set: Union[dict, str],
+        file_type: str = "hdf5",
+    ) -> np.ndarray:
+        """Impute missing values in the given data with the trained model.
+
+        Parameters
+        ----------
+        test_set :
+            The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps),
+            n_features], or a path string locating a data file, e.g. h5 file.
+
+        file_type :
+            The type of the given file if test_set is a path string.
+
+        Returns
+        -------
+        array-like, shape [n_samples, sequence length (n_steps), n_features],
+            Imputed data.
+        """
+
+        result_dict = self.predict(test_set, file_type=file_type)
+        return result_dict["imputation"]
diff --git a/pypots/nn/modules/micn/backbone.py b/pypots/nn/modules/micn/backbone.py
index 3c828254..b3f6b51a 100644
--- a/pypots/nn/modules/micn/backbone.py
+++ b/pypots/nn/modules/micn/backbone.py
@@ -8,7 +8,6 @@
 import torch.nn as nn
 
 from .layers import SeasonalPrediction
-from ..fedformer.layers import SeriesDecompositionMultiBlock
 
 
 class BackboneMICN(nn.Module):
@@ -20,7 +19,9 @@ def __init__(
         n_pred_features,
         n_layers,
         d_model,
-        conv_kernel=[12, 24],
+        decomp_kernel,
+        isometric_kernel,
+        conv_kernel: list,
     ):
         super().__init__()
         self.n_steps = n_steps
@@ -28,18 +29,6 @@ def __init__(
         self.n_pred_steps = n_pred_steps
         self.n_pred_features = n_pred_features
 
-        decomp_kernel = []  # kernel of decomposition operation
-        isometric_kernel = []  # kernel of isometric convolution
-        for ii in conv_kernel:
-            if ii % 2 == 0:  # the kernel of decomposition operation must be odd
-                decomp_kernel.append(ii + 1)
-                isometric_kernel.append((n_steps + n_pred_steps + ii) // ii)
-            else:
-                decomp_kernel.append(ii)
-                isometric_kernel.append((n_steps + n_pred_steps + ii - 1) // ii)
-
-        self.decomp_multi = SeriesDecompositionMultiBlock(decomp_kernel)
-
         self.conv_trans = SeasonalPrediction(
             embedding_size=d_model,
             d_layers=n_layers,
@@ -48,3 +37,7 @@ def __init__(
             conv_kernel=conv_kernel,
             isometric_kernel=isometric_kernel,
         )
+
+    def forward(self, x):
+        dec_out = self.conv_trans(x)
+        return dec_out
diff --git a/tests/imputation/micn.py b/tests/imputation/micn.py
new file mode 100644
index 00000000..ea27fd95
--- /dev/null
+++ b/tests/imputation/micn.py
@@ -0,0 +1,124 @@
+"""
+Test cases for MICN imputation model.
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+
+import os.path
+import unittest
+
+import numpy as np
+import pytest
+
+from pypots.imputation import MICN
+from pypots.optim import Adam
+from pypots.utils.logging import logger
+from pypots.utils.metrics import calc_mse
+from tests.global_test_config import (
+    DATA,
+    EPOCHS,
+    DEVICE,
+    TRAIN_SET,
+    VAL_SET,
+    TEST_SET,
+    GENERAL_H5_TRAIN_SET_PATH,
+    GENERAL_H5_VAL_SET_PATH,
+    GENERAL_H5_TEST_SET_PATH,
+    RESULT_SAVING_DIR_FOR_IMPUTATION,
+    check_tb_and_model_checkpoints_existence,
+)
+
+
+class TestMICN(unittest.TestCase):
+    logger.info("Running tests for an imputation model MICN...")
+
+    # set the log and model saving path
+    saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "MICN")
+    model_save_name = "saved_micn_model.pypots"
+
+    # initialize an Adam optimizer
+    optimizer = Adam(lr=0.001, weight_decay=1e-5)
+
+    # initialize a MICN model
+    micn = MICN(
+        DATA["n_steps"],
+        DATA["n_features"],
+        n_layers=2,
+        d_model=32,
+        conv_kernel=[2, 4],
+        dropout=0,
+        epochs=EPOCHS,
+        saving_path=saving_path,
+        optimizer=optimizer,
+        device=DEVICE,
+    )
+
+    @pytest.mark.xdist_group(name="imputation-micn")
+    def test_0_fit(self):
+        self.micn.fit(TRAIN_SET, VAL_SET)
+
+    @pytest.mark.xdist_group(name="imputation-micn")
+    def test_1_impute(self):
+        imputation_results = self.micn.predict(TEST_SET)
+        assert not np.isnan(
+            imputation_results["imputation"]
+        ).any(), "Output still has missing values after running impute()."
+
+        test_MSE = calc_mse(
+            imputation_results["imputation"],
+            DATA["test_X_ori"],
+            DATA["test_X_indicating_mask"],
+        )
+        logger.info(f"MICN test_MSE: {test_MSE}")
+
+    @pytest.mark.xdist_group(name="imputation-micn")
+    def test_2_parameters(self):
+        assert hasattr(self.micn, "model") and self.micn.model is not None
+
+        assert hasattr(self.micn, "optimizer") and self.micn.optimizer is not None
+
+        assert hasattr(self.micn, "best_loss")
+        self.assertNotEqual(self.micn.best_loss, float("inf"))
+
+        assert (
+            hasattr(self.micn, "best_model_dict")
+            and self.micn.best_model_dict is not None
+        )
+
+    @pytest.mark.xdist_group(name="imputation-micn")
+    def test_3_saving_path(self):
+        # whether the root saving dir exists, which should be created by save_log_into_tb_file
+        assert os.path.exists(
+            self.saving_path
+        ), f"file {self.saving_path} does not exist"
+
+        # check if the tensorboard file and model checkpoints exist
+        check_tb_and_model_checkpoints_existence(self.micn)
+
+        # save the trained model into file, and check if the path exists
+        saved_model_path = os.path.join(self.saving_path, self.model_save_name)
+        self.micn.save(saved_model_path)
+
+        # test loading the saved model, not necessary, but need to test
+        self.micn.load(saved_model_path)
+
+    @pytest.mark.xdist_group(name="imputation-micn")
+    def test_4_lazy_loading(self):
+        self.micn.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH)
+        imputation_results = self.micn.predict(GENERAL_H5_TEST_SET_PATH)
+        assert not np.isnan(
+            imputation_results["imputation"]
+        ).any(), "Output still has missing values after running impute()."
+
+        test_MSE = calc_mse(
+            imputation_results["imputation"],
+            DATA["test_X_ori"],
+            DATA["test_X_indicating_mask"],
+        )
+        logger.info(f"Lazy-loading MICN test_MSE: {test_MSE}")
+
+
+if __name__ == "__main__":
+    unittest.main()
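
Reviewer note: below is a minimal usage sketch of the newly added imputation model, assembled only from the docstrings and the test case in this patch. The toy dataset, its shapes, and the hyperparameter values are illustrative assumptions, not part of the patch itself.

    import numpy as np
    from pypots.imputation import MICN
    from pypots.optim import Adam

    # toy data: 128 samples, 48 steps, 5 features; NaN marks missing values (illustrative only)
    X = np.random.randn(128, 48, 5)
    X[X < -1.5] = np.nan
    train_set = {"X": X}

    # note: max(conv_kernel) must be <= min(n_steps, n_features), per the assertion in MICN.__init__
    model = MICN(
        n_steps=48,
        n_features=5,
        n_layers=2,
        d_model=32,
        conv_kernel=[2, 4],
        dropout=0.1,
        epochs=10,
        optimizer=Adam(lr=1e-3),
    )
    model.fit(train_set)               # self-supervised training with the SAITS-style ORT/MIT losses
    imputed = model.impute(train_set)  # ndarray of shape [n_samples, n_steps, n_features]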