ref: decouple apex second attempt part 9/n #4063

Merged: 2 commits, Oct 10, 2020
8 changes: 4 additions & 4 deletions docs/source/optimizers.rst
@@ -22,7 +22,7 @@ For advanced research topics like reinforcement learning, sparse coding, or GAN
to manually manage the optimization process. To do so, do the following:

* Ignore the optimizer_idx argument
* So we can scale the loss automatically for you, use self.backward(loss) instead of loss.backward()
* So we can scale the loss automatically for you, use self.manual_backward(loss) instead of loss.backward()

.. code-block:: python

@@ -34,16 +34,16 @@ to manually manage the optimization process. To do so, do the following:
loss_a = ...

# use self.backward which will also handle scaling the loss when using amp
self.backward(loss_a, opt_g)
self.manual_backward(loss_a, opt_g)
opt_g.step()
opt_g.zero_grad()

# do anything you want
loss_b = ...

# pass in any args that loss.backward() normally takes
self.backward(loss_b, opt_d, retain_graph=True)
self.backward(loss_b, opt_d, retain_graph=True)
self.manual_backward(loss_b, opt_d, retain_graph=True)
self.manual_backward(loss_b, opt_d, retain_graph=True)
opt_d.step()
opt_d.zero_grad()

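A minimal sketch of the single-optimizer flow this doc change describes, assuming manual optimization is enabled on the Trainer (the automatic_optimization=False flag of this Lightning release); the module, layer sizes, and loss are illustrative only:

.. code-block:: python

    import torch
    import pytorch_lightning as pl


    class ManualOptimModel(pl.LightningModule):
        """Illustrative module for the renamed ``manual_backward`` hook."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx, optimizer_idx=None):
            # the optimizer_idx argument is ignored under manual optimization
            opt = self.optimizers()

            # a placeholder loss; any scalar tensor works here
            loss = self.layer(batch).sum()

            # manual_backward also handles loss scaling when AMP is enabled
            self.manual_backward(loss, opt)
            opt.step()
            opt.zero_grad()

        def configure_optimizers(self):
            return torch.optim.SGD(self.layer.parameters(), lr=0.1)

With a single optimizer, self.optimizers() hands back that optimizer directly, matching the trainer docs touched later in this diff.
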
25 changes: 23 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -1038,7 +1038,7 @@ def configure_optimizers(self):
"`configure_optimizers` must be implemented to be used with the Lightning Trainer"
)

def backward(self, loss: Tensor, optimizer: Optimizer, *args, **kwargs) -> None:
def manual_backward(self, loss: Tensor, optimizer: Optimizer, *args, **kwargs) -> None:
"""
Call this directly from your training_step when doing optimizations manually.
By using this, Lightning can ensure that all the proper scaling (e.g. for 16-bit precision) has been done for you
@@ -1051,10 +1051,31 @@ def training_step(...):
(opt_a, opt_b) = self.optimizers()
loss = ...
# automatically applies scaling, etc...
self.backward(loss, opt_a)
self.manual_backward(loss, opt_a)
"""
self.trainer.train_loop.backward(loss, optimizer, *args, **kwargs)

def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None:
"""
Override backward with your own implementation if you need to.

Args:
loss: Loss is already scaled by accumulated grads
optimizer: Current optimizer being used
optimizer_idx: Index of the current optimizer being used

Called to perform the backward step.
Feel free to override as needed.
The loss passed in has already been scaled for accumulated gradients if requested.

Example::

def backward(self, loss, optimizer, optimizer_idx):
loss.backward()

"""
loss.backward()

def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
"""
Makes sure only the gradients of the current optimizer's parameters are calculated
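Because the old backward name is now the overridable hook shown above, a user module can still customize the backward pass without touching manual_backward. A hedged sketch using the (loss, optimizer, optimizer_idx) signature defined in this diff; the create_graph use case is illustrative, not something this PR requires:

.. code-block:: python

    import torch
    import pytorch_lightning as pl


    class CustomBackwardModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            # a placeholder loss returned to the automatic optimization loop
            return self.layer(batch).sum()

        def configure_optimizers(self):
            return torch.optim.SGD(self.layer.parameters(), lr=0.1)

        def backward(self, loss, optimizer, optimizer_idx):
            # the loss arriving here is already scaled for accumulated gradients;
            # create_graph=True keeps the graph around for higher-order gradients
            # (an illustrative use case, not something this PR requires)
            loss.backward(create_graph=True)

Modules that do not override the hook keep the default loss.backward() shown in the diff.
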
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/__init__.py
@@ -296,7 +296,7 @@ def training_step(self, batch, batch_idx):
(opt) = self.optimizers()

loss = ...
self.backward(loss, opt)
self.manual_backward(loss, opt)
opt.step()
opt.zero_grad()

@@ -311,12 +311,12 @@ def training_step(self, batch, batch_idx, optimizer_idx):
(opt_a, opt_b) = self.optimizers()

gen_loss = ...
self.backward(gen_loss, opt_a)
self.manual_backward(gen_loss, opt_a)
opt_a.step()
opt_a.zero_grad()

disc_loss = ...
self.backward(disc_loss, opt_b)
self.manual_backward(disc_loss, opt_b)
opt_b.step()
opt_b.zero_grad()

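A sketch of the two-optimizer flow documented above, in the GAN-style shape the docs hint at; the generator/discriminator stand-ins and losses are hypothetical, and the Trainer is assumed to run with manual optimization enabled:

.. code-block:: python

    import torch
    import pytorch_lightning as pl


    class TwoOptimizerModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            # hypothetical generator / discriminator stand-ins
            self.gen = torch.nn.Linear(32, 32)
            self.disc = torch.nn.Linear(32, 1)

        def training_step(self, batch, batch_idx, optimizer_idx):
            (opt_a, opt_b) = self.optimizers()

            # generator update
            gen_loss = self.disc(self.gen(batch)).mean()
            self.manual_backward(gen_loss, opt_a)
            opt_a.step()
            opt_a.zero_grad()
            # drop the discriminator grads produced as a side effect of gen_loss
            opt_b.zero_grad()

            # discriminator update on detached generator output
            disc_loss = -self.disc(self.gen(batch).detach()).mean()
            self.manual_backward(disc_loss, opt_b)
            opt_b.step()
            opt_b.zero_grad()

        def configure_optimizers(self):
            opt_a = torch.optim.Adam(self.gen.parameters(), lr=1e-3)
            opt_b = torch.optim.Adam(self.disc.parameters(), lr=1e-3)
            return opt_a, opt_b

The toggle_optimizer helper touched in lightning.py offers another way to keep gradients scoped to the optimizer currently stepping.
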
4 changes: 2 additions & 2 deletions tests/trainer/dynamic_args/test_multiple_optimizers.py
@@ -65,13 +65,13 @@ def training_step(self, batch, batch_idx, optimizer_idx):
loss_1 = self.step(batch[0])

# fake generator
self.backward(loss_1, opt_a)
self.manual_backward(loss_1, opt_a)
opt_a.step()
opt_a.zero_grad()

# fake discriminator
loss_2 = self.step(batch[0])
self.backward(loss_2, opt_b)
self.manual_backward(loss_2, opt_b)
opt_b.step()
opt_b.zero_grad()

18 changes: 9 additions & 9 deletions tests/trainer/optimization/test_manual_optimization.py
@@ -23,7 +23,7 @@ def training_step(self, batch, batch_idx, optimizer_idx):
if batch_idx > 0:
assert torch.all(self.layer.weight.grad == 0)

self.backward(loss_1, opt_a)
self.manual_backward(loss_1, opt_a)
opt_a.step()
opt_a.zero_grad()
assert torch.all(self.layer.weight.grad == 0)
@@ -33,8 +33,8 @@ def training_step(self, batch, batch_idx, optimizer_idx):

# ensure we forward the correct params to the optimizer
# without retain_graph we can't do multiple backward passes
self.backward(loss_2, opt_b, retain_graph=True)
self.backward(loss_2, opt_a, retain_graph=True)
self.manual_backward(loss_2, opt_b, retain_graph=True)
self.manual_backward(loss_2, opt_a, retain_graph=True)

assert self.layer.weight.grad is not None
opt_b.step()
@@ -87,7 +87,7 @@ def training_step(self, batch, batch_idx, optimizer_idx):
if batch_idx > 0:
assert torch.all(self.layer.weight.grad == 0)

self.backward(loss_1, opt_a)
self.manual_backward(loss_1, opt_a)
opt_a.step()
opt_a.zero_grad()
assert torch.all(self.layer.weight.grad == 0)
@@ -97,8 +97,8 @@ def training_step(self, batch, batch_idx, optimizer_idx):

# ensure we forward the correct params to the optimizer
# without retain_graph we can't do multiple backward passes
self.backward(loss_2, opt_b, retain_graph=True)
self.backward(loss_2, opt_a, retain_graph=True)
self.manual_backward(loss_2, opt_b, retain_graph=True)
self.manual_backward(loss_2, opt_a, retain_graph=True)

assert self.layer.weight.grad is not None
opt_b.step()
@@ -157,7 +157,7 @@ def training_step(self, batch, batch_idx, optimizer_idx):
if batch_idx > 0:
assert torch.all(self.layer.weight.grad == 0)

self.backward(loss_1, opt_a)
self.manual_backward(loss_1, opt_a)
opt_a.step()
opt_a.zero_grad()
assert torch.all(self.layer.weight.grad == 0)
@@ -168,8 +168,8 @@ def training_step(self, batch, batch_idx, optimizer_idx):

# ensure we forward the correct params to the optimizer
# without retain_graph we can't do multiple backward passes
self.backward(loss_2, opt_b, retain_graph=True)
self.backward(loss_2, opt_a, retain_graph=True)
self.manual_backward(loss_2, opt_b, retain_graph=True)
self.manual_backward(loss_2, opt_a, retain_graph=True)

assert self.layer.weight.grad is not None
opt_b.step()