From bb1d7074dae905c4c1ef25d66fc016a9937daf12 Mon Sep 17 00:00:00 2001 From: Dmitry Lipin Date: Sun, 16 Feb 2020 07:35:23 +0300 Subject: [PATCH] Allow user to specify 'step' key while logging metrics (#808) * allow to specify 'step' key * add test * docs to log_metrics * fix test * rename * also rename --- pytorch_lightning/core/lightning.py | 5 +++-- pytorch_lightning/trainer/logging.py | 18 ++++++++++------- tests/test_logging.py | 30 ++++++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 3274bb9bac987e..d39486e63508ea 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -415,7 +415,8 @@ def validation_end(self, outputs): The outputs here are strictly for the progress bar. If you don't need to display anything, don't return anything. Any keys present in 'log', 'progress_bar' or the rest of the dictionary - are available for callbacks to access. + are available for callbacks to access. If you want to manually set current step, you can specify it with + 'step' key in the 'log' Dict. Example ------- @@ -465,7 +466,7 @@ def validation_end(self, outputs): # show val_loss and val_acc in progress bar but only log val_loss results = { 'progress_bar': tqdm_dict, - 'log': {'val_loss': val_loss_mean.item()} + 'log': {'val_loss': val_loss_mean.item(), 'step': self.current_epoch} } return results diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index 2025bd511eb494..34b1c114b338a8 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -39,13 +39,12 @@ def configure_logger(self, logger): def log_metrics(self, metrics, grad_norm_dic, step=None): """Logs the metric dict passed in. 
- - :param metrics: - :param grad_norm_dic: + If `step` parameter is None and `step` key is present in metrics, + uses metrics["step"] as a step + :param metrics (dict): Metric values + :param grad_norm_dic (dict): Gradient norms + :param step (int): Step for which metrics should be logged. Default value corresponds to `self.global_step` """ - # added metrics by Lightning for convenience - metrics['epoch'] = self.current_epoch - # add gpu memory if self.on_gpu and self.log_gpu_memory: mem_map = memory.get_memory_profile(self.log_gpu_memory) @@ -57,7 +56,12 @@ def log_metrics(self, metrics, grad_norm_dic, step=None): # turn all tensors to scalars scalar_metrics = self.metrics_to_scalars(metrics) - step = step if step is not None else self.global_step + if "step" in scalar_metrics and step is None: + step = scalar_metrics.pop("step") + else: + # added metrics by Lightning for convenience + metrics['epoch'] = self.current_epoch + step = step if step is not None else self.global_step # log actual metrics if self.proc_rank == 0 and self.logger is not None: self.logger.log_metrics(scalar_metrics, step=step) diff --git a/tests/test_logging.py b/tests/test_logging.py index c0166796ca0c58..0d4104ef7a34bc 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -376,3 +376,33 @@ def version(self): assert logger.hparams_logged == hparams assert logger.metrics_logged != {} assert logger.finalized_status == "success" + + +def test_adding_step_key(tmpdir): + logged_step = 0 + + def _validation_end(outputs): + nonlocal logged_step + logged_step += 1 + return {"log": {"step": logged_step, "val_acc": logged_step / 10}} + + def _log_metrics_decorator(log_metrics_fn): + def decorated(metrics, step): + if "val_acc" in metrics: + assert step == logged_step + return log_metrics_fn(metrics, step) + + return decorated + + model, hparams = tutils.get_model() + model.validation_end = _validation_end + trainer_options = dict( + max_epochs=4, + default_save_path=tmpdir, + 
train_percent_check=0.001, + val_percent_check=0.01, + num_sanity_val_steps=0 + ) + trainer = Trainer(**trainer_options) + trainer.logger.log_metrics = _log_metrics_decorator(trainer.logger.log_metrics) + trainer.fit(model)