-
-
Notifications
You must be signed in to change notification settings - Fork 110
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #505 from ztxtech/main
Add TEFN model
- Loading branch information
Showing
12 changed files
with
993 additions
and
181 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,6 @@ | |
# Created by Wenjie Du <[email protected]> | ||
# License: BSD-3-Clause | ||
|
||
# neural network imputation methods | ||
from .brits import BRITS | ||
from .csdi import CSDI | ||
from .gpvae import GPVAE | ||
|
@@ -44,6 +43,7 @@ | |
from .mean import Mean | ||
from .median import Median | ||
from .lerp import Lerp | ||
from .tefn import TEFN | ||
|
||
__all__ = [ | ||
# neural network imputation methods | ||
|
@@ -84,4 +84,5 @@ | |
"Mean", | ||
"Median", | ||
"Lerp", | ||
"TEFN" | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
""" | ||
The package of the forecasting model TEFN. | ||
Refer to the paper | ||
`Tianxiang Zhan, Yuanpeng He, Yong Deng, and Zhen Li. | ||
Time Evidence Fusion Network: Multi-source View in Long-Term Time Series Forecasting. | ||
In Arxiv, 2024. | ||
<https://arxiv.org/abs/2405.06419>`_ | ||
Notes | ||
----- | ||
This implementation is transfered from the official one https://github.com/ztxtech/Time-Evidence-Fusion-Network | ||
""" | ||
|
||
# Created by Tianxiang Zhan <[email protected]> | ||
# License: BSD-3-Clause | ||
|
||
|
||
from .model import TEFN | ||
|
||
__all__ = [ | ||
"TEFN", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
""" | ||
""" | ||
|
||
# Created by Tianxiang Zhan <[email protected]> | ||
# License: BSD-3-Clause | ||
|
||
import torch.nn as nn | ||
|
||
from ...nn.functional import nonstationary_norm, nonstationary_denorm | ||
from ...nn.modules.tefn import BackboneTEFN | ||
from ...utils.metrics import calc_mse | ||
|
||
|
||
class _TEFN(nn.Module): | ||
def __init__( | ||
self, | ||
n_steps, | ||
n_features, | ||
n_fod, | ||
apply_nonstationary_norm, | ||
): | ||
super().__init__() | ||
|
||
self.seq_len = n_steps | ||
self.n_fod = n_fod | ||
self.apply_nonstationary_norm = apply_nonstationary_norm | ||
|
||
self.model = BackboneTEFN( | ||
n_steps, | ||
n_features, | ||
n_fod, | ||
) | ||
|
||
def forward(self, inputs: dict, training: bool = True) -> dict: | ||
X, missing_mask = inputs["X"], inputs["missing_mask"] | ||
|
||
if self.apply_nonstationary_norm: | ||
# Normalization from Non-stationary Transformer | ||
X, means, stdev = nonstationary_norm(X, missing_mask) | ||
|
||
# TEFN processing | ||
out = self.model(X) | ||
|
||
if self.apply_nonstationary_norm: | ||
# De-Normalization from Non-stationary Transformer | ||
out = nonstationary_denorm(out, means, stdev) | ||
|
||
imputed_data = missing_mask * X + (1 - missing_mask) * out | ||
results = { | ||
"imputed_data": imputed_data, | ||
} | ||
|
||
if training: | ||
# `loss` is always the item for backward propagating to update the model | ||
loss = calc_mse(out, inputs["X_ori"], inputs["indicating_mask"]) | ||
results["loss"] = loss | ||
|
||
return results |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
""" | ||
Dataset class for the imputation model TEFN. | ||
""" | ||
|
||
# Created by Tianxiang Zhan <[email protected]> | ||
# License: BSD-3-Clause | ||
|
||
from typing import Union | ||
|
||
from ..saits.data import DatasetForSAITS | ||
|
||
|
||
class DatasetForTEFN(DatasetForSAITS): | ||
"""Actually TEFN uses the same data strategy as SAITS, needs MIT for training.""" | ||
|
||
def __init__( | ||
self, | ||
data: Union[dict, str], | ||
return_X_ori: bool, | ||
return_y: bool, | ||
file_type: str = "hdf5", | ||
rate: float = 0.2, | ||
): | ||
super().__init__(data, return_X_ori, return_y, file_type, rate) |
Oops, something went wrong.