From 9bd19d32abbb0c6adfeaf06308b8efc5e420ea90 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 19 Oct 2020 21:39:30 +0100 Subject: [PATCH] add dataloader_idx logic directly in trainer --- pytorch_lightning/core/lightning.py | 7 +++++- pytorch_lightning/core/step_result.py | 12 +++++++-- .../trainer/connectors/logger_connector.py | 25 +++++++++---------- pytorch_lightning/trainer/evaluation_loop.py | 2 +- pytorch_lightning/trainer/trainer.py | 2 +- .../logging/test_eval_loop_logging_1_0.py | 16 +++++------- 6 files changed, 36 insertions(+), 28 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 9bbef7c586a822..30ed325d094570 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -251,6 +251,10 @@ def log( self.trainer.callback_connector.validate_callback_logging_arguments(self._current_hook_fx_name, on_step=on_step, on_epoch=on_epoch) + + # make sure user doesn't introduce logic for multi-dataloaders + if "/dataloader_idx_" in name: + raise MisconfigurationException(f"Logged key: {name} should not contain information about dataloader_idx.") self._results.log( name, @@ -265,7 +269,8 @@ def log( enable_graph, sync_dist, sync_dist_op, - sync_dist_group + sync_dist_group, + self._current_dataloader_idx, ) def log_dict( diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index a50586880a1c1d..acf860b706eb58 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -109,6 +109,11 @@ def _assert_grad_tensor_metric(self, name: str, x: Union[torch.Tensor, Any], add m += f' {additional_err}' assert x.grad_fn is not None, m + def add_dl_idx(self, name, dl_idx): + if dl_idx is not None: + name += f"/dataloader_idx_{dl_idx}" + return name + def log( self, name: str, @@ -124,6 +129,7 @@ def log( sync_dist: bool = False, sync_dist_op: Union[Any, str] = 'mean', sync_dist_group: Optional[Any] = None, + current_dataloader_idx: Optional[int] = None, ): # no metrics should be logged with graphs if not enable_graph and isinstance(value, torch.Tensor): @@ -143,7 +149,7 @@ def log( was_forked = True # set step version - step_name = f'{name}_step' + step_name = self.add_dl_idx(f'{name}_step', current_dataloader_idx) self.__set_meta( step_name, value, @@ -159,7 +165,7 @@ def log( self.__setitem__(step_name, value) # set epoch version - epoch_name = f'{name}_epoch' + epoch_name = self.add_dl_idx(f'{name}_epoch', current_dataloader_idx) self.__set_meta( epoch_name, value, @@ -174,6 +180,8 @@ def log( ) self.__setitem__(epoch_name, value) + name = self.add_dl_idx(name, current_dataloader_idx) + # always log the original metric self.__set_meta( name, diff --git a/pytorch_lightning/trainer/connectors/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector.py index d506962e8acddd..41ba7330fc2262 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector.py @@ -184,11 +184,6 @@ def _track_callback_metrics_1_0(self, logs, metrics_to_log=[], reduce_on_epoch=F pbar_metrics = reduced_epoch_metrics.get_epoch_pbar_metrics() forked_metrics = reduced_epoch_metrics.get_forked_metrics() - # make the keys 'k/dl' - logger_metrics = self.__rename_keys_by_dataloader_idx(logger_metrics, dl_idx, num_loaders) - pbar_metrics = self.__rename_keys_by_dataloader_idx(pbar_metrics, dl_idx, num_loaders) - forked_metrics = self.__rename_keys_by_dataloader_idx(forked_metrics, dl_idx, num_loaders) - # track the metrics self.logged_metrics.update(logger_metrics) self.add_progress_bar_metrics(pbar_metrics) @@ -201,24 +196,28 @@ def _track_callback_metrics_1_0(self, logs, metrics_to_log=[], reduce_on_epoch=F self.callback_metrics.update(forked_metrics) # track the final results for the dataloader - self.eval_loop_results.append(deepcopy(self.callback_metrics)) + self.add_to_eval_loop_results(dl_idx) # actually log if len(logger_metrics) > 0: metrics_to_log.append(logger_metrics) + def add_to_eval_loop_results(self, dl_idx): + callback_metrics = deepcopy(self.callback_metrics) + for key in list(callback_metrics.keys()): + if "/dataloader_idx_" in key: + dl_idx_in_key = int(key.split("_")[-1]) + # remove dl_idx from self.callback_metrics not belonging to this dataset. + if dl_idx_in_key != dl_idx: + del callback_metrics[key] + self.eval_loop_results.append(callback_metrics) + + def log_on_evaluation_end(self, metrics_to_log): metrics_to_log = dict(ChainMap(*metrics_to_log)) if len(metrics_to_log) > 0: self.log_metrics(metrics_to_log, {}, step=self.trainer.global_step) - def __rename_keys_by_dataloader_idx(self, metrics, dataloader_idx, num_loaders): - if num_loaders == 1: - return metrics - - result = {f'{k}/dataloader_idx_{dataloader_idx}': v for k, v in metrics.items()} - return result - def _track_callback_metrics(self, eval_results, using_eval_result): if ( len(eval_results) > 0 and diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index affe1fd3008b25..210b74555d0fb9 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -347,7 +347,7 @@ def _unset_dataloader_idx(self): def set_dataloader_idx(self, dl_idx): # reset the result of the PL module model = self.trainer.get_model() - model._current_dataloader_idx = dl_idx if self.num_dataloaders > 1 else None + model._current_dataloader_idx = dl_idx if self.num_dataloaders > 1 else None def on_evaluation_batch_start(self, *args, **kwargs): # reset the result of the PL module diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 204521babe0402..65e4cb5760ff0c 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -560,7 +560,7 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None): dataloader = self.accelerator_backend.process_dataloader(dataloader) dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx] - # set dataloader idx inside model, so we can properly log + # set dataloader idx in pl_model, so we can handle multi dataloaders logging. self.evaluation_loop.set_dataloader_idx(dataloader_idx) for batch_idx, batch in enumerate(dataloader): diff --git a/tests/trainer/logging/test_eval_loop_logging_1_0.py b/tests/trainer/logging/test_eval_loop_logging_1_0.py index d037894ec3e886..d70dfc00440e63 100644 --- a/tests/trainer/logging/test_eval_loop_logging_1_0.py +++ b/tests/trainer/logging/test_eval_loop_logging_1_0.py @@ -589,27 +589,21 @@ def on_epoch_start(self, trainer, pl_module): def on_test_epoch_start(self, trainer, pl_module): self.make_logging(pl_module, 'on_test_epoch_start', 3, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices) - def on_batch_start(self, trainer, pl_module): - self.make_logging(pl_module, 'on_batch_start', 4, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices) - def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): - self.make_logging(pl_module, 'on_test_batch_start', 5, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices) - - def on_batch_end(self, trainer, pl_module): - self.make_logging(pl_module, 'on_batch_end', 6, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices) + self.make_logging(pl_module, 'on_test_batch_start', 4, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices) def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - self.make_logging(pl_module, 'on_test_batch_end', 7, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices) + self.make_logging(pl_module, 'on_test_batch_end', 5, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices) # used to make sure aggregation works fine. # we should obtain func[value * c for c in range(1, max_epochs * limit_test_batches)]) with func = np.mean if on_epoch else func = np.max self.count += 1 def on_epoch_end(self, trainer, pl_module): - self.make_logging(pl_module, 'on_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices) + self.make_logging(pl_module, 'on_epoch_end', 6, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices) def on_test_epoch_end(self, trainer, pl_module): - self.make_logging(pl_module, 'on_test_epoch_end', 9, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices) + self.make_logging(pl_module, 'on_test_epoch_end', 7, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices) max_epochs = 1 num_dataloaders = 2 @@ -644,6 +638,8 @@ def test_dataloader(self): trainer.fit(model) trainer.test() + + breakpoint() # Make sure the func_name exists within callback_metrics. If not, we missed some callback_metrics_keys = [*trainer.callback_metrics.keys()] for func_name in test_callback.callback_funcs_called.keys():