
log metrics for correct dataloader only (Lightning-AI#10522)
Co-authored-by: tchaton <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
3 people authored and Raalsky committed Nov 23, 2021
1 parent 6af991d commit 7a8b0ea
Showing 2 changed files with 41 additions and 3 deletions.
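
Why the change matters: the old inline filter kept a metric for dataloader `dl_idx` whenever `f"dataloader_idx_{dl_idx}"` appeared anywhere in the key, so for `dl_idx = 1` a key such as `acc/dataloader_idx_10` also matched and leaked into the wrong dataloader's results (the `dl_idx: 10` parametrization in the new test below covers exactly this). The extracted `_filter_metrics_for_dataloader` helper instead keeps suffix-free keys and otherwise requires the key to end with `dataloader_idx_<dl_idx>`. A minimal standalone sketch contrasting the two behaviours (illustrative names, not the library code itself):

```python
# Standalone sketch (assumed names, not the library code) contrasting the old
# substring filter with the new suffix filter introduced by this commit.
from typing import Any, Dict


def filter_old(dl_idx: int, metrics: Dict[str, Any]) -> Dict[str, Any]:
    # Old behaviour: substring match -- "dataloader_idx_1" also matches keys
    # ending in "dataloader_idx_10", "dataloader_idx_11", ...
    return {k: v for k, v in metrics.items() if "dataloader_idx" not in k or f"dataloader_idx_{dl_idx}" in k}


def filter_new(dl_idx: int, metrics: Dict[str, Any], metric_prefix: str = "dataloader_idx") -> Dict[str, Any]:
    # New behaviour: keep keys without the prefix; otherwise require an exact
    # "<metric_prefix>_<dl_idx>" suffix.
    result = {}
    for k, v in metrics.items():
        if metric_prefix not in k:
            result[k] = v
            continue
        if k.endswith(f"{metric_prefix}_{dl_idx}"):
            result[k] = v
    return result


metrics = {"acc/dataloader_idx_1": 0.5, "acc/dataloader_idx_10": 0.9}
print(filter_old(1, metrics))   # both keys survive -- the bug
print(filter_new(1, metrics))   # {'acc/dataloader_idx_1': 0.5} -- the fix
```

Extracting the filter into a `@staticmethod` also makes it testable in isolation, which is what the new parametrized test at the bottom of this commit does.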
14 changes: 14 additions & 3 deletions pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -154,6 +154,19 @@ def update_eval_step_metrics(self) -> None:
         # increment the step even if nothing was logged
         self._increment_eval_log_step()
 
+    @staticmethod
+    def _filter_metrics_for_dataloader(
+        dl_idx: int, metrics: Dict[str, Union[Any, Dict[str, Any]]], metric_prefix: str = "dataloader_idx"
+    ) -> Dict[str, Union[Any, Dict[str, Any]]]:
+        result = {}
+        for k, v in metrics.items():
+            if metric_prefix not in k:
+                result[k] = v
+                continue
+            if k.endswith(f"{metric_prefix}_{dl_idx}"):
+                result[k] = v
+        return result
+
     def _prepare_eval_loop_results(self, metrics: _OUT_DICT) -> None:
         if self.trainer.sanity_checking:
             return
@@ -162,9 +175,7 @@ def _prepare_eval_loop_results(self, metrics: _OUT_DICT) -> None:
         has_been_initialized = len(self.eval_loop_results) == num_dataloaders
         for dl_idx in range(self.trainer._evaluation_loop.num_dataloaders):
             # remove callback metrics that don't belong to this dataloader
-            callback_metrics = {
-                k: v for k, v in metrics.items() if "dataloader_idx" not in k or f"dataloader_idx_{dl_idx}" in k
-            }
+            callback_metrics = self._filter_metrics_for_dataloader(dl_idx, metrics)
             if has_been_initialized:
                 self.eval_loop_results[dl_idx].update(callback_metrics)
             else:

27 changes: 27 additions & 0 deletions tests/trainer/logging_/test_eval_loop_logging.py
@@ -23,6 +23,7 @@
 
 from pytorch_lightning import callbacks, Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
@@ -672,3 +673,29 @@ def val_dataloader(self):
         enable_model_summary=False,
     )
     trainer.fit(model)
+
+
+@pytest.mark.parametrize(
+    ["kwargs", "expected"],
+    [
+        ({"dl_idx": 0, "metrics": {"acc": 123}}, {"acc": 123}),
+        (
+            {"dl_idx": 0, "metrics": {"acc/dataloader_idx_0": 123, "acc/dataloader_idx_1": 321}},
+            {"acc/dataloader_idx_0": 123},
+        ),
+        (
+            {"dl_idx": 10, "metrics": {"acc/dataloader_idx_1": 123, "acc/dataloader_idx_10": 321}},
+            {"acc/dataloader_idx_10": 321},
+        ),
+        (
+            {"dl_idx": 3, "metrics": {"top_3_acc/dataloader_idx_0": 123, "top_3_acc/dataloader_idx_3": 321}},
+            {"top_3_acc/dataloader_idx_3": 321},
+        ),
+        # theoretical case, as `/dataloader_idx_3` would have been added
+        ({"dl_idx": 3, "metrics": {"top_3_acc": 123}}, {"top_3_acc": 123}),
+    ],
+)
+def test_filter_metrics_for_dataloader(kwargs, expected):
+    """Logged metrics should only include metrics from the concerned dataloader."""
+    actual = LoggerConnector._filter_metrics_for_dataloader(**kwargs)
+    assert actual == expected
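
For context on where the suffixed keys come from: when a LightningModule returns multiple validation dataloaders, Lightning appends `/dataloader_idx_{i}` to each logged metric name, and `trainer.validate` returns one result dict per dataloader. A hedged sketch of that setup, reusing the `BoringModel`/`RandomDataset` helpers imported by the test file above (the exact key names and the 1.5-era Trainer API are assumptions based on this diff, not verified output):

```python
# Hedged sketch, not part of the commit: reuses the BoringModel/RandomDataset
# test helpers imported above and assumes the 1.5-era Trainer API in this diff.
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from tests.helpers import BoringModel, RandomDataset


class MultiValModel(BoringModel):
    def val_dataloader(self):
        # Two validation dataloaders: Lightning appends "/dataloader_idx_{i}"
        # to every metric name logged from validation_step.
        return [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))]

    def validation_step(self, batch, batch_idx, dataloader_idx):
        loss = self(batch).sum()
        self.log("val_loss", loss)  # stored as "val_loss/dataloader_idx_{i}"


trainer = Trainer(fast_dev_run=True)
results = trainer.validate(MultiValModel())
# `results` holds one dict per dataloader; with this fix each dict contains
# only its own "val_loss/dataloader_idx_{i}" entry, never a neighbour's.
print(results)
```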
