# feat: Add convenience function to read loss from CSVLogger (#267)
### Description

Following #252, this PR
enforces that the `CSVLogger` is always used (even if `WandB` is
requested, for instance) and adds an API entry point in `CAREamist` that
returns a dictionary of the losses.

This allows users to simply plot the losses in a notebook after
training. While they would be better off using `WandB` or
`TensorBoard`, this is enough for most users.

- **What**: Enforce `CSVLogger` and add functions to read the losses
from `metrics.csv`.
- **Why**: So that users have an easy way to plot the loss curves.
- **How**: Add a new `lightning_utils.py` module containing the
csv-reading function, and call it from `CAREamist`.

### Changes Made

- **Added**: `lightning_utils.py`.
- **Modified**: `CAREamist`.

### Related Issues


- Resolves #252

### Additional Notes and Examples

An alternative would have been to add a `Callback` and do the logging
ourselves. I opted for the solution that reads the `csv` file, which is
created by default anyway (when there is no WandB or TB logger), in
order to minimize the code that needs to be maintained.
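
For reference, here is a rough sketch of what the `Callback` alternative
could have looked like; this is not what the PR implements, and the
metric names (`train_loss_epoch`, `val_loss`) are assumptions:

``` python
from pytorch_lightning.callbacks import Callback


class LossHistory(Callback):
    """Hypothetical callback accumulating epoch-level losses in memory."""

    def __init__(self):
        self.train_loss = []
        self.val_loss = []

    def on_train_epoch_end(self, trainer, pl_module):
        # trainer.callback_metrics holds the values logged via self.log(...);
        # "train_loss_epoch" is an assumed metric name
        loss = trainer.callback_metrics.get("train_loss_epoch")
        if loss is not None:
            self.train_loss.append(float(loss))

    def on_validation_epoch_end(self, trainer, pl_module):
        loss = trainer.callback_metrics.get("val_loss")  # assumed name
        if loss is not None:
            self.val_loss.append(float(loss))
```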

One potential issue is that the csv file to read is chosen based on the
experiment name recorded by `CAREamist` and the highest `version_*`
folder, as illustrated in the sketch below. This assumption may not hold
if the paths have changed, but in most cases it is valid when called
right after training.
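
For illustration, the assumed layout and selection logic look like this
(a sketch with illustrative names, mirroring `read_csv_logger` below):

``` python
from pathlib import Path

# layout produced by the CSVLogger right after training:
#   work_dir/csv_logs/<experiment_name>/version_0/metrics.csv
#   work_dir/csv_logs/<experiment_name>/version_1/metrics.csv  <- highest version wins
experiment_dir = Path("work_dir") / "csv_logs" / "my_experiment"
latest = max(
    (d for d in experiment_dir.iterdir() if d.is_dir()),
    key=lambda d: int(d.name.split("_")[-1]),
)
metrics_csv = latest / "metrics.csv"
```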

Here is what it looks like in a notebook:
``` python
import matplotlib.pyplot as plt

losses = careamist.get_losses()

plt.plot(losses["train_epoch"], losses["train_loss"], label="Train Loss")
plt.plot(losses["val_epoch"], losses["val_loss"], label="Val Loss")
plt.legend()
```

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)
jdeschamps authored Nov 12, 2024
1 parent 8dcc447 commit b444084
Showing 5 changed files with 129 additions and 12 deletions.
1 change: 0 additions & 1 deletion pyproject.toml

``` diff
@@ -193,4 +193,3 @@ exclude = [ # don't report on objects that match any of these regex
     "test_*",
     "src/careamics/lvae_training/*",
 ]
-
```
44 changes: 33 additions & 11 deletions src/careamics/careamist.py
``` diff
@@ -11,7 +11,7 @@
     EarlyStopping,
     ModelCheckpoint,
 )
-from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
+from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger
 
 from careamics.config import Configuration, FCNAlgorithmConfig, load_configuration
 from careamics.config.support import (
@@ -33,10 +33,11 @@
 from careamics.model_io import export_to_bmz, load_pretrained
 from careamics.prediction_utils import convert_outputs
 from careamics.utils import check_path_exists, get_logger
+from careamics.utils.lightning_utils import read_csv_logger
 
 logger = get_logger(__name__)
 
-LOGGER_TYPES = Optional[Union[TensorBoardLogger, WandbLogger]]
+LOGGER_TYPES = list[Union[TensorBoardLogger, WandbLogger, CSVLogger]]
 
 
 class CAREamist:
@@ -170,18 +171,29 @@ def __init__(
         self._define_callbacks(callbacks)
 
         # instantiate logger
+        csv_logger = CSVLogger(
+            name=self.cfg.experiment_name,
+            save_dir=self.work_dir / "csv_logs",
+        )
+
         if self.cfg.training_config.has_logger():
             if self.cfg.training_config.logger == SupportedLogger.WANDB:
-                self.experiment_logger: LOGGER_TYPES = WandbLogger(
-                    name=self.cfg.experiment_name,
-                    save_dir=self.work_dir / Path("logs"),
-                )
+                experiment_logger: LOGGER_TYPES = [
+                    WandbLogger(
+                        name=self.cfg.experiment_name,
+                        save_dir=self.work_dir / Path("wandb_logs"),
+                    ),
+                    csv_logger,
+                ]
             elif self.cfg.training_config.logger == SupportedLogger.TENSORBOARD:
-                self.experiment_logger = TensorBoardLogger(
-                    save_dir=self.work_dir / Path("logs"),
-                )
+                experiment_logger = [
+                    TensorBoardLogger(
+                        save_dir=self.work_dir / Path("tb_logs"),
+                    ),
+                    csv_logger,
+                ]
             else:
-                self.experiment_logger = None
+                experiment_logger = [csv_logger]
 
         # instantiate trainer
         self.trainer = Trainer(
@@ -195,7 +207,7 @@ def __init__(
             gradient_clip_algorithm=self.cfg.training_config.gradient_clip_algorithm,
             callbacks=self.callbacks,
             default_root_dir=self.work_dir,
-            logger=self.experiment_logger,
+            logger=experiment_logger,
         )
 
         # place holder for the datamodules
@@ -904,3 +916,13 @@ def export_to_bmz(
             channel_names=channel_names,
             data_description=data_description,
         )
+
+    def get_losses(self) -> dict[str, list]:
+        """Return data that can be used to plot train and validation loss curves.
+
+        Returns
+        -------
+        dict of str: list
+            Dictionary containing the losses for each epoch.
+        """
+        return read_csv_logger(self.cfg.experiment_name, self.work_dir / "csv_logs")
```
57 changes: 57 additions & 0 deletions src/careamics/utils/lightning_utils.py
``` python
"""PyTorch lightning utilities."""

from pathlib import Path
from typing import Union


def read_csv_logger(experiment_name: str, log_folder: Union[str, Path]) -> dict:
    """Return the loss curves from the csv logs.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    log_folder : Path or str
        Path to the folder containing the csv logs.

    Returns
    -------
    dict
        Dictionary containing the loss curves, with keys "train_epoch", "val_epoch",
        "train_loss" and "val_loss".
    """
    path = Path(log_folder) / experiment_name

    # find the most recent of version_* folders
    versions = [int(v.name.split("_")[-1]) for v in path.iterdir() if v.is_dir()]
    version = max(versions)

    path_log = path / f"version_{version}" / "metrics.csv"

    epochs = []
    train_losses_tmp = []
    val_losses_tmp = []
    with open(path_log) as f:
        lines = f.readlines()

        for single_line in lines[1:]:
            epoch, _, train_loss, _, val_loss = single_line.strip().split(",")

            epochs.append(epoch)
            train_losses_tmp.append(train_loss)
            val_losses_tmp.append(val_loss)

    # train and val are not logged on the same row and can have different lengths
    train_epoch = [
        int(epochs[i]) for i in range(len(epochs)) if train_losses_tmp[i] != ""
    ]
    val_epoch = [int(epochs[i]) for i in range(len(epochs)) if val_losses_tmp[i] != ""]
    train_losses = [float(loss) for loss in train_losses_tmp if loss != ""]
    val_losses = [float(loss) for loss in val_losses_tmp if loss != ""]

    return {
        "train_epoch": train_epoch,
        "val_epoch": val_epoch,
        "train_loss": train_losses,
        "val_loss": val_losses,
    }
```
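
To make the parsing concrete, here is a small self-contained example;
the `metrics.csv` column names are illustrative, as the function only
assumes five comma-separated fields with the epoch first, the train loss
third, and the validation loss fifth:

``` python
from pathlib import Path

from careamics.utils.lightning_utils import read_csv_logger

# build a toy log tree matching the layout read_csv_logger expects
# (experiment and column names are illustrative assumptions)
log_dir = Path("csv_logs") / "my_experiment" / "version_0"
log_dir.mkdir(parents=True, exist_ok=True)
(log_dir / "metrics.csv").write_text(
    "epoch,step,train_loss_epoch,lr,val_loss\n"
    "0,31,1.02,,\n"  # train row: val field is empty
    "0,31,,,0.98\n"  # val row: train field is empty
)

losses = read_csv_logger("my_experiment", "csv_logs")
print(losses)
# {'train_epoch': [0], 'val_epoch': [0], 'train_loss': [1.02], 'val_loss': [0.98]}
```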
16 changes: 16 additions & 0 deletions tests/test_careamist.py
``` diff
@@ -1109,3 +1109,19 @@ def _train():
     thread.join()
 
     assert careamist.trainer.should_stop
+
+
+def test_read_logger(tmp_path, minimum_configuration):
+
+    config = Configuration(**minimum_configuration)
+    config.training_config.num_epochs = 10
+
+    array = np.arange(32 * 32).reshape((32, 32))
+
+    careamist = CAREamist(config, work_dir=tmp_path)
+    careamist.train(train_source=array)
+    losses = careamist.get_losses()
+
+    assert len(losses) == 4
+    for key in losses:
+        assert len(losses[key]) == config.training_config.num_epochs
```
23 changes: 23 additions & 0 deletions tests/utils/test_lightning_utils.py
``` python
import numpy as np

from careamics import CAREamist, Configuration
from careamics.utils import cwd
from careamics.utils.lightning_utils import read_csv_logger


def test_read_logger(tmp_path, minimum_configuration):

    config = Configuration(**minimum_configuration)
    config.training_config.num_epochs = 10

    array = np.arange(32 * 32).reshape((32, 32))

    with cwd(tmp_path):
        careamist = CAREamist(config)
        careamist.train(train_source=array)

    losses = read_csv_logger(config.experiment_name, tmp_path / "csv_logs")

    assert len(losses) == 4
    for key in losses:
        assert len(losses[key]) == config.training_config.num_epochs
```
