From 0281b077d8beefa731c11ab3626e151ca79d0fda Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 10 Oct 2020 20:05:05 -0400
Subject: [PATCH] ref: decouple apex second attempt part 10/n (#4064)

* ref: decouple apex second attempt part 9/n

* ref: decouple apex second attempt part 9/n

* ref: decouple apex second attempt part 9/n
---
 pytorch_lightning/accelerators/base_accelerator.py | 6 ++++--
 pytorch_lightning/accelerators/horovod_backend.py  | 4 ++--
 pytorch_lightning/accelerators/tpu_backend.py      | 2 +-
 pytorch_lightning/core/lightning.py                | 2 +-
 pytorch_lightning/plugins/apex.py                  | 2 +-
 pytorch_lightning/plugins/native_amp.py            | 2 +-
 pytorch_lightning/trainer/training_loop.py         | 9 +++++----
 7 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/pytorch_lightning/accelerators/base_accelerator.py b/pytorch_lightning/accelerators/base_accelerator.py
index d55ad045aeb1f..bfa0006b4e3ad 100644
--- a/pytorch_lightning/accelerators/base_accelerator.py
+++ b/pytorch_lightning/accelerators/base_accelerator.py
@@ -71,9 +71,11 @@ def validation_step_end(self, output):
     def process_dataloader(self, dataloader):
         return dataloader
 
-    def backward(self, closure_loss, optimizer, *args, **kwargs):
+    def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
         if self.trainer.precision == 16:
-            closure_loss = self.trainer.precision_connector.backend.backward(closure_loss, optimizer, *args, **kwargs)
+            closure_loss = self.trainer.precision_connector.backend.backward(
+                closure_loss, optimizer, opt_idx, *args, **kwargs
+            )
         else:
             # do backward pass
             closure_loss.backward(*args, **kwargs)
diff --git a/pytorch_lightning/accelerators/horovod_backend.py b/pytorch_lightning/accelerators/horovod_backend.py
index ef6b937ca176e..69e9bc72b6c3f 100644
--- a/pytorch_lightning/accelerators/horovod_backend.py
+++ b/pytorch_lightning/accelerators/horovod_backend.py
@@ -150,8 +150,8 @@ def test_step(self, args):
         output = self.trainer.model.test_step(*args)
         return output
 
-    def backward(self, closure_loss, optimizer, *args, **kwargs):
-        super().backward(closure_loss, optimizer, *args, **kwargs)
+    def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
+        super().backward(closure_loss, optimizer, opt_idx, *args, **kwargs)
         optimizer.synchronize()
 
     def on_train_epoch_end(self, outputs):
diff --git a/pytorch_lightning/accelerators/tpu_backend.py b/pytorch_lightning/accelerators/tpu_backend.py
index c6f4b95fdfa38..4faebf4a3df9b 100644
--- a/pytorch_lightning/accelerators/tpu_backend.py
+++ b/pytorch_lightning/accelerators/tpu_backend.py
@@ -222,7 +222,7 @@ def __setup_tpu_training(self, model: LightningModule, trainer):
                        f' global rank: {trainer.tpu_global_core_rank}'
                        f' with XLA_USE_BF16={os.environ.get("XLA_USE_BF16")}')
 
-    def backward(self, closure_loss, optimizer, *args, **kwargs):
+    def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
         # do backward pass
         closure_loss.backward(*args, **kwargs)
 
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index dd4f5787d4d0f..96ab33de8e872 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1053,7 +1053,7 @@ def training_step(...):
                 # automatically applies scaling, etc...
                 self.manual_backward(loss, opt_a)
         """
-        self.trainer.train_loop.backward(loss, optimizer, *args, **kwargs)
+        self.trainer.train_loop.backward(loss, optimizer, -1, *args, **kwargs)
 
     def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None:
         """
diff --git a/pytorch_lightning/plugins/apex.py b/pytorch_lightning/plugins/apex.py
index e09d6c42c5a32..f4597bdf8d10d 100644
--- a/pytorch_lightning/plugins/apex.py
+++ b/pytorch_lightning/plugins/apex.py
@@ -36,7 +36,7 @@ def training_step(self, fx, args):
         output = fx(args)
         return output
 
-    def backward(self, closure_loss, optimizer, *args, **kwargs):
+    def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
         closure_loss = amp.scale_loss(closure_loss, optimizer)
 
         # enter apex context
diff --git a/pytorch_lightning/plugins/native_amp.py b/pytorch_lightning/plugins/native_amp.py
index 14727f0bb0943..ac19f187822cb 100644
--- a/pytorch_lightning/plugins/native_amp.py
+++ b/pytorch_lightning/plugins/native_amp.py
@@ -23,7 +23,7 @@ def __init__(self, trainer):
     def connect(self, model, optimizers):
         return model, optimizers
 
-    def backward(self, closure_loss, optimizer, *args, **kwargs):
+    def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
         closure_loss = self.trainer.scaler.scale(closure_loss)
 
         # do backward pass
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index e655518567a03..2c8a5b1f48809 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -750,23 +750,24 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
 
         # backward pass
         with self.trainer.profiler.profile('model_backward'):
-            self.backward(result, optimizer)
+            self.backward(result, optimizer, opt_idx)
 
         # hook
         self.on_after_backward(result.training_step_output, batch_idx, result.loss)
 
         return result
 
-    def backward(self, result, optimizer, *args, **kwargs):
+    def backward(self, result, optimizer, opt_idx, *args, **kwargs):
         self.trainer.dev_debugger.track_event('backward_call')
 
-        # backward can be called manually in the training loop.
+        # backward can be called manually in the training loop
         if isinstance(result, torch.Tensor):
-            self.trainer.accelerator_backend.backward(result, optimizer, *args, **kwargs)
+            self.trainer.accelerator_backend.backward(result, optimizer, opt_idx, *args, **kwargs)
         else:
             result.closure_loss = self.trainer.accelerator_backend.backward(
                 result.closure_loss,
                 optimizer,
+                opt_idx,
                 *args,
                 **kwargs
            )
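
Note (not part of the patch): below is a minimal, self-contained sketch of the call chain this patch touches. The class names ToyAccelerator and ToyPrecisionBackend are illustrative stand-ins, not Lightning's real classes; the sketch only shows how the new opt_idx argument is threaded from the training loop through the accelerator into a precision backend, and that manual_backward now passes -1 when no optimizer index applies.

# Standalone sketch (assumed toy class names) of the backward call chain after
# this patch: training loop -> accelerator -> precision backend, with the
# optimizer index (opt_idx) forwarded at every hop.
import torch


class ToyPrecisionBackend:
    """Stands in for the apex / native-amp plugins."""

    def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
        # A real plugin would scale the loss first (amp.scale_loss / GradScaler.scale);
        # opt_idx is now available here for any per-optimizer handling.
        closure_loss.backward(*args, **kwargs)
        return closure_loss


class ToyAccelerator:
    """Stands in for the accelerator: dispatch to the precision backend or plain backward."""

    def __init__(self, precision=32):
        self.precision = precision
        self.precision_backend = ToyPrecisionBackend()

    def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
        if self.precision == 16:
            closure_loss = self.precision_backend.backward(
                closure_loss, optimizer, opt_idx, *args, **kwargs
            )
        else:
            # do backward pass directly
            closure_loss.backward(*args, **kwargs)
        return closure_loss


# Usage: the training loop forwards the real optimizer index; manual_backward
# passes -1 because the user call is not tied to a specific optimizer slot.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accelerator = ToyAccelerator(precision=32)

loss = model(torch.randn(8, 4)).pow(2).mean()
accelerator.backward(loss, optimizer, opt_idx=0)  # automatic-optimization path
optimizer.step()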