diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index 37fcb06a1dc24a..640fc667705a8d 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -154,6 +154,19 @@ def update_eval_step_metrics(self) -> None:
         # increment the step even if nothing was logged
         self._increment_eval_log_step()
 
+    @staticmethod
+    def _filter_metrics_for_dataloader(
+        dl_idx: int, metrics: Dict[str, Union[Any, Dict[str, Any]]], metric_prefix: str = "dataloader_idx"
+    ) -> Dict[str, Union[Any, Dict[str, Any]]]:
+        result = {}
+        for k, v in metrics.items():
+            if metric_prefix not in k:
+                result[k] = v
+                continue
+            if k.endswith(f"{metric_prefix}_{dl_idx}"):
+                result[k] = v
+        return result
+
     def _prepare_eval_loop_results(self, metrics: _OUT_DICT) -> None:
         if self.trainer.sanity_checking:
             return
@@ -162,9 +175,7 @@ def _prepare_eval_loop_results(self, metrics: _OUT_DICT) -> None:
         has_been_initialized = len(self.eval_loop_results) == num_dataloaders
         for dl_idx in range(self.trainer._evaluation_loop.num_dataloaders):
             # remove callback metrics that don't belong to this dataloader
-            callback_metrics = {
-                k: v for k, v in metrics.items() if "dataloader_idx" not in k or f"dataloader_idx_{dl_idx}" in k
-            }
+            callback_metrics = self._filter_metrics_for_dataloader(dl_idx, metrics)
             if has_been_initialized:
                 self.eval_loop_results[dl_idx].update(callback_metrics)
             else:
diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py
index 6ed40b5f030827..88229effbc8c9a 100644
--- a/tests/trainer/logging_/test_eval_loop_logging.py
+++ b/tests/trainer/logging_/test_eval_loop_logging.py
@@ -23,6 +23,7 @@
 
 from pytorch_lightning import callbacks, Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
 
@@ -672,3 +673,29 @@ def val_dataloader(self):
         enable_model_summary=False,
     )
     trainer.fit(model)
+
+
+@pytest.mark.parametrize(
+    ["kwargs", "expected"],
+    [
+        ({"dl_idx": 0, "metrics": {"acc": 123}}, {"acc": 123}),
+        (
+            {"dl_idx": 0, "metrics": {"acc/dataloader_idx_0": 123, "acc/dataloader_idx_1": 321}},
+            {"acc/dataloader_idx_0": 123},
+        ),
+        (
+            {"dl_idx": 10, "metrics": {"acc/dataloader_idx_1": 123, "acc/dataloader_idx_10": 321}},
+            {"acc/dataloader_idx_10": 321},
+        ),
+        (
+            {"dl_idx": 3, "metrics": {"top_3_acc/dataloader_idx_0": 123, "top_3_acc/dataloader_idx_3": 321}},
+            {"top_3_acc/dataloader_idx_3": 321},
+        ),
+        # theoretical case, as `/dataloader_idx_3` would have been added
+        ({"dl_idx": 3, "metrics": {"top_3_acc": 123}}, {"top_3_acc": 123}),
+    ],
+)
+def test_filter_metrics_for_dataloader(kwargs, expected):
+    """Logged metrics should only include metrics from the concerned dataloader."""
+    actual = LoggerConnector._filter_metrics_for_dataloader(**kwargs)
+    assert actual == expected
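
A minimal standalone Python sketch (not part of the patch above, metric keys are illustrative) of the behavior change this diff makes: the old inline dict comprehension used a substring test, so for dl_idx=1 the pattern "dataloader_idx_1" also matched keys suffixed "dataloader_idx_10"; the new endswith-based filter keeps only the exact-suffix match.

# Illustrative sketch contrasting the old and new filters for dl_idx=1.
metrics = {"acc/dataloader_idx_1": 0.9, "acc/dataloader_idx_10": 0.5}

# Old behavior: "dataloader_idx_1" is a substring of "dataloader_idx_10",
# so metrics from dataloader 10 leak into dataloader 1's results.
old = {k: v for k, v in metrics.items() if "dataloader_idx" not in k or "dataloader_idx_1" in k}
assert old == metrics  # bug: both keys survive

# New behavior: only keys ending in the exact "..._1" suffix are kept.
new = {k: v for k, v in metrics.items() if "dataloader_idx" not in k or k.endswith("dataloader_idx_1")}
assert new == {"acc/dataloader_idx_1": 0.9}

The parametrized test added in the diff covers the same collision from the other direction (dl_idx=10 must not pick up "dataloader_idx_1"), plus the prefix-free and same-digit-in-metric-name cases.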