Merge pull request #454 from WenjieDu/dev
Add ImputeFormer, fix RevIN, and update docs
WenjieDu authored Jul 2, 2024
2 parents ac3318f + dfcdfd1 commit 90aa00b
Showing 14 changed files with 1,007 additions and 32 deletions.
7 changes: 4 additions & 3 deletions README.md
@@ -112,6 +112,7 @@ The paper references and links are all listed at the bottom of this file.

| **Type** | **Algo** | **IMPU** | **FORE** | **CLAS** | **CLUS** | **ANOD** | **Year - Venue** |
|:--------------|:----------------------------|:--------:|:--------:|:--------:|:--------:|:--------:|:-------------------|
+| Neural Net | ImputeFormer🧑‍🔧[^34] | ✅ | | | | | `2024 - KDD` |
| Neural Net | iTransformer🧑‍🔧[^24] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | SAITS[^1] | ✅ | | | | | `2023 - ESWA` |
| Neural Net | FreTS🧑‍🔧[^23] | ✅ | | | | | `2023 - NeurIPS` |
@@ -293,9 +294,9 @@ year={2023},
}
```
or
-> Wenjie Du. (2023).
+> Wenjie Du.
> PyPOTS: a Python toolbox for data mining on Partially-Observed Time Series.
-> arXiv, abs/2305.18811. https://arxiv.org/abs/2305.18811
+> arXiv, abs/2305.18811, 2023.

## ❖ Contribution
@@ -380,7 +381,7 @@ PyPOTS community is open, transparent, and surely friendly. Let's work together
[^31]: Kim, T., Kim, J., Tae, Y., Park, C., Choi, J. H., & Choo, J. (2022). [Reversible Instance Normalization for Accurate Time-Series Forecasting against Distribution Shift](https://openreview.net/forum?id=cGDAkQo1C0p). *ICLR 2022*.
[^32]: Kitaev, N., Kaiser, Ł., & Levskaya, A. (2020). [Reformer: The Efficient Transformer](https://openreview.net/forum?id=0EXmFzUn5I). *ICLR 2020*.
[^33]: Cao, D., Wang, Y., Duan, J., Zhang, C., Zhu, X., Huang, C., Tong, Y., Xu, B., Bai, J., Tong, J., & Zhang, Q. (2020). [Spectral Temporal Graph Neural Network for Multivariate Time-series Forecasting](https://proceedings.neurips.cc/paper/2020/hash/cdf6581cb7aca4b7e19ef136c6e601a5-Abstract.html). *NeurIPS 2020*.

+[^34]: Nie, T., Qin, G., Ma, W., Mei, Y., & Sun, J. (2024). [ImputeFormer: Low Rankness-Induced Transformers for Generalizable Spatiotemporal Imputation](https://arxiv.org/abs/2312.01728). *KDD 2024*.

<details>
<summary>🏠 Visits</summary>
4 changes: 3 additions & 1 deletion README_zh.md
@@ -98,6 +98,7 @@ PyPOTS currently supports imputation, forecasting, classification, clustering, as well as

| **Type** | **Algorithm** | **Imputation** | **Forecasting** | **Classification** | **Clustering** | **Anomaly Detection** | **Year - Venue** |
|:--------------|:----------------------------|:------:|:------:|:------:|:------:|:--------:|:-----------------|
+| Neural Net | ImputeFormer🧑‍🔧[^34] | ✅ | | | | | `2024 - KDD` |
| Neural Net | iTransformer🧑‍🔧[^24] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | SAITS[^1] | ✅ | | | | | `2023 - ESWA` |
| Neural Net | FreTS🧑‍🔧[^23] | ✅ | | | | | `2023 - NeurIPS` |
@@ -145,7 +146,7 @@ PyPOTS currently supports imputation, forecasting, classification, clustering, as well as

👈 In PyPOTS, data can be viewed as coffee beans, and POTS data with missing values as incomplete coffee beans.
To let users easily work with a variety of open-source time-series datasets, we created Time Series Data Beans (TSDB), a repository of open-source time-series datasets (think of it as a coffee-bean warehouse),
-TSDB makes loading open-source time-series datasets super easy! Visit [TSDB](https://github.com/WenjieDu/TSDB) to learn more; it currently supports 170 open-source datasets in total!
+TSDB makes loading open-source time-series datasets super easy! Visit [TSDB](https://github.com/WenjieDu/TSDB) to learn more; it currently supports 172 open-source datasets in total!

<a href="https://github.com/WenjieDu/PyGrinder">
<img src="https://pypots.com/figs/pypots_logos/PyGrinder/logo_FFBG.svg" align="right" width="140" alt="PyGrinder logo"/>
@@ -351,6 +352,7 @@ The PyPOTS community is open, transparent, and friendly. Let's work together
[^31]: Kim, T., Kim, J., Tae, Y., Park, C., Choi, J. H., & Choo, J. (2022). [Reversible Instance Normalization for Accurate Time-Series Forecasting against Distribution Shift](https://openreview.net/forum?id=cGDAkQo1C0p). *ICLR 2022*.
[^32]: Kitaev, N., Kaiser, Ł., & Levskaya, A. (2020). [Reformer: The Efficient Transformer](https://openreview.net/forum?id=0EXmFzUn5I). *ICLR 2020*.
[^33]: Cao, D., Wang, Y., Duan, J., Zhang, C., Zhu, X., Huang, C., Tong, Y., Xu, B., Bai, J., Tong, J., & Zhang, Q. (2020). [Spectral Temporal Graph Neural Network for Multivariate Time-series Forecasting](https://proceedings.neurips.cc/paper/2020/hash/cdf6581cb7aca4b7e19ef136c6e601a5-Abstract.html). *NeurIPS 2020*.
+[^34]: Nie, T., Qin, G., Ma, W., Mei, Y., & Sun, J. (2024). [ImputeFormer: Low Rankness-Induced Transformers for Generalizable Spatiotemporal Imputation](https://arxiv.org/abs/2312.01728). *KDD 2024*.


<details>
6 changes: 3 additions & 3 deletions docs/index.rst
@@ -12,7 +12,7 @@ Welcome to PyPOTS docs!

**A Python Toolbox for Machine Learning on Partially-Observed Time Series**

-.. image:: https://img.shields.io/badge/Python-v3.7+-E97040?logo=python&logoColor=white
+.. image:: https://img.shields.io/badge/Python-v3.8+-E97040?logo=python&logoColor=white
:alt: Python version
:target: https://docs.pypots.com/en/latest/install.html#reasons-of-version-limitations-on-dependencies

@@ -306,9 +306,9 @@ or

..
-Wenjie Du. (2023).
+Wenjie Du.
PyPOTS: a Python toolbox for data mining on Partially-Observed Time Series.
-arXiv, abs/2305.18811. https://doi.org/10.48550/arXiv.2305.18811
+arXiv, abs/2305.18811, 2023.


❖ Contribution
11 changes: 11 additions & 0 deletions docs/references.bib
@@ -745,3 +745,14 @@ @inproceedings{xu2024fits
year={2024},
url={https://openreview.net/forum?id=bWcnvZ3qMb}
}

+@inproceedings{nie2024imputeformer,
+  title={ImputeFormer: Low Rankness-Induced Transformers for Generalizable Spatiotemporal Imputation},
+  author={Nie, Tong and Qin, Guoyang and Ma, Wei and Mei, Yuewen and Sun, Jian},
+  booktitle={Proceedings of the 30th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
+  publisher={Association for Computing Machinery},
+  series={KDD '24},
+  year={2024},
+  doi={10.1145/3637528.3671751},
+  url={https://doi.org/10.1145/3637528.3671751},
+}
2 changes: 2 additions & 0 deletions pypots/imputation/__init__.py
@@ -34,6 +34,7 @@
from .tide import TiDE
from .grud import GRUD
from .stemgnn import StemGNN
+from .imputeformer import ImputeFormer

# naive imputation methods
from .locf import LOCF
@@ -70,6 +71,7 @@
"TiDE",
"GRUD",
"StemGNN",
"ImputeFormer",
# naive imputation methods
"LOCF",
"Mean",
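With `ImputeFormer` exported here, it should be usable through the same high-level API as the other neural imputers in PyPOTS. A minimal usage sketch (the hyperparameter names are assumed to mirror the `_ImputeFormer` constructor in `core.py` below, and `epochs`/`fit`/`impute` are assumed to follow the usual PyPOTS neural-imputer interface):

```python
import numpy as np
from pypots.imputation import ImputeFormer

# toy data: 64 samples, 24 steps, 7 features; NaN marks missing values
X = np.random.randn(64, 24, 7)
X[np.random.rand(*X.shape) < 0.2] = np.nan

model = ImputeFormer(
    n_steps=24,
    n_features=7,
    n_layers=2,
    d_input_embed=32,
    d_learnable_embed=32,
    d_proj=16,
    d_ffn=64,
    n_temporal_heads=4,
    epochs=10,
)
model.fit({"X": X})               # ORT+MIT training, as in SAITS
imputed = model.impute({"X": X})  # (64, 24, 7) array with the NaNs filled in
```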
20 changes: 20 additions & 0 deletions pypots/imputation/imputeformer/__init__.py
@@ -0,0 +1,20 @@
"""
The package of the partially-observed time-series imputation model ImputeFormer.
Refer to the paper
`Tong Nie, Guoyang Qin, Wei Ma, Yuewen Mei, Jian Sun.
"ImputeFormer: Low Rankness-Induced Transformers for Generalizable Spatiotemporal Imputation"
KDD 2024.
<https://doi.org/10.48550/arXiv.2312.01728>`_
"""

# Created by Tong Nie <[email protected]> and Wenjie Du <[email protected]>
# License: BSD-3-Clause


from .model import ImputeFormer

__all__ = [
"ImputeFormer",
]
151 changes: 151 additions & 0 deletions pypots/imputation/imputeformer/core.py
@@ -0,0 +1,151 @@
"""
The core wrapper assembles the submodules of the ImputeFormer imputation model
and takes over the forward process of the algorithm.
"""

# Created by Tong Nie <[email protected]> and Wenjie Du <[email protected]>
# License: BSD-3-Clause

import torch
import torch.nn as nn

from ...nn.modules.imputeformer import (
EmbeddedAttentionLayer,
ProjectedAttentionLayer,
MLP,
)
from ...nn.modules.saits import SaitsLoss


class _ImputeFormer(nn.Module):
"""
Spatiotemporal Imputation Transformer induced by low-rank factorization, KDD'24.
Note:
This is a simplified implementation under the SAITS framework (ORT+MIT).
The timestamp encoding is also removed for ease of implementation.
"""

def __init__(
self,
n_steps: int,
n_features: int,
n_layers: int,
d_input_embed: int,
d_learnable_embed: int,
d_proj: int,
d_ffn: int,
n_temporal_heads: int,
dropout: float = 0.0,
input_dim: int = 1,
output_dim: int = 1,
ORT_weight: float = 1,
MIT_weight: float = 1,
):
super().__init__()

self.n_nodes = n_features
self.in_steps = n_steps
self.out_steps = n_steps
self.input_dim = input_dim
self.output_dim = output_dim
self.input_embedding_dim = d_input_embed
self.learnable_embedding_dim = d_learnable_embed
self.model_dim = d_input_embed + d_learnable_embed

self.n_temporal_heads = n_temporal_heads
self.num_layers = n_layers
self.input_proj = nn.Linear(input_dim, self.input_embedding_dim)
self.d_proj = d_proj
self.d_ffn = d_ffn

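# learnable per-(step, node) embedding, Xavier-initialized; it is concatenated
# onto the input embedding and also conditions the spatial attention layers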
self.learnable_embedding = nn.init.xavier_uniform_(
nn.Parameter(
torch.empty(self.in_steps, self.n_nodes, self.learnable_embedding_dim)
)
)

self.readout = MLP(self.model_dim, self.model_dim, output_dim, n_layers=2)

self.attn_layers_t = nn.ModuleList(
[
ProjectedAttentionLayer(
self.n_nodes,
self.d_proj,
self.model_dim,
self.n_temporal_heads,
self.model_dim,
dropout,
)
for _ in range(self.num_layers)
]
)

self.attn_layers_s = nn.ModuleList(
[
EmbeddedAttentionLayer(
self.model_dim,
self.learnable_embedding_dim,
self.d_ffn,
)
for _ in range(self.num_layers)
]
)

# apply SAITS loss function to Transformer on the imputation task
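# ORT penalizes reconstruction error on the observed values; MIT penalizes error
# on artificially-masked values, weighted by ORT_weight and MIT_weight respectively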
self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)

def forward(self, inputs: dict, training: bool = True) -> dict:
x, missing_mask = inputs["X"], inputs["missing_mask"]

# x: (batch_size, in_steps, num_nodes)
# Note that ImputeFormer is designed for spatiotemporal data in the format [B, S, N, C],
# where N is the number of nodes and C is an additional feature dimension.
# We simply add an extra axis here for implementation.
x = x.unsqueeze(-1) # [b s n c]
missing_mask = missing_mask.unsqueeze(-1) # [b s n c]
batch_size = x.shape[0]
# Whiten (zero out) the missing values so they don't leak into the input projection
x = x * missing_mask
x = self.input_proj(x) # (batch_size, in_steps, num_nodes, input_embedding_dim)

# Learnable node embedding
node_emb = self.learnable_embedding.expand(
batch_size, *self.learnable_embedding.shape
)
x = torch.cat(
[x, node_emb], dim=-1
) # (batch_size, in_steps, num_nodes, model_dim)

# Spatial and temporal processing with customized attention layers
x = x.permute(0, 2, 1, 3) # [b n s c]
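# each layer pair applies temporal attention over the step axis (att_t), then
# spatial attention over the node axis (att_s) conditioned on the learnable embedding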
for att_t, att_s in zip(self.attn_layers_t, self.attn_layers_s):
x = att_t(x)
x = att_s(x, self.learnable_embedding, dim=1)

# Readout
x = x.permute(0, 2, 1, 3) # [b s n c]
reconstruction = self.readout(x)
reconstruction = reconstruction.squeeze(-1) # [b s n]
missing_mask = missing_mask.squeeze(-1) # [b s n]

# Below is the SAITS processing pipeline:
# replace the observed part with values from X
imputed_data = missing_mask * inputs["X"] + (1 - missing_mask) * reconstruction

# ensemble the results as a dictionary for return
results = {
"imputed_data": imputed_data,
}

# if in training mode, return results with losses
if training:
X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
loss, ORT_loss, MIT_loss = self.saits_loss_func(
reconstruction, X_ori, missing_mask, indicating_mask
)
results["ORT_loss"] = ORT_loss
results["MIT_loss"] = MIT_loss
# `loss` is always the item for backward propagating to update the model
results["loss"] = loss

return results
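As a sanity check on the shapes above, here is a minimal smoke test of the wrapper (a sketch assuming a checkout of this commit is importable; the sizes are arbitrary, and `training=False` skips the loss so no `X_ori`/`indicating_mask` is needed):

```python
import torch

from pypots.imputation.imputeformer.core import _ImputeFormer

model = _ImputeFormer(
    n_steps=24, n_features=7, n_layers=2,
    d_input_embed=32, d_learnable_embed=32,
    d_proj=16, d_ffn=64, n_temporal_heads=4,
)
X = torch.randn(8, 24, 7)                            # (batch, steps, features)
missing_mask = (torch.rand(8, 24, 7) > 0.3).float()  # 1 = observed, 0 = missing
inputs = {"X": X * missing_mask, "missing_mask": missing_mask}

out = model(inputs, training=False)
print(out["imputed_data"].shape)  # torch.Size([8, 24, 7])
```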
22 changes: 22 additions & 0 deletions pypots/imputation/imputeformer/data.py
@@ -0,0 +1,22 @@
"""
Dataset class for the imputation model ImputeFormer.
"""

# Created by Tong Nie <[email protected]> and Wenjie Du <[email protected]>
# License: BSD-3-Clause

from typing import Union

from ..saits.data import DatasetForSAITS


class DatasetForImputeFormer(DatasetForSAITS):
def __init__(
self,
data: Union[dict, str],
return_X_ori: bool,
return_y: bool,
file_type: str = "hdf5",
rate: float = 0.2,
):
super().__init__(data, return_X_ori, return_y, file_type, rate)
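`DatasetForImputeFormer` inherits everything from `DatasetForSAITS`: during training, a fraction `rate` (0.2 here) of the observed values is artificially masked out so the MIT loss has known targets. A hand-rolled sketch of that masking idea (illustrative only, not the library's actual masking code):

```python
import torch

X_ori = torch.randn(24, 7)                            # one sample: 24 steps, 7 features
observed = ~torch.isnan(X_ori)                        # originally-observed entries
hold_out = observed & (torch.rand_like(X_ori) < 0.2)  # artificially mask 20% of them
X = X_ori.masked_fill(hold_out, float("nan"))         # what the model actually sees
missing_mask = (~torch.isnan(X)).float()              # observed entries of X
indicating_mask = hold_out.float()                    # MIT targets: masked-but-known values
```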