feat: Add convenience function to read loss from CSVLogger #267

Merged · 3 commits · Nov 12, 2024

1 change: 0 additions & 1 deletion pyproject.toml
@@ -193,4 +193,3 @@ exclude = [ # don't report on objects that match any of these regex
"test_*",
"src/careamics/lvae_training/*",
]

44 changes: 33 additions & 11 deletions src/careamics/careamist.py
@@ -11,7 +11,7 @@
    EarlyStopping,
    ModelCheckpoint,
)
-from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
+from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger

from careamics.config import Configuration, FCNAlgorithmConfig, load_configuration
from careamics.config.support import (
@@ -33,10 +33,11 @@
from careamics.model_io import export_to_bmz, load_pretrained
from careamics.prediction_utils import convert_outputs
from careamics.utils import check_path_exists, get_logger
+from careamics.utils.lightning_utils import read_csv_logger

logger = get_logger(__name__)

-LOGGER_TYPES = Optional[Union[TensorBoardLogger, WandbLogger]]
+LOGGER_TYPES = list[Union[TensorBoardLogger, WandbLogger, CSVLogger]]


class CAREamist:
@@ -170,18 +171,29 @@ def __init__(
        self._define_callbacks(callbacks)

        # instantiate logger
+        csv_logger = CSVLogger(
+            name=self.cfg.experiment_name,
+            save_dir=self.work_dir / "csv_logs",
+        )
+
        if self.cfg.training_config.has_logger():
            if self.cfg.training_config.logger == SupportedLogger.WANDB:
-                self.experiment_logger: LOGGER_TYPES = WandbLogger(
-                    name=self.cfg.experiment_name,
-                    save_dir=self.work_dir / Path("logs"),
-                )
+                experiment_logger: LOGGER_TYPES = [
+                    WandbLogger(
+                        name=self.cfg.experiment_name,
+                        save_dir=self.work_dir / Path("wandb_logs"),
+                    ),
+                    csv_logger,
+                ]
            elif self.cfg.training_config.logger == SupportedLogger.TENSORBOARD:
-                self.experiment_logger = TensorBoardLogger(
-                    save_dir=self.work_dir / Path("logs"),
-                )
+                experiment_logger = [
+                    TensorBoardLogger(
+                        save_dir=self.work_dir / Path("tb_logs"),
+                    ),
+                    csv_logger,
+                ]
            else:
-                self.experiment_logger = None
+                experiment_logger = [csv_logger]

        # instantiate trainer
        self.trainer = Trainer(
@@ -195,7 +207,7 @@ def __init__(
            gradient_clip_algorithm=self.cfg.training_config.gradient_clip_algorithm,
            callbacks=self.callbacks,
            default_root_dir=self.work_dir,
-            logger=self.experiment_logger,
+            logger=experiment_logger,
        )

        # place holder for the datamodules
@@ -904,3 +916,13 @@ def export_to_bmz(
            channel_names=channel_names,
            data_description=data_description,
        )
+
+    def get_losses(self) -> dict[str, list]:
+        """Return data that can be used to plot train and validation loss curves.
+
+        Returns
+        -------
+        dict of str: list
+            Dictionary containing the losses for each epoch.
+        """
+        return read_csv_logger(self.cfg.experiment_name, self.work_dir / "csv_logs")
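
Since PyTorch Lightning's `Trainer` accepts a list of loggers, a `CSVLogger` is now always attached alongside any user-configured logger, and `get_losses()` reads its output back from `<work_dir>/csv_logs`. A minimal plotting sketch, assuming `careamist` is an already trained `CAREamist` instance (see the tests below for a full setup):

```python
import matplotlib.pyplot as plt

# losses is a dict with keys "train_epoch", "val_epoch",
# "train_loss" and "val_loss"
losses = careamist.get_losses()

plt.plot(losses["train_epoch"], losses["train_loss"], label="train")
plt.plot(losses["val_epoch"], losses["val_loss"], label="validation")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.show()
```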
57 changes: 57 additions & 0 deletions src/careamics/utils/lightning_utils.py
@@ -0,0 +1,57 @@
"""PyTorch lightning utilities."""

from pathlib import Path
from typing import Union


def read_csv_logger(experiment_name: str, log_folder: Union[str, Path]) -> dict:
"""Return the loss curves from the csv logs.

Parameters
----------
experiment_name : str
Name of the experiment.
log_folder : Path or str
Path to the folder containing the csv logs.

Returns
-------
dict
Dictionary containing the loss curves, with keys "train_epoch", "val_epoch",
"train_loss" and "val_loss".
"""
    path = Path(log_folder) / experiment_name

    # find the most recent of the version_* folders
    versions = [int(v.name.split("_")[-1]) for v in path.iterdir() if v.is_dir()]
    version = max(versions)

    path_log = path / f"version_{version}" / "metrics.csv"

    epochs = []
    train_losses_tmp = []
    val_losses_tmp = []
    with open(path_log) as f:
        lines = f.readlines()

        for single_line in lines[1:]:
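            # rows are expected to have five comma-separated fields: the epoch,
            # a skipped field, the train loss, another skipped field, and the
            # val loss, in that order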
            epoch, _, train_loss, _, val_loss = single_line.strip().split(",")

            epochs.append(epoch)
            train_losses_tmp.append(train_loss)
            val_losses_tmp.append(val_loss)

    # train and val are not logged on the same row and can have different lengths
    train_epoch = [
        int(epochs[i]) for i in range(len(epochs)) if train_losses_tmp[i] != ""
    ]
    val_epoch = [int(epochs[i]) for i in range(len(epochs)) if val_losses_tmp[i] != ""]
    train_losses = [float(loss) for loss in train_losses_tmp if loss != ""]
    val_losses = [float(loss) for loss in val_losses_tmp if loss != ""]

    return {
        "train_epoch": train_epoch,
        "val_epoch": val_epoch,
        "train_loss": train_losses,
        "val_loss": val_losses,
    }
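
To make the expected on-disk layout concrete, here is a self-contained sketch that fabricates a `metrics.csv` in the `<log_folder>/<experiment_name>/version_*` structure produced by Lightning's `CSVLogger` and reads it back. The header names are illustrative only, since the parser skips the header row and relies solely on the column order:

```python
from pathlib import Path

from careamics.utils.lightning_utils import read_csv_logger

# fabricate <log_folder>/<experiment_name>/version_0/metrics.csv
log_dir = Path("demo_logs")
run_dir = log_dir / "my_experiment" / "version_0"
run_dir.mkdir(parents=True, exist_ok=True)
(run_dir / "metrics.csv").write_text(
    "epoch,step,train_loss,train_loss_step,val_loss\n"  # header is skipped
    "0,31,0.95,,\n"  # train loss row for epoch 0
    "0,31,,,0.90\n"  # val loss row for epoch 0
    "1,63,0.80,,\n"
    "1,63,,,0.78\n"
)

losses = read_csv_logger("my_experiment", log_dir)
print(losses["train_epoch"], losses["train_loss"])  # [0, 1] [0.95, 0.8]
print(losses["val_epoch"], losses["val_loss"])  # [0, 1] [0.9, 0.78]
```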
16 changes: 16 additions & 0 deletions tests/test_careamist.py
@@ -1109,3 +1109,19 @@ def _train():
    thread.join()

    assert careamist.trainer.should_stop


def test_read_logger(tmp_path, minimum_configuration):

    config = Configuration(**minimum_configuration)
    config.training_config.num_epochs = 10

    array = np.arange(32 * 32).reshape((32, 32))

    careamist = CAREamist(config, work_dir=tmp_path)
    careamist.train(train_source=array)
    losses = careamist.get_losses()

    assert len(losses) == 4
    for key in losses:
        assert len(losses[key]) == config.training_config.num_epochs
23 changes: 23 additions & 0 deletions tests/utils/test_lightning_utils.py
@@ -0,0 +1,23 @@
import numpy as np

from careamics import CAREamist, Configuration
from careamics.utils import cwd
from careamics.utils.lightning_utils import read_csv_logger


def test_read_logger(tmp_path, minimum_configuration):

    config = Configuration(**minimum_configuration)
    config.training_config.num_epochs = 10

    array = np.arange(32 * 32).reshape((32, 32))

    with cwd(tmp_path):
        careamist = CAREamist(config)
        careamist.train(train_source=array)

        losses = read_csv_logger(config.experiment_name, tmp_path / "csv_logs")

        assert len(losses) == 4
        for key in losses:
            assert len(losses[key]) == config.training_config.num_epochs