Commit

ref: decouple apex second attempt part 10/n (#4064)
* ref: decouple apex second attempt part 9/n

* ref: decouple apex second attempt part 9/n

* ref: decouple apex second attempt part 9/n
williamFalcon authored Oct 11, 2020
1 parent dbfe2b6 commit 0281b07
Showing 7 changed files with 15 additions and 12 deletions.
6 changes: 4 additions & 2 deletions pytorch_lightning/accelerators/base_accelerator.py
@@ -71,9 +71,11 @@ def validation_step_end(self, output):
     def process_dataloader(self, dataloader):
         return dataloader
 
-    def backward(self, closure_loss, optimizer, *args, **kwargs):
+    def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
         if self.trainer.precision == 16:
-            closure_loss = self.trainer.precision_connector.backend.backward(closure_loss, optimizer, *args, **kwargs)
+            closure_loss = self.trainer.precision_connector.backend.backward(
+                closure_loss, optimizer, opt_idx, *args, **kwargs
+            )
         else:
             # do backward pass
             closure_loss.backward(*args, **kwargs)
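The base accelerator is where the new opt_idx argument enters the dispatch: with 16-bit precision the loss is handed to the precision backend's backward, otherwise a plain backward is run. Below is a minimal, self-contained sketch of that dispatch pattern; ToyAccelerator, ToyPrecisionBackend, and the placeholder scaling factor are illustrative stand-ins, not the library's classes.

    import torch

    class ToyPrecisionBackend:
        """Stand-in for a precision plugin: scale the loss, then backpropagate."""
        def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
            scaled = closure_loss * 0.5  # placeholder scaling; real plugins use apex or GradScaler
            scaled.backward(*args, **kwargs)
            return scaled

    class ToyAccelerator:
        """Stand-in for the shared accelerator backward dispatch."""
        def __init__(self, precision=32):
            self.precision = precision
            self.precision_backend = ToyPrecisionBackend()

        def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
            if self.precision == 16:
                # the precision backend now also receives the optimizer index
                closure_loss = self.precision_backend.backward(
                    closure_loss, optimizer, opt_idx, *args, **kwargs)
            else:
                # plain backward pass
                closure_loss.backward(*args, **kwargs)
            return closure_loss

    model = torch.nn.Linear(2, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    ToyAccelerator().backward(model(torch.randn(4, 2)).mean(), opt, opt_idx=0)
    print(model.weight.grad is not None)  # True: gradients were populated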
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/horovod_backend.py
@@ -150,8 +150,8 @@ def test_step(self, args):
         output = self.trainer.model.test_step(*args)
         return output
 
-    def backward(self, closure_loss, optimizer, *args, **kwargs):
-        super().backward(closure_loss, optimizer, *args, **kwargs)
+    def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
+        super().backward(closure_loss, optimizer, opt_idx, *args, **kwargs)
         optimizer.synchronize()
 
     def on_train_epoch_end(self, outputs):
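The Horovod backend keeps the parent signature and only appends optimizer.synchronize(), which on a real hvd.DistributedOptimizer waits for the asynchronous allreduce of gradients before the optimizer step. A small sketch of that override pattern, with a toy optimizer standing in for Horovod's so it stays runnable without Horovod installed:

    import torch

    class BaseBackend:
        """Minimal stand-in for the shared accelerator backward."""
        def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
            closure_loss.backward(*args, **kwargs)

    class SyncingSGD(torch.optim.SGD):
        """Toy optimizer exposing synchronize(), like hvd.DistributedOptimizer."""
        def synchronize(self):
            pass  # real Horovod finishes the async gradient allreduce here

    class HorovodLikeBackend(BaseBackend):
        def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
            # same (new) signature as the parent, plus a gradient sync afterwards
            super().backward(closure_loss, optimizer, opt_idx, *args, **kwargs)
            optimizer.synchronize()

    model = torch.nn.Linear(3, 1)
    opt = SyncingSGD(model.parameters(), lr=0.1)
    HorovodLikeBackend().backward(model(torch.ones(1, 3)).sum(), opt, opt_idx=0)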
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/tpu_backend.py
@@ -222,7 +222,7 @@ def __setup_tpu_training(self, model: LightningModule, trainer):
                  f' global rank: {trainer.tpu_global_core_rank}'
                  f' with XLA_USE_BF16={os.environ.get("XLA_USE_BF16")}')
 
-    def backward(self, closure_loss, optimizer, *args, **kwargs):
+    def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
         # do backward pass
         closure_loss.backward(*args, **kwargs)
 
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -1053,7 +1053,7 @@ def training_step(...):
                # automatically applies scaling, etc...
                self.manual_backward(loss, opt_a)
         """
-        self.trainer.train_loop.backward(loss, optimizer, *args, **kwargs)
+        self.trainer.train_loop.backward(loss, optimizer, -1, *args, **kwargs)
 
     def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None:
         """
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/apex.py
@@ -36,7 +36,7 @@ def training_step(self, fx, args):
         output = fx(args)
         return output
 
-    def backward(self, closure_loss, optimizer, *args, **kwargs):
+    def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
         closure_loss = amp.scale_loss(closure_loss, optimizer)
 
         # enter apex context
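In the apex plugin, amp.scale_loss(closure_loss, optimizer) returns a context manager; the "# enter apex context" comment refers to the with-block (below the visible hunk) that yields the scaled loss to backpropagate. A hedged sketch of the usual apex pattern, assuming NVIDIA apex is installed and the model/optimizer pair was set up with amp.initialize beforehand (the apex_backward helper is illustrative, not the plugin's code):

    # Assumes NVIDIA apex is installed and amp.initialize(model, optimizer, ...)
    # has already been called on this model/optimizer pair.
    from apex import amp

    def apex_backward(closure_loss, optimizer, opt_idx, *args, **kwargs):
        # scale_loss is a context manager; entering it yields the scaled loss
        with amp.scale_loss(closure_loss, optimizer) as scaled_loss:
            scaled_loss.backward(*args, **kwargs)
        return closure_loss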
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/native_amp.py
@@ -23,7 +23,7 @@ def __init__(self, trainer):
     def connect(self, model, optimizers):
         return model, optimizers
 
-    def backward(self, closure_loss, optimizer, *args, **kwargs):
+    def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
         closure_loss = self.trainer.scaler.scale(closure_loss)
 
         # do backward pass
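The native-AMP plugin's backward scales the loss through the trainer's torch.cuda.amp.GradScaler before running the backward pass (scaler.step/scaler.update happen later in the optimizer step). A self-contained sketch of that scale-then-backward step; the module-level scaler here is illustrative, and scaling is only enabled when CUDA is available so the snippet also runs on CPU:

    import torch

    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

    def native_amp_backward(closure_loss, optimizer, opt_idx, *args, **kwargs):
        # scale the loss so fp16 gradients do not underflow, then backpropagate
        closure_loss = scaler.scale(closure_loss)
        closure_loss.backward(*args, **kwargs)
        return closure_loss

    model = torch.nn.Linear(4, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    native_amp_backward(model(torch.randn(2, 4)).sum(), opt, opt_idx=0)
    # later in the step: scaler.step(opt); scaler.update()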
9 changes: 5 additions & 4 deletions pytorch_lightning/trainer/training_loop.py
@@ -750,23 +750,24 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,

         # backward pass
         with self.trainer.profiler.profile('model_backward'):
-            self.backward(result, optimizer)
+            self.backward(result, optimizer, opt_idx)
 
         # hook
         self.on_after_backward(result.training_step_output, batch_idx, result.loss)
 
         return result
 
-    def backward(self, result, optimizer, *args, **kwargs):
+    def backward(self, result, optimizer, opt_idx, *args, **kwargs):
         self.trainer.dev_debugger.track_event('backward_call')
 
-        # backward can be called manually in the training loop.
+        # backward can be called manually in the training loop
         if isinstance(result, torch.Tensor):
-            self.trainer.accelerator_backend.backward(result, optimizer, *args, **kwargs)
+            self.trainer.accelerator_backend.backward(result, optimizer, opt_idx, *args, **kwargs)
         else:
             result.closure_loss = self.trainer.accelerator_backend.backward(
                 result.closure_loss,
                 optimizer,
+                opt_idx,
                 *args,
                 **kwargs
             )
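Taken together, the refactor threads opt_idx from the training loop through the accelerator backend down to the precision plugin, so every backward call knows which optimizer it belongs to; a bare tensor (e.g. from manual_backward) and a structured result object take slightly different branches. A toy end-to-end sketch of that call chain (all class names here are illustrative stand-ins):

    import torch

    class ToyPrecisionPlugin:
        def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
            closure_loss.backward(*args, **kwargs)
            return closure_loss

    class ToyAcceleratorBackend:
        def __init__(self):
            self.precision_plugin = ToyPrecisionPlugin()

        def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
            # opt_idx is forwarded untouched to the precision plugin
            return self.precision_plugin.backward(
                closure_loss, optimizer, opt_idx, *args, **kwargs)

    class ToyTrainingLoop:
        def __init__(self):
            self.accelerator_backend = ToyAcceleratorBackend()

        def backward(self, result, optimizer, opt_idx, *args, **kwargs):
            if isinstance(result, torch.Tensor):
                # bare tensor, e.g. from a manual backward call
                self.accelerator_backend.backward(result, optimizer, opt_idx, *args, **kwargs)
            else:
                # structured result object carrying the closure loss
                result.closure_loss = self.accelerator_backend.backward(
                    result.closure_loss, optimizer, opt_idx, *args, **kwargs)

    model = torch.nn.Linear(2, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    ToyTrainingLoop().backward(model(torch.randn(3, 2)).mean(), opt, opt_idx=0)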
