Implement Non-stationary Transformer as an imputation model #388

Merged · 8 commits · May 7, 2024
2 changes: 1 addition & 1 deletion pypots/__init__.py
@@ -22,7 +22,7 @@
#
# Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
# 'X.Y.dev0' is the canonical version of 'X.Y.dev'
-__version__ = "0.4.1"
+__version__ = "0.5"


from . import imputation, classification, clustering, forecasting, optim, data, utils
2 changes: 2 additions & 0 deletions pypots/imputation/__init__.py
@@ -13,6 +13,7 @@
from .saits import SAITS
from .transformer import Transformer
from .itransformer import iTransformer
+from .nonstationary_transformer import NonstationaryTransformer
from .timesnet import TimesNet
from .etsformer import ETSformer
from .fedformer import FEDformer
@@ -45,6 +46,7 @@
"DLinear",
"Informer",
"Autoformer",
"NonstationaryTransformer",
"BRITS",
"MRNN",
"GPVAE",
24 changes: 24 additions & 0 deletions pypots/imputation/nonstationary_transformer/__init__.py
@@ -0,0 +1,24 @@
"""
The package of the partially-observed time-series imputation model Nonstationary-Transformer.

Refer to the paper
`Yong Liu, Haixu Wu, Jianmin Wang, Mingsheng Long.
Non-stationary Transformers: Exploring the Stationarity in Time Series Forecasting.
Advances in Neural Information Processing Systems 35 (2022): 9881-9893.
<https://proceedings.neurips.cc/paper_files/paper/2022/file/4054556fcaa934b0bf76da52cf4f92cb-Paper-Conference.pdf>`_

Notes
-----
This implementation is inspired by the official one https://github.com/thuml/Nonstationary_Transformers

"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause


from .model import NonstationaryTransformer

__all__ = [
"NonstationaryTransformer",
]
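For context, here is a minimal usage sketch of the wrapper this package exposes (not part of the diff). The hyperparameter names mirror the `_NonstationaryTransformer` constructor in core.py below; the `fit`/`predict` calls and the shape of the input dict follow the usual PyPOTS estimator interface, so the exact argument names are assumptions:

import numpy as np
from pypots.imputation import NonstationaryTransformer

# toy data: 32 samples, 24 time steps, 5 features, ~10% missing
X = np.random.randn(32, 24, 5)
X[np.random.rand(*X.shape) < 0.1] = np.nan

model = NonstationaryTransformer(
    n_steps=24,
    n_features=5,
    n_layers=2,
    d_model=64,
    n_heads=4,
    d_ffn=128,
    d_projector_hidden=64,           # width of the tau/delta projector MLP
    n_projector_hidden_layers=2,
    dropout=0.1,
    attn_dropout=0.1,
)
model.fit({"X": X})                  # self-supervised training with MIT masking
imputed = model.predict({"X": X})["imputation"]  # (32, 24, 5), missing values filled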
111 changes: 111 additions & 0 deletions pypots/imputation/nonstationary_transformer/core.py
@@ -0,0 +1,111 @@
"""
The core wrapper assembles the submodules of the NonstationaryTransformer imputation model
and takes over the forward process of the algorithm.
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

import torch.nn as nn

from ...nn.modules.nonstationary_transformer import (
NonstationaryTransformerEncoder,
Projector,
)
from ...nn.modules.saits import SaitsLoss, SaitsEmbedding
from ...nn.functional.normalization import nonstationary_norm, nonstationary_denorm


class _NonstationaryTransformer(nn.Module):
def __init__(
self,
n_steps: int,
n_features: int,
n_layers: int,
d_model: int,
n_heads: int,
d_ffn: int,
d_projector_hidden: int,
n_projector_hidden_layers: int,
dropout: float,
attn_dropout: float,
ORT_weight: float = 1,
MIT_weight: float = 1,
):
super().__init__()

d_k = d_v = d_model // n_heads
self.n_steps = n_steps

self.saits_embedding = SaitsEmbedding(
n_features * 2,
d_model,
with_pos=False,
dropout=dropout,
)
self.encoder = NonstationaryTransformerEncoder(
n_layers,
d_model,
n_heads,
d_k,
d_v,
d_ffn,
dropout,
attn_dropout,
)
self.tau_learner = Projector(
d_in=n_features,
n_steps=n_steps,
d_hidden=d_projector_hidden,
n_hidden_layers=n_projector_hidden_layers,
d_output=1,
)
self.delta_learner = Projector(
d_in=n_features,
n_steps=n_steps,
d_hidden=d_projector_hidden,
n_hidden_layers=n_projector_hidden_layers,
d_output=n_steps,
)

# for the imputation task, the output dim is the same as the input dim
self.output_projection = nn.Linear(d_model, n_features)
self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)

def forward(self, inputs: dict, training: bool = True) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]
X_enc, means, stdev = nonstationary_norm(X, missing_mask)

tau = self.tau_learner(X, stdev).exp()
delta = self.delta_learner(X, means)

# WDU: the original Nonstationary Transformer was not proposed for the imputation task. Hence the model doesn't
# take the missing mask into account, i.e. during processing the model doesn't know which parts of
# the input data are missing, and this may hurt its imputation performance. Therefore, I apply the
# SAITS embedding method to project the concatenation of features and masks into a hidden space, as well as
# the output layers to project back from the hidden space to the original space.
enc_out = self.saits_embedding(X, missing_mask)

# NonstationaryTransformer encoder processing
enc_out, attns = self.encoder(enc_out, tau=tau, delta=delta)
# project back to the original data space
reconstruction = self.output_projection(enc_out)
reconstruction = nonstationary_denorm(reconstruction, means, stdev)

imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction
results = {
"imputed_data": imputed_data,
}

# if in training mode, return results with losses
if training:
X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
loss, ORT_loss, MIT_loss = self.saits_loss_func(
reconstruction, X_ori, missing_mask, indicating_mask
)
results["ORT_loss"] = ORT_loss
results["MIT_loss"] = MIT_loss
# `loss` is always the item used for backpropagation to update the model
results["loss"] = loss

return results
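The forward pass above leans on two helpers imported from pypots.nn.functional.normalization. Conceptually, nonstationary_norm implements the paper's series stationarization, a per-instance normalization whose statistics are computed only over observed values (hence the missing mask argument), and nonstationary_denorm restores the original scale after reconstruction. A minimal sketch of that idea (the function names here are illustrative stand-ins, not the library's actual implementation):

import torch

def stationarize(X, missing_mask, eps=1e-5):
    # X, missing_mask: (batch, n_steps, n_features); mask is 1 where observed.
    X = X.masked_fill(missing_mask == 0, 0)
    n_obs = missing_mask.sum(dim=1, keepdim=True).clamp(min=1)   # observed steps per feature
    means = X.sum(dim=1, keepdim=True) / n_obs
    X_centered = (X - means) * missing_mask                      # keep missing slots at zero
    stdev = torch.sqrt((X_centered ** 2).sum(dim=1, keepdim=True) / n_obs + eps)
    return X_centered / stdev, means, stdev

def destationarize(X_norm, means, stdev):
    # invert the normalization to return to the original data scale
    return X_norm * stdev + means

The statistics removed here are exactly what `tau_learner` and `delta_learner` re-inject: per the paper's de-stationary attention, the encoder rescales the attention scores roughly as softmax(tau * Q K^T / sqrt(d_k) + delta), so attention can still see the non-stationary information that the normalization stripped out.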
24 changes: 24 additions & 0 deletions pypots/imputation/nonstationary_transformer/data.py
@@ -0,0 +1,24 @@
"""
Dataset class for NonstationaryTransformer.
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

from typing import Union

from ..saits.data import DatasetForSAITS


class DatasetForNonstationaryTransformer(DatasetForSAITS):
"""Actually NonstationaryTransformer uses the same data strategy as SAITS, needs MIT for training."""

def __init__(
self,
data: Union[dict, str],
return_X_ori: bool,
return_y: bool,
file_type: str = "hdf5",
rate: float = 0.2,
):
super().__init__(data, return_X_ori, return_y, file_type, rate)
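The `rate=0.2` above controls how much artificial missingness the dataset injects for the masked imputation task (MIT): a fraction of the observed values is hidden from the model's input but kept as reconstruction targets. A rough numpy sketch of that MCAR-style masking (an illustrative stand-in, not the actual SAITS dataset code):

import numpy as np

def mcar_mask(X_ori, rate=0.2, rng=None):
    # Randomly hide `rate` of the observed values so the model can be
    # supervised on entries it never sees in its input.
    rng = np.random.default_rng() if rng is None else rng
    observed = ~np.isnan(X_ori)
    hide = observed & (rng.random(X_ori.shape) < rate)

    X = X_ori.copy()
    X[hide] = np.nan                                   # model input

    missing_mask = (~np.isnan(X)).astype(np.float32)   # 1 where X is observed
    indicating_mask = hide.astype(np.float32)          # 1 where artificially masked
    return X, missing_mask, indicating_mask

In core.py above, `SaitsLoss` then computes the MIT loss on `indicating_mask` positions (artificially masked, ground truth known) and the ORT loss on `missing_mask` positions (values observed in the input), which is why the dataset must return both masks.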