From f5216bbf98f1e0781d143fa525c48627fa68f052 Mon Sep 17 00:00:00 2001 From: Anand Date: Thu, 23 Jan 2020 12:16:58 +0900 Subject: [PATCH] Added optimizer_idx to backward call --- pytorch_lightning/core/hooks.py | 3 ++- pytorch_lightning/trainer/training_loop.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index dc5a502c5f89a..c89b82e00a101 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -124,12 +124,13 @@ def on_after_backward(self): """ pass - def backward(self, use_amp, loss, optimizer): + def backward(self, use_amp, loss, optimizer, optimizer_idx): """Override backward with your own implementation if you need to :param use_amp: Whether amp was requested or not :param loss: Loss is already scaled by accumulated grads :param optimizer: Current optimizer being used + :param optimizer_idx: Index of the current optimizer being used :return: Called to perform backward step. diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index d5f16c9462bf1..9f07fb523b4e6 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -486,13 +486,14 @@ def optimizer_closure(): # backward pass model_ref = self.get_model() - model_ref.backward(self.use_amp, closure_loss, optimizer) + model_ref.backward(self.use_amp, closure_loss, optimizer, opt_idx) # track metrics for callbacks all_callback_metrics.append(callback_metrics) # track progress bar metrics self.add_tqdm_metrics(progress_bar_metrics) + self.add_tqdm_metrics(progress_bar_metrics) all_log_metrics.append(log_metrics) # insert after step hook