Skip to content

Commit

Permalink
Merge pull request #348 from WenjieDu/(refactor)modular_multitask_nn
Browse files Browse the repository at this point in the history
Modularize neural network models
  • Loading branch information
WenjieDu authored Apr 13, 2024
2 parents 77c7ab2 + 0d7c493 commit e29388a
Show file tree
Hide file tree
Showing 128 changed files with 3,142 additions and 1,793 deletions.
4 changes: 2 additions & 2 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ Additionally, we present you a usage example of imputing missing values in time
❖ Available Algorithms
^^^^^^^^^^^^^^^^^^^^^^^
PyPOTS supports imputation, classification, clustering, and forecasting tasks on multivariate time series with missing values.
The currently available algorithms of four tasks are cataloged in the following table with four partitions. The paper references are all listed `on the reference page </references.html>`_.
The currently available algorithms of four tasks are cataloged in the following table with four partitions. The paper references are provided and you can click them to check out.

🌟 Since **v0.2**, all neural-network models in PyPOTS has got hyperparameter-optimization support.
This functionality is implemented with the `Microsoft NNI <https://github.com/microsoft/nni>`_ framework.
Expand All @@ -183,7 +183,7 @@ the same as we did in `SAITS paper <https://arxiv.org/pdf/2202.08516)>`_.
Task Type Algorithm Year Reference
============================== ================ ========================================================================================================= ====== =========
Imputation Neural Net SAITS (Self-Attention-based Imputation for Time Series) 2023 :cite:`du2023SAITS`
Imputation Neural Net Transformer 2017 :cite:`vaswani2017Transformer`, :cite:`du2023SAITS`
Imputation Neural Net Transformer 2017 :cite:`vaswani2017Transformer`
Imputation Neural Net Crossformer (Transformer Utilizing Cross-Dimension Dependency for Multivariate Time Series Forecasting) 2023 :cite:`nie2023patchtst`
Imputation Neural Net TimesNet (Temporal 2D-Variation Modeling for General Time Series Analysis) 2023 :cite:`wu2023timesnet`
Imputation Neural Net PatchTST (A Time Series is Worth 64 Words: Long-Term Forecasting with Transformers) 2023 :cite:`nie2023patchtst`
Expand Down
126 changes: 123 additions & 3 deletions docs/pypots.nn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,134 @@ pypots.nn.functional
:members:


pypots.nn.modules.rnn
---------------------
pypots.nn.modules.autoformer
-----------------------------

.. automodule:: pypots.nn.modules.autoformer
:members:


pypots.nn.modules.brits
-----------------------------

.. automodule:: pypots.nn.modules.brits
:members:


pypots.nn.modules.crli
-----------------------------

.. automodule:: pypots.nn.modules.crli
:members:


pypots.nn.modules.crossformer
-----------------------------

.. automodule:: pypots.nn.modules.crossformer
:members:


pypots.nn.modules.csdi
-----------------------------

.. automodule:: pypots.nn.modules.csdi
:members:


pypots.nn.modules.dlinear
-----------------------------

.. automodule:: pypots.nn.modules.dlinear
:members:


pypots.nn.modules.etsformer
-----------------------------

.. automodule:: pypots.nn.modules.etsformer
:members:

.. automodule:: pypots.nn.modules.rnn

pypots.nn.modules.fedformer
-----------------------------

.. automodule:: pypots.nn.modules.fedformer
:members:


pypots.nn.modules.gpvae
-----------------------------

.. automodule:: pypots.nn.modules.gpvae
:members:


pypots.nn.modules.grud
-----------------------------

.. automodule:: pypots.nn.modules.grud
:members:


pypots.nn.modules.informer
-----------------------------

.. automodule:: pypots.nn.modules.informer
:members:


pypots.nn.modules.mrnn
-----------------------------

.. automodule:: pypots.nn.modules.mrnn
:members:


pypots.nn.modules.patchtst
-----------------------------

.. automodule:: pypots.nn.modules.patchtst
:members:


pypots.nn.modules.raindrop
-----------------------------

