[bug-fix] DDP and automatic_optimization=False (#4485)
* resolve bug

* add self._running_manual_optim

* update

* update tests

* update lightning module

* resolve bug

* update tests

* update

* resolve pep8

* update

* replace by `ddp_spawn`

* temporary fix

* update

* update

* move update to training_loop

* make both ddp_spawn

* introduce `manual_optimizer_step`

* update changelog

* added changelog wrong place

* add force_optimizer_step

* update docstring for tests

* update optimizer_step

* update zero_grad

* resolve flake8

* move update into manual_optimizer_step

* add zero_grad

* remove zero_grad tests

* remove manual_backward in AMP, it doesn't help

* update

* loosen tests

* update

* update doc

* add TODO

* Removed unnecessary get model from native amp

* Remove try except with pytest raise

* Add seed, clean up imports, remove try catch to reproduce error

* update code

* update test

* revert back

* formatting

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Jirka Borovec <[email protected]>

Co-authored-by: SeanNaren <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>

(cherry picked from commit 7e08b0d)
tchaton authored and SeanNaren committed Nov 10, 2020
1 parent 13e95cd commit 10488dc
Showing 8 changed files with 398 additions and 31 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -33,6 +33,7 @@ timit_data/
.Python
ide_layouts/
build/
_build/
develop-eggs/
dist/
downloads/
6 changes: 6 additions & 0 deletions docs/source/lightning_module.rst
@@ -1009,6 +1009,12 @@ manual_backward
.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_backward
:noindex:

manual_optimizer_step
~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_optimizer_step
:noindex:

on_after_backward
~~~~~~~~~~~~~~~~~

7 changes: 3 additions & 4 deletions docs/source/optimizers.rst
@@ -36,17 +36,16 @@ to manually manage the optimization process. To do so, do the following:
# use self.backward which will also handle scaling the loss when using amp
self.manual_backward(loss_a, opt_g)
opt_g.step()
opt_g.zero_grad()
self.manual_optimizer_step(opt_g)
# do anything you want
loss_b = ...
# pass in any args that loss.backward() normally takes
self.manual_backward(loss_b, opt_d, retain_graph=True)
self.manual_backward(loss_b, opt_d)
opt_d.step()
opt_d.zero_grad()
self.manual_optimizer_step(opt_d)
# log losses
self.log('loss_a', loss_a)
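For orientation, here is a minimal sketch of how the updated documentation pattern above looks inside a complete `training_step`. It is illustrative only: the module name and loss helpers are invented, and running under `Trainer(automatic_optimization=False)` is assumed from the commit title rather than shown in this hunk.

```python
import pytorch_lightning as pl


class ManualOptimModule(pl.LightningModule):  # hypothetical example module
    def training_step(self, batch, batch_idx):
        (opt_g, opt_d) = self.optimizers()

        loss_a = self.compute_loss_a(batch)  # hypothetical loss helper
        # manual_backward also handles loss scaling when AMP is enabled
        self.manual_backward(loss_a, opt_g)
        # replaces opt_g.step()/opt_g.zero_grad(); honors accumulate_grad_batches,
        # and gradient zeroing is handled by the training loop
        self.manual_optimizer_step(opt_g)

        loss_b = self.compute_loss_b(batch)  # hypothetical loss helper
        self.manual_backward(loss_b, opt_d)
        self.manual_optimizer_step(opt_d)

        # log the losses; do NOT return the raw loss tensor -- in manual
        # optimization the new _check_training_step_output guard raises a
        # MisconfigurationException if training_step returns a Tensor
        self.log('loss_a', loss_a)
        self.log('loss_b', loss_b)
```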
9 changes: 5 additions & 4 deletions pytorch_lightning/accelerators/accelerator.py
@@ -109,10 +109,11 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
model_ref = self.trainer.get_model()
is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
native_amp = self.trainer.amp_backend == AMPType.NATIVE
using_native_amp = self.trainer.amp_backend == AMPType.NATIVE
automatic_optimization = self.trainer.train_loop.automatic_optimization

# native amp + lbfgs is a no go right now
if native_amp and is_lbfgs:
if using_native_amp and is_lbfgs:
raise MisconfigurationException(
'native PyTorch amp and lbfgs are not compatible.'
' To request, please file a Github issue in PyTorch and tag @mcarilli')
@@ -125,12 +126,12 @@ def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
optimizer_idx=opt_idx,
optimizer_closure=lambda_closure,
on_tpu=False, # TPUAccelerator class sets this as True
using_native_amp=native_amp,
using_native_amp=using_native_amp,
using_lbfgs=is_lbfgs
)

# scale when native amp
if native_amp:
if automatic_optimization and using_native_amp:
self.trainer.scaler.update()

def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
50 changes: 49 additions & 1 deletion pytorch_lightning/core/lightning.py
@@ -111,6 +111,7 @@ def __init__(self, *args, **kwargs):
self._datamodule = None
self._results: Optional[Result] = None
self._current_fx_name = ''
self._running_manual_backward = False

def optimizers(self):
opts = self.trainer.optimizers
@@ -1070,19 +1071,65 @@ def manual_backward(self, loss: Tensor, optimizer: Optimizer, *args, **kwargs) -
.. tip:: In manual mode we still automatically clip grads if Trainer(gradient_clip_val=x) is set
.. tip:: In manual mode we still automatically accumulate grad over batches if Trainer(accumulate_grad_batches=x) is set
and you use `model.manual_optimizer_step(optimizer)`
Example::
def training_step(...):
(opt_a, opt_b) = self.optimizers()
loss = ...
# automatically applies scaling, etc...
self.manual_backward(loss, opt_a)
self.manual_optimizer_step(opt_a)
"""
# make sure we're using manual opt
self._verify_is_manual_optimization('manual_backward')

# backward
self._running_manual_backward = True
self.trainer.train_loop.backward(loss, optimizer, -1, *args, **kwargs)
self._running_manual_backward = False

def manual_optimizer_step(self, optimizer: Optimizer, force_optimizer_step: bool = False) -> None:
"""
Call this directly from your training_step when doing optimizations manually.
By using this we can ensure that all the proper scaling when using 16-bit etc has been done for you
.. tip:: In manual mode we still automatically accumulate grad over batches if Trainer(accumulate_grad_batches=x) is set.
Args:
optimizer: Optimizer used to perform `.step()` call
force_optimizer_step: Whether to force an optimizer step. Could be useful when having 2 optimizers
and one should use accumulated gradients but not the other one.
One could put its own logic to force an optimizer step.
Example::
def training_step(...):
(opt_a, opt_b) = self.optimizers()
loss = ...
# automatically applies scaling, etc...
self.manual_backward(loss, opt_a)
# This will force an opt.step() even if accumulate_grad_batches is set.
self.manual_optimizer_step(opt_a, force_optimizer_step=True)
"""
# make sure we're using manual opt
self._verify_is_manual_optimization('manual_optimizer_step')

if not self.trainer.train_loop.should_accumulate() or force_optimizer_step:

# mock closure function as the user is responsible to call `manual_backward`
def mock_optimizer_closure():
return

self.trainer.train_loop.optimizer_step(optimizer, None, self.trainer.batch_idx, mock_optimizer_closure)

# update will be called after every optimizer_step call
if self.trainer.amp_backend == AMPType.NATIVE:
self.trainer.scaler.update()

def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None:
"""
@@ -1103,7 +1150,8 @@ def backward(self, loss, optimizer, optimizer_idx):
loss.backward()
"""
loss.backward(*args, **kwargs)
if self.trainer.train_loop.automatic_optimization or self._running_manual_backward:
loss.backward(*args, **kwargs)

def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
"""
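A short usage sketch for the `force_optimizer_step` flag introduced above, covering the two-optimizer case its docstring mentions. The class name, loss helpers, and the `Trainer(accumulate_grad_batches=2, automatic_optimization=False)` configuration are assumptions for illustration, not part of this commit's diff.

```python
import pytorch_lightning as pl


class TwoOptimizerModule(pl.LightningModule):  # hypothetical example module
    def training_step(self, batch, batch_idx):
        (opt_a, opt_b) = self.optimizers()

        loss_a = self.compute_loss_a(batch)  # hypothetical loss helper
        self.manual_backward(loss_a, opt_a)
        # with Trainer(accumulate_grad_batches=2), this call is skipped on
        # accumulation batches and only steps once the window closes
        self.manual_optimizer_step(opt_a)

        loss_b = self.compute_loss_b(batch)  # hypothetical loss helper
        self.manual_backward(loss_b, opt_b)
        # force a step on every batch for this optimizer, regardless of accumulation
        self.manual_optimizer_step(opt_b, force_optimizer_step=True)
```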
71 changes: 58 additions & 13 deletions pytorch_lightning/trainer/training_loop.py
@@ -303,6 +303,12 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
# when in dev debugging track the losses
self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach())

def _check_training_step_output(self, training_step_output):
if isinstance(training_step_output, torch.Tensor) and not self.automatic_optimization:
if training_step_output.grad_fn is None:
# TODO: Find why - RuntimeError: Expected to mark a variable ready only once ...
raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor")

def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
# give the PL module a result for logging
model = self.trainer.get_model()
@@ -312,6 +318,8 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
with self.trainer.profiler.profile("model_forward"):
args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
training_step_output = self.trainer.accelerator_backend.training_step(args)
self._check_training_step_output(training_step_output)

training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

training_step_output_for_epoch_end, training_step_output = self._process_training_step_output(
@@ -612,6 +620,9 @@ def run_training_epoch(self):
# progress global step according to grads progress
self.increment_accumulated_grad_global_step()

# epoch end hook
self.run_on_epoch_end_hook(epoch_output)

# log epoch metrics
self.trainer.logger_connector.log_train_epoch_end_metrics(
epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers
@@ -723,6 +734,8 @@ def train_step_and_backward_closure():

if self._curr_step_result is None:
# user decided to skip optimization
# make sure to zero grad.
self.zero_grad_handler(batch_idx, optimizer, opt_idx)
continue

batch_outputs = self._process_closure_result(
@@ -735,20 +748,11 @@ def train_step_and_backward_closure():
grad_norm_dic = self._cur_grad_norm_dict
self._cur_grad_norm_dict = None

# hook
self.on_before_zero_grad(optimizer)

# clear gradients
self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
# hook + clear gradients
self.zero_grad_handler(batch_idx, optimizer, opt_idx)

accumulated_loss = self.accumulated_loss.mean()

if accumulated_loss is not None:
# calculate running loss for display
self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches)

# reset for next set of accumulated grads
self.accumulated_loss.reset()
# update running loss + reset accumulated loss
self.update_running_loss()

# collapse all metrics into one dict
batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()}
@@ -949,3 +953,44 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu
epoch_end_outputs.append(optimizer_idx_outputs)

return epoch_end_outputs

def prepare_optimizers(self):
# in manual optimization we loop over all optimizers at once
optimizers = self.get_optimizers_iterable()
if not self.automatic_optimization:
optimizers = [optimizers[0]]
return optimizers

def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer):
# set split_idx to trainer for tracking
self.trainer.split_idx = split_idx

# make sure only the gradients of the current optimizer's parameters are calculated
# in the training step to prevent dangling gradients in multiple-optimizer setup.
if self.automatic_optimization and len(self.trainer.optimizers) > 1:
model = self.trainer.get_model()
model.toggle_optimizer(optimizer, opt_idx)

# use to track metrics internally
self.trainer.logger_connector.on_train_split_start(split_idx, opt_idx, split_batch)

def update_running_loss(self):
accumulated_loss = self.accumulated_loss.mean()

if accumulated_loss is not None:
# calculate running loss for display
self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches)

# reset for next set of accumulated grads
self.accumulated_loss.reset()

def zero_grad_handler(self, batch_idx, optimizer, opt_idx):
if self.automatic_optimization:
# hook
self.on_before_zero_grad(optimizer)
optimizers = enumerate([optimizer])
else:
optimizers = self.get_optimizers_iterable()

for idx, optimizer in optimizers:
self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
