
Commit

ref: decouple apex second attemp part 9/n (#4063)
williamFalcon authored Oct 10, 2020
1 parent e3717ed commit dbfe2b6
Showing 5 changed files with 41 additions and 20 deletions.
8 changes: 4 additions & 4 deletions docs/source/optimizers.rst
@@ -22,7 +22,7 @@ For advanced research topics like reinforcement learning, sparse coding, or GAN
 to manually manage the optimization process. To do so, do the following:
 
 * Ignore the optimizer_idx argument
-* So we can scale the loss automatically for you use self.backward(loss) instead of loss.backward()
+* So we can scale the loss automatically for you use self.manual_backward(loss) instead of loss.backward()
 
 .. code-block:: python
@@ -34,16 +34,16 @@ to manually manage the optimization process. To do so, do the following:
         loss_a = ...
         # use self.backward which will also handle scaling the loss when using amp
-        self.backward(loss_a, opt_g)
+        self.manual_backward(loss_a, opt_g)
         opt_g.step()
         opt_g.zero_grad()
 
         # do anything you want
         loss_b = ...
 
         # pass in any args that loss.backward() normally takes
-        self.backward(loss_b, opt_d, retain_graph=True)
-        self.backward(loss_b, opt_d, retain_graph=True)
+        self.manual_backward(loss_b, opt_d, retain_graph=True)
+        self.manual_backward(loss_b, opt_d, retain_graph=True)
 
         loss_b.step()
         loss_b.zero_grad()
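Put together, the updated docs pattern looks like the sketch below: a minimal two-optimizer training_step using the renamed self.manual_backward. The model, losses, and optimizer choices are illustrative assumptions, not code from this commit, and the sketch assumes batches are plain float tensors and that manual optimization has been enabled on the Trainer. Note also that the docs snippet's trailing loss_b.step() / loss_b.zero_grad() read like they are meant to be calls on opt_d, which is what the sketch does.

import torch
from torch import nn
import pytorch_lightning as pl


class ManualGAN(pl.LightningModule):
    """Illustrative module for the manual-optimization docs pattern above."""

    def __init__(self):
        super().__init__()
        self.generator = nn.Linear(32, 32)
        self.discriminator = nn.Linear(32, 1)

    def training_step(self, batch, batch_idx, optimizer_idx):
        # ignore optimizer_idx and drive both optimizers by hand
        (opt_g, opt_d) = self.optimizers()

        loss_a = self.generator(batch).pow(2).mean()  # stand-in generator loss
        # manual_backward also handles loss scaling when 16-bit / AMP is enabled
        self.manual_backward(loss_a, opt_g)
        opt_g.step()
        opt_g.zero_grad()

        loss_b = self.discriminator(batch).pow(2).mean()  # stand-in discriminator loss
        # any args loss.backward() normally takes can be forwarded
        self.manual_backward(loss_b, opt_d, retain_graph=True)
        self.manual_backward(loss_b, opt_d)
        opt_d.step()
        opt_d.zero_grad()

    def configure_optimizers(self):
        opt_g = torch.optim.SGD(self.generator.parameters(), lr=0.1)
        opt_d = torch.optim.SGD(self.discriminator.parameters(), lr=0.1)
        return opt_g, opt_d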
25 changes: 23 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -1038,7 +1038,7 @@ def configure_optimizers(self):
             "`configure_optimizers` must be implemented to be used with the Lightning Trainer"
         )
 
-    def backward(self, loss: Tensor, optimizer: Optimizer, *args, **kwargs) -> None:
+    def manual_backward(self, loss: Tensor, optimizer: Optimizer, *args, **kwargs) -> None:
         """
         Call this directly from your training_step when doing optimizations manually.
         By using this we can ensure that all the proper scaling when using 16-bit etc has been done for you
@@ -1051,10 +1051,31 @@ def training_step(...):
                 (opt_a, opt_b) = self.optimizers()
                 loss = ...
                 # automatically applies scaling, etc...
-                self.backward(loss, opt_a)
+                self.manual_backward(loss, opt_a)
         """
         self.trainer.train_loop.backward(loss, optimizer, *args, **kwargs)
 
+    def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None:
+        """
+        Override backward with your own implementation if you need to.
+
+        Args:
+            loss: Loss is already scaled by accumulated grads
+            optimizer: Current optimizer being used
+            optimizer_idx: Index of the current optimizer being used
+
+        Called to perform backward step.
+        Feel free to override as needed.
+        The loss passed in has already been scaled for accumulated gradients if requested.
+
+        Example::
+
+            def backward(self, trainer, loss, optimizer, optimizer_idx):
+                loss.backward()
+        """
+        loss.backward()
+
     def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
         """
         Makes sure only the gradients of the current optimizer's parameters are calculated
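Besides renaming the manual call to manual_backward, the hunk above adds an overridable backward hook to LightningModule. Below is a sketch of overriding it from user code; it follows the (loss, optimizer, optimizer_idx) signature of the added method (the docstring's example shows an extra trainer argument that the signature itself does not take, an inconsistency worth noting), and the model, loss, and optimizer are illustrative assumptions rather than code from this commit.

import torch
from torch import nn
import pytorch_lightning as pl


class CustomBackwardModule(pl.LightningModule):
    """Illustrative module that customizes the backward hook added above."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(16, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.layer(x), y)

    def backward(self, loss, optimizer, optimizer_idx):
        # the loss arriving here is already scaled for gradient accumulation;
        # keep the graph alive so a later pass could reuse it
        loss.backward(retain_graph=True)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.01)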
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/__init__.py
@@ -296,7 +296,7 @@ def training_step(self, batch, batch_idx):
         (opt) = self.optimizers()
         loss = ...
-        self.backward(loss, opt)
+        self.manual_backward(loss, opt)
         opt.step()
         opt.zero_grad()
@@ -311,12 +311,12 @@ def training_step(self, batch, batch_idx, optimizer_idx):
         (opt_a, opt_b) = self.optimizers()
 
         gen_loss = ...
-        self.backward(gen_loss, opt_a)
+        self.manual_backward(gen_loss, opt_a)
         opt_a.step()
         opt_a.zero_grad()
 
         disc_loss = ...
-        self.backward(disc_loss, opt_b)
+        self.manual_backward(disc_loss, opt_b)
         opt_b.step()
         opt_b.zero_grad()
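Both trainer docstring examples above assume the automatic optimization loop has been switched off so that training_step owns the stepping. A minimal end-to-end sketch of the single-optimizer variant follows; the automatic_optimization=False Trainer flag is an assumption about how manual mode was enabled around this release (it does not appear in this diff), and the dataset is synthetic.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class SingleOptManual(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        (x,) = batch
        loss = self.layer(x).pow(2).mean()
        self.manual_backward(loss, opt)
        opt.step()
        opt.zero_grad()

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


# hypothetical wiring; the flag name below is an assumption, not taken from this commit
train_data = DataLoader(TensorDataset(torch.randn(64, 8)), batch_size=8)
trainer = pl.Trainer(max_epochs=1, automatic_optimization=False)
trainer.fit(SingleOptManual(), train_data)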
4 changes: 2 additions & 2 deletions tests/trainer/dynamic_args/test_multiple_optimizers.py
@@ -65,13 +65,13 @@ def training_step(self, batch, batch_idx, optimizer_idx):
         loss_1 = self.step(batch[0])
 
         # fake generator
-        self.backward(loss_1, opt_a)
+        self.manual_backward(loss_1, opt_a)
         opt_a.step()
         opt_a.zero_grad()
 
         # fake discriminator
         loss_2 = self.step(batch[0])
-        self.backward(loss_2, opt_b)
+        self.manual_backward(loss_2, opt_b)
         opt_b.step()
         opt_b.zero_grad()
 
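The test above exercises only the training_step half of the fake generator/discriminator; for completeness, here is a sketch of the configure_optimizers it implies. The layer size, the self.step helper, and the learning rates are assumptions mirroring the test's naming rather than code from this commit, and the training_step body would be the one shown in the hunk above.

import torch
from torch import nn
import pytorch_lightning as pl


class FakeGANModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def step(self, x):
        # tiny stand-in loss, analogous to the test's self.step(batch[0])
        return self.layer(x).sum()

    # training_step: as in the diff hunk above (manual updates of opt_a, then opt_b)

    def configure_optimizers(self):
        # returning two optimizers is what makes Lightning pass optimizer_idx
        # into training_step, even though the manual loop ignores it
        opt_a = torch.optim.SGD(self.layer.parameters(), lr=0.001)
        opt_b = torch.optim.SGD(self.layer.parameters(), lr=0.001)
        return opt_a, opt_b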
18 changes: 9 additions & 9 deletions tests/trainer/optimization/test_manual_optimization.py
@@ -23,7 +23,7 @@ def training_step(self, batch, batch_idx, optimizer_idx):
         if batch_idx > 0:
             assert torch.all(self.layer.weight.grad == 0)
 
-        self.backward(loss_1, opt_a)
+        self.manual_backward(loss_1, opt_a)
         opt_a.step()
         opt_a.zero_grad()
         assert torch.all(self.layer.weight.grad == 0)
@@ -33,8 +33,8 @@ def training_step(self, batch, batch_idx, optimizer_idx):
 
         # ensure we forward the correct params to the optimizer
         # without retain_graph we can't do multiple backward passes
-        self.backward(loss_2, opt_b, retain_graph=True)
-        self.backward(loss_2, opt_a, retain_graph=True)
+        self.manual_backward(loss_2, opt_b, retain_graph=True)
+        self.manual_backward(loss_2, opt_a, retain_graph=True)
 
         assert self.layer.weight.grad is not None
         opt_b.step()
@@ -87,7 +87,7 @@ def training_step(self, batch, batch_idx, optimizer_idx):
         if batch_idx > 0:
             assert torch.all(self.layer.weight.grad == 0)
 
-        self.backward(loss_1, opt_a)
+        self.manual_backward(loss_1, opt_a)
         opt_a.step()
         opt_a.zero_grad()
         assert torch.all(self.layer.weight.grad == 0)
@@ -97,8 +97,8 @@ def training_step(self, batch, batch_idx, optimizer_idx):
 
         # ensure we forward the correct params to the optimizer
         # without retain_graph we can't do multiple backward passes
-        self.backward(loss_2, opt_b, retain_graph=True)
-        self.backward(loss_2, opt_a, retain_graph=True)
+        self.manual_backward(loss_2, opt_b, retain_graph=True)
+        self.manual_backward(loss_2, opt_a, retain_graph=True)
 
         assert self.layer.weight.grad is not None
         opt_b.step()
@@ -157,7 +157,7 @@ def training_step(self, batch, batch_idx, optimizer_idx):
         if batch_idx > 0:
             assert torch.all(self.layer.weight.grad == 0)
 
-        self.backward(loss_1, opt_a)
+        self.manual_backward(loss_1, opt_a)
         opt_a.step()
         opt_a.zero_grad()
         assert torch.all(self.layer.weight.grad == 0)
@@ -168,8 +168,8 @@ def training_step(self, batch, batch_idx, optimizer_idx):
 
         # ensure we forward the correct params to the optimizer
         # without retain_graph we can't do multiple backward passes
-        self.backward(loss_2, opt_b, retain_graph=True)
-        self.backward(loss_2, opt_a, retain_graph=True)
+        self.manual_backward(loss_2, opt_b, retain_graph=True)
+        self.manual_backward(loss_2, opt_a, retain_graph=True)
 
         assert self.layer.weight.grad is not None
         opt_b.step()
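The comment repeated in these test hunks ("without retain_graph we can't do multiple backward passes") is plain autograd behaviour rather than anything specific to this commit: the first backward call frees the graph unless retain_graph=True is passed, so a second backward on the same loss would otherwise fail. A small standalone sketch:

import torch

layer = torch.nn.Linear(4, 1)

loss = layer(torch.randn(2, 4)).sum()
loss.backward(retain_graph=True)  # keep the graph alive for another pass
loss.backward()                   # ok: gradients accumulate into layer.weight.grad

loss2 = layer(torch.randn(2, 4)).sum()
loss2.backward()
# loss2.backward()  # would raise a RuntimeError about backward through the graph a second time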