.. automodule:: pypots.nn.modules.raindrop
:members:


pypots.nn.modules.saits
-----------------------------

.. automodule:: pypots.nn.modules.saits
:members:


pypots.nn.modules.timesnet
-----------------------------

.. automodule:: pypots.nn.modules.timesnet
:members:


pypots.nn.modules.transformer
-----------------------------

.. automodule:: pypots.nn.modules.transformer
:members:


pypots.nn.modules.usgan
-----------------------------

.. automodule:: pypots.nn.modules.usgan
:members:


pypots.nn.modules.vader
-----------------------------

.. automodule:: pypots.nn.modules.vader
:members:
89 changes: 89 additions & 0 deletions pypots/classification/brits/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
The implementation of BRITS for the partially-observed time-series classification task.
Refer to the paper "Cao, W., Wang, D., Li, J., Zhou, H., Li, L., & Li, Y. (2018).
BRITS: Bidirectional Recurrent Imputation for Time Series. NeurIPS 2018."
Notes
-----
Partial implementation uses code from https://github.com/caow13/BRITS. The bugs in the original implementation
are fixed here.
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...nn.modules.brits import BackboneBRITS


class _BRITS(nn.Module):
def __init__(
self,
n_steps: int,
n_features: int,
rnn_hidden_size: int,
n_classes: int,
classification_weight: float,
reconstruction_weight: float,
):
super().__init__()
self.n_steps = n_steps
self.n_features = n_features
self.rnn_hidden_size = rnn_hidden_size
self.n_classes = n_classes
self.classification_weight = classification_weight
self.reconstruction_weight = reconstruction_weight

# create models
self.model = BackboneBRITS(n_steps, n_features, rnn_hidden_size)
self.f_classifier = nn.Linear(self.rnn_hidden_size, n_classes)
self.b_classifier = nn.Linear(self.rnn_hidden_size, n_classes)

def forward(self, inputs: dict, training: bool = True) -> dict:
(
imputed_data,
f_reconstruction,
b_reconstruction,
f_hidden_states,
b_hidden_states,
consistency_loss,
reconstruction_loss,
) = self.model(inputs)

f_logits = self.f_classifier(f_hidden_states)
b_logits = self.b_classifier(b_hidden_states)
f_prediction = torch.softmax(f_logits, dim=1)
b_prediction = torch.softmax(b_logits, dim=1)
classification_pred = (f_prediction + b_prediction) / 2

results = {
"imputed_data": imputed_data,
"classification_pred": classification_pred,
}

# if in training mode, return results with losses
if training:
results["consistency_loss"] = consistency_loss
results["reconstruction_loss"] = reconstruction_loss
f_classification_loss = F.nll_loss(torch.log(f_prediction), inputs["label"])
b_classification_loss = F.nll_loss(torch.log(b_prediction), inputs["label"])
classification_loss = (f_classification_loss + b_classification_loss) / 2
loss = (
consistency_loss
+ reconstruction_loss * self.reconstruction_weight
+ classification_loss * self.classification_weight
)

# `loss` is always the item for backward propagating to update the model
results["loss"] = loss
results["reconstruction"] = (f_reconstruction + b_reconstruction) / 2
results["classification_loss"] = classification_loss
results["f_reconstruction"] = f_reconstruction
results["b_reconstruction"] = b_reconstruction

return results
3 changes: 1 addition & 2 deletions pypots/classification/brits/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import torch
from torch.utils.data import DataLoader

from .core import _BRITS
from .data import DatasetForBRITS
from .modules import _BRITS
from ..base import BaseNNClassifier
from ...optim.adam import Adam
from ...optim.base import Optimizer
Expand Down Expand Up @@ -135,7 +135,6 @@ def __init__(
self.n_classes,
self.classification_weight,
self.reconstruction_weight,
self.device,
)
self._send_model_to_given_device()
self._print_model_size()
Expand Down
13 changes: 0 additions & 13 deletions pypots/classification/brits/modules/__init__.py

This file was deleted.

Loading

0 comments on commit e29388a

Please sign in to comment.