[FEAT] Add lambda closure to manual_optimizer_step #4618

Merged 22 commits on Nov 12, 2020 · Changes shown from 16 commits
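In short, this PR lets `manual_optimizer_step` accept an `optimizer_closure` callable and a `make_optimizer_step` flag, and forwards any extra arguments down to `optimizer.step()`. As a quick orientation before the diff, here is a minimal usage sketch condensed from the new docstring examples below; the loss computation (`self.generator_loss`) is a placeholder, not part of the PR:

def training_step(self, batch, batch_idx, optimizer_idx):
    opt_gen, opt_dis = self.optimizers()

    def gen_closure():
        loss_gen = self.generator_loss(batch)  # placeholder loss computation
        self.log("loss_gen", loss_gen, on_step=True, on_epoch=True)
        self.manual_backward(loss_gen, opt_gen)

    # accumulate generator gradients over 2 batches, then call opt_gen.step()
    self.manual_optimizer_step(
        opt_gen,
        optimizer_closure=gen_closure,
        make_optimizer_step=(batch_idx % 2 == 0),
    )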
1 change: 1 addition & 0 deletions .gitignore
@@ -140,3 +140,4 @@ mlruns/
pytorch\ lightning
test-reports/
wandb
.forked/
6 changes: 4 additions & 2 deletions pytorch_lightning/accelerators/accelerator.py
@@ -106,7 +106,7 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
closure_loss = closure_loss.detach()
return closure_loss

def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure, *args, **kwargs):
model_ref = self.trainer.get_model()
is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
using_native_amp = self.trainer.amp_backend == AMPType.NATIVE
@@ -127,7 +127,9 @@ def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
optimizer_closure=lambda_closure,
on_tpu=False, # TPUAccelerator class sets this as True
using_native_amp=using_native_amp,
using_lbfgs=is_lbfgs
using_lbfgs=is_lbfgs,
*args,
**kwargs,
)

# scale when native amp
6 changes: 4 additions & 2 deletions pytorch_lightning/accelerators/tpu_accelerator.py
@@ -245,7 +245,7 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):

return closure_loss

def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure, *args, **kwargs):
model_ref = self.trainer.get_model()
is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)

@@ -258,7 +258,9 @@ def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
optimizer_closure=lambda_closure,
on_tpu=True,
using_native_amp=False,
using_lbfgs=is_lbfgs
using_lbfgs=is_lbfgs,
*args,
**kwargs,
)

def clip_gradients(self, optimizer, clip_val=None):
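Both accelerator changes follow the same pattern: extra positional and keyword arguments are no longer dropped but forwarded to the LightningModule's `optimizer_step` hook, and from there to `optimizer.step()`. A standalone sketch of that forwarding idea (the function names below are illustrative stand-ins, not the library's internals):

from typing import Callable, Optional

def accelerator_optimizer_step(optimizer, batch_idx, opt_idx, lambda_closure, *args, **kwargs):
    # the accelerator simply forwards everything it receives to the module hook
    module_optimizer_step(optimizer, optimizer_closure=lambda_closure, *args, **kwargs)

def module_optimizer_step(optimizer, *args, optimizer_closure: Optional[Callable] = None, **kwargs):
    # the default hook hands the closure plus any extras to the optimizer itself
    optimizer.step(closure=optimizer_closure, *args, **kwargs)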
144 changes: 125 additions & 19 deletions pytorch_lightning/core/lightning.py
@@ -17,6 +17,7 @@
import copy
import inspect
import re
import types
from abc import ABC
from argparse import Namespace
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, Mapping
@@ -1106,7 +1107,12 @@ def training_step(...):
self.trainer.train_loop.backward(loss, optimizer, -1, *args, **kwargs)
self._running_manual_backward = False

def manual_optimizer_step(self, optimizer: Optimizer, force_optimizer_step:bool = False) -> None:
def manual_optimizer_step(self,
optimizer: Optimizer,
*args,
make_optimizer_step: Optional[bool] = None,
optimizer_closure: Optional[Callable] = None,
**kwargs) -> None:
"""
Call this directly from your training_step when doing optimizations manually.
By using this, we can ensure that all the proper scaling when using 16-bit precision etc. has been done for you
@@ -1116,36 +1122,132 @@ def manual_optimizer_step(self, optimizer: Optimizer, force_optimizer_step:bool
Args:
optimizer: Optimizer used to perform `.step()` call

force_optimizer_step: Whether to force an optimizer step. Could be useful when having 2 optimizers
and one should use accumulated gradients but not the other one.
One could put its own logic to force an optimizer step.
make_optimizer_step: Whether to force an optimizer step. When not provided, `accumulate_grad_batches`
determines the accumulation frequency. However, one could pass True or False based on a custom
schedule; cf. examples 2 and 3.

optimizer_closure: Optionally provide your own optimizer closure. Set to None by default.

args: Any extra positional arguments passed on to `optimizer.step()`

kwargs: Any extra keyword arguments passed on to `optimizer.step()`

Example::

def training_step(...):
(opt_a, opt_b) = self.optimizers()
loss = ...
# automatically applies scaling, etc...

self.manual_backward(loss, opt_a)
# This will force an opt.step() even if accumulate_grad_batches is set.
self.manual_optimizer_step(opt_a, make_optimizer_step=True)

# This will accumulate gradients for `accumulate_grad_batches` batches
# and then run opt_a.step()
self.manual_optimizer_step(opt_a)

Example::

def training_step(self, batch, batch_idx):
# using Boring Model
opt = self.optimizers() # only 1 optimizer

def compute_loss():
x = batch[0]
x = F.dropout(x, 0.1)
predictions = self(x)
predictions = F.dropout(predictions, 0.1)
loss = self.loss(None, predictions)
return loss

def optimizer_closure():
# emulate MC dropout training
num_backward = 1
losses = []
for backward_idx in range(num_backward + 1):
loss = compute_loss()
losses.append(loss)
retain_graph = num_backward != backward_idx
self.manual_backward(loss, opt, retain_graph=retain_graph)
loss_mean = torch.stack(losses).mean()
loss_std = torch.stack(losses).std()
self.log("train_loss_mean", loss_mean, on_step=True, prog_bar=True, on_epoch=True)
self.log("train_loss_std", loss_std, on_step=True, prog_bar=True, on_epoch=True)

self.manual_optimizer_step(opt, optimizer_closure=optimizer_closure)

Example::

# Scenario for a GAN.

def training_step(self, batch, batch_idx, optimizer_idx):

# emulate GAN training
opt_gen, opt_dis = self.optimizers()

# Note: Be careful not to log on the same key with self.log in both closures,
# as they will be aggregated together on epoch_end

def gen_closure():
... forward and compute loss for generator
loss_gen = ...
self.log("loss_gen", loss_gen, on_step=True, on_epoch=True)
self.manual_backward(loss_gen, opt_gen)

def dis_closure():
... forward and compute loss for discriminator
loss_dis = ...
self.log("loss_dis", loss_dis, on_step=True, on_epoch=True)
self.manual_backward(loss_dis, opt_dis)

# this will accumulate gradients for 2 batches and then call opt_gen.step()
self.manual_optimizer_step(
opt_gen,
optimizer_closure=gen_closure,
make_optimizer_step=batch_idx % 2 == 0)

# update discriminator every 4 batches
# therefore, no gradient accumulation for discriminator
if batch_idx % 4 == 0:
# Note: Set make_optimizer_step to True, or it will fall back to the
# Trainer(accumulate_grad_batches=x) default
self.manual_optimizer_step(
opt_dis,
optimizer_closure=dis_closure,
make_optimizer_step=True)
"""
# make sure we're using manual opt
self._verify_is_manual_optimization('manual_optimizer_step')

if not self.trainer.train_loop.should_accumulate() or force_optimizer_step:
should_make_optimizer_step = not self.trainer.train_loop.should_accumulate()
make_optimizer_step = make_optimizer_step if make_optimizer_step is not None else should_make_optimizer_step

if make_optimizer_step:

# mock closure function as the user is responsible for calling `manual_backward`
def mock_optimizer_closure():
return

self.trainer.train_loop.optimizer_step(optimizer, None, self.trainer.batch_idx, mock_optimizer_closure)
optimizer_closure = optimizer_closure if isinstance(optimizer_closure, types.FunctionType) else mock_optimizer_closure

self.trainer.train_loop.optimizer_step(optimizer,
None,
self.trainer.batch_idx,
optimizer_closure,
*args,
**kwargs)

# update will be called after every optimizer_step call
if self.trainer.amp_backend == AMPType.NATIVE:
self.trainer.scaler.update()

# perform zero grad
optimizer.zero_grad()

else:
# make sure to call optimizer_closure when accumulating
if isinstance(optimizer_closure, types.FunctionType):
optimizer_closure()

def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None:
"""
Override backward with your own implementation if you need to.
@@ -1190,21 +1292,25 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):

def optimizer_step(
self,
epoch: int,
batch_idx: int,
optimizer: Optimizer,
optimizer_idx: int,
optimizer_closure: Optional[Callable],
on_tpu: bool,
using_native_amp: bool,
using_lbfgs: bool,
*args,
epoch: int = None,
batch_idx: int = None,
optimizer: Optimizer = None,
optimizer_idx: int = None,
optimizer_closure: Optional[Callable] = None,
on_tpu: bool = None,
using_native_amp: bool = None,
using_lbfgs: bool = None,
**kwargs,
) -> None:
r"""
Override this method to adjust the default way the
:class:`~pytorch_lightning.trainer.trainer.Trainer` calls each optimizer.
By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example
once per optimizer.

.. tip:: Consider using ``manual_optimizer_step`` instead of overriding this method, as was done previously.

Warning:
If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter
to ``optimizer.step()`` function as shown in the examples. This ensures that
@@ -1272,7 +1378,7 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,

"""
if on_tpu:
xm.optimizer_step(optimizer, optimizer_args={'closure': optimizer_closure})
xm.optimizer_step(optimizer, optimizer_args={'closure': optimizer_closure, **kwargs})
elif self.trainer.amp_backend == AMPType.NATIVE:
# native amp does not yet support closures.
# TODO: pass the closure to the step ASAP
@@ -1282,9 +1388,9 @@
# apex amp does not yet support closures.
# TODO: pass the closure to the step ASAP
optimizer_closure()
optimizer.step()
optimizer.step(*args, **kwargs)
else:
optimizer.step(closure=optimizer_closure)
optimizer.step(closure=optimizer_closure, *args, **kwargs)

def optimizer_zero_grad(
self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int
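The ``Warning`` in the ``optimizer_step`` docstring above points to examples that are collapsed in this diff view. As a rough illustration of the point it makes (not the docstring's own example), an override that adds learning-rate warm-up while still forwarding the closure might look like the sketch below; the 500-step window and the 0.001 base learning rate are assumptions:

# Rough illustration only; not the collapsed docstring example from the diff.
# The point of the warning: always forward `optimizer_closure` to `optimizer.step()`.
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                   optimizer_closure, on_tpu, using_native_amp, using_lbfgs, **kwargs):
    # hypothetical linear warm-up over the first 500 steps (assumed values)
    if self.trainer.global_step < 500:
        lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
        for pg in optimizer.param_groups:
            pg["lr"] = lr_scale * 0.001  # assumed base learning rate

    # forwarding the closure ensures the training step and backward still run
    optimizer.step(closure=optimizer_closure)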
7 changes: 4 additions & 3 deletions pytorch_lightning/trainer/training_loop.py
@@ -471,11 +471,11 @@ def _process_result(self, training_step_output, split_batch):

return training_step_output_for_epoch_end

def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure):
def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure, *args, **kwargs):
with self.trainer.profiler.profile("optimizer_step"):
# optimizer step lightningModule hook
self.trainer.accelerator_backend.optimizer_step(
optimizer, batch_idx, opt_idx, train_step_and_backward_closure
optimizer, batch_idx, opt_idx, train_step_and_backward_closure, *args, **kwargs
)

def on_before_zero_grad(self, optimizer):
@@ -945,7 +945,8 @@ def zero_grad_handler(self, batch_idx, optimizer, opt_idx):
self.on_before_zero_grad(optimizer)
optimizers = enumerate([optimizer])
else:
optimizers = self.get_optimizers_iterable()
# zero_grad is handled in `manual_optimizer_step`
optimizers = []

for idx, optimizer in optimizers:
self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)