Enable customized loss and val funcs #526

Merged Feb 19, 2025

Changes from all commits (25 commits)
41b00a3
feat: enable users to customize training loss func and val metric func;
WenjieDu May 8, 2024
06ba371
Merge branch 'dev' into (feat)customized_loss_and_val_func
WenjieDu May 8, 2024
620a95c
refactor: move pypots.utils.metrics to pypots.nn.functional;
WenjieDu May 8, 2024
c7dc6be
feat: add BaseLoss and BaseMetric;
WenjieDu May 9, 2024
83879ff
Merge branch 'dev' into (feat)customized_loss_and_val_func
WenjieDu Sep 25, 2024
6b5c449
Merge branch 'dev' into (feat)customized_loss_and_val_func
WenjieDu Sep 26, 2024
92773f1
Merge branch 'dev' into (feat)customized_loss_and_val_func
WenjieDu Sep 27, 2024
f674a3c
refactor: make FITS able to apply customized loss func;
WenjieDu Sep 27, 2024
40e9602
refactor: replace arg training with self attribute in new added models;
WenjieDu Sep 27, 2024
cc286ee
fix: globally replace importing from pypots.utils.metrics with pypots.nn.functional;
WenjieDu Sep 30, 2024
ba4588c
refactor: still keep pypots.utils.metrics for future compatibility;
WenjieDu Sep 30, 2024
86340d0
refactor: do not expose by default;
WenjieDu Oct 3, 2024
b4b5b48
refactor: remove linting issues;
WenjieDu Oct 8, 2024
054b3c9
refactor: add customized loss and metric funcs for CSAI models;
WenjieDu Dec 2, 2024
9bc40f4
docs: update doc strings;
WenjieDu Dec 2, 2024
2f5840a
refactor: add customized loss and metric funcs for SegRNN;
WenjieDu Dec 3, 2024
566b859
refactor: simplify some parts;
WenjieDu Dec 3, 2024
30efa4e
refactor: disable clustering algos to customize training loss and val…
WenjieDu Dec 3, 2024
6f3b36e
feat: add loss and metric classes;
WenjieDu Dec 3, 2024
8705db6
refactor: use classes to replace funcs in CLAS algos;
WenjieDu Dec 3, 2024
5b58313
refactor: simplify some parts;
WenjieDu Dec 4, 2024
507cc63
Merge branch 'dev' into (feat)customized_loss_and_val_func
WenjieDu Feb 8, 2025
af3bf85
fix: rename labels in classification CSAI into y;
WenjieDu Dec 2, 2024
9dbc1b1
Merge branch 'dev' into (feat)customized_loss_and_val_func
WenjieDu Feb 19, 2025
6ca6da8
refactor: import from pypots.nn.functional instead of pypots.utils.me…
WenjieDu Feb 19, 2025
2 changes: 1 addition & 1 deletion README.md
@@ -279,7 +279,7 @@ print(X.shape) # (11988, 48, 37), 11988 samples and each sample has 48 time ste

# Model training. This is PyPOTS showtime.
from pypots.imputation import SAITS
from pypots.utils.metrics import calc_mae
from pypots.nn.functional import calc_mae
saits = SAITS(n_steps=48, n_features=37, n_layers=2, d_model=256, n_heads=4, d_k=64, d_v=64, d_ffn=128, dropout=0.1, epochs=10)
# Here I use the whole dataset as the training set because ground truth is not visible to the model, you can also split it into train/val/test sets
saits.fit(dataset) # train the model on the dataset
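
The only change in this hunk is the import path: calc_mae now comes from pypots.nn.functional. Below is a hedged continuation of the README snippet showing the relocated function in use; X and X_ori are assumed to come from the earlier, unshown part of the quickstart, and calc_mae is assumed to keep its (predictions, targets, masks) signature.

# Hedged continuation of the README snippet above: evaluate the imputation with
# the relocated calc_mae. `X` and `X_ori` are assumed to be defined earlier in
# the quickstart (not shown in this hunk).
import numpy as np
from pypots.nn.functional import calc_mae  # new import path introduced by this PR

imputation = saits.impute(dataset)  # impute the missing values in X
indicating_mask = np.isnan(X) ^ np.isnan(X_ori)  # score only the artificially masked entries
mae = calc_mae(imputation, np.nan_to_num(X_ori), indicating_mask)
print(f"Testing MAE: {mae:.4f}")
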
2 changes: 1 addition & 1 deletion README_zh.md
@@ -261,7 +261,7 @@ print(X.shape) # X's shape is (11988, 48, 37), i.e. 11988 samples, each sample

# Model training. This is PyPOTS showtime.
from pypots.imputation import SAITS
from pypots.utils.metrics import calc_mae
from pypots.nn.functional import calc_mae
saits = SAITS(n_steps=48, n_features=37, n_layers=2, d_model=256, n_heads=4, d_k=64, d_v=64, d_ffn=128, dropout=0.1, epochs=10)
# Because the ground truth is not visible to the model, the whole dataset is used as the training set; you can also split it into train/val/test sets
saits.fit(dataset) # train the model on the dataset
2 changes: 1 addition & 1 deletion docs/examples.rst
@@ -25,7 +25,7 @@ You can also find a simple and quick-start tutorial notebook on Google Colab
from pygrinder import mcar
from pypots.data import load_specific_dataset
from pypots.imputation import SAITS
from pypots.utils.metrics import calc_mae
from pypots.nn.functional import calc_mae

# Data preprocessing. Tedious, but PyPOTS can help. 🤓
data = load_specific_dataset('physionet_2012') # PyPOTS will automatically download and extract it.
4 changes: 2 additions & 2 deletions docs/pypots.utils.rst
@@ -10,10 +10,10 @@ pypots.utils.file
:show-inheritance:
:inherited-members:

pypots.utils.metrics
pypots.nn.functional
---------------------------

.. automodule:: pypots.utils.metrics
.. automodule:: pypots.nn.functional
:members:
:undoc-members:
:show-inheritance:
2 changes: 0 additions & 2 deletions pypots/__init__.py
@@ -15,7 +15,6 @@
data,
utils,
)
from .gungnir import Gungnir
from .version import __version__

__all__ = [
@@ -26,6 +25,5 @@
"optim",
"data",
"utils",
"Gungnir",
"__version__",
]
40 changes: 34 additions & 6 deletions pypots/base.py
@@ -9,7 +9,7 @@
from abc import ABC
from abc import abstractmethod
from datetime import datetime
from typing import Optional, Union, Iterable
from typing import Optional, Union, Iterable, Callable

import torch
from torch.utils.tensorboard import SummaryWriter
@@ -221,7 +221,9 @@ def _save_log_into_tb_file(self, step: int, stage: str, loss_dict: dict) -> None
# save all items containing "loss" or "error" in the name
# WDU: may enable customization keywords in the future
if ("loss" in item_name) or ("error" in item_name):
self.summary_writer.add_scalar(f"{stage}/{item_name}", loss.sum(), step)
if isinstance(loss, torch.Tensor):
loss = loss.sum()
self.summary_writer.add_scalar(f"{stage}/{item_name}", loss, step)

def _auto_save_model_if_necessary(
self,
@@ -415,9 +417,17 @@ class BaseNNModel(BaseModel):
Training epochs, i.e. the maximum rounds of the model to be trained with.

patience :
Number of epochs the training procedure will keep if loss doesn't decrease.
Once exceeding the number, the training will stop.
Must be smaller than or equal to the value of ``epochs``.
The patience for the early-stopping mechanism. Given a positive integer, training will be
stopped when the model does not improve for that number of epochs.
Leaving it as None (the default) disables early stopping.

train_loss_func:
The customized loss function designed by users for training the model.
If not given, will use the default loss as claimed in the original paper.

val_metric_func:
The customized metric function designed by users for validating the model.
If not given, will use the default MSE metric.

num_workers :
The number of subprocesses to use for data loading.
@@ -474,6 +484,8 @@ def __init__(
batch_size: int,
epochs: int,
patience: Optional[int] = None,
train_loss_func: Optional[dict] = None,
val_metric_func: Optional[dict] = None,
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
saving_path: str = None,
@@ -487,17 +499,33 @@ def __init__(
verbose,
)

# check patience
if patience is None:
patience = -1 # early stopping on patience won't work if it is set as < 0
else:
assert (
patience <= epochs
), f"patience must be smaller than epochs which is {epochs}, but got patience={patience}"

# training hyper-parameters
# check train_loss_func and val_metric_func
train_loss_func_name, val_metric_func_name = "default", "loss (default)"
if train_loss_func is not None:
train_loss_func_name = train_loss_func.__class__.__name__
assert isinstance(train_loss_func, Callable), "train_loss_func should be a callable instance"
logger.info(f"Using customized {train_loss_func_name} as the training loss function.")
if val_metric_func is not None:
val_metric_func_name = val_metric_func.__class__.__name__
assert isinstance(val_metric_func, Callable), "val_metric_func should be a callable instance"
logger.info(f"Using customized {val_metric_func_name} as the validation metric function.")

# set up the hyper-parameters
self.batch_size = batch_size
self.epochs = epochs
self.patience = patience
self.train_loss_func = train_loss_func
self.train_loss_func_name = train_loss_func_name
self.val_metric_func = val_metric_func
self.val_metric_func_name = val_metric_func_name
self.original_patience = patience
self.num_workers = num_workers

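
The constructor above only checks that train_loss_func and val_metric_func are callable, logs their class names, and stores them on the model, so any object with __call__ works. Below is a minimal sketch of what a user-supplied pair could look like; the class names and the three-argument (predictions, targets, masks) signature are assumptions that fit imputation-style models, while classifiers call the metric with two arguments as shown further down.

# Minimal sketch of user-supplied callables accepted by BaseNNModel after this PR.
# Only callability is required; class names and signatures here are illustrative.
import torch


class MaskedMAELoss:
    """Training loss: mean absolute error computed only on observed entries."""

    def __call__(self, predictions: torch.Tensor, targets: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
        return torch.sum(torch.abs(predictions - targets) * masks) / (torch.sum(masks) + 1e-12)


class MaskedRMSE:
    """Validation metric: root mean squared error computed only on observed entries."""

    def __call__(self, predictions: torch.Tensor, targets: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
        return torch.sqrt(torch.sum((predictions - targets) ** 2 * masks) / (torch.sum(masks) + 1e-12))


# Instances, not classes, are what get passed, e.g. (model name is hypothetical):
# model = SomeImputer(..., train_loss_func=MaskedMAELoss(), val_metric_func=MaskedRMSE())
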
71 changes: 50 additions & 21 deletions pypots/classification/base.py
@@ -15,6 +15,8 @@
from torch.utils.data import DataLoader

from ..base import BaseModel, BaseNNModel
from ..nn.modules.loss import CrossEntropy
from ..nn.modules.metric import Accuracy
from ..utils.logging import logger

try:
@@ -155,9 +157,17 @@ class BaseNNClassifier(BaseNNModel):
Training epochs, i.e. the maximum rounds of the model to be trained with.

patience :
Number of epochs the training procedure will keep if loss doesn't decrease.
Once exceeding the number, the training will stop.
Must be smaller than or equal to the value of ``epochs``.
The patience for the early-stopping mechanism. Given a positive integer, training will be
stopped when the model does not improve for that number of epochs.
Leaving it as None (the default) disables early stopping.

train_loss_func:
The customized loss function designed by users for training the model.
If not given, will use the default loss as claimed in the original paper.

val_metric_func:
The customized metric function designed by users for validating the model.
If not given, will use the default loss from the original paper as the metric.

num_workers :
The number of subprocesses to use for data loading.
@@ -202,24 +212,36 @@ def __init__(
batch_size: int,
epochs: int,
patience: Optional[int] = None,
train_loss_func: Optional[dict] = None,
val_metric_func: Optional[dict] = None,
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
saving_path: str = None,
model_saving_strategy: Optional[str] = "best",
verbose: bool = True,
):
super().__init__(
batch_size,
epochs,
patience,
num_workers,
device,
saving_path,
model_saving_strategy,
verbose,
batch_size=batch_size,
epochs=epochs,
patience=patience,
train_loss_func=train_loss_func,
val_metric_func=val_metric_func,
num_workers=num_workers,
device=device,
saving_path=saving_path,
model_saving_strategy=model_saving_strategy,
verbose=verbose,
)
self.n_classes = n_classes

# set default training loss function and validation metric function if not given
if train_loss_func is None:
self.train_loss_func = CrossEntropy()
self.train_loss_func_name = self.train_loss_func.__class__.__name__
if val_metric_func is None:
self.val_metric_func = Accuracy()
self.val_metric_func_name = self.val_metric_func.__class__.__name__

@abstractmethod
def _assemble_input_for_training(self, data: list) -> dict:
"""Assemble the given data into a dictionary for training input.
@@ -308,30 +330,39 @@ def _train_model(

if val_loader is not None:
self.model.eval()
epoch_val_loss_collector = []
epoch_val_pred_collector = []
epoch_val_label_collector = []
with torch.no_grad():
for idx, data in enumerate(val_loader):
inputs = self._assemble_input_for_validating(data)
results = self.model.forward(inputs)
epoch_val_loss_collector.append(results["loss"].sum().item())
results = self.model(inputs)
epoch_val_pred_collector.append(results["classification_pred"])
epoch_val_label_collector.append(inputs["y"])

epoch_val_pred_collector = torch.cat(epoch_val_pred_collector, dim=-1)
epoch_val_label_collector = torch.cat(epoch_val_label_collector, dim=-1)

mean_val_loss = np.mean(epoch_val_loss_collector)
# TODO: refactor the following code to a function
epoch_val_pred_collector = np.argmax(epoch_val_pred_collector, axis=1)
mean_val_loss = self.val_metric_func(epoch_val_pred_collector, epoch_val_label_collector.numpy())

# save validation loss logs into the tensorboard file for every epoch if in need
if self.summary_writer is not None:
val_loss_dict = {
"classification_loss": mean_val_loss,
self.val_metric_func_name: mean_val_loss,
}
self._save_log_into_tb_file(epoch, "validating", val_loss_dict)

logger.info(
f"Epoch {epoch:03d} - "
f"training loss: {mean_train_loss:.4f}, "
f"validation loss: {mean_val_loss:.4f}"
f"training loss ({self.train_loss_func_name}): {mean_train_loss:.4f}, "
f"validation {self.val_metric_func_name}: {mean_val_loss:.4f}"
)
mean_loss = mean_val_loss
else:
logger.info(f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}")
logger.info(
f"Epoch {epoch:03d} - training loss ({self.train_loss_func_name}): {mean_train_loss:.4f}"
)
mean_loss = mean_train_loss

if np.isnan(mean_loss):
@@ -431,8 +462,6 @@ def classify(
) -> np.ndarray:
"""Classify the input data with the trained model.



Parameters
----------
test_set :
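
In the reworked validation loop above, the per-batch predictions are collected, argmax-ed into class indices, and handed to self.val_metric_func(predictions, targets), so a custom classification metric only needs that two-argument signature and a scalar return value. Below is a hedged sketch of one such metric; the name and the balanced-accuracy choice are illustrative only and not part of the PR, which defaults to the new Accuracy class.

# Illustrative custom validation metric for BaseNNClassifier subclasses.
# It is invoked as val_metric_func(class_predictions, targets) in the loop above,
# where class_predictions are already argmax-ed class indices.
import numpy as np


class BalancedAccuracy:
    """Mean per-class recall, useful when the classes are imbalanced."""

    def __call__(self, predictions, targets) -> float:
        predictions = np.asarray(predictions)
        targets = np.asarray(targets)
        per_class_recall = [
            np.mean(predictions[targets == c] == c) for c in np.unique(targets)
        ]
        return float(np.mean(per_class_recall))

Passing val_metric_func=BalancedAccuracy() to a classifier built on this base would make the epoch log above report validation BalancedAccuracy instead of the default Accuracy.
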
8 changes: 4 additions & 4 deletions pypots/classification/brits/core.py
@@ -36,7 +36,7 @@ def __init__(
self.f_classifier = nn.Linear(self.rnn_hidden_size, n_classes)
self.b_classifier = nn.Linear(self.rnn_hidden_size, n_classes)

def forward(self, inputs: dict, training: bool = True) -> dict:
def forward(self, inputs: dict) -> dict:
(
imputed_data,
f_reconstruction,
@@ -59,11 +59,11 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
}

# if in training mode, return results with losses
if training:
if self.training:
results["consistency_loss"] = consistency_loss
results["reconstruction_loss"] = reconstruction_loss
f_classification_loss = F.nll_loss(torch.log(f_prediction), inputs["label"])
b_classification_loss = F.nll_loss(torch.log(b_prediction), inputs["label"])
f_classification_loss = F.nll_loss(torch.log(f_prediction), inputs["y"])
b_classification_loss = F.nll_loss(torch.log(b_prediction), inputs["y"])
classification_loss = (f_classification_loss + b_classification_loss) / 2
loss = (
consistency_loss
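
The core's forward no longer takes a training flag; it branches on self.training, the attribute torch.nn.Module itself toggles via .train() and .eval(). A minimal standalone sketch of that pattern follows, unrelated to BRITS internals; the tiny model and the dict keys are illustrative only.

# Minimal sketch of the pattern adopted above: branch on nn.Module's built-in
# self.training flag instead of threading a `training` argument through forward().
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyClassifier(nn.Module):
    """Illustrative module, unrelated to BRITS internals."""

    def __init__(self, n_features: int, n_classes: int):
        super().__init__()
        self.linear = nn.Linear(n_features, n_classes)

    def forward(self, inputs: dict) -> dict:
        logits = self.linear(inputs["X"])
        results = {"classification_pred": torch.softmax(logits, dim=-1)}
        if self.training:  # set by model.train(), cleared by model.eval()
            results["loss"] = F.cross_entropy(logits, inputs["y"])
        return results


model = TinyClassifier(n_features=8, n_classes=3)
batch = {"X": torch.randn(4, 8), "y": torch.tensor([0, 1, 2, 0])}
model.train()
assert "loss" in model(batch)      # training path computes the loss
model.eval()
assert "loss" not in model(batch)  # inference path skips it, matching the predict() change further down
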
36 changes: 24 additions & 12 deletions pypots/classification/brits/model.py
@@ -53,6 +53,14 @@ class BRITS(BaseNNClassifier):
stopped when the model does not improve for that number of epochs.
Leaving it as None (the default) disables early stopping.

train_loss_func:
The customized loss function designed by users for training the model.
If not given, will use the default loss as claimed in the original paper.

val_metric_func:
The customized metric function designed by users for validating the model.
If not given, will use the default loss from the original paper as the metric.

optimizer :
The optimizer for model training.
If not given, will use a default Adam optimizer.
@@ -96,6 +104,8 @@ def __init__(
batch_size: int = 32,
epochs: int = 100,
patience: Optional[int] = None,
train_loss_func: Optional[dict] = None,
val_metric_func: Optional[dict] = None,
optimizer: Optional[Optimizer] = Adam(),
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
@@ -104,15 +114,17 @@ def __init__(
verbose: bool = True,
):
super().__init__(
n_classes,
batch_size,
epochs,
patience,
num_workers,
device,
saving_path,
model_saving_strategy,
verbose,
n_classes=n_classes,
batch_size=batch_size,
epochs=epochs,
patience=patience,
train_loss_func=train_loss_func,
val_metric_func=val_metric_func,
num_workers=num_workers,
device=device,
saving_path=saving_path,
model_saving_strategy=model_saving_strategy,
verbose=verbose,
)

self.n_steps = n_steps
@@ -147,13 +159,13 @@ def _assemble_input_for_training(self, data: list) -> dict:
back_X,
back_missing_mask,
back_deltas,
label,
y,
) = self._send_data_to_given_device(data)

# assemble input data
inputs = {
"indices": indices,
"label": label,
"y": y,
"forward": {
"X": X,
"missing_mask": missing_mask,
@@ -248,7 +260,7 @@ def predict(
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
results = self.model.forward(inputs)
classification_pred = results["classification_pred"]
classification_collector.append(classification_pred)

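
With the keyword-argument constructor above, the new train_loss_func and val_metric_func hooks can be supplied directly when building a BRITS classifier. Below is a hedged usage sketch: the CrossEntropy and Accuracy classes come from pypots.nn.modules.loss and pypots.nn.modules.metric as imported in classification/base.py above, but whether they are meant to be imported directly by users, and the dataset dict keys, are assumptions here.

# Hedged usage sketch: constructing BRITS with the new customization hooks.
# Passing CrossEntropy()/Accuracy() explicitly reproduces the defaults; any
# callable with a compatible signature could be substituted instead.
from pypots.classification import BRITS
from pypots.nn.modules.loss import CrossEntropy
from pypots.nn.modules.metric import Accuracy
from pypots.optim import Adam

clf = BRITS(
    n_steps=48,
    n_features=37,
    n_classes=2,
    rnn_hidden_size=128,
    batch_size=32,
    epochs=10,
    patience=3,                      # stop after 3 epochs without improvement
    train_loss_func=CrossEntropy(),  # explicit here, same as the default
    val_metric_func=Accuracy(),      # explicit here, same as the default
    optimizer=Adam(lr=1e-3),
)
# clf.fit(train_set, val_set)  # dataset dicts with "X" and "y" keys, per PyPOTS convention
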
4 changes: 2 additions & 2 deletions pypots/classification/csai/core.py
@@ -81,8 +81,8 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
results["consistency_loss"] = consistency_loss
results["reconstruction_loss"] = reconstruction_loss
# print(inputs["labels"].unsqueeze(1))
f_classification_loss = F.nll_loss(torch.log(f_prediction), inputs["labels"])
b_classification_loss = F.nll_loss(torch.log(b_prediction), inputs["labels"])
f_classification_loss = F.nll_loss(torch.log(f_prediction), inputs["y"])
b_classification_loss = F.nll_loss(torch.log(b_prediction), inputs["y"])
# f_classification_loss, _ = criterion(f_prediction, f_logits, inputs["labels"].unsqueeze(1).float())
# b_classification_loss, _ = criterion(b_prediction, b_logits, inputs["labels"].unsqueeze(1).float())
classification_loss = f_classification_loss + b_classification_loss
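
In both the BRITS and CSAI cores, the classifier heads output class probabilities, so the loss is computed as F.nll_loss(torch.log(probs), y), which is numerically equivalent to cross-entropy on the underlying logits. A tiny sketch of that equivalence, with illustrative values:

# Tiny sketch: nll_loss on log-probabilities equals cross_entropy on logits,
# which is why the cores above wrap the softmax outputs in torch.log().
import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)          # (batch, n_classes), illustrative values
y = torch.tensor([0, 2, 1, 2])

probs = torch.softmax(logits, dim=1)
loss_via_nll = F.nll_loss(torch.log(probs), y)
loss_via_ce = F.cross_entropy(logits, y)
assert torch.allclose(loss_via_nll, loss_via_ce, atol=1e-6)
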