From f29ecbfd909ff431ef837fcc8ebff451e897cb0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 15 Apr 2021 18:48:16 +0200 Subject: [PATCH 01/21] Typing for accelerators and plugins (#7022) --- pytorch_lightning/accelerators/accelerator.py | 57 ++++++++++--------- pytorch_lightning/accelerators/tpu.py | 17 +++--- .../plugins/precision/apex_amp.py | 40 +++++++------ .../plugins/precision/deepspeed_precision.py | 7 +-- .../plugins/precision/precision_plugin.py | 54 ++++++++---------- .../plugins/precision/sharded_native_amp.py | 12 ++-- setup.cfg | 13 +---- setup.py | 1 - tests/trainer/test_training_loop.py | 4 +- tests/utilities/test_parsing.py | 1 - 10 files changed, 95 insertions(+), 111 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index db8eb28e2bce51..30454436994b5d 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -15,22 +15,26 @@ from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Union import torch +from torch import Tensor +from torch.nn import Module from torch.optim import Optimizer from torch.utils.data import DataLoader import pytorch_lightning as pl -from pytorch_lightning.core import LightningModule from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin from pytorch_lightning.plugins.training_type import TrainingTypePlugin from pytorch_lightning.trainer.states import TrainerState -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, rank_zero_warn from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum +if _NATIVE_AMP_AVAILABLE: + from torch.cuda.amp import GradScaler + _STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None] -class Accelerator(object): +class Accelerator: """ The Accelerator Base Class. An Accelerator is meant to deal with one type of Hardware. @@ -52,7 +56,6 @@ def __init__( training_type_plugin: TrainingTypePlugin, ) -> None: """ - Args: precision_plugin: the plugin to handle precision-specific parts training_type_plugin: the plugin to handle different training routines @@ -64,7 +67,7 @@ def __init__( self.lr_schedulers: Sequence = [] self.optimizer_frequencies: Sequence = [] - def connect(self, model: LightningModule) -> None: + def connect(self, model: 'pl.LightningModule') -> None: """Transfers ownership of the model to this plugin""" self.training_type_plugin.connect(model) @@ -76,7 +79,7 @@ def setup_environment(self) -> None: """ self.training_type_plugin.setup_environment() - def setup(self, trainer: 'pl.Trainer', model: LightningModule) -> None: + def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None: """ Setup plugins for the trainer fit and creates optimizers. @@ -111,22 +114,22 @@ def post_dispatch(self, trainer: 'pl.Trainer') -> None: self.precision_plugin.post_dispatch() @property - def model(self) -> torch.nn.Module: - """Returns the model. This can also be a wrapped LightningModule. + def model(self) -> Module: + """ + Returns the model. This can also be a wrapped LightningModule. For retrieving the pure LightningModule use :attr:`Accelerator.lightning_module` - """ return self.training_type_plugin.model @model.setter - def model(self, new_model: torch.nn.Module) -> None: + def model(self, new_model: Module) -> None: self.training_type_plugin.model = new_model @property - def lightning_module(self) -> LightningModule: - """Returns the pure LightningModule. + def lightning_module(self) -> 'pl.LightningModule': + """ + Returns the pure LightningModule. To get the potentially wrapped model use :attr:`Accelerator.model` - """ return self.training_type_plugin.lightning_module @@ -135,7 +138,8 @@ def root_device(self) -> torch.device: return self.training_type_plugin.root_device def teardown(self) -> None: - """This method is called to teardown the training process. + """ + This method is called to teardown the training process. It is the right place to release memory and free other ressources. """ pass @@ -268,13 +272,13 @@ def validation_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE: def backward( self, - closure_loss: torch.Tensor, + closure_loss: Tensor, optimizer: Optimizer, optimizer_idx: int, should_accumulate: bool, *args: Any, **kwargs: Any, - ) -> torch.Tensor: + ) -> Tensor: """Forwards backward-calls to the precision plugin. Args: @@ -325,9 +329,7 @@ def clip_gradients( gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, ) -> None: """clips all the optimizer parameters to the given value""" - self.precision_plugin.clip_gradients( - self.model, optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm - ) + self.precision_plugin.clip_gradients(optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm) def on_train_epoch_end(self, outputs: Sequence[_STEP_OUTPUT_TYPE]) -> None: """Hook to do something on the end of an training epoch @@ -342,11 +344,11 @@ def on_train_end(self) -> None: pass def setup_optimizers(self, trainer: 'pl.Trainer') -> None: - """creates optimizers and schedulers + """ + Creates optimizers and schedulers Args: trainer: the Trainer, these optimizers should be connected to - model: the model to be optimized by the created optimizers """ if trainer.state not in (TrainerState.FITTING, TrainerState.TUNING): return @@ -357,7 +359,7 @@ def setup_optimizers(self, trainer: 'pl.Trainer') -> None: self.lr_schedulers = lr_schedulers self.optimizer_frequencies = optimizer_frequencies - def setup_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None: + def setup_training_type_plugin(self, plugin: TrainingTypePlugin, model: 'pl.LightningModule') -> None: """Attaches the training type plugin to the accelerator.""" plugin.setup(model) @@ -390,22 +392,21 @@ def precision(self) -> Union[str, int]: return self.precision_plugin.precision @property - def scaler(self) -> Optional['torch.cuda.amp.GradScaler']: - + def scaler(self) -> Optional['GradScaler']: return getattr(self.precision_plugin, 'scaler', None) @property def rpc_enabled(self) -> bool: return self.training_type_plugin.rpc_enabled - def optimizer_state(self, optimizer: Optimizer) -> Dict[str, torch.Tensor]: + def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: """ Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom plugins. """ return getattr(self.training_type_plugin, 'optimizer_state', lambda x: x.state_dict())(optimizer) - def on_save(self, checkpoint: Dict[str, Union[Any, torch.Tensor]]) -> Dict[str, Union[Any, torch.Tensor]]: + def on_save(self, checkpoint: Dict[str, Union[Any, Tensor]]) -> Dict[str, Union[Any, Tensor]]: return self.training_type_plugin.on_save(checkpoint) def barrier(self, name: Optional[str] = None) -> None: @@ -420,7 +421,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: """ return self.training_type_plugin.broadcast(obj, src) - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: """ Function to gather a tensor from several distributed processes. @@ -464,7 +465,7 @@ def model_sharded_context(self) -> Generator[None, None, None]: yield # todo: remove in v1.5 - def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None: + def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: 'pl.LightningModule') -> None: """ Attaches the training type plugin to the accelerator. Also transfers ownership of the model to this plugin diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 9aac6854db142a..b1b9a2d96f7f5f 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -15,6 +15,7 @@ from torch.optim import Optimizer +import pytorch_lightning as pl from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.plugins.precision import MixedPrecisionPlugin from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin @@ -26,10 +27,9 @@ import torch_xla.core.xla_model as xm from torch_xla._patched_functions import clip_grad_norm_ + # rename to mock in a test xla_clip_grad_norm_ = clip_grad_norm_ -import pytorch_lightning as pl - class TPUAccelerator(Accelerator): """ Accelerator for TPU devices. """ @@ -59,19 +59,16 @@ def clip_gradients( self, optimizer: Optimizer, clip_val: Union[float, int], - norm_type: float = 2.0, - gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM + gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, ) -> None: - assert gradient_clip_algorithm is GradClipAlgorithmType.NORM, \ + assert gradient_clip_algorithm == GradClipAlgorithmType.NORM, \ "Only NORM gradient clipping is supported on TPU for now" - model = self.lightning_module - parameters = model.parameters() - grad_clip_val = float(clip_val) if grad_clip_val <= 0: return - max_norm = grad_clip_val + parameters = self.model.parameters() + norm_type = 2.0 - xla_clip_grad_norm_(parameters, max_norm, norm_type) + xla_clip_grad_norm_(parameters, grad_clip_val, norm_type) diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index b2b1c726a04674..30614d3faa1877 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -11,14 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Generator, List, Sequence, Tuple, Type +from typing import Any, Callable, ContextManager, Iterator, List, Sequence, Tuple, Type import torch +from torch import Tensor +from torch.nn import Module from torch.optim import Optimizer from pytorch_lightning.core import LightningModule from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin -from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType, rank_zero_warn +from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType + +PARAMETERS = Iterator[torch.nn.Parameter] if _APEX_AVAILABLE: from apex import amp @@ -32,11 +36,15 @@ def __init__(self, amp_level: str = "O2") -> None: self.backend = AMPType.APEX self.amp_level = amp_level - def master_params(self, optimizer: Optimizer) -> Generator[torch.Tensor, None, None]: + def master_params(self, optimizer: Optimizer) -> PARAMETERS: return amp.master_params(optimizer) - def connect(self, model: torch.nn.Module, optimizers: Sequence[Optimizer], - lr_schedulers: Sequence[Any]) -> Tuple[torch.nn.Module, Sequence[Optimizer], Sequence[Any]]: + def connect( + self, + model: Module, + optimizers: Sequence[Optimizer], + lr_schedulers: Sequence[Any], + ) -> Tuple[Module, Sequence[Optimizer], Sequence[Any]]: """Connects the precision plugin to the training process, configures apex and reinits the schedulers """ @@ -49,28 +57,28 @@ def connect(self, model: torch.nn.Module, optimizers: Sequence[Optimizer], def backward( self, model: LightningModule, - closure_loss: torch.Tensor, + closure_loss: Tensor, optimizer: Optimizer, opt_idx: int, should_accumulate: bool, *args: Any, **kwargs: Any, - ) -> torch.Tensor: + ) -> Tensor: """performs the actual backpropagation Args: model: the model to be optimized closure_loss: the loss value obtained from the closure optimizer: the optimizer to perform the step lateron - opt_idx: the optimizer's index + opt_idx: the optimizer index should_accumulate: whether to accumulate gradients or not """ - closure_loss = amp.scale_loss(closure_loss, model.trainer.optimizers if optimizer is None else optimizer) + opt = model.trainer.optimizers if optimizer is None else optimizer + scaled_loss: ContextManager[Tensor] = amp.scale_loss(closure_loss, opt) # enter apex context - context = closure_loss - closure_loss = closure_loss.__enter__() + closure_loss = scaled_loss.__enter__() # do backward pass # TODO: not entirely sure, why we need this @@ -84,10 +92,8 @@ def backward( closure_loss.backward(*args, **kwargs) # exit amp context - a, b, c = None, None, None - error = context.__exit__(a, b, c) + error = scaled_loss.__exit__(None, None, None) if error: - rank_zero_warn(a, b, c) raise Exception("apex unscale error") # once backward has been applied, release graph @@ -97,17 +103,17 @@ def backward( def configure_apex( self, amp: Type, - model: LightningModule, + model: Module, optimizers: List[Optimizer], amp_level: str, - ) -> Tuple[LightningModule, List[Optimizer]]: + ) -> Tuple[Module, List[Optimizer]]: r""" Override to init AMP your own way. Must return a model and list of optimizers. Args: amp: pointer to amp library object. - model: pointer to current :class:`LightningModule`. + model: pointer to current :class:`torch.nn.Module`. optimizers: list of optimizers passed in :meth:`configure_optimizers`. amp_level: AMP mode chosen ('O1', 'O2', etc...) diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py index 22fa5bf0823574..dc29a5cee40145 100644 --- a/pytorch_lightning/plugins/precision/deepspeed_precision.py +++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Any, Callable, Union -import torch +from torch import Tensor from torch.optim import Optimizer import pytorch_lightning as pl @@ -54,13 +54,13 @@ def pre_optimizer_step( def backward( self, model: 'pl.LightningModule', - closure_loss: torch.Tensor, + closure_loss: Tensor, optimizer: Optimizer, opt_idx: int, should_accumulate: bool, *args: Any, **kwargs: Any, - ) -> torch.Tensor: + ) -> Tensor: if is_overridden('backward', model): warning_cache.warn( "Overridden backward hook in the LightningModule will be ignored since DeepSpeed handles" @@ -76,7 +76,6 @@ def backward( def clip_gradients( self, - model: 'pl.LightningModule', optimizer: Optimizer, clip_val: Union[int, float], gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 059506f830b8f3..ac33eeea287eb2 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -12,16 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Any, Callable, Generator, Sequence, Tuple, Union +from typing import Any, Callable, Iterator, Sequence, Tuple, Union import torch -import torch.nn as nn +from torch import Tensor +from torch.nn import Module from torch.optim import Optimizer import pytorch_lightning as pl from pytorch_lightning.plugins.base_plugin import Plugin from pytorch_lightning.utilities import GradClipAlgorithmType +PARAMETERS = Iterator[torch.nn.Parameter] + class PrecisionPlugin(Plugin): """ @@ -32,17 +35,10 @@ class PrecisionPlugin(Plugin): EPSILON: float = 1e-6 precision: Union[str, int] = 32 - def __init__(self) -> None: - super().__init__() - self.clip_grad_funcs = { - GradClipAlgorithmType.VALUE: self.clip_grad_by_value, - GradClipAlgorithmType.NORM: self.clip_grad_by_norm, - } - - def master_params(self, optimizer: Optimizer) -> Generator[torch.Tensor, None, None]: - """The master params of the model. Returns the plain model params here. + def master_params(self, optimizer: Optimizer) -> PARAMETERS: + """ + The master params of the model. Returns the plain model params here. Maybe different in other precision plugins. - """ for group in optimizer.param_groups: for p in group["params"]: @@ -50,23 +46,23 @@ def master_params(self, optimizer: Optimizer) -> Generator[torch.Tensor, None, N def connect( self, - model: nn.Module, + model: Module, optimizers: Sequence[Optimizer], lr_schedulers: Sequence[Any], - ) -> Tuple[nn.Module, Sequence[Optimizer], Sequence[Any]]: + ) -> Tuple[Module, Sequence[Optimizer], Sequence[Any]]: """Connects this plugin to the accelerator and the training process""" return model, optimizers, lr_schedulers def backward( self, model: 'pl.LightningModule', - closure_loss: torch.Tensor, + closure_loss: Tensor, optimizer: Optimizer, opt_idx: int, should_accumulate: bool, *args: Any, **kwargs: Any, - ) -> torch.Tensor: + ) -> Tensor: """performs the actual backpropagation Args: @@ -106,7 +102,6 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None: def clip_gradients( self, - model: 'pl.LightningModule', optimizer: Optimizer, clip_val: Union[int, float], gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, @@ -119,24 +114,25 @@ def clip_gradients( if clip_val <= 0: return - clip_grad_func = self.clip_grad_funcs[gradient_clip_algorithm] - clip_grad_func(optimizer, clip_val) # type: ignore + if gradient_clip_algorithm == GradClipAlgorithmType.VALUE: + self.clip_grad_by_value(optimizer, clip_val) + elif gradient_clip_algorithm == GradClipAlgorithmType.NORM: + # TODO: there should be a mechanism to set `norm_type` + self.clip_grad_by_norm(optimizer, clip_val, eps=self.EPSILON) def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None: """Clip gradients by value""" - parameters = list(self.master_params(optimizer)) + parameters = self.master_params(optimizer) torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val) - def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = 2.0) -> None: + def clip_grad_by_norm( + self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = 2.0, eps: float = 1e-6 + ) -> None: """Clip gradients by norm""" - # TODO: separate TPU case from here - parameters = list(self.master_params(optimizer)) - max_norm = clip_val + parameters = self.master_params(optimizer) - if isinstance(parameters, torch.Tensor): - parameters = [parameters] + # TODO: replace this with torch.nn.clip_grad_norm_ parameters = list(filter(lambda p: p.grad is not None, parameters)) - device = parameters[0].device if norm_type == math.inf: @@ -147,9 +143,7 @@ def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float], n torch.norm(p.grad.data.to(device), norm_type, out=out[i]) total_norm = torch.norm(out, norm_type) - eps = self.EPSILON - - clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps) + clip_coef = torch.tensor(clip_val, device=device) / (total_norm + eps) clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef)) for p in parameters: p.grad.data.mul_(clip_coef.to(p.grad.data.device)) diff --git a/pytorch_lightning/plugins/precision/sharded_native_amp.py b/pytorch_lightning/plugins/precision/sharded_native_amp.py index 28555a1a60b8d2..4d8a2f0f934ed1 100644 --- a/pytorch_lightning/plugins/precision/sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/sharded_native_amp.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import cast, Union - -from torch.optim import Optimizer +from typing import Union from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE @@ -24,13 +22,13 @@ class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin): - """Mixed Precision for Sharded Training - """ + """Mixed Precision for Sharded Training""" def __init__(self) -> None: super().__init__() self.scaler = ShardedGradScaler() - def clip_grad_by_norm(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None: - optimizer = cast(OSS, optimizer) + def clip_grad_by_norm( + self, optimizer: 'OSS', clip_val: Union[int, float], norm_type: float = 2.0, eps: float = 1e-6 + ) -> None: optimizer.clip_grad_norm(clip_val, norm_type=norm_type) diff --git a/setup.cfg b/setup.cfg index 70139348462aa4..3fa6e390767254 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,7 +50,6 @@ omit = [flake8] -# TODO: this should be 88 or 100 according PEP8 max-line-length = 120 exclude = .tox, @@ -105,9 +104,6 @@ NO_SPACES_AROUND_SELECTED_BINARY_OPERATORS = false [mypy] -# Typing tests is low priority, but enabling type checking on the -# untyped test functions (using `--check-untyped-defs`) is still -# high-value because it helps test the typing. files = pytorch_lightning, pl_examples, benchmarks, tests disallow_untyped_defs = True ignore_missing_imports = True @@ -115,12 +111,10 @@ show_error_codes = True warn_redundant_casts = True warn_unused_configs = True warn_unused_ignores = True +allow_redefinition = True +# disable this rule as the Trainer attributes are defined in the connectors, not in its __init__ disable_error_code = attr-defined -# todo: this is magically failing, need to be revisited -[mypy-pytorch_lightning.accelerators.tpu.*] -ignore_errors = True - # todo: add proper typing to this module... [mypy-pytorch_lightning.callbacks.*] ignore_errors = True @@ -164,8 +158,7 @@ ignore_errors = True # todo: add proper typing to this module... [mypy-pytorch_lightning.trainer.*] ignore_errors = True - -# whitelist evaluation_loop.py +# whitelist [mypy-pytorch_lightning.trainer.evaluation_loop] ignore_errors = False diff --git a/setup.py b/setup.py index 7e75c514734b5a..264f219e22b55c 100755 --- a/setup.py +++ b/setup.py @@ -23,7 +23,6 @@ try: from pytorch_lightning import __about__ as info from pytorch_lightning import setup_tools - except ImportError: # alternative https://stackoverflow.com/a/67692/4521646 sys.path.append("pytorch_lightning") diff --git a/tests/trainer/test_training_loop.py b/tests/trainer/test_training_loop.py index 25be29d73f1a4f..b024a7eabbecc5 100644 --- a/tests/trainer/test_training_loop.py +++ b/tests/trainer/test_training_loop.py @@ -72,9 +72,7 @@ def optimizer_step( super().optimizer_step( epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs ) - self.called.append( - "optimizer_step" - ) # append after as closure calls other methods + self.called.append("optimizer_step") # append after as closure calls other methods def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): self.called.append("on_train_batch_end") diff --git a/tests/utilities/test_parsing.py b/tests/utilities/test_parsing.py index 9c6900f81fcae3..57e49df2df0662 100644 --- a/tests/utilities/test_parsing.py +++ b/tests/utilities/test_parsing.py @@ -14,7 +14,6 @@ import inspect import pytest - from torch.jit import ScriptModule from pytorch_lightning.utilities.parsing import ( From 4c07ab5e99dd20c1f309d9e73cdaacc1ebad9499 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 15 Apr 2021 15:10:34 -0700 Subject: [PATCH 02/21] Use PyTorch API logging for Lightning Trainer (#6771) * Update trainer.py * Update trainer.py * Update trainer.py --- pytorch_lightning/trainer/trainer.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4d4fc53f857745..b1c29ff2c8892e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -295,7 +295,7 @@ def __init__( """ super().__init__() - + Trainer._log_api_event("init") distributed_backend = distributed_backend or accelerator # init connectors @@ -416,6 +416,7 @@ def fit( If the model has a predefined val_dataloaders method this will be skipped """ + Trainer._log_api_event("fit") # we reuse fit for other functions. When already set, it shouldn't be modified. if not self.state.running: self.state = TrainerState.FITTING @@ -881,6 +882,7 @@ def validate( # -------------------- # SETUP HOOK # -------------------- + Trainer._log_api_event("validate") self.verbose_evaluate = verbose self.state = TrainerState.VALIDATING @@ -943,6 +945,7 @@ def test( # -------------------- # SETUP HOOK # -------------------- + Trainer._log_api_event("test") self.verbose_evaluate = verbose self.state = TrainerState.TESTING @@ -1039,6 +1042,7 @@ def predict( # SETUP HOOK # -------------------- # If you supply a datamodule you can't supply dataloaders + Trainer._log_api_event("predict") model = model or self.lightning_module @@ -1084,6 +1088,7 @@ def tune( If the model has a predefined val_dataloaders method this will be skipped """ + Trainer._log_api_event("tune") self.state = TrainerState.TUNING self.tuning = True @@ -1174,3 +1179,7 @@ def call_hook(self, hook_name, *args, **kwargs): if not skip: self._cache_logged_metrics() return output + + @staticmethod + def _log_api_event(event: str) -> None: + torch._C._log_api_usage_once("lightning.trainer." + event) From 402a258705c10c8ad57bfdc16c39a8420b1425ee Mon Sep 17 00:00:00 2001 From: Matthew Sarmiento Date: Thu, 15 Apr 2021 16:52:47 -0700 Subject: [PATCH 03/21] [docs]: pass parser to Trainer.add_argparse_args() (#7029) --- docs/source/common/trainer.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index eea68cbd460c5e..73d248aaf450ce 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -130,7 +130,7 @@ So you can run it like so: if __name__ == '__main__': parser = ArgumentParser() - parser = Trainer.add_argparse_args() + parser = Trainer.add_argparse_args(parser) args = parser.parse_args() main(args) From 67d21609c9f3947736796ad3ab1d953e06c403a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 16 Apr 2021 13:38:57 +0200 Subject: [PATCH 04/21] Add Trainer max_time argument + Callback (#6823) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí Co-authored-by: Akihiro Nitta Co-authored-by: Akihiro Nitta --- CHANGELOG.md | 3 + docs/source/common/trainer.rst | 20 ++ pytorch_lightning/callbacks/__init__.py | 2 + pytorch_lightning/callbacks/timer.py | 173 +++++++++++++++ .../trainer/connectors/callback_connector.py | 18 +- pytorch_lightning/trainer/trainer.py | 18 +- tests/callbacks/test_timer.py | 204 ++++++++++++++++++ 7 files changed, 435 insertions(+), 3 deletions(-) create mode 100644 pytorch_lightning/callbacks/timer.py create mode 100644 tests/callbacks/test_timer.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 277fee3463e22b..85a827104d2c46 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -99,6 +99,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `LightningModule.lr_schedulers()` for manual optimization ([#6567](https://github.com/PyTorchLightning/pytorch-lightning/pull/6567)) +- Added `max_time` Trainer argument to limit training time ([#6823](https://github.com/PyTorchLightning/pytorch-lightning/pull/6823)) + + ### Changed - Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259)) diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index 73d248aaf450ce..dd432011261e54 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -987,6 +987,26 @@ Trainer will train model for at least min_steps or min_epochs (latest). # Run at least for 100 steps (disable min_epochs) trainer = Trainer(min_steps=100, min_epochs=0) +max_time +^^^^^^^^ + +Set the maximum amount of time for training. Training will get interrupted mid-epoch. +For customizable options use the :class:`~pytorch_lightning.callbacks.timer.Timer` callback. + +.. testcode:: + + # Default (disabled) + trainer = Trainer(max_time=None) + + # Stop after 12 hours of training or when reaching 10 epochs (string) + trainer = Trainer(max_time="00:12:00:00", max_epochs=10) + + # Stop after 1 day and 5 hours (dict) + trainer = Trainer(max_time={"days": 1, "hours": 5}) + +In case ``max_time`` is used together with ``min_steps`` or ``min_epochs``, the ``min_*`` requirement +always has precedence. + num_nodes ^^^^^^^^^ diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py index fb61ad81aee283..76e6d8f7eb0b6e 100644 --- a/pytorch_lightning/callbacks/__init__.py +++ b/pytorch_lightning/callbacks/__init__.py @@ -23,6 +23,7 @@ from pytorch_lightning.callbacks.pruning import ModelPruning from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging +from pytorch_lightning.callbacks.timer import Timer __all__ = [ 'BackboneFinetuning', @@ -39,4 +40,5 @@ 'ProgressBarBase', 'QuantizationAwareTraining', 'StochasticWeightAveraging', + 'Timer', ] diff --git a/pytorch_lightning/callbacks/timer.py b/pytorch_lightning/callbacks/timer.py new file mode 100644 index 00000000000000..9b93499c82ea17 --- /dev/null +++ b/pytorch_lightning/callbacks/timer.py @@ -0,0 +1,173 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r""" +Timer +^^^^^ +""" +import logging +import time +from datetime import timedelta +from typing import Any, Dict, Optional, Union + +import pytorch_lightning as pl +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.utilities import LightningEnum +from pytorch_lightning.utilities.distributed import rank_zero_info +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +log = logging.getLogger(__name__) + + +class Interval(LightningEnum): + step = "step" + epoch = "epoch" + + +class Timer(Callback): + """ + The Timer callback tracks the time spent in the training, validation, and test loops and interrupts the Trainer + if the given time limit for the training loop is reached. + + Args: + duration: A string in the format DD:HH:MM:SS (days, hours, minutes seconds), or a :class:`datetime.timedelta`, + or a dict containing key-value compatible with :class:`~datetime.timedelta`. + interval: Determines if the interruption happens on epoch level or mid-epoch. + Can be either ``"epoch"`` or ``"step"``. + verbose: Set this to ``False`` to suppress logging messages. + + Raises: + MisconfigurationException: + If ``interval`` is not one of the supported choices. + + Example:: + from pytorch_lightning import Trainer + from pytorch_lightning.callbacks import Timer + + # stop training after 12 hours + timer = Timer(duration="00:12:00:00") + + # or provide a datetime.timedelta + from datetime import timedelta + timer = Timer(duration=timedelta(weeks=1)) + + # or provide a dictionary + timer = Timer(duration=dict(weeks=4, days=2)) + + # force training to stop after given time limit + trainer = Trainer(callbacks=[timer]) + + # query training/validation/test time (in seconds) + timer.time_elapsed("train") + timer.start_time("validate") + timer.end_time("test") + """ + + def __init__( + self, + duration: Optional[Union[str, timedelta, Dict[str, int]]] = None, + interval: str = Interval.step, + verbose: bool = True, + ) -> None: + super().__init__() + if isinstance(duration, str): + dhms = duration.strip().split(":") + dhms = [int(i) for i in dhms] + duration = timedelta(days=dhms[0], hours=dhms[1], minutes=dhms[2], seconds=dhms[3]) + if isinstance(duration, dict): + duration = timedelta(**duration) + if interval not in set(Interval): + raise MisconfigurationException( + f"Unsupported parameter value `Timer(interval={interval})`. Possible choices are:" + f" {', '.join(set(Interval))}" + ) + self._duration = duration.total_seconds() if duration is not None else None + self._interval = interval + self._verbose = verbose + self._start_time = {stage: None for stage in RunningStage} + self._end_time = {stage: None for stage in RunningStage} + self._offset = 0 + + def start_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]: + """Return the start time of a particular stage (in seconds)""" + stage = RunningStage(stage) + return self._start_time[stage] + + def end_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]: + """Return the end time of a particular stage (in seconds)""" + stage = RunningStage(stage) + return self._end_time[stage] + + def time_elapsed(self, stage: str = RunningStage.TRAINING) -> float: + """Return the time elapsed for a particular stage (in seconds)""" + start = self.start_time(stage) + end = self.end_time(stage) + offset = self._offset if stage == RunningStage.TRAINING else 0 + if start is None: + return offset + if end is None: + return time.monotonic() - start + offset + return end - start + offset + + def time_remaining(self, stage: str = RunningStage.TRAINING) -> Optional[float]: + """Return the time remaining for a particular stage (in seconds)""" + if self._duration is not None: + return self._duration - self.time_elapsed(stage) + + def on_train_start(self, *args, **kwargs) -> None: + self._start_time[RunningStage.TRAINING] = time.monotonic() + + def on_train_end(self, *args, **kwargs) -> None: + self._end_time[RunningStage.TRAINING] = time.monotonic() + + def on_validation_start(self, *args, **kwargs) -> None: + self._start_time[RunningStage.VALIDATING] = time.monotonic() + + def on_validation_end(self, *args, **kwargs) -> None: + self._end_time[RunningStage.VALIDATING] = time.monotonic() + + def on_test_start(self, *args, **kwargs) -> None: + self._start_time[RunningStage.TESTING] = time.monotonic() + + def on_test_end(self, *args, **kwargs) -> None: + self._end_time[RunningStage.TESTING] = time.monotonic() + + def on_train_batch_end(self, trainer: 'pl.Trainer', *args, **kwargs) -> None: + if self._interval != Interval.step or self._duration is None: + return + self._check_time_remaining(trainer) + + def on_train_epoch_end(self, trainer: 'pl.Trainer', *args, **kwargs) -> None: + if self._interval != Interval.epoch or self._duration is None: + return + self._check_time_remaining(trainer) + + def on_save_checkpoint( + self, + trainer: 'pl.Trainer', + pl_module: 'pl.LightningModule', + checkpoint: Dict[str, Any], + ) -> Dict[str, Any]: + return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in list(RunningStage)}} + + def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None: + time_elapsed = callback_state.get("time_elapsed", {}) + self._offset = time_elapsed.get(RunningStage.TRAINING.value, 0) + + def _check_time_remaining(self, trainer: 'pl.Trainer') -> None: + should_stop = self.time_elapsed() >= self._duration + should_stop = trainer.accelerator.broadcast(should_stop) + trainer.should_stop = trainer.should_stop or should_stop + if should_stop and self._verbose: + rank_zero_info(f"Time limit reached. Elapsed time is {self.time_elapsed}. Signaling Trainer to stop.") diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 8a5289e608c945..544b229a217281 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import List, Union +from datetime import timedelta +from typing import List, Union, Optional, Dict from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase +from pytorch_lightning.callbacks.timer import Timer from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_info from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -35,6 +37,7 @@ def on_trainer_init( weights_save_path, resume_from_checkpoint, stochastic_weight_avg, + max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None, ): self.trainer.resume_from_checkpoint = resume_from_checkpoint @@ -55,6 +58,8 @@ def on_trainer_init( # configure swa callback self._configure_swa_callbacks() + self._configure_timer_callback(max_time) + # init progress bar self.trainer._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position) @@ -106,6 +111,17 @@ def configure_progress_bar(self, refresh_rate=None, process_position=0): return progress_bar_callback + def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None) -> None: + if max_time is None: + return + if any(isinstance(cb, Timer) for cb in self.trainer.callbacks): + rank_zero_info( + "Ignoring `Trainer(max_time=...)`, callbacks list already contains a Timer." + ) + return + timer = Timer(duration=max_time, interval="step") + self.trainer.callbacks.append(timer) + def _trainer_has_checkpoint_callbacks(self): return len(self.trainer.checkpoint_callbacks) > 0 diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b1c29ff2c8892e..62c8f530dca064 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -14,6 +14,7 @@ """Trainer to automate the training.""" import logging import warnings +from datetime import timedelta from itertools import count from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Union @@ -108,6 +109,7 @@ def __init__( min_epochs: Optional[int] = None, max_steps: Optional[int] = None, min_steps: Optional[int] = None, + max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None, limit_train_batches: Union[int, float] = 1.0, limit_val_batches: Union[int, float] = 1.0, limit_test_batches: Union[int, float] = 1.0, @@ -241,6 +243,11 @@ def __init__( min_steps: Force training for at least these number of steps. Disabled by default (None). + max_time: Stop training after this amount of time has passed. Disabled by default (None). + The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes seconds), as a + :class:`datetime.timedelta`, or a dictionary with keys that will be passed to + :class:`datetime.timedelta`. + num_nodes: number of GPU nodes for distributed training. num_processes: number of processes for distributed training with distributed_backend="ddp_cpu" @@ -332,8 +339,15 @@ def __init__( # init callbacks # Declare attributes to be set in callback_connector on_trainer_init self.callback_connector.on_trainer_init( - callbacks, checkpoint_callback, progress_bar_refresh_rate, process_position, default_root_dir, - weights_save_path, resume_from_checkpoint, stochastic_weight_avg + callbacks, + checkpoint_callback, + progress_bar_refresh_rate, + process_position, + default_root_dir, + weights_save_path, + resume_from_checkpoint, + stochastic_weight_avg, + max_time, ) # hook diff --git a/tests/callbacks/test_timer.py b/tests/callbacks/test_timer.py new file mode 100644 index 00000000000000..c27eebbeb7805f --- /dev/null +++ b/tests/callbacks/test_timer.py @@ -0,0 +1,204 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import time +from datetime import timedelta +from unittest.mock import Mock, patch + +import pytest + +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.callbacks.timer import Timer +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf + + +def test_trainer_flag(caplog): + + class TestModel(BoringModel): + + def on_fit_start(self): + raise SystemExit() + + trainer = Trainer(max_time=dict(seconds=1337)) + with pytest.raises(SystemExit): + trainer.fit(TestModel()) + timer = [c for c in trainer.callbacks if isinstance(c, Timer)][0] + assert timer._duration == 1337 + + trainer = Trainer(max_time=dict(seconds=1337), callbacks=[Timer()]) + with pytest.raises(SystemExit), caplog.at_level(level=logging.INFO): + trainer.fit(TestModel()) + assert "callbacks list already contains a Timer" in caplog.text + + +@pytest.mark.parametrize( + "duration,expected", [ + (None, None), + ("00:00:00:22", timedelta(seconds=22)), + ("12:34:56:65", timedelta(days=12, hours=34, minutes=56, seconds=65)), + (timedelta(weeks=52, milliseconds=1), timedelta(weeks=52, milliseconds=1)), + (dict(weeks=52, days=1), timedelta(weeks=52, days=1)), + ] +) +def test_timer_parse_duration(duration, expected): + timer = Timer(duration=duration) + assert (timer.time_remaining() == expected is None) or (timer.time_remaining() == expected.total_seconds()) + + +def test_timer_interval_choice(): + Timer(duration=timedelta(), interval="step") + Timer(duration=timedelta(), interval="epoch") + with pytest.raises(MisconfigurationException, match="Unsupported parameter value"): + Timer(duration=timedelta(), interval="invalid") + + +@patch("pytorch_lightning.callbacks.timer.time") +def test_timer_time_remaining(time_mock): + """ Test that the timer tracks the elapsed and remaining time correctly. """ + start_time = time.monotonic() + duration = timedelta(seconds=10) + time_mock.monotonic.return_value = start_time + timer = Timer(duration=duration) + assert timer.time_remaining() == duration.total_seconds() + assert timer.time_elapsed() == 0 + + # timer not started yet + time_mock.monotonic.return_value = start_time + 60 + assert timer.start_time() is None + assert timer.time_remaining() == 10 + assert timer.time_elapsed() == 0 + + # start timer + time_mock.monotonic.return_value = start_time + timer.on_train_start(trainer=Mock(), pl_module=Mock()) + assert timer.start_time() == start_time + + # pretend time has elapsed + elapsed = 3 + time_mock.monotonic.return_value = start_time + elapsed + assert timer.start_time() == start_time + assert round(timer.time_remaining()) == 7 + assert round(timer.time_elapsed()) == 3 + + +def test_timer_stops_training(tmpdir): + """ Test that the timer stops training before reaching max_epochs """ + model = BoringModel() + duration = timedelta(milliseconds=100) + timer = Timer(duration=duration) + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1000, + callbacks=[timer], + ) + trainer.fit(model) + assert trainer.global_step > 1 + assert trainer.current_epoch < 999 + + +@pytest.mark.parametrize("interval", ["step", "epoch"]) +def test_timer_zero_duration_stop(tmpdir, interval): + """ Test that the timer stops training immediately after the first check occurs. """ + model = BoringModel() + duration = timedelta(0) + timer = Timer(duration=duration, interval=interval) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[timer], + ) + trainer.fit(model) + if interval == "step": + # timer triggers stop on step end + assert trainer.global_step == 1 + assert trainer.current_epoch == 0 + else: + # timer triggers stop on epoch end + assert trainer.global_step == len(trainer.train_dataloader) + assert trainer.current_epoch == 0 + + +@pytest.mark.parametrize("min_steps,min_epochs", [ + (None, 2), + (3, None), + (3, 2), +]) +def test_timer_duration_min_steps_override(tmpdir, min_steps, min_epochs): + model = BoringModel() + duration = timedelta(0) + timer = Timer(duration=duration) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[timer], + min_steps=min_steps, + min_epochs=min_epochs, + ) + trainer.fit(model) + if min_epochs: + assert trainer.current_epoch >= min_epochs - 1 + if min_steps: + assert trainer.global_step >= min_steps - 1 + assert timer.time_elapsed() > duration.total_seconds() + + +def test_timer_resume_training(tmpdir): + """ Test that the timer can resume together with the Trainer. """ + model = BoringModel() + timer = Timer(duration=timedelta(milliseconds=200)) + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1) + + # initial training + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=100, + callbacks=[timer, checkpoint_callback], + ) + trainer.fit(model) + assert not timer._offset + assert timer.time_remaining() <= 0 + assert trainer.current_epoch < 99 + saved_global_step = trainer.global_step + + # resume training (with depleted timer + timer = Timer(duration=timedelta(milliseconds=200)) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[timer, checkpoint_callback], + resume_from_checkpoint=checkpoint_callback.best_model_path, + ) + trainer.fit(model) + assert timer._offset > 0 + assert trainer.global_step == saved_global_step + 1 + + +@RunIf(skip_windows=True) +def test_timer_track_stages(tmpdir): + """ Test that the timer tracks time also for other stages (train/val/test). """ + # note: skipped on windows because time resolution of time.monotonic() is not high enough for this fast test + model = BoringModel() + timer = Timer() + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=5, + callbacks=[timer], + ) + trainer.fit(model) + assert timer.time_elapsed() == timer.time_elapsed("train") > 0 + assert timer.time_elapsed("validate") > 0 + assert timer.time_elapsed("test") == 0 + trainer.test(model) + assert timer.time_elapsed("test") > 0 From 832a03af7cae21bd4e6b7f752689a7ac6c2d4ce3 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Fri, 16 Apr 2021 18:01:56 +0530 Subject: [PATCH 05/21] Add Training Type Plugins Registry (#6982) Co-authored-by: Sean Naren Co-authored-by: thomas chaton --- pytorch_lightning/plugins/__init__.py | 11 ++ pytorch_lightning/plugins/plugins_registry.py | 146 ++++++++++++++++++ .../plugins/training_type/deepspeed.py | 20 +++ .../training_type/training_type_plugin.py | 4 + .../connectors/accelerator_connector.py | 14 +- tests/plugins/test_plugins_registry.py | 83 ++++++++++ 6 files changed, 276 insertions(+), 2 deletions(-) create mode 100644 pytorch_lightning/plugins/plugins_registry.py create mode 100644 tests/plugins/test_plugins_registry.py diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index a67235baa47679..444d2aaef978b5 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -1,4 +1,8 @@ from pytorch_lightning.plugins.base_plugin import Plugin # noqa: F401 +from pytorch_lightning.plugins.plugins_registry import ( # noqa: F401 + call_training_type_register_plugins, + TrainingTypePluginsRegistry, +) from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401 @@ -47,3 +51,10 @@ 'DDPShardedPlugin', 'DDPSpawnShardedPlugin', ] + +from pathlib import Path + +FILE_ROOT = Path(__file__).parent +TRAINING_TYPE_BASE_MODULE = "pytorch_lightning.plugins.training_type" + +call_training_type_register_plugins(FILE_ROOT, TRAINING_TYPE_BASE_MODULE) diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py new file mode 100644 index 00000000000000..59dd7d8db6bffb --- /dev/null +++ b/pytorch_lightning/plugins/plugins_registry.py @@ -0,0 +1,146 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import os +from collections import UserDict +from inspect import getmembers, isclass +from pathlib import Path +from typing import Any, Callable, List, Optional + +from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class _TrainingTypePluginsRegistry(UserDict): + """ + This class is a Registry that stores information about the Training Type Plugins. + + The Plugins are mapped to strings. These strings are names that idenitify + a plugin, e.g., "deepspeed". It also returns Optional description and + parameters to initialize the Plugin, which were defined durng the + registration. + + The motivation for having a TrainingTypePluginRegistry is to make it convenient + for the Users to try different Plugins by passing just strings + to the plugins flag to the Trainer. + + Example:: + + @TrainingTypePluginsRegistry.register("lightning", description="Super fast", a=1, b=True) + class LightningPlugin: + def __init__(self, a, b): + ... + + or + + TrainingTypePluginsRegistry.register("lightning", LightningPlugin, description="Super fast", a=1, b=True) + + """ + + def register( + self, + name: str, + plugin: Optional[Callable] = None, + description: Optional[str] = None, + override: bool = False, + **init_params: Any, + ) -> Callable: + """ + Registers a plugin mapped to a name and with required metadata. + + Args: + name : the name that identifies a plugin, e.g. "deepspeed_stage_3" + plugin : plugin class + description : plugin description + override : overrides the registered plugin, if True + init_params: parameters to initialize the plugin + """ + if not (name is None or isinstance(name, str)): + raise TypeError(f'`name` must be a str, found {name}') + + if name in self and not override: + raise MisconfigurationException( + f"'{name}' is already present in the registry." + " HINT: Use `override=True`." + ) + + data = {} + data["description"] = description if description is not None else "" + + data["init_params"] = init_params + + def do_register(plugin: Callable) -> Callable: + data["plugin"] = plugin + self[name] = data + return plugin + + if plugin is not None: + return do_register(plugin) + + return do_register + + def get(self, name: str) -> Any: + """ + Calls the registered plugin with the required parameters + and returns the plugin object + + Args: + name (str): the name that identifies a plugin, e.g. "deepspeed_stage_3" + """ + if name in self: + data = self[name] + return data["plugin"](**data["init_params"]) + + err_msg = "'{}' not found in registry. Available names: {}" + available_names = ", ".join(sorted(self.keys())) or "none" + raise KeyError(err_msg.format(name, available_names)) + + def remove(self, name: str) -> None: + """Removes the registered plugin by name""" + self.pop(name) + + def available_plugins(self) -> List: + """Returns a list of registered plugins""" + return list(self.keys()) + + def __str__(self) -> str: + return "Registered Plugins: {}".format(", ".join(self.keys())) + + +TrainingTypePluginsRegistry = _TrainingTypePluginsRegistry() + + +def is_register_plugins_overridden(plugin: Callable) -> bool: + method_name = "register_plugins" + plugin_attr = getattr(plugin, method_name) + super_attr = getattr(TrainingTypePlugin, method_name) + + if hasattr(plugin_attr, 'patch_loader_code'): + is_overridden = plugin_attr.patch_loader_code != str(super_attr.__code__) + else: + is_overridden = plugin_attr.__code__ is not super_attr.__code__ + return is_overridden + + +def call_training_type_register_plugins(root: Path, base_module: str) -> None: + # Ref: https://github.com/facebookresearch/ClassyVision/blob/master/classy_vision/generic/registry_utils.py#L14 + directory = "training_type" + for file in os.listdir(root / directory): + if file.endswith(".py") and not file.startswith("_"): + module = file[:file.find(".py")] + module = importlib.import_module(".".join([base_module, module])) + for _, mod in getmembers(module, isclass): + if issubclass(mod, TrainingTypePlugin) and is_register_plugins_overridden(mod): + mod.register_plugins(TrainingTypePluginsRegistry) + break diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index ce2fea839837b9..9c67a1ccb53a9a 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -542,3 +542,23 @@ def update_global_step(self, total_batch_idx: int, current_global_step: int) -> if total_batch_idx % self._original_accumulate_grad_batches == 0: current_global_step += 1 return current_global_step + + @classmethod + def register_plugins(cls, plugin_registry): + plugin_registry.register("deepspeed", cls, description="Default DeepSpeed Plugin") + plugin_registry.register("deepspeed_stage_2", cls, description="DeepSpeed with ZeRO Stage 2 enabled", stage=2) + plugin_registry.register( + "deepspeed_stage_2_offload", + cls, + description="DeepSpeed ZeRO Stage 2 and CPU Offload", + stage=2, + cpu_offload=True + ) + plugin_registry.register("deepspeed_stage_3", cls, description="DeepSpeed ZeRO Stage 3", stage=3) + plugin_registry.register( + "deepspeed_stage_3_offload", + cls, + description="DeepSpeed ZeRO Stage 3 and CPU Offload", + stage=3, + cpu_offload=True + ) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 6dbcfadcb2bc41..1be8be78fd0ef1 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -286,3 +286,7 @@ def call_configure_sharded_model_hook(self) -> bool: @call_configure_sharded_model_hook.setter def call_configure_sharded_model_hook(self, mode: bool) -> None: self._call_configure_sharded_model_hook = mode + + @classmethod + def register_plugins(cls, plugin_registry): + pass diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 1f086bbee8ca3a..475f935fd835f6 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -42,6 +42,7 @@ TPUHalfPrecisionPlugin, TPUSpawnPlugin, TrainingTypePlugin, + TrainingTypePluginsRegistry, ) from pytorch_lightning.plugins.environments import ( ClusterEnvironment, @@ -163,7 +164,16 @@ def handle_given_plugins( cluster_environment = None for plug in plugins: - if isinstance(plug, str): + if isinstance(plug, str) and plug in TrainingTypePluginsRegistry: + if training_type is None: + training_type = TrainingTypePluginsRegistry.get(plug) + else: + raise MisconfigurationException( + 'You can only specify one precision and one training type plugin.' + ' Found more than 1 training type plugin:' + f' {TrainingTypePluginsRegistry[plug]["plugin"]} registered to {plug}' + ) + elif isinstance(plug, str): # Reset the distributed type as the user has overridden training type # via the plugins argument self._distrib_type = None @@ -530,7 +540,7 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None): rank_zero_warn( 'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.' ) - # todo: in some cases it yield in comarison None and int + # todo: in some cases it yield in comparison None and int if (self.num_nodes and self.num_nodes > 1) or (self.num_processes and self.num_processes > 1): self._distrib_type = DistributedType.DDP else: diff --git a/tests/plugins/test_plugins_registry.py b/tests/plugins/test_plugins_registry.py new file mode 100644 index 00000000000000..91d9596578dfc9 --- /dev/null +++ b/tests/plugins/test_plugins_registry.py @@ -0,0 +1,83 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from pytorch_lightning import Trainer +from pytorch_lightning.plugins import TrainingTypePluginsRegistry +from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin +from tests.helpers.runif import RunIf + + +def test_training_type_plugins_registry_with_new_plugin(): + + class TestPlugin: + + def __init__(self, param1, param2): + self.param1 = param1 + self.param2 = param2 + + plugin_name = "test_plugin" + plugin_description = "Test Plugin" + + TrainingTypePluginsRegistry.register( + plugin_name, TestPlugin, description=plugin_description, param1="abc", param2=123 + ) + + assert plugin_name in TrainingTypePluginsRegistry + assert TrainingTypePluginsRegistry[plugin_name]["description"] == plugin_description + assert TrainingTypePluginsRegistry[plugin_name]["init_params"] == {"param1": "abc", "param2": 123} + assert isinstance(TrainingTypePluginsRegistry.get(plugin_name), TestPlugin) + + TrainingTypePluginsRegistry.remove(plugin_name) + assert plugin_name not in TrainingTypePluginsRegistry + + +@pytest.mark.parametrize( + "plugin_name, init_params", + [ + ("deepspeed", {}), + ("deepspeed_stage_2", { + "stage": 2 + }), + ("deepspeed_stage_2_offload", { + "stage": 2, + "cpu_offload": True + }), + ("deepspeed_stage_3", { + "stage": 3 + }), + ("deepspeed_stage_3_offload", { + "stage": 3, + "cpu_offload": True + }), + ], +) +def test_training_type_plugins_registry_with_deepspeed_plugins(plugin_name, init_params): + + assert plugin_name in TrainingTypePluginsRegistry + assert TrainingTypePluginsRegistry[plugin_name]["init_params"] == init_params + assert TrainingTypePluginsRegistry[plugin_name]["plugin"] == DeepSpeedPlugin + + +@RunIf(deepspeed=True) +@pytest.mark.parametrize("plugin", ["deepspeed", "deepspeed_stage_2_offload", "deepspeed_stage_3"]) +def test_training_type_plugins_registry_with_trainer(tmpdir, plugin): + + trainer = Trainer( + default_root_dir=tmpdir, + plugins=plugin, + precision=16, + ) + + assert isinstance(trainer.training_type_plugin, DeepSpeedPlugin) From 3fb8eada34057a514189220466ae76c1b1f24b1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 16 Apr 2021 20:34:14 +0200 Subject: [PATCH 06/21] rc2 (#7057) --- pytorch_lightning/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/__about__.py b/pytorch_lightning/__about__.py index c699c6aa2bceb9..0ce2273febf005 100644 --- a/pytorch_lightning/__about__.py +++ b/pytorch_lightning/__about__.py @@ -1,7 +1,7 @@ import time _this_year = time.strftime("%Y") -__version__ = '1.3.0rc1' +__version__ = '1.3.0rc2' __author__ = 'William Falcon et al.' __author_email__ = 'waf2107@columbia.edu' __license__ = 'Apache-2.0' From 6a7b4cf5d349aa64938149f2d2629cb3fdd364af Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Sat, 17 Apr 2021 01:33:41 +0530 Subject: [PATCH 07/21] Fix mypy for plugins registry (#7062) --- pytorch_lightning/plugins/plugins_registry.py | 9 ++++++--- pytorch_lightning/plugins/training_type/deepspeed.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py index 59dd7d8db6bffb..755c4f5be18ada 100644 --- a/pytorch_lightning/plugins/plugins_registry.py +++ b/pytorch_lightning/plugins/plugins_registry.py @@ -16,7 +16,7 @@ from collections import UserDict from inspect import getmembers, isclass from pathlib import Path -from typing import Any, Callable, List, Optional +from typing import Any, Callable, Dict, List, Optional from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -75,7 +75,7 @@ def register( " HINT: Use `override=True`." ) - data = {} + data: Dict[str, Any] = {} data["description"] = description if description is not None else "" data["init_params"] = init_params @@ -90,7 +90,7 @@ def do_register(plugin: Callable) -> Callable: return do_register - def get(self, name: str) -> Any: + def get(self, name: str, default: Optional[Any] = None) -> Any: """ Calls the registered plugin with the required parameters and returns the plugin object @@ -102,6 +102,9 @@ def get(self, name: str) -> Any: data = self[name] return data["plugin"](**data["init_params"]) + if default is not None: + return default + err_msg = "'{}' not found in registry. Available names: {}" available_names = ", ".join(sorted(self.keys())) or "none" raise KeyError(err_msg.format(name, available_names)) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 9c67a1ccb53a9a..34a9f504082e1f 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -544,7 +544,7 @@ def update_global_step(self, total_batch_idx: int, current_global_step: int) -> return current_global_step @classmethod - def register_plugins(cls, plugin_registry): + def register_plugins(cls, plugin_registry: Dict) -> None: plugin_registry.register("deepspeed", cls, description="Default DeepSpeed Plugin") plugin_registry.register("deepspeed_stage_2", cls, description="DeepSpeed with ZeRO Stage 2 enabled", stage=2) plugin_registry.register( From 8bcd16976797d4ac31b780c4dd854d587638c98c Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 16 Apr 2021 13:18:54 -0700 Subject: [PATCH 08/21] [fix] Fix multi-node DDP launch by using local rank instead of global rank for main process (#7061) * Update ddp.py * Update CHANGELOG.md --- CHANGELOG.md | 3 +++ pytorch_lightning/plugins/training_type/ddp.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 85a827104d2c46..84c7feb765443a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -202,6 +202,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed multi-node DDP sub-process launch by using `local_rank` instead of `global_rank` for main process assertion ([#7061](https://github.com/PyTorchLightning/pytorch-lightning/pull/7061)) + + - Fixed incorrect removal of `WORLD_SIZE` environment variable in DDP training when launching with torch distributed/torchelastic ([#6942](https://github.com/PyTorchLightning/pytorch-lightning/pull/6942)) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 977145a4cc7bae..28910e9b77fa31 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -110,7 +110,7 @@ def setup_environment(self): def _call_children_scripts(self): # bookkeeping of spawned processes - assert self.global_rank == 0 + assert self.local_rank == 0 self._check_can_spawn_children() self._has_spawned_children = True From 7b0b0d284494d08e3983321d0cc42fe9e5faeb41 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 16 Apr 2021 21:22:19 +0100 Subject: [PATCH 09/21] update (#7056) --- pytorch_lightning/accelerators/accelerator.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 30454436994b5d..21cfe08b6852a9 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -372,12 +372,7 @@ def setup_precision_plugin(self, plugin: PrecisionPlugin) -> None: def to_device(self, batch: Any) -> Any: """Pushes the batch to the root device""" - # Todo (tchaton) Better fix - is_dict = isinstance(batch, dict) - if is_dict: - batch = [batch] - batch = self.batch_to_device(batch, self.root_device) - return batch[0] if is_dict else batch + return self.batch_to_device(batch, self.root_device) @property def amp_backend(self) -> Optional[LightningEnum]: From 97be8432261d5a49e245ca6afb8941ad173c2574 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Sun, 18 Apr 2021 14:53:12 +0530 Subject: [PATCH 10/21] Better approach to register plugins (#7063) * Better approach to register plugins * Add ddp_with_find_unused_parameters_false * Remove unnecessary break * Revert back the ddp commit * Update register override logic * Update register override logic * fix mypy --- pytorch_lightning/plugins/plugins_registry.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py index 755c4f5be18ada..319c9053c2094a 100644 --- a/pytorch_lightning/plugins/plugins_registry.py +++ b/pytorch_lightning/plugins/plugins_registry.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import importlib -import os +import inspect from collections import UserDict from inspect import getmembers, isclass from pathlib import Path @@ -124,10 +124,16 @@ def __str__(self) -> str: TrainingTypePluginsRegistry = _TrainingTypePluginsRegistry() -def is_register_plugins_overridden(plugin: Callable) -> bool: +def is_register_plugins_overridden(plugin: type) -> bool: + method_name = "register_plugins" plugin_attr = getattr(plugin, method_name) - super_attr = getattr(TrainingTypePlugin, method_name) + previous_super_cls = inspect.getmro(plugin)[1] + + if issubclass(previous_super_cls, TrainingTypePlugin): + super_attr = getattr(previous_super_cls, method_name) + else: + return False if hasattr(plugin_attr, 'patch_loader_code'): is_overridden = plugin_attr.patch_loader_code != str(super_attr.__code__) @@ -137,13 +143,7 @@ def is_register_plugins_overridden(plugin: Callable) -> bool: def call_training_type_register_plugins(root: Path, base_module: str) -> None: - # Ref: https://github.com/facebookresearch/ClassyVision/blob/master/classy_vision/generic/registry_utils.py#L14 - directory = "training_type" - for file in os.listdir(root / directory): - if file.endswith(".py") and not file.startswith("_"): - module = file[:file.find(".py")] - module = importlib.import_module(".".join([base_module, module])) - for _, mod in getmembers(module, isclass): - if issubclass(mod, TrainingTypePlugin) and is_register_plugins_overridden(mod): - mod.register_plugins(TrainingTypePluginsRegistry) - break + module = importlib.import_module(base_module) + for _, mod in getmembers(module, isclass): + if issubclass(mod, TrainingTypePlugin) and is_register_plugins_overridden(mod): + mod.register_plugins(TrainingTypePluginsRegistry) From 71b4611c64059d7589e4d80115209fd2c89e8bdb Mon Sep 17 00:00:00 2001 From: Soham Roy Date: Sun, 18 Apr 2021 15:58:04 +0530 Subject: [PATCH 11/21] Update default gym env version to CartPole-v1 (#7079) Version v1 generates a better baseline with higher max_episodes and reward_threshold attained. changed_params --> register( id='CartPole-v1', entry_point='gym.envs.classic_control:CartPoleEnv', max_episode_steps=500, reward_threshold=475.0, ) --- pl_examples/domain_templates/reinforce_learn_Qnet.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pl_examples/domain_templates/reinforce_learn_Qnet.py b/pl_examples/domain_templates/reinforce_learn_Qnet.py index 4d90faeb45bcfb..70726a748818c1 100644 --- a/pl_examples/domain_templates/reinforce_learn_Qnet.py +++ b/pl_examples/domain_templates/reinforce_learn_Qnet.py @@ -20,7 +20,7 @@ To run the template, just run: `python reinforce_learn_Qnet.py` -After ~1500 steps, you will see the total_reward hitting the max score of 200. +After ~1500 steps, you will see the total_reward hitting the max score of 475+. Open up TensorBoard to see the metrics: `tensorboard --logdir default` @@ -149,7 +149,7 @@ class Agent: """ Base Agent class handling the interaction with the environment - >>> env = gym.make("CartPole-v0") + >>> env = gym.make("CartPole-v1") >>> buffer = ReplayBuffer(10) >>> Agent(env, buffer) # doctest: +ELLIPSIS <...reinforce_learn_Qnet.Agent object at ...> @@ -229,7 +229,7 @@ def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = 'cpu') - class DQNLightning(pl.LightningModule): """ Basic DQN Model - >>> DQNLightning(env="CartPole-v0") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> DQNLightning(env="CartPole-v1") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE DQNLightning( (net): DQN( (net): Sequential(...) @@ -393,7 +393,7 @@ def add_model_specific_args(parent_parser): # pragma: no-cover parser = parent_parser.add_argument_group("DQNLightning") parser.add_argument("--batch_size", type=int, default=16, help="size of the batches") parser.add_argument("--lr", type=float, default=1e-2, help="learning rate") - parser.add_argument("--env", type=str, default="CartPole-v0", help="gym environment tag") + parser.add_argument("--env", type=str, default="CartPole-v1", help="gym environment tag") parser.add_argument("--gamma", type=float, default=0.99, help="discount factor") parser.add_argument("--sync_rate", type=int, default=10, help="how many frames do we update the target network") parser.add_argument("--replay_size", type=int, default=1000, help="capacity of the replay buffer") From 30b7440e123d5a91806b493995293e97a2ab80f8 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Mon, 19 Apr 2021 03:12:48 +0530 Subject: [PATCH 12/21] TPU Spawn Rank & root device Error (#7074) * TPU Spawn Rank Error * Update tpu spawn * Fix root device property for tpu spawn * Update changelog --- CHANGELOG.md | 3 +++ .../plugins/training_type/tpu_spawn.py | 17 ++++++++++------- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 84c7feb765443a..763bd2248ef2b2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -285,6 +285,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed process rank not being available right away after `Trainer` instantiation ([#6941](https://github.com/PyTorchLightning/pytorch-lightning/pull/6941)) +- Fixed the order to call for world ranks & the `root_device` property in `TPUSpawnPlugin` ([#7074](https://github.com/PyTorchLightning/pytorch-lightning/pull/7074)) + + ## [1.2.7] - 2021-04-06 ### Fixed diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 902471ea55f51a..4f0e1bc2ee1045 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -65,6 +65,10 @@ def local_rank(self) -> int: def world_size(self) -> int: return self.num_processes + @property + def root_device(self) -> torch.device: + return self.device + @staticmethod def _validate_dataloader(dataloaders: Union[List['DataLoader'], 'DataLoader']): if not isinstance(dataloaders, list): @@ -116,9 +120,7 @@ def is_distributed(self): def process_dataloader(self, dataloader: 'DataLoader') -> MpDeviceLoader: TPUSpawnPlugin._validate_dataloader(dataloader) - device = xm.xla_device() - dataloader = MpDeviceLoader(dataloader, device) - return dataloader + return MpDeviceLoader(dataloader, self.device) def configure_ddp(self) -> None: pass @@ -127,8 +129,7 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None: pass def set_world_ranks(self, process_idx: int = 0) -> None: - self.tpu_local_core_rank = xm.get_local_ordinal() - self.tpu_global_core_rank = xm.get_ordinal() + pass def new_process(self, process_idx: int, trainer, mp_queue) -> None: self.mp_queue = mp_queue @@ -137,7 +138,8 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None: if seed is not None: seed_everything(int(seed)) - self.set_world_ranks() + self.tpu_local_core_rank = xm.get_local_ordinal() + self.tpu_global_core_rank = xm.get_ordinal() # set warning rank rank_zero_only.rank = self.global_rank @@ -163,7 +165,8 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None: time.sleep(2) def model_to_device(self) -> None: - self._model.to(xm.xla_device()) + self.device = xm.xla_device() + self.model.to(self.device) def barrier(self, name: Optional[str] = None) -> None: rendezvous(name) From 490ddce2ac0f45df425362eec241e327e2d299e8 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Mon, 19 Apr 2021 15:50:42 +0530 Subject: [PATCH 13/21] Update CI torch-xla version to 1.8 (#7019) * Update CI torch-xla version to 1.8 * Update minimal to 1.6 --- .github/workflows/ci_test-tpu.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci_test-tpu.yml b/.github/workflows/ci_test-tpu.yml index b2c3d09107d932..22bb7bd7cd4e5f 100644 --- a/.github/workflows/ci_test-tpu.yml +++ b/.github/workflows/ci_test-tpu.yml @@ -23,7 +23,7 @@ jobs: fail-fast: false matrix: python-version: [3.7] - xla-version: [1.6, 1.7] + xla-version: [1.6, 1.8] # Timeout: https://stackoverflow.com/a/59076067/4521646 timeout-minutes: 50 From 898ec8a94adc79a480b57c1992c4881ac862bcee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 19 Apr 2021 14:43:16 +0200 Subject: [PATCH 14/21] Create pytorch_lightning/utilities/types.py (#7048) --- pytorch_lightning/accelerators/accelerator.py | 27 ++++++------- pytorch_lightning/core/lightning.py | 23 ++++++----- .../plugins/precision/apex_amp.py | 13 +++--- pytorch_lightning/plugins/precision/double.py | 8 ++-- .../plugins/precision/precision_plugin.py | 13 +++--- .../plugins/precision/tpu_bfloat.py | 8 ++-- .../logger_connector/metrics_holder.py | 15 ++++--- pytorch_lightning/trainer/evaluation_loop.py | 40 +++++++++---------- pytorch_lightning/trainer/predict_loop.py | 2 +- pytorch_lightning/utilities/types.py | 13 ++++++ 10 files changed, 85 insertions(+), 77 deletions(-) create mode 100644 pytorch_lightning/utilities/types.py diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 21cfe08b6852a9..a7a82297623037 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import contextlib -from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Union +from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union import torch from torch import Tensor @@ -27,12 +27,11 @@ from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, rank_zero_warn from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum +from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT if _NATIVE_AMP_AVAILABLE: from torch.cuda.amp import GradScaler -_STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None] - class Accelerator: """ @@ -63,9 +62,9 @@ def __init__( self.precision_plugin = precision_plugin self.training_type_plugin = training_type_plugin - self.optimizers: Sequence = [] - self.lr_schedulers: Sequence = [] - self.optimizer_frequencies: Sequence = [] + self.optimizers: List = [] + self.lr_schedulers: List = [] + self.optimizer_frequencies: List = [] def connect(self, model: 'pl.LightningModule') -> None: """Transfers ownership of the model to this plugin""" @@ -166,7 +165,7 @@ def on_train_start(self) -> None: def training_step( self, args: List[Union[Any, int]], - ) -> _STEP_OUTPUT_TYPE: + ) -> STEP_OUTPUT: """The actual training step. Args: @@ -188,7 +187,7 @@ def training_step( def post_training_step(self) -> None: self.training_type_plugin.post_training_step() - def validation_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: + def validation_step(self, args: List[Union[Any, int]]) -> Optional[STEP_OUTPUT]: """The actual validation step. Args: @@ -207,7 +206,7 @@ def validation_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context(): return self.training_type_plugin.validation_step(*args) - def test_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: + def test_step(self, args: List[Union[Any, int]]) -> Optional[STEP_OUTPUT]: """The actual test step. Args: @@ -226,7 +225,7 @@ def test_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context(): return self.training_type_plugin.test_step(*args) - def predict_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: + def predict_step(self, args: List[Union[Any, int]]) -> STEP_OUTPUT: """The actual predict step. Args: @@ -246,7 +245,7 @@ def predict_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE: with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context(): return self.training_type_plugin.predict_step(*args) - def training_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE: + def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT: """A hook to do something at the end of the training step Args: @@ -254,7 +253,7 @@ def training_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE: """ return self.training_type_plugin.training_step_end(output) - def test_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE: + def test_step_end(self, output: Optional[STEP_OUTPUT]) -> Optional[STEP_OUTPUT]: """A hook to do something at the end of the test step Args: @@ -262,7 +261,7 @@ def test_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE: """ return self.training_type_plugin.test_step_end(output) - def validation_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE: + def validation_step_end(self, output: Optional[STEP_OUTPUT]) -> Optional[STEP_OUTPUT]: """A hook to do something at the end of the validation step Args: @@ -331,7 +330,7 @@ def clip_gradients( """clips all the optimizer parameters to the given value""" self.precision_plugin.clip_gradients(optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm) - def on_train_epoch_end(self, outputs: Sequence[_STEP_OUTPUT_TYPE]) -> None: + def on_train_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: """Hook to do something on the end of an training epoch Args: diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 637e9159e4fc1a..32f6ee366e7a1d 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -43,6 +43,7 @@ from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args +from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT log = logging.getLogger(__name__) @@ -483,7 +484,7 @@ def all_gather( all_gather = partial(all_gather, group=group, sync_grads=sync_grads) return apply_to_collection(data, torch.Tensor, all_gather) - def forward(self, *args, **kwargs): + def forward(self, *args, **kwargs) -> Any: r""" Same as :meth:`torch.nn.Module.forward()`, however in Lightning you want this to define the operations you want to use for prediction (i.e.: on a server or as a feature extractor). @@ -535,7 +536,7 @@ def forward(self, batch): """ return super().forward(*args, **kwargs) - def training_step(self, *args, **kwargs): + def training_step(self, *args, **kwargs) -> STEP_OUTPUT: r""" Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger. @@ -601,7 +602,7 @@ def training_step(self, batch, batch_idx, hiddens): """ rank_zero_warn("`training_step` must be implemented to be used with the Lightning Trainer") - def training_step_end(self, *args, **kwargs): + def training_step_end(self, *args, **kwargs) -> STEP_OUTPUT: """ Use this when training with dp or ddp2 because :meth:`training_step` will operate on only part of the batch. However, this is still optional @@ -663,7 +664,7 @@ def training_step_end(self, training_step_outputs): See the :ref:`advanced/multi_gpu:Multi-GPU training` guide for more details. """ - def training_epoch_end(self, outputs: List[Any]) -> None: + def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: """ Called at the end of the training epoch with the outputs of all training steps. Use this in case you need to do something with all the outputs for every training_step. @@ -704,7 +705,7 @@ def training_epoch_end(self, training_step_outputs): # do something here """ - def validation_step(self, *args, **kwargs): + def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: r""" Operates on a single batch of data from the validation set. In this step you'd might generate examples or calculate anything of interest like accuracy. @@ -791,7 +792,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx): the model goes back to training mode and gradients are enabled. """ - def validation_step_end(self, *args, **kwargs): + def validation_step_end(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: """ Use this when validating with dp or ddp2 because :meth:`validation_step` will operate on only part of the batch. However, this is still optional @@ -845,7 +846,7 @@ def validation_step_end(self, val_step_outputs): See the :ref:`advanced/multi_gpu:Multi-GPU training` guide for more details. """ - def validation_epoch_end(self, outputs: List[Any]) -> None: + def validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: """ Called at the end of the validation epoch with the outputs of all validation steps. @@ -890,7 +891,7 @@ def validation_epoch_end(self, outputs): self.log('final_metric', final_value) """ - def test_step(self, *args, **kwargs): + def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: r""" Operates on a single batch of data from the test set. In this step you'd normally generate examples or calculate anything of interest @@ -966,7 +967,7 @@ def test_step(self, batch, batch_idx, dataloader_idx): to training mode and gradients are enabled. """ - def test_step_end(self, *args, **kwargs): + def test_step_end(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: """ Use this when testing with dp or ddp2 because :meth:`test_step` will operate on only part of the batch. However, this is still optional @@ -1020,7 +1021,7 @@ def test_step_end(self, output_results): See the :ref:`advanced/multi_gpu:Multi-GPU training` guide for more details. """ - def test_epoch_end(self, outputs: List[Any]) -> None: + def test_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: """ Called at the end of a test epoch with the output of all test steps. @@ -1071,7 +1072,7 @@ def test_epoch_end(self, outputs): self.log('final_metric', final_value) """ - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None): + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any: """ Use this function with trainer.predict(...). Override if you need to add any processing logic. """ diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index 30614d3faa1877..762095e10e0ae9 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, ContextManager, Iterator, List, Sequence, Tuple, Type +from typing import Any, Callable, ContextManager, List, Sequence, Tuple, Type import torch from torch import Tensor @@ -21,8 +21,7 @@ from pytorch_lightning.core import LightningModule from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType - -PARAMETERS = Iterator[torch.nn.Parameter] +from pytorch_lightning.utilities.types import _PARAMETERS if _APEX_AVAILABLE: from apex import amp @@ -36,15 +35,15 @@ def __init__(self, amp_level: str = "O2") -> None: self.backend = AMPType.APEX self.amp_level = amp_level - def master_params(self, optimizer: Optimizer) -> PARAMETERS: + def master_params(self, optimizer: Optimizer) -> _PARAMETERS: return amp.master_params(optimizer) def connect( self, model: Module, - optimizers: Sequence[Optimizer], - lr_schedulers: Sequence[Any], - ) -> Tuple[Module, Sequence[Optimizer], Sequence[Any]]: + optimizers: List[Optimizer], + lr_schedulers: List[Any], + ) -> Tuple[Module, List[Optimizer], List[Any]]: """Connects the precision plugin to the training process, configures apex and reinits the schedulers """ diff --git a/pytorch_lightning/plugins/precision/double.py b/pytorch_lightning/plugins/precision/double.py index d7486345d2a080..268b7480d483ad 100644 --- a/pytorch_lightning/plugins/precision/double.py +++ b/pytorch_lightning/plugins/precision/double.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import wraps -from typing import Any, List, Sequence, Tuple +from typing import Any, List, Tuple import torch import torch.nn as nn @@ -71,9 +71,9 @@ def __init__(self) -> None: def connect( self, model: nn.Module, - optimizers: Sequence[Optimizer], - lr_schedulers: Sequence[Any], - ) -> Tuple[nn.Module, Sequence[Optimizer], Sequence[Any]]: + optimizers: List[Optimizer], + lr_schedulers: List[Any], + ) -> Tuple[nn.Module, List[Optimizer], List[Any]]: """Converts the model to double precision and wraps the `training_step`, `validation_step`, `test_step`, `predict_step`, and `forward` methods to convert incoming floating point data to double. Does not alter `optimizers` or `lr_schedulers`.""" diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index ac33eeea287eb2..c1ea3287964a8d 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Any, Callable, Iterator, Sequence, Tuple, Union +from typing import Any, Callable, List, Tuple, Union import torch from torch import Tensor @@ -22,8 +22,7 @@ import pytorch_lightning as pl from pytorch_lightning.plugins.base_plugin import Plugin from pytorch_lightning.utilities import GradClipAlgorithmType - -PARAMETERS = Iterator[torch.nn.Parameter] +from pytorch_lightning.utilities.types import _PARAMETERS class PrecisionPlugin(Plugin): @@ -35,7 +34,7 @@ class PrecisionPlugin(Plugin): EPSILON: float = 1e-6 precision: Union[str, int] = 32 - def master_params(self, optimizer: Optimizer) -> PARAMETERS: + def master_params(self, optimizer: Optimizer) -> _PARAMETERS: """ The master params of the model. Returns the plain model params here. Maybe different in other precision plugins. @@ -47,9 +46,9 @@ def master_params(self, optimizer: Optimizer) -> PARAMETERS: def connect( self, model: Module, - optimizers: Sequence[Optimizer], - lr_schedulers: Sequence[Any], - ) -> Tuple[Module, Sequence[Optimizer], Sequence[Any]]: + optimizers: List[Optimizer], + lr_schedulers: List[Any], + ) -> Tuple[Module, List[Optimizer], List[Any]]: """Connects this plugin to the accelerator and the training process""" return model, optimizers, lr_schedulers diff --git a/pytorch_lightning/plugins/precision/tpu_bfloat.py b/pytorch_lightning/plugins/precision/tpu_bfloat.py index 8561b73208cf8e..6534aa11045b8c 100644 --- a/pytorch_lightning/plugins/precision/tpu_bfloat.py +++ b/pytorch_lightning/plugins/precision/tpu_bfloat.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Any, Sequence, Tuple +from typing import Any, List, Tuple import torch.nn as nn from torch.optim import Optimizer @@ -28,8 +28,8 @@ class TPUHalfPrecisionPlugin(PrecisionPlugin): def connect( self, model: nn.Module, - optimizers: Sequence[Optimizer], - lr_schedulers: Sequence[Any], - ) -> Tuple[nn.Module, Sequence[Optimizer], Sequence[Any]]: + optimizers: List[Optimizer], + lr_schedulers: List[Any], + ) -> Tuple[nn.Module, List[Optimizer], List[Any]]: os.environ["XLA_USE_BF16"] = str(1) return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py index 1efbcc638674fc..8f12f57c640b0d 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import numbers -from typing import Any, Dict, Optional, Union +from typing import Dict, Optional import torch from torchmetrics import Metric from pytorch_lightning.utilities.exceptions import MisconfigurationException - -_METRIC_TYPE = Union[Metric, torch.Tensor, int, float, Any] +from pytorch_lightning.utilities.types import _METRIC class MetricsHolder: @@ -31,16 +30,16 @@ class MetricsHolder: """ def __init__(self, to_float: bool = False) -> None: - self.metrics: Dict[str, _METRIC_TYPE] = {} + self.metrics: Dict[str, _METRIC] = {} self._to_float = to_float def update(self, metrics: dict) -> None: self.metrics.update(metrics) - def pop(self, key: str, default: _METRIC_TYPE) -> _METRIC_TYPE: + def pop(self, key: str, default: _METRIC) -> _METRIC: return self.metrics.pop(key, default) - def reset(self, metrics: Dict[str, _METRIC_TYPE]) -> None: + def reset(self, metrics: Dict[str, _METRIC]) -> None: self.metrics = metrics def convert(self, device: Optional[torch.device]) -> None: @@ -57,7 +56,7 @@ def convert(self, device: Optional[torch.device]) -> None: self.metrics[key] = converted @staticmethod - def _convert_to_float(current: _METRIC_TYPE) -> float: + def _convert_to_float(current: _METRIC) -> float: if isinstance(current, Metric): current = current.compute().detach() @@ -70,7 +69,7 @@ def _convert_to_float(current: _METRIC_TYPE) -> float: return current @staticmethod - def _convert_to_tensor(current: _METRIC_TYPE, device: Optional[torch.device]) -> torch.Tensor: + def _convert_to_tensor(current: _METRIC, device: Optional[torch.device]) -> torch.Tensor: if isinstance(current, Metric): current = current.compute().detach() diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 2019244362cd32..9f10ca8306ff30 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -11,30 +11,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Dict, List, Optional, Tuple, Union +from torch.utils.data import DataLoader + +import pytorch_lightning as pl from pytorch_lightning.core.step_result import Result from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature +from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache -if TYPE_CHECKING: - from torch.utils.data import DataLoader - - from pytorch_lightning import Trainer - from pytorch_lightning.accelerators.accelerator import _STEP_OUTPUT_TYPE - class EvaluationLoop(object): - def __init__(self, trainer: 'Trainer'): - self.trainer: 'Trainer' = trainer - self.outputs: List['_STEP_OUTPUT_TYPE'] = [] + def __init__(self, trainer: 'pl.Trainer'): + self.trainer: 'pl.Trainer' = trainer + self.outputs: EPOCH_OUTPUT = [] self.predictions: Optional[PredictionCollection] = None self.max_batches: Optional[List[Union[int, float]]] = None - self.warning_cache: WarningCache = WarningCache() + self.warning_cache = WarningCache() self.num_dataloaders: Optional[int] = None def on_trainer_init(self) -> None: @@ -51,7 +49,7 @@ def on_trainer_init(self) -> None: # when true, print evaluation results in .validate() and .test() self.trainer.verbose_evaluate = True - def get_evaluation_dataloaders(self) -> Tuple[Optional[List['DataLoader']], List[Union[int, float]]]: + def get_evaluation_dataloaders(self) -> Tuple[Optional[List[DataLoader]], List[Union[int, float]]]: model = self.trainer.lightning_module # select dataloaders @@ -83,14 +81,14 @@ def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: else: self.trainer.call_hook('on_validation_start', *args, **kwargs) - def on_evaluation_model_eval(self, *_: Any, **__: Any) -> None: + def on_evaluation_model_eval(self) -> None: model_ref = self.trainer.lightning_module if self.trainer.testing: model_ref.on_test_model_eval() else: model_ref.on_validation_model_eval() - def on_evaluation_model_train(self, *_: Any, **__: Any) -> None: + def on_evaluation_model_train(self) -> None: model_ref = self.trainer.lightning_module if self.trainer.testing: model_ref.on_test_model_train() @@ -114,7 +112,7 @@ def reload_evaluation_dataloaders(self) -> None: else: self.trainer.reset_val_dataloader(model) - def setup(self, max_batches: List[Union[int, float]], dataloaders: List['DataLoader']) -> None: + def setup(self, max_batches: List[Union[int, float]], dataloaders: List[DataLoader]) -> None: # bookkeeping self.outputs = [] self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size) @@ -148,7 +146,7 @@ def _build_args(self, batch: Any, batch_idx: int, dataloader_idx: int) -> List[U return args - def _get_num_dataloaders(self, dataloaders: Optional[List['DataLoader']]) -> int: + def _get_num_dataloaders(self, dataloaders: Optional[List[DataLoader]]) -> int: # case where user does: # return dl1, dl2 if dataloaders is not None: @@ -159,7 +157,7 @@ def _get_num_dataloaders(self, dataloaders: Optional[List['DataLoader']]) -> int else: return 0 - def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> '_STEP_OUTPUT_TYPE': + def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[STEP_OUTPUT]: # configure args args = self._build_args(batch, batch_idx, dataloader_idx) @@ -183,14 +181,14 @@ def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> '_ return output - def evaluation_step_end(self, *args: Any, **kwargs: Any) -> '_STEP_OUTPUT_TYPE': + def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: if self.trainer.testing: output = self.trainer.call_hook('test_step_end', *args, **kwargs) else: output = self.trainer.call_hook('validation_step_end', *args, **kwargs) return output - def evaluation_epoch_end(self, outputs: List['_STEP_OUTPUT_TYPE']) -> None: + def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: # unset dataloder_idx in model self.trainer.logger_connector.evaluation_epoch_end() @@ -221,7 +219,7 @@ def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: def on_evaluation_batch_end( self, - output: '_STEP_OUTPUT_TYPE', + output: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int, @@ -234,7 +232,7 @@ def on_evaluation_batch_end( # store predicitons if do_write_predictions and track eval loss history self.store_predictions(output, batch_idx, dataloader_idx) - def store_predictions(self, output: '_STEP_OUTPUT_TYPE', batch_idx: int, dataloader_idx: int) -> None: + def store_predictions(self, output: Optional[STEP_OUTPUT], batch_idx: int, dataloader_idx: int) -> None: # Add step predictions to prediction collection to write later if output is not None and self.predictions is not None: if isinstance(output, Result) and self.trainer.testing: diff --git a/pytorch_lightning/trainer/predict_loop.py b/pytorch_lightning/trainer/predict_loop.py index dd77115c64a009..064ba6730d38c7 100644 --- a/pytorch_lightning/trainer/predict_loop.py +++ b/pytorch_lightning/trainer/predict_loop.py @@ -39,7 +39,7 @@ def get_predict_dataloaders(self): def should_skip_predict(self, max_batches): return sum(max_batches) == 0 - def on_predict_model_eval(self, *_, **__): + def on_predict_model_eval(self): model_ref = self.trainer.lightning_module model_ref.on_predict_model_eval() diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py new file mode 100644 index 00000000000000..c1c40b98c71c75 --- /dev/null +++ b/pytorch_lightning/utilities/types.py @@ -0,0 +1,13 @@ +from typing import Any, Dict, Iterator, List, Union + +import torch +from torchmetrics import Metric +""" +Convention: + - Do not include any `_TYPE` suffix + - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no trailing `_`) +""" +_METRIC = Union[Metric, torch.Tensor, int, float] +STEP_OUTPUT = Union[torch.Tensor, Dict[str, Any]] +EPOCH_OUTPUT = List[STEP_OUTPUT] +_PARAMETERS = Iterator[torch.nn.Parameter] From e61daff5cc15f7680b0dd85d8f936f491f6979ae Mon Sep 17 00:00:00 2001 From: mlech26l Date: Mon, 19 Apr 2021 14:48:44 +0200 Subject: [PATCH 15/21] Typo LightningMoule -> LightningModule (#7038) --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 32f6ee366e7a1d..d207fac7276ae4 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -248,7 +248,7 @@ def log( The default behavior per hook is as follows .. csv-table:: ``*`` also applies to the test loop - :header: "LightningMoule Hook", "on_step", "on_epoch", "prog_bar", "logger" + :header: "LightningModule Hook", "on_step", "on_epoch", "prog_bar", "logger" :widths: 20, 10, 10, 10, 10 "training_step", "T", "F", "F", "T" From fbee5a86e70282664f172c26aa307cf3c7882b34 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 19 Apr 2021 15:48:48 +0200 Subject: [PATCH 16/21] Correctly reset metric objects in self.log (#7055) * reset * fix tests * fix tests * Apply suggestions from code review Co-authored-by: ananthsub * move logic * chglog * pep8 * Add test * Improve test Co-authored-by: Ethan Harris Co-authored-by: ananthsub --- CHANGELOG.md | 3 + pytorch_lightning/core/step_result.py | 21 ++-- .../logger_connector/epoch_result_store.py | 17 ++- pytorch_lightning/trainer/trainer.py | 6 +- tests/core/test_metric_result_integration.py | 2 + .../trainer/logging_/test_logger_connector.py | 116 +++++++++++++++++- 6 files changed, 149 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 763bd2248ef2b2..f25f6e783bb581 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -288,6 +288,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed the order to call for world ranks & the `root_device` property in `TPUSpawnPlugin` ([#7074](https://github.com/PyTorchLightning/pytorch-lightning/pull/7074)) +- Fixed metric objects passed directly to `self.log` not being reset correctly ([#7055](https://github.com/PyTorchLightning/pytorch-lightning/pull/7055)) + + ## [1.2.7] - 2021-04-06 ### Fixed diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 7a193662b597bd..f2cdd31ab739fc 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -287,16 +287,12 @@ def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict: if options['logger'] and options['on_epoch']: if isinstance(self[k], Metric): result[dl_key] = self[k].compute().detach() - self[k].reset() else: result[dl_key] = self[k] if k in self and not options['on_epoch'] and isinstance(self[k], Metric): - # reset metric anyway so state does not accumulate - # NOTE: we must compute before reseting just in case the computed value is needed - # later (i.e. if the step metric gets visited first, and then the epoch metric) + # compute for reuse later self[k].compute() - self[k].reset() return result @@ -319,16 +315,12 @@ def get_epoch_pbar_metrics(self, add_dataloader_idx=False): if options['prog_bar'] and options['on_epoch']: if isinstance(self[k], Metric): result[dl_key] = self[k].compute().detach() - self[k].reset() else: result[dl_key] = self[k] if k in self and not options['on_epoch'] and isinstance(self[k], Metric): - # reset metric anyway so state does not accumulate - # NOTE: we must compute before reseting just in case the computed value is needed - # later (i.e. if the step metric gets visited first, and then the epoch metric) + # compute for reuse later self[k].compute() - self[k].reset() return result @@ -348,7 +340,6 @@ def get_forked_metrics(self, add_dataloader_idx=False): if options['forked']: if isinstance(self[k], Metric): result[dl_key] = self[k].compute().detach() - self[k].reset() else: result[dl_key] = self[k] @@ -587,6 +578,14 @@ def get_non_metrics_keys(self): """ return [k for k, v in self.items() if not isinstance(v, Metric)] + def reset(self) -> None: + """ + Call at the end of epoch to reset all metric objects + """ + for k, value in self.items(): + if isinstance(value, Metric): + value.reset() + def choose_last(x): if isinstance(x, (torch.Tensor, list)): diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index 594da76192aed1..61f66dbed9dfaf 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -233,6 +233,18 @@ def auto_reduce_results_on_epoch_end(self) -> None: self.has_reduced = True + def reset(self) -> None: + """ + Call at the end of epoch to reset Result objects + """ + for dl_idx in range(self.num_dataloaders): + epoch_metrics = self._internals[dl_idx] if not self.has_reduced else self._internals_reduced[dl_idx] + if self._internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP: + for opt_idx in list(epoch_metrics): + epoch_metrics[opt_idx].reset() + else: + epoch_metrics.reset() + def __getitem__(self, key: str) -> Any: return self._internals.get(key, None) @@ -262,6 +274,7 @@ def __init__(self, trainer: 'pl.Trainer') -> None: _should_warn = trainer.accelerator_connector.is_distributed _should_warn &= not trainer.training_type_plugin.rpc_enabled self._should_warn = _should_warn + self._internals = {} self.reset() @@ -442,7 +455,9 @@ def get_epoch_log_metrics(self) -> Dict: def get_forked_metrics(self) -> Dict: return self.run_epoch_by_func_name("get_forked_metrics") - def reset(self): + def reset(self) -> None: + for k, value in self._internals.items(): + value.reset() self._internals = {} self._dataloader_idx: Optional[int] = None self._split_idx: Optional[int] = None diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 62c8f530dca064..772f2dc4aa6f57 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -667,9 +667,6 @@ def run_evaluation(self, on_epoch=False): ) self.validating = True - # reset cached results - self.logger_connector.reset() - # prepare dataloaders dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders() @@ -759,6 +756,9 @@ def run_evaluation(self, on_epoch=False): # enable train mode again self.evaluation_loop.on_evaluation_model_train() + # reset cached results + self.logger_connector.reset() + torch.set_grad_enabled(True) return eval_loop_results diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 0b797dff0e42f0..725655c54136df 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -76,6 +76,7 @@ def _ddp_test_fn(rank, worldsize): assert batch_expected[k] == batch_log[k] epoch_log = result.get_epoch_log_metrics() + result.reset() # assert metric state reset to default values assert metric_a.x == metric_a._defaults['x'] @@ -127,6 +128,7 @@ def test_result_metric_integration(): assert batch_expected[k] == batch_log[k] epoch_log = result.get_epoch_log_metrics() + result.reset() # assert metric state reset to default values assert metric_a.x == metric_a._defaults['x'] diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index bddde7e77f5a84..118899e32276e8 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -22,10 +22,11 @@ import pytest import torch from torch.utils.data import DataLoader +from torchmetrics import Accuracy, AveragePrecision +from pytorch_lightning import LightningModule from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.core.step_result import Result -from pytorch_lightning.metrics import Accuracy from pytorch_lightning.trainer import Trainer from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder @@ -590,3 +591,116 @@ def validation_step(self, batch, batch_idx): assert trainer.dev_debugger.logged_metrics[0]['global_step'] == 1 assert trainer.dev_debugger.logged_metrics[1]['global_step'] == 3 + + +def test_metrics_reset(tmpdir): + """Tests that metrics are reset correctly after the end of the train/val/test epoch.""" + + class TestModel(LightningModule): + + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(32, 1) + + for stage in ['train', 'val', 'test']: + acc = Accuracy() + acc.reset = mock.Mock(side_effect=acc.reset) + ap = AveragePrecision(num_classes=1, pos_label=1) + ap.reset = mock.Mock(side_effect=ap.reset) + self.add_module(f"acc_{stage}", acc) + self.add_module(f"ap_{stage}", ap) + + def forward(self, x): + return self.layer(x) + + def _step(self, stage, batch): + labels = (batch.detach().sum(1) > 0).float() # Fake some targets + logits = self.forward(batch) + loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels.unsqueeze(1)) + probs = torch.sigmoid(logits.detach()) + self.log(f"loss/{stage}", loss) + + acc = self._modules[f"acc_{stage}"] + ap = self._modules[f"ap_{stage}"] + + labels_int = labels.to(torch.long) + acc(probs, labels_int) + ap(probs, labels_int) + + # Metric.forward calls reset so reset the mocks here + acc.reset.reset_mock() + ap.reset.reset_mock() + + self.log(f"{stage}/accuracy", acc) + self.log(f"{stage}/ap", ap) + + return loss + + def training_step(self, batch, batch_idx, *args, **kwargs): + return self._step('train', batch) + + def validation_step(self, batch, batch_idx, *args, **kwargs): + return self._step('val', batch) + + def test_step(self, batch, batch_idx, *args, **kwargs): + return self._step('test', batch) + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer], [lr_scheduler] + + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def val_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def test_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def _assert_epoch_end(self, stage): + acc = self._modules[f"acc_{stage}"] + ap = self._modules[f"ap_{stage}"] + + acc.reset.asset_not_called() + ap.reset.assert_not_called() + + def on_train_epoch_end(self, outputs): + self._assert_epoch_end('train') + + def on_validation_epoch_end(self, outputs): + self._assert_epoch_end('val') + + def on_test_epoch_end(self, outputs): + self._assert_epoch_end('test') + + def _assert_called(model, stage): + acc = model._modules[f"acc_{stage}"] + ap = model._modules[f"ap_{stage}"] + + acc.reset.assert_called_once() + acc.reset.reset_mock() + + ap.reset.assert_called_once() + ap.reset.reset_mock() + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + max_epochs=1, + progress_bar_refresh_rate=0, + ) + + trainer.fit(model) + _assert_called(model, 'train') + _assert_called(model, 'val') + + trainer.validate(model) + _assert_called(model, 'val') + + trainer.test(model) + _assert_called(model, 'test') From 4d5f10108a7ee0bb938152999ebad704854baa18 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 19 Apr 2021 15:49:25 +0200 Subject: [PATCH 17/21] skipp drafts for full test (#7046) --- .github/workflows/ci_test-full.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index ec9e71c5b83b29..c668186689b607 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -6,11 +6,14 @@ on: # Trigger the workflow on push or pull request, but only for the master bra branches: [master, "release/*"] pull_request: branches: [master, "release/*"] + types: [opened, reopened, ready_for_review, synchronize] jobs: + pytest: runs-on: ${{ matrix.os }} + if: github.event.pull_request.draft == false strategy: fail-fast: false matrix: From e9fca760ac651954dc16c91c16c5fd66aa1ce6fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 19 Apr 2021 15:50:31 +0200 Subject: [PATCH 18/21] Set `DistributedSampler` seed if `seed_everything` was called (#7024) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- CHANGELOG.md | 3 +++ docs/source/advanced/multi_gpu.rst | 2 ++ pytorch_lightning/trainer/data_loading.py | 7 +++++-- tests/accelerators/ddp_model.py | 2 +- tests/plugins/test_deepspeed_plugin.py | 2 +- tests/trainer/test_dataloaders.py | 22 ++++++++++++++++------ 6 files changed, 28 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f25f6e783bb581..c105b76c8ea8db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -128,6 +128,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed warnings and recommendations for dataloaders in `ddp_spawn` ([#6762](https://github.com/PyTorchLightning/pytorch-lightning/pull/6762/)) +- `pl.seed_everyting` will now also set the seed on the `DistributedSampler` ([#7024](https://github.com/PyTorchLightning/pytorch-lightning/pull/7024)) + + ### Deprecated - Deprecated `TrainerTrainingTricksMixin` in favor of a separate utilities module for NaN/Inf detection for gradients and parameters ([#6834](https://github.com/PyTorchLightning/pytorch-lightning/pull/6834/)) diff --git a/docs/source/advanced/multi_gpu.rst b/docs/source/advanced/multi_gpu.rst index 2ab5d5ec8ac68d..cda6b3c6d4e9c5 100644 --- a/docs/source/advanced/multi_gpu.rst +++ b/docs/source/advanced/multi_gpu.rst @@ -102,6 +102,8 @@ Lightning adds the correct samplers when needed, so no need to explicitly add sa .. note:: By default it will add ``shuffle=True`` for train sampler and ``shuffle=False`` for val/test sampler. ``drop_last`` in :class:`~torch.utils.data.distributed.DistributedSampler` will be set to its default value in PyTorch. + If you called :func:`~pytorch_lightning.utilities.seed.seed_everyting`, Lightning will set the same seed for the + sampler. .. note:: You can disable this behavior with ``Trainer(replace_sampler_ddp=False)`` diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 1f2f6a82f5267e..8fd39dfe94a899 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect import multiprocessing +import os from abc import ABC from copy import deepcopy from typing import Iterable, List, Optional, Tuple, Union @@ -24,7 +25,7 @@ from pytorch_lightning.core import LightningModule from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.supporters import CombinedLoader -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.debugging import InternalDebugger @@ -197,7 +198,9 @@ def __init__(self, num_features, dataset, *args, **kwargs): def _get_distributed_sampler(self, dataloader, shuffle): kwargs = self.distributed_sampler_kwargs - kwargs['shuffle'] = shuffle and not self.overfit_batches + kwargs["shuffle"] = shuffle and not self.overfit_batches + if _TORCH_GREATER_EQUAL_1_6: + kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0))) sampler = DistributedSampler(dataloader.dataset, **kwargs) return sampler diff --git a/tests/accelerators/ddp_model.py b/tests/accelerators/ddp_model.py index 78d1306665c59d..b052c8f49d51f1 100644 --- a/tests/accelerators/ddp_model.py +++ b/tests/accelerators/ddp_model.py @@ -25,7 +25,7 @@ def main(): - seed_everything(1234) + seed_everything(4321) parser = ArgumentParser(add_help=False) parser = Trainer.add_argparse_args(parser) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index ac3c922c338113..c768a9aabf8fb8 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -463,7 +463,7 @@ def test_deepspeed_multigpu_stage_3(tmpdir, deepspeed_config): def run_checkpoint_test(tmpdir, save_full_weights): - seed_everything(42) + seed_everything(1) model = ModelParallelClassificationModel() dm = ClassifDataModule() ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 7f9cf6210ce7c7..89793ba71ed14c 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -23,8 +23,9 @@ from torch.utils.data.sampler import SequentialSampler import tests.helpers.pipelines as tpipes -from pytorch_lightning import Callback, Trainer +from pytorch_lightning import Callback, seed_everything, Trainer from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6 from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate @@ -739,26 +740,35 @@ class CustomSampler(torch.utils.data.Sampler): class DistribSamplerCallback(Callback): + def __init__(self, expected_seeds=(0, 0, 0)): + self.expected_seed = expected_seeds + def on_train_start(self, trainer, pl_module): train_sampler = trainer.train_dataloader.sampler assert isinstance(train_sampler, DistributedSampler) assert train_sampler.shuffle + if _TORCH_GREATER_EQUAL_1_6: + assert train_sampler.seed == self.expected_seed[0] def on_validation_start(self, trainer, pl_module): val_sampler = trainer.val_dataloaders[0].sampler assert isinstance(val_sampler, DistributedSampler) assert not val_sampler.shuffle + if _TORCH_GREATER_EQUAL_1_6: + assert val_sampler.seed == self.expected_seed[1] def on_test_start(self, trainer, pl_module): test_sampler = trainer.test_dataloaders[0].sampler assert isinstance(test_sampler, DistributedSampler) assert not test_sampler.shuffle + if _TORCH_GREATER_EQUAL_1_6: + assert test_sampler.seed == self.expected_seed[2] @RunIf(min_gpus=2, skip_windows=True) def test_dataloader_distributed_sampler(tmpdir): """ Test DistributedSampler and it's arguments for DDP backend """ - + seed_everything(123) model = EvalModelTemplate() trainer = Trainer( gpus=[0, 1], @@ -766,7 +776,7 @@ def test_dataloader_distributed_sampler(tmpdir): accelerator='ddp_spawn', default_root_dir=tmpdir, max_steps=1, - callbacks=[DistribSamplerCallback()], + callbacks=[DistribSamplerCallback(expected_seeds=(123, 123, 123))], ) trainer.fit(model) trainer.test(ckpt_path=None) @@ -776,7 +786,7 @@ class ModelWithDataLoaderDistributedSampler(EvalModelTemplate): def train_dataloader(self): dataloader = super().train_dataloader() - dist_sampler = DistributedSampler(dataloader.dataset, shuffle=True) + dist_sampler = DistributedSampler(dataloader.dataset, shuffle=True, seed=11) return DataLoader( dataloader.dataset, batch_size=self.batch_size, drop_last=False, sampler=dist_sampler, shuffle=False ) @@ -785,7 +795,7 @@ def train_dataloader(self): @RunIf(min_gpus=2, skip_windows=True) def test_dataloader_distributed_sampler_already_attached(tmpdir): """ Test DistributedSampler and it's arguments for DDP backend when DistSampler already included on dataloader """ - + seed_everything(123) model = ModelWithDataLoaderDistributedSampler() trainer = Trainer( gpus=[0, 1], @@ -793,7 +803,7 @@ def test_dataloader_distributed_sampler_already_attached(tmpdir): accelerator='ddp_spawn', default_root_dir=tmpdir, max_steps=100, - callbacks=[DistribSamplerCallback()], + callbacks=[DistribSamplerCallback(expected_seeds=(11, 123, 0))], replace_sampler_ddp=True, ) trainer.fit(model) From a5e356adb1bd80ff9b39dd55fe8a39dea32202fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 19 Apr 2021 15:53:21 +0200 Subject: [PATCH 19/21] Deprecate `@auto_move_data` in favor of `trainer.predict` (#6993) * Deprecated `@auto_move_data` in favor of `trainer.predict` * Update CHANGELOG --- CHANGELOG.md | 3 ++ docs/source/common/trainer.rst | 27 ----------- docs/source/starter/introduction_guide.rst | 24 ++++++---- pytorch_lightning/core/__init__.py | 1 - pytorch_lightning/core/decorators.py | 7 ++- pytorch_lightning/core/lightning.py | 55 +++++----------------- tests/deprecated_api/test_remove_1-5.py | 11 +++++ 7 files changed, 49 insertions(+), 79 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c105b76c8ea8db..59575e2d0216c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -148,6 +148,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `PytorchProfiler(profiled_functions)` in favor of `record_functions` ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349)) +- Deprecated `@auto_move_data` in favor of `trainer.predict` ([#6993](https://github.com/PyTorchLightning/pytorch-lightning/pull/6993)) + + - Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505), [#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530), [#6540](https://github.com/PyTorchLightning/pytorch-lightning/pull/6540), diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index dd432011261e54..36dda97917e037 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -174,33 +174,6 @@ Once you're done training, feel free to run the test set! ------------ -Deployment / prediction ------------------------ -You just trained a LightningModule which is also just a torch.nn.Module. -Use it to do whatever! - -.. code-block:: python - - # load model - pretrained_model = LightningModule.load_from_checkpoint(PATH) - pretrained_model.freeze() - - # use it for finetuning - def forward(self, x): - features = pretrained_model(x) - classes = classifier(features) - - # or for prediction - out = pretrained_model(x) - api_write({'response': out} - - -You may wish to run the model on a variety of devices. Instead of moving the data -manually to the correct device, decorate the forward method (or any other method you use for inference) -with :func:`~pytorch_lightning.core.decorators.auto_move_data` and Lightning will take care of the rest. - ------------- - Reproducibility --------------- diff --git a/docs/source/starter/introduction_guide.rst b/docs/source/starter/introduction_guide.rst index c399047034fc04..8d35e271856497 100644 --- a/docs/source/starter/introduction_guide.rst +++ b/docs/source/starter/introduction_guide.rst @@ -855,7 +855,8 @@ In this case, we've set this LightningModel to predict logits. But we could also x = mnist_image() feature_maps = model(x) -Or maybe we have a model that we use to do generation +Or maybe we have a model that we use to do generation. +A :class:`~pytorch_lightning.core.lightning.LightningModule` is also just a :class:`torch.nn.Module`. .. testcode:: @@ -880,8 +881,11 @@ Or maybe we have a model that we use to do generation generated_imgs = model(z) -To perform inference at scale, it is possible to use ``trainer.predict`` with LightningModule ``predict_step`` function -By default, LightningModule ``predict_step`` calls forward, but it can be overriden to add any processing logic. +To perform inference at scale, it is possible to use :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict` +with :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` +By default, :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` +calls :meth:`~pytorch_lightning.core.lightning.LightningModule.forward`, +but it can be overridden to add any processing logic. .. code-block:: python @@ -899,14 +903,18 @@ By default, LightningModule ``predict_step`` calls forward, but it can be overri trainer.predict(model, datamodule) -How you split up what goes in ``forward`` vs ``training_step`` vs ``predict`` depends on how you want to use this model for -prediction. -However, we recommend ``forward`` to contain only tensor operation with your model, ``training_step`` to encapsulate ``forward`` logic with logging, -metrics and loss computation and ``predict`` to encapsulate ``forward`` with preprocess, postprocess functions. +How you split up what goes in :meth:`~pytorch_lightning.core.lightning.LightningModule.forward` +vs :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` +vs :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` depends on how you want to use this model for prediction. +However, we recommend :meth:`~pytorch_lightning.core.lightning.LightningModule.forward` to contain only tensor operations with your model. +:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` to encapsulate +:meth:`~pytorch_lightning.core.lightning.LightningModule.forward` logic with logging, metrics, and loss computation. +:meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` to encapsulate +:meth:`~pytorch_lightning.core.lightning.LightningModule.forward` with any necessary preprocess or postprocess functions. ---------------- -The nonessentials +The non-essentials ================== Extensibility diff --git a/pytorch_lightning/core/__init__.py b/pytorch_lightning/core/__init__.py index bcab67d821e098..7a909fe8c916d4 100644 --- a/pytorch_lightning/core/__init__.py +++ b/pytorch_lightning/core/__init__.py @@ -19,4 +19,3 @@ 'LightningDataModule', 'LightningModule', ] -# __call__ = __all__ diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py index 5def3c0caa445b..51c602add95410 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -16,7 +16,7 @@ from functools import wraps from typing import Callable -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn def auto_move_data(fn: Callable) -> Callable: @@ -61,6 +61,11 @@ def auto_transfer_args(self, *args, **kwargs): args, kwargs = self.transfer_batch_to_device((args, kwargs)) return fn(self, *args, **kwargs) + rank_zero_deprecation( + "The `@auto_move_data` decorator is deprecated in v1.3 and will be removed in v1.5." + f" Please use `trainer.predict` instead for inference. The decorator was applied to `{fn.__name__}`" + ) + return auto_transfer_args diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index d207fac7276ae4..8466dc88cfce00 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -486,53 +486,14 @@ def all_gather( def forward(self, *args, **kwargs) -> Any: r""" - Same as :meth:`torch.nn.Module.forward()`, however in Lightning you want this to define - the operations you want to use for prediction (i.e.: on a server or as a feature extractor). - - Normally you'd call ``self()`` from your :meth:`training_step` method. - This makes it easy to write a complex system for training with the outputs - you'd want in a prediction setting. - - You may also find the :func:`~pytorch_lightning.core.decorators.auto_move_data` decorator useful - when using the module outside Lightning in a production setting. + Same as :meth:`torch.nn.Module.forward()`. Args: *args: Whatever you decide to pass into the forward method. **kwargs: Keyword arguments are also possible. Return: - Predicted output - - Examples:: - - # example if we were using this model as a feature extractor - def forward(self, x): - feature_maps = self.convnet(x) - return feature_maps - - def training_step(self, batch, batch_idx): - x, y = batch - feature_maps = self(x) - logits = self.classifier(feature_maps) - - # ... - return loss - - # splitting it this way allows model to be used a feature extractor - model = MyModelAbove() - - inputs = server.get_request() - results = model(inputs) - server.write_results(results) - - # ------------- - # This is in stark contrast to torch.nn.Module where normally you would have this: - def forward(self, batch): - x, y = batch - feature_maps = self.convnet(x) - logits = self.classifier(feature_maps) - return logits - + Your model's output """ return super().forward(*args, **kwargs) @@ -1074,7 +1035,17 @@ def test_epoch_end(self, outputs): def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any: """ - Use this function with trainer.predict(...). Override if you need to add any processing logic. + Step function called during :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict` + By default, it calls :meth:`~pytorch_lightning.core.lightning.LightningModule.forward`. + Override to add any processing logic. + + Args: + batch: Current batch + batch_idx: Index of current batch + dataloader_idx: Index of the current dataloader + + Return: + Predicted output """ return self(batch) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 8757fb625d7605..b1227e20d18885 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -20,6 +20,7 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.core.decorators import auto_move_data from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.profiler import AdvancedProfiler, BaseProfiler, PyTorchProfiler, SimpleProfiler from pytorch_lightning.trainer.callback_hook import warning_cache as callback_warning_cache @@ -231,3 +232,13 @@ def test_v1_5_0_trainer_training_trick_mixin(tmpdir): dummy_loss = torch.tensor(1.0) with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"): trainer.detect_nan_tensors(dummy_loss) + + +def test_v1_5_0_auto_move_data(): + with pytest.deprecated_call(match="deprecated in v1.3 and will be removed in v1.5.*was applied to `bar`"): + + class Foo: + + @auto_move_data + def bar(self): + pass From 2b232d3fbd213da17b3ee7274d888160db34dbcd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 19 Apr 2021 16:08:09 +0200 Subject: [PATCH 20/21] fix docs rendering in datamodule (#7064) * [docs]: add newline to correctly render Example * whitespace Co-authored-by: Matthew Sarmiento --- pytorch_lightning/core/datamodule.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 4178c9eeacd503..2dd6f0b76a1b05 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -58,7 +58,7 @@ def track_data_hook_calls(fn): - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``. - Its corresponding `dm_has_setup_{stage}` attribute gets set to True + Its corresponding `dm_has_setup_{stage}` attribute gets set to True - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup`` Args: @@ -319,11 +319,12 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): Args: args: The parser or namespace to take arguments from. Only known arguments will be - parsed and passed to the :class:`LightningDataModule`. + parsed and passed to the :class:`LightningDataModule`. **kwargs: Additional keyword arguments that may override ones in the parser or namespace. - These must be valid DataModule arguments. + These must be valid DataModule arguments. Example:: + parser = ArgumentParser(add_help=False) parser = LightningDataModule.add_argparse_args(parser) module = LightningDataModule.from_argparse_args(args) From d1529c28a18af26a99cf42ba3dc37b04f39f1ed2 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Mon, 19 Apr 2021 23:08:49 +0900 Subject: [PATCH 21/21] Optimization docs (#6907) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * . * . * Fix link to the section * Fix link to the section * Consistent indent * Update docs * Apply suggestions from code review Co-authored-by: Carlos Mocholí Co-authored-by: Adrian Wälchli * Add note for optimizer.optimizer * . * Update hooks * Update closure docstring * Update optimizer methods * Update optimizer * Remove manopt + grad clipping (by @flukeskywalker) Co-authored-by: Carlos Mocholí Co-authored-by: Adrian Wälchli --- docs/source/common/lightning_module.rst | 22 +- docs/source/common/optimizers.rst | 586 +++++++++++++-------- pytorch_lightning/core/hooks.py | 7 +- pytorch_lightning/core/lightning.py | 155 +++--- pytorch_lightning/trainer/training_loop.py | 4 +- 5 files changed, 454 insertions(+), 320 deletions(-) diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst index 912c5e4d309440..28cd6f2b8858fd 100644 --- a/docs/source/common/lightning_module.rst +++ b/docs/source/common/lightning_module.rst @@ -698,6 +698,12 @@ log_dict .. automethod:: pytorch_lightning.core.lightning.LightningModule.log_dict :noindex: +manual_backward +~~~~~~~~~~~~~~~ + +.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_backward + :noindex: + print ~~~~~ @@ -916,7 +922,10 @@ True if using Automatic Mixed Precision (AMP) automatic_optimization ~~~~~~~~~~~~~~~~~~~~~~ -When set to ``False``, Lightning does not automate the optimization process. This means you are responsible for handling your optimizers. However, we do take care of precision and any accelerators used. +When set to ``False``, Lightning does not automate the optimization process. This means you are responsible for handling +your optimizers. However, we do take care of precision and any accelerators used. + +See :ref:`manual optimization` for details. .. code-block:: python @@ -931,7 +940,9 @@ When set to ``False``, Lightning does not automate the optimization process. Thi self.manual_backward(loss) opt.step() -This is recommended only if using 2+ optimizers AND if you know how to perform the optimization procedure properly. Note that automatic optimization can still be used with multiple optimizers by relying on the ``optimizer_idx`` parameter. Manual optimization is most useful for research topics like reinforcement learning, sparse coding, and GAN research. +This is recommended only if using 2+ optimizers AND if you know how to perform the optimization procedure properly. Note +that automatic optimization can still be used with multiple optimizers by relying on the ``optimizer_idx`` parameter. +Manual optimization is most useful for research topics like reinforcement learning, sparse coding, and GAN research. .. code-block:: python @@ -1086,13 +1097,6 @@ get_progress_bar_dict .. automethod:: pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict :noindex: -manual_backward -~~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_backward - :noindex: - - on_after_backward ~~~~~~~~~~~~~~~~~ diff --git a/docs/source/common/optimizers.rst b/docs/source/common/optimizers.rst index 422302ea8987e7..d9b8d259110090 100644 --- a/docs/source/common/optimizers.rst +++ b/docs/source/common/optimizers.rst @@ -3,27 +3,39 @@ ************ Optimization ************ - Lightning offers two modes for managing the optimization process: -- automatic optimization (AutoOpt) +- automatic optimization - manual optimization -For the majority of research cases, **automatic optimization** will do the right thing for you and it is what -most users should use. +For the majority of research cases, **automatic optimization** will do the right thing for you and it is what most +users should use. For advanced/expert users who want to do esoteric optimization schedules or techniques, use **manual optimization**. ------- +----- Manual optimization =================== -For advanced research topics like reinforcement learning, sparse coding, or GAN research, it may be desirable -to manually manage the optimization process. To do so, do the following: +For advanced research topics like reinforcement learning, sparse coding, or GAN research, it may be desirable to +manually manage the optimization process. + +This is only recommended for experts who need ultimate flexibility. +Lightning will handle only precision and accelerators logic. +The users are left with ``optimizer.zero_grad()``, gradient accumulation, model toggling, etc.. + +To manually optimize, do the following: + +* Set ``self.automatic_optimization=False`` in your ``LightningModule``'s ``__init__``. +* Use the following functions and call them manually: -* Set the ``automatic_optimization`` property to ``False`` in your ``LightningModule`` ``__init__`` function -* Use ``self.manual_backward(loss)`` instead of ``loss.backward()``. + * ``self.optimizers()`` to access your optimizers (one or multiple) + * ``optimizer.zero_grad()`` to clear the gradients from the previous training step + * ``self.manual_backward(loss)`` instead of ``loss.backward()`` + * ``optimizer.step()`` to update your model parameters +Here is a minimal example of manual optimization. + .. testcode:: python from pytorch_lightning import LightningModule @@ -32,25 +44,37 @@ to manually manage the optimization process. To do so, do the following: def __init__(self): super().__init__() - # Important: This property activate ``manual optimization`` for your model + # Important: This property activates manual optimization. self.automatic_optimization = False def training_step(batch, batch_idx): opt = self.optimizers() + opt.zero_grad() loss = self.compute_loss(batch) self.manual_backward(loss) + opt.step() -.. note:: This is only recommended for experts who need ultimate flexibility. Lightning will handle only precision and accelerators logic. The users are left with ``optimizer.zero_grad()``, gradient accumulation, model toggling, etc.. +.. warning:: + Before 1.2, ``optimizer.step()`` was calling ``optimizer.zero_grad()`` internally. + From 1.2, it is left to the user's expertise. -.. warning:: Before 1.2, ``optimzer.step`` was calling ``optimizer.zero_grad()`` internally. From 1.2, it is left to the users expertise. +.. tip:: + Be careful where you call ``optimizer.zero_grad()``, or your model won't converge. + It is good practice to call ``optimizer.zero_grad()`` before ``self.manual_backward(loss)``. -.. tip:: To perform ``accumulate_grad_batches`` with one optimizer, you can do as such. +----- -.. tip:: ``self.optimizers()`` will return ``LightningOptimizer`` objects. You can access your own optimizer with ``optimizer.optimizer``. However, if you use your own optimizer to perform a step, Lightning won't be able to support accelerators and precision for you. +Gradient accumulation +--------------------- +You can accumulate gradients over batches similarly to +:attr:`~pytorch_lightning.trainer.Trainer.accumulate_grad_batches` of automatic optimization. +To perform gradient accumulation with one optimizer, you can do as such. -.. code-block:: python +.. testcode:: python + # accumulate gradients over `n` batches def __init__(self): + super().__init__() self.automatic_optimization = False def training_step(self, batch, batch_idx): @@ -59,36 +83,16 @@ to manually manage the optimization process. To do so, do the following: loss = self.compute_loss(batch) self.manual_backward(loss) - # accumulate gradient batches - if batch_idx % 2 == 0: + # accumulate gradients of `n` batches + if (batch_idx + 1) % n == 0: opt.step() opt.zero_grad() -.. tip:: It is a good practice to provide the optimizer with a ``closure`` function that performs a ``forward``, ``zero_grad`` and ``backward`` of your model. It is optional for most optimizers, but makes your code compatible if you switch to an optimizer which requires a closure. See also `the PyTorch docs `_. +----- -Here is the same example as above using a ``closure``. - -.. testcode:: python - - def __init__(self): - self.automatic_optimization = False - - def training_step(self, batch, batch_idx): - opt = self.optimizers() - - def closure(): - # Only zero_grad on the first batch to accumulate gradients - is_first_batch_to_accumulate = batch_idx % 2 == 0 - if is_first_batch_to_accumulate: - opt.zero_grad() - - loss = self.compute_loss(batch) - self.manual_backward(loss) - return loss - - opt.step(closure=closure) - -.. tip:: Be careful where you call ``zero_grad`` or your model won't converge. It is good pratice to call ``zero_grad`` before ``manual_backward``. +Use multiple optimizers (like GANs) [manual] +-------------------------------------------- +Here is an example training a simple GAN with multiple optimizers. .. testcode:: python @@ -97,13 +101,12 @@ Here is the same example as above using a ``closure``. from pytorch_lightning import LightningModule class SimpleGAN(LightningModule): - def __init__(self): super().__init__() self.G = Generator() self.D = Discriminator() - # Important: This property activate ``manual optimization`` for this model + # Important: This property activates manual optimization. self.automatic_optimization = False def sample_z(self, n) -> Tensor: @@ -115,7 +118,8 @@ Here is the same example as above using a ``closure``. return self.G(z) def training_step(self, batch, batch_idx): - # Implementation follows https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html + # Implementation follows the PyTorch tutorial: + # https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html g_opt, d_opt = self.optimizers() X, _ = batch @@ -126,10 +130,9 @@ Here is the same example as above using a ``closure``. g_X = self.sample_G(batch_size) - ########################### - # Optimize Discriminator # - ########################### - d_opt.zero_grad() + ########################## + # Optimize Discriminator # + ########################## d_x = self.D(X) errD_real = self.criterion(d_x, real_label) @@ -138,17 +141,17 @@ Here is the same example as above using a ``closure``. errD = (errD_real + errD_fake) + d_opt.zero_grad() self.manual_backward(errD) d_opt.step() - ####################### - # Optimize Generator # - ####################### - g_opt.zero_grad() - + ###################### + # Optimize Generator # + ###################### d_z = self.D(g_X) errG = self.criterion(d_z, real_label) + g_opt.zero_grad() self.manual_backward(errG) g_opt.step() @@ -159,32 +162,98 @@ Here is the same example as above using a ``closure``. d_opt = torch.optim.Adam(self.D.parameters(), lr=1e-5) return g_opt, d_opt -.. note:: ``LightningOptimizer`` provides a ``toggle_model`` function as a ``@context_manager`` for advanced users. It can be useful when performing gradient accumulation with several optimizers or training in a distributed setting. +----- + +Learning rate scheduling [manual] +--------------------------------- +You can call ``lr_scheduler.step()`` at arbitrary intervals. +Use ``self.lr_schedulers()`` in your :class:`~pytorch_lightning.LightningModule` to access any learning rate schedulers +defined in your :meth:`~pytorch_lightning.LightningModule.configure_optimizers`. + +.. warning:: + * Before 1.3, Lightning automatically called ``lr_scheduler.step()`` in both automatic and manual optimization. From + 1.3, ``lr_scheduler.step()`` is now for the user to call at arbitrary intervals. + * Note that the lr_dict keys, such as ``"step"`` and ``""interval"``, will be ignored even if they are provided in + your ``configure_optimizers()`` during manual optimization. + +Here is an example calling ``lr_scheduler.step()`` every step. + +.. testcode:: python + + # step every batch + def __init__(self): + super().__init__() + self.automatic_optimization = False + + def training_step(self, batch, batch_idx): + # do forward, backward, and optimization + ... + + # single scheduler + sch = self.lr_schedulers() + sch.step() + + # multiple schedulers + sch1, sch2 = self.lr_schedulers() + sch1.step() + sch2.step() + +If you want to call ``lr_scheduler.step()`` every ``n`` steps/epochs, do the following. + +.. testcode:: python + + def __init__(self): + super().__init__() + self.automatic_optimization = False + + def training_step(self, batch, batch_idx): + # do forward, backward, and optimization + ... + + sch = self.lr_schedulers() + + # step every `n` batches + if (batch_idx + 1) % n == 0: + sch.step() + + # step every `n` epochs + if self.trainer.is_last_batch and (self.trainer.current_epoch + 1) % n == 0: + sch.step() + +----- + +Improve training speed with model toggling +------------------------------------------ +Toggling models can improve your training speed when performing gradient accumulation with multiple optimizers in a +distributed setting. Here is an explanation of what it does: -Considering the current optimizer as A and all other optimizers as B. -Toggling means that all parameters from B exclusive to A will have their ``requires_grad`` attribute set to ``False``. Their original state will be restored when exiting the context manager. +* Considering the current optimizer as A and all other optimizers as B. +* Toggling means that all parameters from B exclusive to A will have their ``requires_grad`` attribute set to ``False``. +* Their original state will be restored when exiting the context manager. When performing gradient accumulation, there is no need to perform grad synchronization during the accumulation phase. Setting ``sync_grad`` to ``False`` will block this synchronization and improve your training speed. +:class:`~pytorch_lightning.core.optimizer.LightningOptimizer` provides a +:meth:`~pytorch_lightning.core.optimizer.LightningOptimizer.toggle_model` function as a +:func:`contextlib.contextmanager` for advanced users. Here is an example for advanced use-case. .. testcode:: python # Scenario for a GAN with gradient accumulation every 2 batches and optimized for multiple gpus. - class SimpleGAN(LightningModule): - ... - def __init__(self): + super().__init__() self.automatic_optimization = False def training_step(self, batch, batch_idx): - # Implementation follows https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html + # Implementation follows the PyTorch tutorial: + # https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html g_opt, d_opt = self.optimizers() X, _ = batch @@ -194,14 +263,18 @@ Here is an example for advanced use-case. real_label = torch.ones((batch_size, 1), device=self.device) fake_label = torch.zeros((batch_size, 1), device=self.device) - accumulated_grad_batches = batch_idx % 2 == 0 + # Sync and clear gradients + # at the end of accumulation or + # at the end of an epoch. + is_last_batch_to_accumulate = \ + (batch_idx + 1) % 2 == 0 or self.trainer.is_last_batch g_X = self.sample_G(batch_size) - ########################### - # Optimize Discriminator # - ########################### - with d_opt.toggle_model(sync_grad=accumulated_grad_batches): + ########################## + # Optimize Discriminator # + ########################## + with d_opt.toggle_model(sync_grad=is_last_batch_to_accumulate): d_x = self.D(X) errD_real = self.criterion(d_x, real_label) @@ -211,36 +284,88 @@ Here is an example for advanced use-case. errD = (errD_real + errD_fake) self.manual_backward(errD) - if accumulated_grad_batches: + if is_last_batch_to_accumulate: d_opt.step() d_opt.zero_grad() - ####################### - # Optimize Generator # - ####################### - with g_opt.toggle_model(sync_grad=accumulated_grad_batches): + ###################### + # Optimize Generator # + ###################### + with g_opt.toggle_model(sync_grad=is_last_batch_to_accumulate): d_z = self.D(g_X) errG = self.criterion(d_z, real_label) self.manual_backward(errG) - if accumulated_grad_batches: + if is_last_batch_to_accumulate: g_opt.step() g_opt.zero_grad() self.log_dict({'g_loss': errG, 'd_loss': errD}, prog_bar=True) +----- + +Use closure for LBFGS-like optimizers +------------------------------------- +It is a good practice to provide the optimizer with a closure function that performs a ``forward``, ``zero_grad`` and +``backward`` of your model. It is optional for most optimizers, but makes your code compatible if you switch to an +optimizer which requires a closure, such as :class:`torch.optim.LBFGS`. + +See `the PyTorch docs `_ for more about the closure. + +Here is an example using a closure function. + +.. testcode:: python + + def __init__(self): + super().__init__() + self.automatic_optimization = False + + def configure_optimizers(self): + return torch.optim.LBFGS(...) + + def training_step(self, batch, batch_idx): + opt = self.optimizers() + + def closure(): + loss = self.compute_loss(batch) + opt.zero_grad() + self.manual_backward(loss) + return loss + + opt.step(closure=closure) + ------ +Access your own optimizer [manual] +---------------------------------- +``optimizer`` is a :class:`~pytorch_lightning.core.optimizer.LightningOptimizer` object wrapping your own optimizer +configured in your :meth:`~pytorch_lightning.LightningModule.configure_optimizers`. You can access your own optimizer +with ``optimizer.optimizer``. However, if you use your own optimizer to perform a step, Lightning won't be able to +support accelerators and precision for you. + +.. testcode:: python + + def __init__(self): + super().__init__() + self.automatic_optimization = False + + def training_step(batch, batch_idx): + optimizer = self.optimizers() + + # `optimizer` is a `LightningOptimizer` wrapping the optimizer. + # To access it, do the following. + # However, it won't work on TPU, AMP, etc... + optimizer = optimizer.optimizer + ... + +----- + Automatic optimization ====================== -With Lightning most users don't have to think about when to call ``.zero_grad()``, ``.backward()`` and ``.step()`` +With Lightning, most users don't have to think about when to call ``.zero_grad()``, ``.backward()`` and ``.step()`` since Lightning automates that for you. -.. warning:: - Before 1.2.2, ``.zero_grad()`` was called after ``.backward()`` and ``.step()`` internally. - From 1.2.2, Lightning calls ``.zero_grad()`` before ``.backward()``. - -Under the hood Lightning does the following: +Under the hood, Lightning does the following: .. code-block:: python @@ -269,221 +394,220 @@ In the case of multiple optimizers, Lightning does the following: for lr_scheduler in lr_schedulers: lr_scheduler.step() +.. warning:: + Before 1.2.2, Lightning internally calls ``backward``, ``step`` and ``zero_grad`` in the order. + From 1.2.2, the order is changed to ``zero_grad``, ``backward`` and ``step``. + +----- Learning rate scheduling ------------------------ -Every optimizer you use can be paired with any `Learning Rate Scheduler `_. -In the basic use-case, the scheduler (or multiple schedulers) should be returned as the second output from the ``.configure_optimizers`` method: +Every optimizer you use can be paired with any +`Learning Rate Scheduler `_. In the basic +use-case, the scheduler(s) should be returned as the second output from the +:meth:`~pytorch_lightning.LightningModule.configure_optimizers` method: -.. testcode:: +.. testcode:: python # no LR scheduler def configure_optimizers(self): - return Adam(...) + return Adam(...) # Adam + LR scheduler def configure_optimizers(self): - optimizer = Adam(...) - scheduler = LambdaLR(optimizer, ...) - return [optimizer], [scheduler] + optimizer = Adam(...) + scheduler = LambdaLR(optimizer, ...) + return [optimizer], [scheduler] # Two optimizers each with a scheduler def configure_optimizers(self): - optimizer1 = Adam(...) - optimizer2 = SGD(...) - scheduler1 = LambdaLR(optimizer1, ...) - scheduler2 = LambdaLR(optimizer2, ...) - return [optimizer1, optimizer2], [scheduler1, scheduler2] + optimizer1 = Adam(...) + optimizer2 = SGD(...) + scheduler1 = LambdaLR(optimizer1, ...) + scheduler2 = LambdaLR(optimizer2, ...) + return [optimizer1, optimizer2], [scheduler1, scheduler2] -When there are schedulers in which the ``.step()`` method is conditioned on a metric value (for example the -:class:`~torch.optim.lr_scheduler.ReduceLROnPlateau` scheduler), Lightning requires that the output -from ``configure_optimizers`` should be dicts, one for each optimizer, with the keyword ``monitor`` -set to metric that the scheduler should be conditioned on. +When there are schedulers in which the ``.step()`` method is conditioned on a metric value, such as the +:class:`~torch.optim.lr_scheduler.ReduceLROnPlateau` scheduler, Lightning requires that the output from +:meth:`~pytorch_lightning.LightningModule.configure_optimizers` should be dicts, one for each optimizer, with the +keyword ``"monitor"`` set to metric that the scheduler should be conditioned on. .. testcode:: - # The ReduceLROnPlateau scheduler requires a monitor - def configure_optimizers(self): - return { - 'optimizer': Adam(...), - 'lr_scheduler': ReduceLROnPlateau(optimizer, ...), - 'monitor': 'metric_to_track' - } - - # In the case of two optimizers, only one using the ReduceLROnPlateau scheduler - def configure_optimizers(self): - optimizer1 = Adam(...) - optimizer2 = SGD(...) - scheduler1 = ReduceLROnPlateau(optimizer1, ...) - scheduler2 = LambdaLR(optimizer2, ...) - return ( - {'optimizer': optimizer1, 'lr_scheduler': scheduler1, 'monitor': 'metric_to_track'}, - {'optimizer': optimizer2, 'lr_scheduler': scheduler2}, - ) + # The ReduceLROnPlateau scheduler requires a monitor + def configure_optimizers(self): + optimizer = Adam(...) + return { + 'optimizer': optimizer, + 'lr_scheduler': ReduceLROnPlateau(optimizer, ...), + 'monitor': 'metric_to_track', + } + + # In the case of two optimizers, only one using the ReduceLROnPlateau scheduler + def configure_optimizers(self): + optimizer1 = Adam(...) + optimizer2 = SGD(...) + scheduler1 = ReduceLROnPlateau(optimizer1, ...) + scheduler2 = LambdaLR(optimizer2, ...) + return ( + {'optimizer': optimizer1, 'lr_scheduler': scheduler1, 'monitor': 'metric_to_track'}, + {'optimizer': optimizer2, 'lr_scheduler': scheduler2}, + ) .. note:: - Metrics can be made availble to condition on by simply logging it using ``self.log('metric_to_track', metric_val)`` - in your lightning module. + Metrics can be made available to monitor by simply logging it using ``self.log('metric_to_track', metric_val)`` in + your :class:`~pytorch_lightning.LightningModule`. -By default, all schedulers will be called after each epoch ends. To change this behaviour, a scheduler configuration should be -returned as a dict which can contain the following keywords: +By default, all schedulers will be called after each epoch ends. To change this behaviour, a scheduler configuration +should be returned as a dict which can contain the following keywords: -* ``scheduler`` (required): the actual scheduler object -* ``monitor`` (optional): metric to condition -* ``interval`` (optional): either ``epoch`` (default) for stepping after each epoch ends or ``step`` for stepping +* ``"scheduler"`` (required): the actual scheduler object +* ``"monitor"`` (optional): metric to condition +* ``"interval"`` (optional): either ``"epoch"`` (default) for stepping after each epoch ends or ``"step"`` for stepping after each optimization step -* ``frequency`` (optional): how many epochs/steps should pass between calls to ``scheduler.step()``. Default is 1, +* ``"frequency"`` (optional): how many epochs/steps should pass between calls to ``scheduler.step()``. Default is 1, corresponding to updating the learning rate after every epoch/step. -* ``strict`` (optional): if set to ``True`` will enforce that value specified in ``monitor`` is available while trying - to call ``scheduler.step()``, and stop training if not found. If ``False`` will only give a warning and continue training - (without calling the scheduler). -* ``name`` (optional): if using the :class:`~pytorch_lightning.callbacks.LearningRateMonitor` callback to monitor the - learning rate progress, this keyword can be used to specify a specific name the learning rate should be logged as. +* ``"strict"`` (optional): if set to ``True``, will enforce that value specified in ``"monitor"`` is available while + trying to call ``scheduler.step()``, and stop training if not found. If ``False``, it will only give a warning and + continue training without calling the scheduler. +* ``"name"`` (optional): if using the :class:`~pytorch_lightning.callbacks.LearningRateMonitor` callback to monitor the + learning rate progress, this keyword can be used to specify a name the learning rate should be logged as. -.. testcode:: - - # Same as the above example with additional params passed to the first scheduler - # In this case the ReduceLROnPlateau will step after every 10 processed batches - def configure_optimizers(self): - optimizers = [Adam(...), SGD(...)] - schedulers = [ - { - 'scheduler': ReduceLROnPlateau(optimizers[0], ...), - 'monitor': 'metric_to_track', - 'interval': 'step', - 'frequency': 10, - 'strict': True, - }, - LambdaLR(optimizers[1], ...) - ] - return optimizers, schedulers +.. testcode:: python ----------- + # Same as the above example with additional params passed to the first scheduler + # In this case the ReduceLROnPlateau will step after every 10 processed batches + def configure_optimizers(self): + optimizers = [Adam(...), SGD(...)] + schedulers = [ + { + 'scheduler': ReduceLROnPlateau(optimizers[0], ...), + 'monitor': 'metric_to_track', + 'interval': 'step', + 'frequency': 10, + 'strict': True, + }, + LambdaLR(optimizers[1], ...) + ] + return optimizers, schedulers + +----- Use multiple optimizers (like GANs) ----------------------------------- -To use multiple optimizers return two or more optimizers from :meth:`pytorch_lightning.core.LightningModule.configure_optimizers` +To use multiple optimizers (optionally with learning rate schedulers), return two or more optimizers from +:meth:`~pytorch_lightning.core.LightningModule.configure_optimizers`. -.. testcode:: +.. testcode:: python - # one optimizer - def configure_optimizers(self): - return Adam(...) + # two optimizers, no schedulers + def configure_optimizers(self): + return Adam(...), SGD(...) - # two optimizers, no schedulers - def configure_optimizers(self): - return Adam(...), SGD(...) + # two optimizers, one scheduler for adam only + def configure_optimizers(self): + opt1 = Adam(...) + opt2 = SGD(...) + optimizers = [opt1, opt2] + lr_schedulers = {'scheduler': ReduceLROnPlateau(opt1, ...), 'monitor': 'metric_to_track'} + return optimizers, lr_schedulers - # Two optimizers, one scheduler for adam only - def configure_optimizers(self): - return [Adam(...), SGD(...)], {'scheduler': ReduceLROnPlateau(), 'monitor': 'metric_to_track'} + # two optimizers, two schedulers + def configure_optimizers(self): + opt1 = Adam(...) + opt2 = SGD(...) + return [opt1, opt2], [StepLR(opt1, ...), OneCycleLR(opt2, ...)] -Lightning will call each optimizer sequentially: +Under the hood, Lightning will call each optimizer sequentially: .. code-block:: python - for epoch in epochs: - for batch in data: - for opt in optimizers: - loss = train_step(batch, batch_idx, optimizer_idx) - opt.zero_grad() - loss.backward() - opt.step() + for epoch in epochs: + for batch in data: + for opt in optimizers: + loss = train_step(batch, batch_idx, optimizer_idx) + opt.zero_grad() + loss.backward() + opt.step() - for lr_scheduler in lr_schedulers: - lr_scheduler.step() + for lr_scheduler in lr_schedulers: + lr_scheduler.step() ----------- +----- Step optimizers at arbitrary intervals -------------------------------------- To do more interesting things with your optimizers such as learning rate warm-up or odd scheduling, -override the :meth:`optimizer_step` function. +override the :meth:`~pytorch_lightning.LightningModule.optimizer_step` function. -For example, here step optimizer A every 2 batches and optimizer B every 4 batches +.. warning:: + If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter to + ``optimizer.step()`` function as shown in the examples because ``training_step()``, ``optimizer.zero_grad()``, + ``backward()`` are called in the closure function. -.. testcode:: +For example, here step optimizer A every batch and optimizer B every 2 batches. - def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx): - optimizer.zero_grad() +.. testcode:: python - # Alternating schedule for optimizer steps (ie: GANs) - def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False): - # update generator opt every 2 steps + # Alternating schedule for optimizer steps (e.g. GANs) + def optimizer_step( + self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, + on_tpu=False, using_native_amp=False, using_lbfgs=False, + ): + # update generator every step if optimizer_idx == 0: - if batch_nb % 2 == 0 : - optimizer.step(closure=closure) + optimizer.step(closure=optimizer_closure) - # update discriminator opt every 4 steps + # update discriminator every 2 steps if optimizer_idx == 1: - if batch_nb % 4 == 0 : - optimizer.step(closure=closure) + if (batch_idx + 1) % 2 == 0: + optimizer.step(closure=optimizer_closure) -Here we add a learning-rate warm up + # ... + # add as many optimizers as you want -.. testcode:: +Here we add a learning rate warm-up. + +.. testcode:: python # learning rate warm-up - def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False): - # warm up lr + def optimizer_step( + self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, + on_tpu=False, using_native_amp=False, using_lbfgs=False, + ): + # skip the first 500 steps if self.trainer.global_step < 500: lr_scale = min(1., float(self.trainer.global_step + 1) / 500.) for pg in optimizer.param_groups: pg['lr'] = lr_scale * self.hparams.learning_rate # update params - optimizer.step(closure=closure) + optimizer.step(closure=optimizer_closure) -.. note:: The default ``optimizer_step`` is relying on the internal ``LightningOptimizer`` to properly perform a step. It handles TPUs, AMP, accumulate_grad_batches and much more ... - -.. testcode:: - - # function hook in LightningModule - def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False): - optimizer.step(closure=closure) +----- -.. note:: To access your wrapped Optimizer from ``LightningOptimizer``, do as follow. +Access your own optimizer +------------------------- +``optimizer`` is a :class:`~pytorch_lightning.core.optimizer.LightningOptimizer` object wrapping your own optimizer +configured in your :meth:`~pytorch_lightning.LightningModule.configure_optimizers`. You can access your own optimizer +with ``optimizer.optimizer``. However, if you use your own optimizer to perform a step, Lightning won't be able to +support accelerators and precision for you. -.. testcode:: +.. testcode:: python # function hook in LightningModule - def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False): - - # `optimizer is a ``LightningOptimizer`` wrapping the optimizer. - # To access it, do as follow: - optimizer = optimizer.optimizer - - # run step. However, it won't work on TPU, AMP, etc... - optimizer.step(closure=closure) - - ----------- - -Using the closure functions for optimization --------------------------------------------- - -When using optimization schemes such as LBFGS, the `second_order_closure` needs to be enabled. By default, this function is defined by wrapping the `training_step` and the backward steps as follows - -.. warning:: - Before 1.2.2, ``.zero_grad()`` was called outside the closure internally. - From 1.2.2, the closure calls ``.zero_grad()`` inside, so there is no need to define your own closure - when using similar optimizers to :class:`torch.optim.LBFGS` which requires reevaluation of the loss with the closure in ``optimizer.step()``. - -.. testcode:: - - def second_order_closure(pl_module, split_batch, batch_idx, opt_idx, optimizer, hidden): - # Model training step on a given batch - result = pl_module.training_step(split_batch, batch_idx, opt_idx, hidden) - - # Model backward pass - pl_module.backward(result, optimizer, opt_idx) - - # on_after_backward callback - pl_module.on_after_backward(result.training_step_output, batch_idx, result.loss) - - return result - - # This default `second_order_closure` function can be enabled by passing it directly into the `optimizer.step` - def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure, on_tpu=False, using_native_amp=False, using_lbfgs=False): - # update params - optimizer.step(second_order_closure) + def optimizer_step( + self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, + on_tpu=False, using_native_amp=False, using_lbfgs=False, + ): + optimizer.step(closure=optimizer_closure) + + # `optimizer` is a `LightningOptimizer` wrapping the optimizer. + # To access it, do the following. + # However, it won't work on TPU, AMP, etc... + def optimizer_step( + self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, + on_tpu=False, using_native_amp=False, using_lbfgs=False, + ): + optimizer = optimizer.optimizer + optimizer.step(closure=optimizer_closure) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 1efac79b63b833..72ee3b3c52e4a6 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -260,7 +260,7 @@ def on_predict_end(self) -> None: def on_before_zero_grad(self, optimizer: Optimizer) -> None: """ - Called after optimizer.step() and before optimizer.zero_grad(). + Called after ``training_step()`` and before ``optimizer.zero_grad()``. Called in the training loop after taking an optimizer step and before zeroing grads. Good place to inspect weight information with weights updated. @@ -268,10 +268,13 @@ def on_before_zero_grad(self, optimizer: Optimizer) -> None: This is where it is called:: for optimizer in optimizers: - optimizer.step() + out = training_step(...) + model.on_before_zero_grad(optimizer) # < ---- called here optimizer.zero_grad() + backward() + Args: optimizer: The optimizer for which grads should be zeroed. """ diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 8466dc88cfce00..291631e2652151 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1083,28 +1083,22 @@ def configure_optimizers(self): Return: Any of these 6 options. - - Single optimizer. - - List or Tuple - List of optimizers. - - Two lists - The first list has multiple optimizers, the second a list of LR schedulers (or lr_dict). - - Dictionary, with an 'optimizer' key, and (optionally) a 'lr_scheduler' + - **Single optimizer**. + - **List or Tuple** of optimizers. + - **Two lists** - The first list has multiple optimizers, and the second has multiple LR schedulers (or + multiple lr_dict). + - **Dictionary**, with an ``"optimizer"`` key, and (optionally) a ``"lr_scheduler"`` key whose value is a single LR scheduler or lr_dict. - - Tuple of dictionaries as described, with an optional 'frequency' key. - - None - Fit will run without any optimizer. + - **Tuple of dictionaries** as described above, with an optional ``"frequency"`` key. + - **None** - Fit will run without any optimizer. Note: - The 'frequency' value is an int corresponding to the number of sequential batches - optimized with the specific optimizer. It should be given to none or to all of the optimizers. - There is a difference between passing multiple optimizers in a list, - and passing multiple optimizers in dictionaries with a frequency of 1: - In the former case, all optimizers will operate on the given batch in each optimization step. - In the latter, only one optimizer will operate on the given batch at every step. - - The lr_dict is a dictionary which contains the scheduler and its associated configuration. - The default configuration is shown below. + The lr_dict is a dictionary which contains the scheduler and its associated configuration. The default + configuration is shown below. .. code-block:: python - { + lr_dict = { 'scheduler': lr_scheduler, # The LR scheduler instance (required) 'interval': 'epoch', # The unit of the scheduler's step size 'frequency': 1, # The frequency of the scheduler @@ -1114,43 +1108,51 @@ def configure_optimizers(self): 'name': None, # Custom name for LearningRateMonitor to use } - Only the ``scheduler`` key is required, the rest will be set to the defaults above. + Only the ``"scheduler"`` key is required, the rest will be set to the defaults above. + + Note: + The ``"frequency"`` value is an ``int`` corresponding to the number of sequential batches optimized with the + specific optimizer. It should be given to none or to all of the optimizers. + + There is a difference between passing multiple optimizers in a list and passing multiple optimizers in + dictionaries with a frequency of 1: + In the former case, all optimizers will operate on the given batch in each optimization step. + In the latter, only one optimizer will operate on the given batch at every step. Examples:: # most cases def configure_optimizers(self): - opt = Adam(self.parameters(), lr=1e-3) - return opt + return Adam(self.parameters(), lr=1e-3) # multiple optimizer case (e.g.: GAN) def configure_optimizers(self): - generator_opt = Adam(self.model_gen.parameters(), lr=0.01) - disriminator_opt = Adam(self.model_disc.parameters(), lr=0.02) - return generator_opt, disriminator_opt + gen_opt = Adam(self.model_gen.parameters(), lr=0.01) + dis_opt = Adam(self.model_dis.parameters(), lr=0.02) + return gen_opt, dis_opt # example with learning rate schedulers def configure_optimizers(self): - generator_opt = Adam(self.model_gen.parameters(), lr=0.01) - disriminator_opt = Adam(self.model_disc.parameters(), lr=0.02) - discriminator_sched = CosineAnnealing(discriminator_opt, T_max=10) - return [generator_opt, disriminator_opt], [discriminator_sched] + gen_opt = Adam(self.model_gen.parameters(), lr=0.01) + dis_opt = Adam(self.model_dis.parameters(), lr=0.02) + dis_sch = CosineAnnealing(dis_opt, T_max=10) + return [gen_opt, dis_opt], [dis_sch] # example with step-based learning rate schedulers def configure_optimizers(self): gen_opt = Adam(self.model_gen.parameters(), lr=0.01) - dis_opt = Adam(self.model_disc.parameters(), lr=0.02) - gen_sched = {'scheduler': ExponentialLR(gen_opt, 0.99), - 'interval': 'step'} # called after each training step - dis_sched = CosineAnnealing(discriminator_opt, T_max=10) # called every epoch - return [gen_opt, dis_opt], [gen_sched, dis_sched] + dis_opt = Adam(self.model_dis.parameters(), lr=0.02) + gen_sch = {'scheduler': ExponentialLR(gen_opt, 0.99), + 'interval': 'step'} # called after each training step + dis_sch = CosineAnnealing(dis_opt, T_max=10) # called every epoch + return [gen_opt, dis_opt], [gen_sch, dis_sch] # example with optimizer frequencies # see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1 # https://arxiv.org/abs/1704.00028 def configure_optimizers(self): gen_opt = Adam(self.model_gen.parameters(), lr=0.01) - dis_opt = Adam(self.model_disc.parameters(), lr=0.02) + dis_opt = Adam(self.model_dis.parameters(), lr=0.02) n_critic = 5 return ( {'optimizer': dis_opt, 'frequency': n_critic}, @@ -1158,32 +1160,22 @@ def configure_optimizers(self): ) Note: - Some things to know: - - Lightning calls ``.backward()`` and ``.step()`` on each optimizer - and learning rate scheduler as needed. - - - If you use 16-bit precision (``precision=16``), Lightning will automatically - handle the optimizers for you. - - - If you use multiple optimizers, :meth:`training_step` will have an additional - ``optimizer_idx`` parameter. - - - If you use LBFGS Lightning handles the closure function automatically for you. - - - If you use multiple optimizers, gradients will be calculated only - for the parameters of current optimizer at each training step. - - - If you need to control how often those optimizers step or override the - default ``.step()`` schedule, override the :meth:`optimizer_step` hook. - - - If you only want to call a learning rate scheduler every ``x`` step or epoch, - or want to monitor a custom metric, you can specify these in a lr_dict: + - Lightning calls ``.backward()`` and ``.step()`` on each optimizer and learning rate scheduler as needed. + - If you use 16-bit precision (``precision=16``), Lightning will automatically handle the optimizers. + - If you use multiple optimizers, :meth:`training_step` will have an additional ``optimizer_idx`` parameter. + - If you use :class:`torch.optim.LBFGS`, Lightning handles the closure function automatically for you. + - If you use multiple optimizers, gradients will be calculated only for the parameters of current optimizer + at each training step. + - If you need to control how often those optimizers step or override the default ``.step()`` schedule, + override the :meth:`optimizer_step` hook. + - If you only want to call a learning rate scheduler every ``x`` step or epoch, or want to monitor a custom + metric, you can specify these in a lr_dict: .. code-block:: python - { + lr_dict = { 'scheduler': lr_scheduler, 'interval': 'step', # or 'epoch' 'monitor': 'val_f1', @@ -1196,23 +1188,21 @@ def configure_optimizers(self): def manual_backward(self, loss: Tensor, optimizer: Optional[Optimizer] = None, *args, **kwargs) -> None: """ Call this directly from your training_step when doing optimizations manually. - By using this we can ensure that all the proper scaling when using 16-bit etc has been done for you + By using this we can ensure that all the proper scaling when using 16-bit etc has been done for you. This function forwards all args to the .backward() call as well. - .. tip:: In manual mode we still automatically clip grads if Trainer(gradient_clip_val=x) is set - - .. tip:: In manual mode we still automatically accumulate grad over batches if - Trainer(accumulate_grad_batches=x) is set and you use `optimizer.step()` + See :ref:`manual optimization` for more examples. Example:: def training_step(...): - opt_a, opt_b = self.optimizers() + opt = self.optimizers() loss = ... + opt.zero_grad() # automatically applies scaling, etc... self.manual_backward(loss) - opt_a.step() + opt.step() """ if optimizer is not None: rank_zero_deprecation( @@ -1322,18 +1312,18 @@ def optimizer_step( Warning: If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter to ``optimizer.step()`` function as shown in the examples. This ensures that - ``train_step_and_backward_closure`` is called within + ``training_step()``, ``optimizer.zero_grad()``, ``backward()`` are called within :meth:`~pytorch_lightning.trainer.training_loop.TrainLoop.run_training_batch`. Args: epoch: Current epoch batch_idx: Index of current batch optimizer: A PyTorch optimizer - optimizer_idx: If you used multiple optimizers this indexes into that list. - optimizer_closure: closure for all optimizers - on_tpu: true if TPU backward is required - using_native_amp: True if using native amp - using_lbfgs: True if the matching optimizer is lbfgs + optimizer_idx: If you used multiple optimizers, this indexes into that list. + optimizer_closure: Closure for all optimizers + on_tpu: ``True`` if TPU backward is required + using_native_amp: ``True`` if using native amp + using_lbfgs: True if the matching optimizer is :class:`torch.optim.LBFGS` Examples:: @@ -1345,22 +1335,18 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, # Alternating schedule for optimizer steps (i.e.: GANs) def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs): - # update generator opt every 2 steps + # update generator opt every step if optimizer_idx == 0: - if batch_idx % 2 == 0 : - optimizer.step(closure=optimizer_closure) - optimizer.zero_grad() + optimizer.step(closure=optimizer_closure) - # update discriminator opt every 4 steps + # update discriminator opt every 2 steps if optimizer_idx == 1: - if batch_idx % 4 == 0 : + if (batch_idx + 1) % 2 == 0 : optimizer.step(closure=optimizer_closure) - optimizer.zero_grad() # ... # add as many optimizers as you want - Here's another example showing how to use this for more advanced things such as learning rate warm-up: @@ -1377,7 +1363,6 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, # update params optimizer.step(closure=optimizer_closure) - optimizer.zero_grad() """ if not isinstance(optimizer, LightningOptimizer): @@ -1386,6 +1371,26 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer.step(closure=optimizer_closure) def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int): + """Override this method to change the default behaviour of ``optimizer.zero_grad()``. + + Args: + epoch: Current epoch + batch_idx: Index of current batch + optimizer: A PyTorch optimizer + optimizer_idx: If you used multiple optimizers this indexes into that list. + + Examples:: + + # DEFAULT + def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx): + optimizer.zero_grad() + + # Set gradients to `None` instead of zero to improve performance. + def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx): + optimizer.zero_grad(set_to_none=True) + + See :meth:`torch.optim.Optimizer.zero_grad` for the explanation of the above example. + """ optimizer.zero_grad() def tbptt_split_batch(self, batch: Tensor, split_size: int) -> list: diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 3504c483501636..ea56e54e76e135 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -748,9 +748,7 @@ def _process_closure_result(self, batch_outputs: list, opt_idx: int) -> list: return batch_outputs def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens): - """ - wrap the forward step in a closure so second order methods work - """ + """Wrap forward, zero_grad and backward in a closure so second order methods work""" with self.trainer.profiler.profile("training_step_and_backward"): # lightning module hook result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)