From 1071ba11f6061947a03e833bdd192cc656353ada Mon Sep 17 00:00:00 2001
From: dima
Date: Sun, 9 Feb 2020 17:01:46 +0300
Subject: [PATCH 1/6] allow to specify 'step' key

---
 pytorch_lightning/core/lightning.py  |  5 +++--
 pytorch_lightning/trainer/logging.py | 15 +++++++++------
 2 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 66b887fa76374..4f3910e32a569 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -409,7 +409,8 @@ def validation_end(self, outputs):
         The outputs here are strictly for the progress bar.
         If you don't need to display anything, don't return anything.
         Any keys present in 'log', 'progress_bar' or the rest of the dictionary
-        are available for callbacks to access.
+        are available for callbacks to access. If you want to manually set the current step,
+        you can specify it with the 'step' key in the 'log' dict.

         Example
         -------
@@ -459,7 +460,7 @@ def validation_end(self, outputs):
                 # show val_loss and val_acc in progress bar but only log val_loss
                 results = {
                     'progress_bar': tqdm_dict,
-                    'log': {'val_loss': val_loss_mean.item()}
+                    'log': {'val_loss': val_loss_mean.item(), 'step': self.current_epoch}
                 }
                 return results
diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py
index 2025bd511eb49..f9c6a5b4a6d00 100644
--- a/pytorch_lightning/trainer/logging.py
+++ b/pytorch_lightning/trainer/logging.py
@@ -38,14 +38,12 @@ def configure_logger(self, logger):
             self.logger.rank = 0

     def log_metrics(self, metrics, grad_norm_dic, step=None):
-        """Logs the metric dict passed in.
-
+        """Logs the metric dict passed in. If the 'step' parameter is None and
+        a 'step' key is present in metrics, uses metrics['step'] as the step.
         :param metrics:
         :param grad_norm_dic:
+        :param step:
         """
-        # added metrics by Lightning for convenience
-        metrics['epoch'] = self.current_epoch
-
         # add gpu memory
         if self.on_gpu and self.log_gpu_memory:
             mem_map = memory.get_memory_profile(self.log_gpu_memory)
@@ -57,7 +55,12 @@ def log_metrics(self, metrics, grad_norm_dic, step=None):
         # turn all tensors to scalars
         scalar_metrics = self.metrics_to_scalars(metrics)

-        step = step if step is not None else self.global_step
+        if "step" in scalar_metrics and step is None:
+            step = scalar_metrics.pop("step")
+        else:
+            # added metrics by Lightning for convenience
+            metrics['epoch'] = self.current_epoch
+            step = step if step is not None else self.global_step

         # log actual metrics
         if self.proc_rank == 0 and self.logger is not None:
             self.logger.log_metrics(scalar_metrics, step=step)
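The logging.py hunk above gives log_metrics a simple precedence: an explicit 'step' argument wins, then a 'step' key inside the metrics dict, and only then the trainer's global step. Below is a minimal standalone sketch of that rule; resolve_step and its keyword arguments are illustrative stand-ins for the Trainer attributes, not part of the patch.

    def resolve_step(metrics, step=None, global_step=0, current_epoch=0):
        # Illustrative restatement of the precedence implemented above.
        if "step" in metrics and step is None:
            # the module opted in via the 'step' key; pop it so the step
            # value is not also logged as a metric in its own right
            step = metrics.pop("step")
        else:
            # default path: add 'epoch' for convenience and fall back to
            # the global step when no explicit step was given
            metrics["epoch"] = current_epoch
            step = step if step is not None else global_step
        return step

    assert resolve_step({"val_acc": 0.9, "step": 7}) == 7            # 'step' key wins
    assert resolve_step({"val_acc": 0.9, "step": 7}, step=42) == 42  # explicit argument wins
    assert resolve_step({"val_acc": 0.9}) == 0                       # falls back to global_step

Note that the 'step' key is consumed (popped) only when no explicit step is passed; when both are present, the explicit argument is used and the key is logged as an ordinary metric.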
From af72079322163c2f61360fe065098d4034e3425b Mon Sep 17 00:00:00 2001
From: Dmitry Lipin
Date: Tue, 11 Feb 2020 12:32:32 +0300
Subject: [PATCH 2/6] add test

---
 tests/test_logging.py | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/tests/test_logging.py b/tests/test_logging.py
index 1b531420c8541..64847f2d08f4f 100644
--- a/tests/test_logging.py
+++ b/tests/test_logging.py
@@ -363,3 +363,33 @@ def version(self):
     assert logger.hparams_logged == hparams
     assert logger.metrics_logged != {}
     assert logger.finalized_status == "success"
+
+
+def test_adding_step_key(tmpdir):
+    logged_step = 0
+
+    def validation_end(outputs):
+        nonlocal logged_step
+        logged_step += 1
+        return {"log": {"step": logged_step, "val_acc": logged_step / 10}}
+
+    def log_metrics_decorator(log_metrics_fn):
+        def decorated(metrics, step):
+            if "val_acc" in metrics:
+                assert metrics["step"] == logged_step
+            return log_metrics_fn(metrics, step)
+
+        return decorated
+
+    model, hparams = tutils.get_model()
+    model.validation_end = validation_end
+    trainer_options = dict(
+        max_epochs=4,
+        default_save_path=tmpdir,
+        train_percent_check=0.001,
+        val_percent_check=0.01,
+        num_sanity_val_steps=0
+    )
+    trainer = Trainer(**trainer_options)
+    trainer.logger.log_metrics = log_metrics_decorator(trainer.logger.log_metrics)
+    trainer.fit(model)

From e0583a434831d2bd07df2258474700da8553fc60 Mon Sep 17 00:00:00 2001
From: Dmitry Lipin
Date: Fri, 14 Feb 2020 17:41:00 +0300
Subject: [PATCH 3/6] docs to log_metrics

---
 pytorch_lightning/trainer/logging.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py
index f9c6a5b4a6d00..34b1c114b338a 100644
--- a/pytorch_lightning/trainer/logging.py
+++ b/pytorch_lightning/trainer/logging.py
@@ -38,11 +38,12 @@ def configure_logger(self, logger):
             self.logger.rank = 0

     def log_metrics(self, metrics, grad_norm_dic, step=None):
-        """Logs the metric dict passed in. If the 'step' parameter is None and
-        a 'step' key is present in metrics, uses metrics['step'] as the step.
-        :param metrics:
-        :param grad_norm_dic:
-        :param step:
+        """Logs the metric dict passed in.
+        If the `step` parameter is None and a `step` key is present in metrics,
+        uses metrics["step"] as the step.
+        :param metrics (dict): Metric values
+        :param grad_norm_dic (dict): Gradient norms
+        :param step (int): Step at which metrics should be logged. Default value corresponds to `self.global_step`.
         """
         # add gpu memory
         if self.on_gpu and self.log_gpu_memory:

From 68a31b36c8ab8b9c2f1f7961f0648c322389e0c4 Mon Sep 17 00:00:00 2001
From: Dmitry Lipin
Date: Fri, 14 Feb 2020 18:11:31 +0300
Subject: [PATCH 4/6] fix test

---
 tests/test_logging.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_logging.py b/tests/test_logging.py
index 64847f2d08f4f..f9c207f428692 100644
--- a/tests/test_logging.py
+++ b/tests/test_logging.py
@@ -376,7 +376,7 @@ def validation_end(outputs):
     def log_metrics_decorator(log_metrics_fn):
         def decorated(metrics, step):
             if "val_acc" in metrics:
-                assert metrics["step"] == logged_step
+                assert step == logged_step
             return log_metrics_fn(metrics, step)

         return decorated

From eb6d6498209e0a13c0d777ca8cf73b40d164faa2 Mon Sep 17 00:00:00 2001
From: Dmitry Lipin
Date: Fri, 14 Feb 2020 18:48:57 +0300
Subject: [PATCH 5/6] rename

---
 tests/test_logging.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_logging.py b/tests/test_logging.py
index f9c207f428692..2ba3d67a2a590 100644
--- a/tests/test_logging.py
+++ b/tests/test_logging.py
@@ -368,7 +368,7 @@
 def test_adding_step_key(tmpdir):
     logged_step = 0

-    def validation_end(outputs):
+    def _validation_end(outputs):
         nonlocal logged_step
         logged_step += 1
         return {"log": {"step": logged_step, "val_acc": logged_step / 10}}
@@ -382,7 +382,7 @@ def decorated(metrics, step):
         return decorated

     model, hparams = tutils.get_model()
-    model.validation_end = validation_end
+    model.validation_end = _validation_end
     trainer_options = dict(
         max_epochs=4,
         default_save_path=tmpdir,
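The test added in patch 2 (and touched up in patches 4 and 5) verifies the feature by wrapping the logger's log_metrics method in place: a small spy that inspects each call before forwarding it. Here is a self-contained sketch of that pattern, with a hypothetical FakeLogger standing in for the trainer's logger; none of these names come from the patch itself.

    class FakeLogger:
        # records every forwarded call so a test can assert on it
        def __init__(self):
            self.calls = []

        def log_metrics(self, metrics, step):
            self.calls.append((metrics, step))

    def log_metrics_spy(log_metrics_fn, seen_steps):
        # wrap an existing log_metrics, noting the step before forwarding
        def decorated(metrics, step):
            seen_steps.append(step)
            return log_metrics_fn(metrics, step)
        return decorated

    logger = FakeLogger()
    seen = []
    logger.log_metrics = log_metrics_spy(logger.log_metrics, seen)
    logger.log_metrics({"val_acc": 0.9}, step=3)
    assert seen == [3]
    assert logger.calls == [({"val_acc": 0.9}, 3)]

Patch 4 is the substantive fix here: after log_metrics pops the 'step' key, the key is no longer in the metrics dict the spy receives, so the assertion has to check the forwarded step argument instead of metrics["step"].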
From a667fb829590b91efd1c29bef98bba6acd0c617c Mon Sep 17 00:00:00 2001
From: Dmitry Lipin
Date: Fri, 14 Feb 2020 18:49:51 +0300
Subject: [PATCH 6/6] also rename

---
 tests/test_logging.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_logging.py b/tests/test_logging.py
index 2ba3d67a2a590..20f8b5d60c332 100644
--- a/tests/test_logging.py
+++ b/tests/test_logging.py
@@ -373,7 +373,7 @@ def _validation_end(outputs):
         logged_step += 1
         return {"log": {"step": logged_step, "val_acc": logged_step / 10}}

-    def log_metrics_decorator(log_metrics_fn):
+    def _log_metrics_decorator(log_metrics_fn):
         def decorated(metrics, step):
             if "val_acc" in metrics:
                 assert step == logged_step
@@ -391,5 +391,5 @@ def decorated(metrics, step):
         num_sanity_val_steps=0
     )
     trainer = Trainer(**trainer_options)
-    trainer.logger.log_metrics = log_metrics_decorator(trainer.logger.log_metrics)
+    trainer.logger.log_metrics = _log_metrics_decorator(trainer.logger.log_metrics)
     trainer.fit(model)
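End to end, the series means the test's four epochs should reach the logger with steps 1 through 4 rather than with global-step values. A short worked trace, reusing the illustrative resolve_step sketch from after patch 1 (again an assumption-level sketch, not the Trainer's actual code path):

    # max_epochs=4 -> _validation_end fires once per epoch and emits
    # {"step": n, "val_acc": n / 10}; the trainer should forward exactly n
    for n in range(1, 5):
        emitted = {"step": n, "val_acc": n / 10}
        assert resolve_step(dict(emitted)) == n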