feat: Add convenience function to read loss from CSVLogger (#267)
### Description

Following #252, this PR enforces that the `CSVLogger` is always used (even if `WandB` is requested, for instance) and adds an API entry point in `CAREamist` that returns a dictionary of the losses. This allows users to simply plot the loss in a notebook after training. While they would be better off using `WandB` or `TensorBoard`, this is enough for most users.

- **What**: Enforce `CSVLogger` and add a function to read the losses from `metrics.csv`.
- **Why**: So that users have an easy way to plot the loss curves.
- **How**: Add a new `lightning_utils.py` file with the csv-reading function, and call this method from `CAREamist`.

### Changes Made

- **Added**: `lightning_utils.py`.
- **Modified**: `CAREamist`.

### Related Issues

- Resolves #252

### Additional Notes and Examples

An alternative path would have been to add a `Callback` and do the logging ourselves. I decided on the solution that uses the `csv` file that is created by default anyway (when there are no WandB or TB loggers), to minimize the code that needs to be maintained. One potential issue: the particular csv file to read is chosen from the experiment name recorded by `CAREamist` and the last `version_*` folder. This may not hold if the paths have changed, but in most cases it should be valid when called right after training.

Here is what it looks like in a notebook:

``` python
import matplotlib.pyplot as plt

losses = careamist.get_losses()

plt.plot(losses["train_epoch"], losses["train_loss"], label="Train Loss")
plt.plot(losses["val_epoch"], losses["val_loss"], label="Val Loss")
plt.legend()
```

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)
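The "last `version_*`" heuristic mentioned above can be sketched as follows. This is an illustrative snippet with hypothetical folder names, not the library's code: it builds a fake log directory and picks the most recent version folder by its numeric suffix.

``` python
import tempfile
from pathlib import Path

# Hypothetical layout: csv_logs/<experiment_name>/version_<n>/metrics.csv
root = Path(tempfile.mkdtemp()) / "csv_logs" / "my_experiment"
for n in (0, 1, 2):
    (root / f"version_{n}").mkdir(parents=True)

# pick the most recent version_* folder by its numeric suffix
versions = [int(p.name.split("_")[-1]) for p in root.iterdir() if p.is_dir()]
latest = root / f"version_{max(versions)}"
print(latest.name)  # → version_2
```

Note that this relies on every subfolder following the `version_<n>` naming scheme; a stray directory would break the `int(...)` conversion.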
1 parent 8dcc447 · commit b444084 · 5 changed files with 129 additions and 12 deletions
`lightning_utils.py` (new file, +57 lines):

``` python
"""PyTorch Lightning utilities."""

from pathlib import Path
from typing import Union


def read_csv_logger(experiment_name: str, log_folder: Union[str, Path]) -> dict:
    """Return the loss curves from the csv logs.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    log_folder : Path or str
        Path to the folder containing the csv logs.

    Returns
    -------
    dict
        Dictionary containing the loss curves, with keys "train_epoch",
        "val_epoch", "train_loss" and "val_loss".
    """
    path = Path(log_folder) / experiment_name

    # find the most recent of the version_* folders
    versions = [int(v.name.split("_")[-1]) for v in path.iterdir() if v.is_dir()]
    version = max(versions)

    path_log = path / f"version_{version}" / "metrics.csv"

    epochs = []
    train_losses_tmp = []
    val_losses_tmp = []
    with open(path_log) as f:
        lines = f.readlines()

    for single_line in lines[1:]:
        epoch, _, train_loss, _, val_loss = single_line.strip().split(",")

        epochs.append(epoch)
        train_losses_tmp.append(train_loss)
        val_losses_tmp.append(val_loss)

    # train and val are not logged on the same row and can have different lengths
    train_epoch = [
        int(epochs[i]) for i in range(len(epochs)) if train_losses_tmp[i] != ""
    ]
    val_epoch = [int(epochs[i]) for i in range(len(epochs)) if val_losses_tmp[i] != ""]
    train_losses = [float(loss) for loss in train_losses_tmp if loss != ""]
    val_losses = [float(loss) for loss in val_losses_tmp if loss != ""]

    return {
        "train_epoch": train_epoch,
        "val_epoch": val_epoch,
        "train_loss": train_losses,
        "val_loss": val_losses,
    }
```
Test (new file, +23 lines):

``` python
import numpy as np

from careamics import CAREamist, Configuration
from careamics.utils import cwd
from careamics.utils.lightning_utils import read_csv_logger


def test_read_logger(tmp_path, minimum_configuration):
    config = Configuration(**minimum_configuration)
    config.training_config.num_epochs = 10

    array = np.arange(32 * 32).reshape((32, 32))

    with cwd(tmp_path):
        careamist = CAREamist(config)
        careamist.train(train_source=array)

        losses = read_csv_logger(config.experiment_name, tmp_path / "csv_logs")

        assert len(losses) == 4
        for key in losses:
            assert len(losses[key]) == config.training_config.num_epochs
```