diff --git a/.gitignore b/.gitignore
index 946d5f0f4c2ca..743fdaaf33dc2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -140,3 +140,4 @@ mlruns/
 pytorch\ lightning
 test-reports/
 wandb
+.forked/
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index a0d8f6f21a2f7..e765c2ab626df 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -98,7 +98,7 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
             closure_loss = closure_loss.detach()
         return closure_loss
 
-    def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
+    def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure, *args, **kwargs):
         model_ref = self.trainer.get_model()
         is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
         using_native_amp = self.trainer.amp_backend == AMPType.NATIVE
@@ -119,7 +119,9 @@ def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
             optimizer_closure=lambda_closure,
             on_tpu=False,  # TPUAccelerator class sets this as True
             using_native_amp=using_native_amp,
-            using_lbfgs=is_lbfgs
+            using_lbfgs=is_lbfgs,
+            *args,
+            **kwargs,
         )
 
         # scale when native amp
diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py
index 54ee57b74a16a..cc7da4dc10781 100644
--- a/pytorch_lightning/accelerators/tpu_accelerator.py
+++ b/pytorch_lightning/accelerators/tpu_accelerator.py
@@ -247,7 +247,7 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
 
         return closure_loss
 
-    def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
+    def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure, *args, **kwargs):
         model_ref = self.trainer.get_model()
         is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
 
@@ -260,7 +260,9 @@ def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
             optimizer_closure=lambda_closure,
             on_tpu=True,
             using_native_amp=False,
-            using_lbfgs=is_lbfgs
+            using_lbfgs=is_lbfgs,
+            *args,
+            **kwargs,
         )
 
     def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: Union[float, int], norm_type: float = 2.0):
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index a332c0dcaa99a..acfdc107cc7e0 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -17,6 +17,7 @@
 import copy
 import inspect
 import re
+import types
 from abc import ABC
 from argparse import Namespace
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, Mapping
@@ -969,7 +970,8 @@ def configure_optimizers(
             - Single optimizer.
             - List or Tuple - List of optimizers.
             - Two lists - The first list has multiple optimizers, the second a list of LR schedulers (or lr_dict).
-            - Dictionary, with an 'optimizer' key, and (optionally) a 'lr_scheduler' key which value is a single LR scheduler or lr_dict.
+            - Dictionary, with an 'optimizer' key, and (optionally) a 'lr_scheduler'
+              key whose value is a single LR scheduler or lr_dict.
             - Tuple of dictionaries as described, with an optional 'frequency' key.
             - None - Fit will run without any optimizer.
@@ -1086,8 +1088,8 @@ def manual_backward(self, loss: Tensor, optimizer: Optimizer, *args, **kwargs) -
 
         .. tip:: In manual mode we still automatically clip grads if Trainer(gradient_clip_val=x) is set
 
-        .. tip:: In manual mode we still automatically accumulate grad over batches if Trainer(accumulate_grad_batches=x) is set
-           and you use `model.manual_optimizer_step(optimizer)`
+        .. tip:: In manual mode we still automatically accumulate grad over batches if
+           Trainer(accumulate_grad_batches=x) is set and you use `model.manual_optimizer_step(optimizer)`
 
         Example::
 
@@ -1106,19 +1108,32 @@ def training_step(...):
         self.trainer.train_loop.backward(loss, optimizer, -1, *args, **kwargs)
         self._running_manual_backward = False
 
-    def manual_optimizer_step(self, optimizer: Optimizer, force_optimizer_step:bool = False) -> None:
+    def manual_optimizer_step(self,
+                              optimizer: Optimizer,
+                              *args,
+                              make_optimizer_step: Optional[bool] = None,
+                              optimizer_closure: Optional[Callable] = None,
+                              **kwargs) -> None:
         """
         Call this directly from your training_step when doing optimizations manually.
         By using this we can ensure that all the proper scaling when using 16-bit etc has been done for you
 
-        .. tip:: In manual mode we still automatically accumulate grad over batches if Trainer(accumulate_grad_batches=x) is set.
+        .. tip:: In manual mode we still automatically accumulate grad over batches if
+           Trainer(accumulate_grad_batches=x) is set.
 
         Args:
             optimizer: Optimizer used to perform `.step()` call
 
-            force_optimizer_step: Whether to force an optimizer step. Could be useful when having 2 optimizers
-                and one should use accumulated gradients but not the other one.
-                One could put its own logic to force an optimizer step.
+            make_optimizer_step: Whether to force an optimizer step. When nothing is provided,
+                we will use `accumulate_grad_batches` for the accumulation frequency by default.
+                However, one could provide True and False based on one's own scheduling;
+                see examples 2 and 3 below.
+
+            optimizer_closure: One could provide one's own optimizer_closure. Set to None by default.
+
+            args: Any parameters provided to optimizer.step()
+
+            kwargs: Any parameters provided to optimizer.step()
 
         Example::
 
@@ -1126,26 +1141,119 @@ def training_step(...):
             (opt_a, opt_b) = self.optimizers()
             loss = ...
             # automatically applies scaling, etc...
+            self.manual_backward(loss, opt_a)
-            # This will force an opt.step() even if accumulate_grad_batches is set.
-            self.manual_optimizer_step(opt_a, force_optimizer_step=True)
+            # This will accumulate gradients for `accumulate_grad_batches` batches
+            # and then run opt_a.step()
+            self.manual_optimizer_step(opt_a)
+
+        Example::
+
+            def training_step(self, batch, batch_idx):
+                # using Boring Model
+                opt = self.optimizers()  # only 1 optimizer
+
+                def compute_loss():
+                    x = batch[0]
+                    x = F.dropout(x, 0.1)
+                    predictions = self(x)
+                    predictions = F.dropout(predictions, 0.1)
+                    loss = self.loss(None, predictions)
+                    return loss
+
+                def optimizer_closure():
+                    # emulate MC dropout training
+                    num_backward = 1
+                    losses = []
+                    for backward_idx in range(num_backward + 1):
+                        loss = compute_loss()
+                        losses.append(loss)
+                        retain_graph = num_backward != backward_idx
+                        self.manual_backward(loss, opt, retain_graph=retain_graph)
+                    loss_mean = torch.stack(losses).mean()
+                    loss_std = torch.stack(losses).std()
+                    self.log("train_loss_mean", loss_mean, on_step=True, prog_bar=True, on_epoch=True)
+                    self.log("train_loss_std", loss_std, on_step=True, prog_bar=True, on_epoch=True)
+
+                self.manual_optimizer_step(opt, optimizer_closure=optimizer_closure)
+
+        Example::
+
+            # Scenario for a GAN.
+
+            def training_step(self, batch, batch_idx, optimizer_idx):
+
+                # emulate GAN training
+                opt_gen, opt_dis = self.optimizers()
+
+                # Note: Be careful, don't log on the same key in self.log in both closures,
+                # as they will be aggregated together on epoch_end
+
+                def gen_closure():
+                    ... forward and compute loss for generator
+                    loss_gen = ...
+                    self.log("loss_gen", loss_gen, on_step=True, on_epoch=True)
+                    self.manual_backward(loss_gen, opt_gen)
+
+                def dis_closure():
+                    ... forward and compute loss for discriminator
+                    loss_dis = ...
+                    self.log("loss_dis", loss_dis, on_step=True, on_epoch=True)
+                    self.manual_backward(loss_dis, opt_dis)
+
+                # this will accumulate gradients for 2 batches and then call opt_gen.step()
+                self.manual_optimizer_step(
+                    opt_gen,
+                    optimizer_closure=gen_closure,
+                    make_optimizer_step=batch_idx % 2 == 0)
+
+                # update the discriminator every 4 batches
+                # therefore, no gradient accumulation for the discriminator
+                if batch_idx % 4 == 0:
+                    # Note: Set make_optimizer_step to True, otherwise it will default to
+                    # Trainer(accumulate_grad_batches=x)
+                    self.manual_optimizer_step(
+                        opt_dis,
+                        optimizer_closure=dis_closure,
+                        make_optimizer_step=True)
         """
         # make sure we're using manual opt
         self._verify_is_manual_optimization('manual_optimizer_step')
 
-        if not self.trainer.train_loop.should_accumulate() or force_optimizer_step:
+        should_make_optimizer_step = not self.trainer.train_loop.should_accumulate()
+        make_optimizer_step = make_optimizer_step if make_optimizer_step is not None else should_make_optimizer_step
+
+        if make_optimizer_step:
 
             # mock closure function as the user is responsible to call `manual_backward`
-            def mock_optimizer_closure():
+            def do_nothing_optimizer_closure():
                 return
 
-            self.trainer.train_loop.optimizer_step(optimizer, None, self.trainer.batch_idx, mock_optimizer_closure)
+            is_callable = isinstance(optimizer_closure, types.FunctionType)
+            optimizer_closure = optimizer_closure if is_callable else do_nothing_optimizer_closure
+
+            self.trainer.train_loop.optimizer_step(
+                optimizer,
+                None,
+                self.trainer.batch_idx,
+                optimizer_closure,
+                *args,
+                **kwargs,
+            )
 
             # update will be called after every optimizer_step call
             if self.trainer.amp_backend == AMPType.NATIVE:
                 self.trainer.scaler.update()
 
+            # perform zero grad
+            optimizer.zero_grad()
+
+        else:
+            # make sure to call optimizer_closure when accumulating
+            if isinstance(optimizer_closure, types.FunctionType):
+                optimizer_closure()
+
     def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None:
         """
         Override backward with your own implementation if you need to.
@@ -1190,14 +1298,16 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
 
     def optimizer_step(
         self,
-        epoch: int,
-        batch_idx: int,
-        optimizer: Optimizer,
-        optimizer_idx: int,
-        optimizer_closure: Optional[Callable],
-        on_tpu: bool,
-        using_native_amp: bool,
-        using_lbfgs: bool,
+        *args,
+        epoch: int = None,
+        batch_idx: int = None,
+        optimizer: Optimizer = None,
+        optimizer_idx: int = None,
+        optimizer_closure: Optional[Callable] = None,
+        on_tpu: bool = None,
+        using_native_amp: bool = None,
+        using_lbfgs: bool = None,
+        **kwargs,
     ) -> None:
         r"""
         Override this method to adjust the default way the
@@ -1205,6 +1315,8 @@ def optimizer_step(
         By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example
         once per optimizer.
 
+        .. tip:: Consider using `manual_optimizer_step` instead of overriding this method as done previously.
+
         Warning:
             If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter
             to ``optimizer.step()`` function as shown in the examples. This ensures that
@@ -1272,7 +1384,7 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
 
         """
         if on_tpu:
-            xm.optimizer_step(optimizer, optimizer_args={'closure': optimizer_closure})
+            xm.optimizer_step(optimizer, optimizer_args={'closure': optimizer_closure, **kwargs})
         elif self.trainer.amp_backend == AMPType.NATIVE:
             # native amp does not yet support closures.
             # TODO: pass the closure to the step ASAP
@@ -1282,9 +1394,9 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
             # apex amp does not yet support closures.
             # TODO: pass the closure to the step ASAP
             optimizer_closure()
-            optimizer.step()
+            optimizer.step(*args, **kwargs)
         else:
-            optimizer.step(closure=optimizer_closure)
+            optimizer.step(closure=optimizer_closure, *args, **kwargs)
 
     def optimizer_zero_grad(
         self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index f705d82868da7..8af55f64715f2 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -471,11 +471,11 @@ def _process_result(self, training_step_output, split_batch):
 
         return training_step_output_for_epoch_end
 
-    def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure):
+    def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure, *args, **kwargs):
         with self.trainer.profiler.profile("optimizer_step"):
             # optimizer step lightningModule hook
             self.trainer.accelerator_backend.optimizer_step(
-                optimizer, batch_idx, opt_idx, train_step_and_backward_closure
+                optimizer, batch_idx, opt_idx, train_step_and_backward_closure, *args, **kwargs
             )
 
     def on_before_zero_grad(self, optimizer):
@@ -945,7 +945,8 @@ def zero_grad_handler(self, batch_idx, optimizer, opt_idx):
             self.on_before_zero_grad(optimizer)
             optimizers = enumerate([optimizer])
         else:
-            optimizers = self.get_optimizers_iterable()
+            # zero grad is handled in `manual_optimizer_step` for manual optimization
+            optimizers = []
 
         for idx, optimizer in optimizers:
             self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py
index d816c1e9bc5b1..ec7c22196cc06 100644
--- a/tests/trainer/optimization/test_manual_optimization.py
+++ b/tests/trainer/optimization/test_manual_optimization.py
@@ -16,6 +16,8 @@
 
 import pytest
 import torch
+import torch.nn.functional as F
+from unittest.mock import patch, call, ANY
 
 from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.utilities import APEX_AVAILABLE
@@ -621,3 +623,277 @@ def configure_optimizers(self):
 
     num_manual_backward_calls = 3
     assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * num_manual_backward_calls
+
+
+def test_manual_optimizer_step_with_optimizer_closure(tmpdir):
+    """
+    Tests that `manual_optimizer_step` works with optimizer_closure
+    """
+    os.environ['PL_DEV_DEBUG'] = '1'
+
+    class TestModel(BoringModel):
+
+        _losses = []
+
+        def training_step(self, batch, batch_idx):
+            # manual
+
+            # make sure there are no grads
+            if self.layer.weight.grad is not None:
+                assert torch.all(self.layer.weight.grad == 0)
+
+            opt = self.optimizers()
+
+            def compute_loss():
+                x = batch[0]
+                x = F.dropout(x, 0.1)
+                predictions = self(x)
+                predictions = F.dropout(predictions, 0.1)
+                loss = self.loss(None, predictions)
+                return loss
+
+            def optimizer_closure():
+                # emulate MC dropout training with two stochastic backward passes
+                num_backward = 2
+                losses = []
+                for backward_idx in range(num_backward):
+                    loss = compute_loss()
+                    losses.append(loss)
+                    retain_graph = (num_backward - 1) != backward_idx
+                    self.manual_backward(loss, opt, retain_graph=retain_graph)
+                loss = torch.stack(losses).mean()
+                self._losses.append(loss)
+                self.log("train_loss", loss, on_step=True, prog_bar=True, on_epoch=True)
+                assert losses[0] != losses[1]
+
+            weight_before = self.layer.weight.clone()
+
+            self.manual_optimizer_step(opt, optimizer_closure=optimizer_closure)
+
+            weight_after = self.layer.weight.clone()
+            assert not torch.equal(weight_before, weight_after)
+
+        def training_epoch_end(self, outputs) -> None:
+            # outputs should be an array with an entry per optimizer
+            assert len(outputs) == 2
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            return optimizer
+
+    model = TestModel()
+    model.val_dataloader = None
+    model.training_epoch_end = None
+
+    limit_train_batches = 2
+    trainer = Trainer(
+        automatic_optimization=False,
+        default_root_dir=tmpdir,
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=2,
+        max_epochs=1,
+        log_every_n_steps=1,
+    )
+
+    trainer.fit(model)
+    assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * 2
+    assert trainer.logger_connector.progress_bar_metrics["train_loss_step"] == model._losses[-1]
+    assert trainer.logger_connector.progress_bar_metrics["train_loss_epoch"] == torch.stack(model._losses).mean()
+
+
+def test_manual_optimizer_step_with_optimizer_closure_and_accumulated_grad(tmpdir):
+    """
+    Tests that `manual_optimizer_step` works with optimizer_closure and accumulated_grad
+    """
+    os.environ['PL_DEV_DEBUG'] = '1'
+
+    class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            # manual
+            opt = self.optimizers()
+            x = batch[0]
+
+            loss_1 = self(x)
+            loss_1 = self.loss(loss_1, loss_1)
+
+            def optimizer_closure():
+                # emulate two backward passes on the same loss
+                num_backward = 1
+                for backward_idx in range(num_backward + 1):
+                    retain_graph = num_backward != backward_idx  # noqa E225
+                    self.manual_backward(loss_1, opt, retain_graph=retain_graph)
+
+            weight_before = self.layer.weight.clone()
+
+            self.manual_optimizer_step(opt, optimizer_closure=optimizer_closure)
+
+            weight_after = self.layer.weight.clone()
+            if not self.trainer.train_loop.should_accumulate():
+                assert not torch.equal(weight_before, weight_after)
+            else:
+                assert self.layer.weight.grad is not None
+
+        def training_epoch_end(self, outputs) -> None:
+            # outputs should be an array with an entry per optimizer
+            assert len(outputs) == 2
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            return optimizer
+
+    model = TestModel()
+    model.val_dataloader = None
+    model.training_epoch_end = None
+
+    limit_train_batches = 4
+    trainer = Trainer(
+        automatic_optimization=False,
+        default_root_dir=tmpdir,
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=2,
+        max_epochs=1,
+        log_every_n_steps=1,
+        accumulate_grad_batches=2,
+    )
+
+    trainer.fit(model)
+    assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * 2
+
+
+@patch("torch.optim.SGD.step")
+def test_manual_optimizer_step_with_optimizer_closure_and_extra_arguments(step_mock, tmpdir):
+    """
+    Tests that `manual_optimizer_step` works with optimizer_closure and extra arguments
+    """
+    os.environ['PL_DEV_DEBUG'] = '1'
+
+    class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            # manual
+            opt = self.optimizers()
+            x = batch[0]
+
+            loss_1 = self(x)
+            loss_1 = self.loss(loss_1, loss_1)
+
+            def optimizer_closure():
+                # emulate two backward passes on the same loss
+                num_backward = 1
+                for backward_idx in range(num_backward + 1):
+                    retain_graph = num_backward != backward_idx  # noqa E225
+                    self.manual_backward(loss_1, opt, retain_graph=retain_graph)
+
+            self.manual_optimizer_step(opt, 1, optimizer_closure=optimizer_closure, something="new")
+
+        def training_epoch_end(self, outputs) -> None:
+            # outputs should be an array with an entry per optimizer
+            assert len(outputs) == 2
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            return optimizer
+
+    model = TestModel()
+    model.val_dataloader = None
+    model.training_epoch_end = None
+
+    limit_train_batches = 4
+    trainer = Trainer(
+        automatic_optimization=False,
+        default_root_dir=tmpdir,
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=2,
+        max_epochs=1,
+        log_every_n_steps=1,
+        accumulate_grad_batches=2,
+    )
+
+    trainer.fit(model)
+    expected_calls = [call(1, closure=ANY, something="new") for s in range(2)]
+    step_mock.assert_has_calls(expected_calls)
+
+
+@patch("torch.optim.Adam.step")
+@patch("torch.optim.SGD.step")
+def test_manual_optimizer_step_with_optimizer_closure_with_different_frequencies(mock_sgd_step, mock_adam_step, tmpdir):
+    """
+    Tests that `manual_optimizer_step` works with optimizer_closure and different gradient accumulation frequencies
+    """
+    os.environ['PL_DEV_DEBUG'] = '1'
+
+    class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx, optimizer_idx):
+
+            # emulate GAN training
+            opt_gen, opt_dis = self.optimizers()
+
+            # Note: Be careful, don't log on the same key in self.log in both closures,
+            # as they will be aggregated together on epoch_end
+
+            def compute_loss():
+                x = batch[0]
+                x = F.dropout(x, 0.1)
+                predictions = self(x)
+                predictions = F.dropout(predictions, 0.1)
+                loss = self.loss(None, predictions)
+                return loss
+
+            def gen_closure():
+                loss_gen = compute_loss()
+                self.log("loss_gen", loss_gen, on_step=True, on_epoch=True)
+                self.manual_backward(loss_gen, opt_gen)
+
+            def dis_closure():
+                loss_dis = compute_loss()
+                self.log("loss_dis", loss_dis, on_step=True, on_epoch=True)
+                self.manual_backward(loss_dis, opt_dis)
+
+            # this will accumulate gradients for 2 batches and then call opt_gen.step()
+            self.manual_optimizer_step(
+                opt_gen,
+                optimizer_closure=gen_closure,
+                make_optimizer_step=batch_idx % 2 == 0,
+                optim='sgd')
+
+            # update the discriminator every 4 batches
+            # therefore, no gradient accumulation for the discriminator
+            if batch_idx % 4 == 0:
+                # Note: Set make_optimizer_step to True, otherwise it will default to
+                # Trainer(accumulate_grad_batches=x)
+                self.manual_optimizer_step(
+                    opt_dis,
+                    optimizer_closure=dis_closure,
+                    make_optimizer_step=True,
+                    optim='adam')
+
+        def training_epoch_end(self, outputs) -> None:
+            # outputs should be an array with an entry per optimizer
+            assert len(outputs) == 2
+
+        def configure_optimizers(self):
+            optimizer_gen = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            optimizer_dis = torch.optim.Adam(self.layer.parameters(), lr=0.001)
+            return [optimizer_gen, optimizer_dis]
+
+    model = TestModel()
+    model.val_dataloader = None
+    model.training_epoch_end = None
+
+    limit_train_batches = 8
+    trainer = Trainer(
+        automatic_optimization=False,
+        default_root_dir=tmpdir,
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=2,
+        max_epochs=1,
+        log_every_n_steps=1,
+        accumulate_grad_batches=2,
+    )
+
+    trainer.fit(model)
+    expected_calls = [call(closure=ANY, optim='sgd') for s in range(4)]
+    mock_sgd_step.assert_has_calls(expected_calls)
+
+    expected_calls = [call(closure=ANY, optim='adam') for s in range(2)]
+    mock_adam_step.assert_has_calls(expected_calls)
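
The heart of this patch is argument plumbing: extra positional and keyword arguments passed to `manual_optimizer_step` travel unchanged through `TrainingLoop.optimizer_step` and the accelerator into `LightningModule.optimizer_step`, and finally into `optimizer.step()`, with a no-op closure substituted when none is given. A minimal self-contained sketch of that forwarding pattern, outside of Lightning (`VerboseSGD` and `manual_step` are hypothetical names used for illustration only, not part of this patch)::

    from typing import Callable, Optional

    import torch


    class VerboseSGD(torch.optim.SGD):
        """Toy optimizer whose step() accepts an extra keyword argument."""

        def step(self, closure: Optional[Callable] = None, scale: float = 1.0):
            # `scale` stands in for any optimizer-specific extra argument,
            # like `something="new"` in the mocked tests above
            for group in self.param_groups:
                group['lr'] *= scale
            return super().step(closure=closure)


    def manual_step(optimizer, *args, optimizer_closure: Optional[Callable] = None, **kwargs):
        # mirror the patch: fall back to a no-op closure, then forward
        # *args / **kwargs untouched into optimizer.step()
        closure = optimizer_closure if callable(optimizer_closure) else (lambda: None)
        optimizer.step(closure=closure, *args, **kwargs)


    param = torch.nn.Parameter(torch.randn(3))
    opt = VerboseSGD([param], lr=0.1)


    def closure():
        opt.zero_grad()
        loss = (param ** 2).sum()
        loss.backward()


    manual_step(opt, optimizer_closure=closure, scale=0.5)  # `scale` reaches VerboseSGD.step

Note that the patch itself gates the closure on `isinstance(optimizer_closure, types.FunctionType)` rather than `callable(...)`, so callables such as a `functools.partial` would fall back to the no-op closure; the sketch above uses `callable` only for brevity.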