From 7a8b0ea7e016e886903e25eddb45689ab54c0ae2 Mon Sep 17 00:00:00 2001
From: shabie <30535146+shabie@users.noreply.github.com>
Date: Thu, 18 Nov 2021 18:29:13 +0100
Subject: [PATCH] log metrics for correct dataloader only (#10522)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: tchaton
Co-authored-by: Carlos Mocholí
---
 .../logger_connector/logger_connector.py      | 17 +++++++++---
 .../logging_/test_eval_loop_logging.py        | 27 +++++++++++++++++++
 2 files changed, 41 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index 37fcb06a1dc24a..640fc667705a8d 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -154,6 +154,19 @@ def update_eval_step_metrics(self) -> None:
         # increment the step even if nothing was logged
         self._increment_eval_log_step()
 
+    @staticmethod
+    def _filter_metrics_for_dataloader(
+        dl_idx: int, metrics: Dict[str, Union[Any, Dict[str, Any]]], metric_prefix: str = "dataloader_idx"
+    ) -> Dict[str, Union[Any, Dict[str, Any]]]:
+        result = {}
+        for k, v in metrics.items():
+            if metric_prefix not in k:
+                result[k] = v
+                continue
+            if k.endswith(f"{metric_prefix}_{dl_idx}"):
+                result[k] = v
+        return result
+
     def _prepare_eval_loop_results(self, metrics: _OUT_DICT) -> None:
         if self.trainer.sanity_checking:
             return
@@ -162,9 +175,7 @@ def _prepare_eval_loop_results(self, metrics: _OUT_DICT) -> None:
         has_been_initialized = len(self.eval_loop_results) == num_dataloaders
         for dl_idx in range(self.trainer._evaluation_loop.num_dataloaders):
             # remove callback metrics that don't belong to this dataloader
-            callback_metrics = {
-                k: v for k, v in metrics.items() if "dataloader_idx" not in k or f"dataloader_idx_{dl_idx}" in k
-            }
+            callback_metrics = self._filter_metrics_for_dataloader(dl_idx, metrics)
             if has_been_initialized:
                 self.eval_loop_results[dl_idx].update(callback_metrics)
             else:
diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py
index 6ed40b5f030827..88229effbc8c9a 100644
--- a/tests/trainer/logging_/test_eval_loop_logging.py
+++ b/tests/trainer/logging_/test_eval_loop_logging.py
@@ -23,6 +23,7 @@
 
 from pytorch_lightning import callbacks, Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
 
@@ -672,3 +673,29 @@ def val_dataloader(self):
         enable_model_summary=False,
     )
     trainer.fit(model)
+
+
+@pytest.mark.parametrize(
+    ["kwargs", "expected"],
+    [
+        ({"dl_idx": 0, "metrics": {"acc": 123}}, {"acc": 123}),
+        (
+            {"dl_idx": 0, "metrics": {"acc/dataloader_idx_0": 123, "acc/dataloader_idx_1": 321}},
+            {"acc/dataloader_idx_0": 123},
+        ),
+        (
+            {"dl_idx": 10, "metrics": {"acc/dataloader_idx_1": 123, "acc/dataloader_idx_10": 321}},
+            {"acc/dataloader_idx_10": 321},
+        ),
+        (
+            {"dl_idx": 3, "metrics": {"top_3_acc/dataloader_idx_0": 123, "top_3_acc/dataloader_idx_3": 321}},
+            {"top_3_acc/dataloader_idx_3": 321},
+        ),
+        # theoretical case, as `/dataloader_idx_3` would have been added
+        ({"dl_idx": 3, "metrics": {"top_3_acc": 123}}, {"top_3_acc": 123}),
+    ],
+)
+def test_filter_metrics_for_dataloader(kwargs, expected):
+    """Logged metrics should only include metrics from the concerned dataloader."""
+    actual = LoggerConnector._filter_metrics_for_dataloader(**kwargs)
+    assert actual == expected