From a6e6e209823058855873f253b8a851c65cc81b74 Mon Sep 17 00:00:00 2001 From: justusschock Date: Sat, 30 Jan 2021 14:51:36 +0100 Subject: [PATCH 01/13] add basic accelerator class. Co-Authored with @awaelchi --- pytorch_lightning/accelerators/accelerator.py | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index c5c77d4711e6a..efb6b51a821a7 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -18,8 +18,12 @@ from pytorch_lightning.core import LightningModule from pytorch_lightning.plugins import TrainingTypePlugin +from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin +from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin +from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin +from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin from pytorch_lightning.utilities.apply_func import move_data_to_device -from pytorch_lightning.utilities.enums import LightningEnum +from pytorch_lightning.utilities.enums import AMPType, LightningEnum class Accelerator(object): @@ -39,7 +43,7 @@ class Accelerator(object): def __init__( self, - precision_plugin, #: PrecisionPlugin # fixme + precision_plugin: PrecisionPlugin, training_type_plugin: TrainingTypePlugin, ) -> None: """ @@ -230,9 +234,8 @@ def backward( ) # TODO: this is a hack, find a better solution for this (hook?) - # fixme: uncomment when this class is added - # if isinstance(self.training_type_plugin, HorovodPlugin): - # optimizer.synchronize() + if isinstance(self.training_type_plugin, HorovodPlugin): + optimizer.synchronize() return output @@ -256,11 +259,9 @@ def optimizer_step( """ model_ref = self.lightning_module is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) - # fixme: uncomment when this class is added - # is_native_amp = ( - # isinstance(self.precision_plugin, MixedPrecisionPlugin) and self.precision_plugin.backend == AMPType.NATIVE - # ) - is_native_amp = False + native_amp = ( + isinstance(self.precision_plugin, MixedPrecisionPlugin) and self.precision_plugin.backend == AMPType.NATIVE + ) self.precision_plugin.pre_optimizer_step(optimizer, opt_idx) self.training_type_plugin.pre_optimizer_step(optimizer, opt_idx) @@ -273,7 +274,7 @@ def optimizer_step( optimizer_idx=opt_idx, optimizer_closure=lambda_closure, on_tpu=False, # TPUAccelerator class sets this as True - using_native_amp=is_native_amp, + using_native_amp=native_amp, using_lbfgs=is_lbfgs, ) @@ -339,13 +340,12 @@ def to_device(self, batch: Any) -> Any: @property def amp_backend(self) -> Optional[LightningEnum]: - # fixme: uncomment when this class is added - # if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin): - # return AMPType.APEX - # elif isinstance(self.precision_plugin, NativeMixedPrecisionPlugin): - # return AMPType.NATIVE - # return None - pass + if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin): + return AMPType.APEX + elif isinstance(self.precision_plugin, NativeMixedPrecisionPlugin): + return AMPType.NATIVE + else: + return None @property def precision(self) -> int: From 3c5b0a6ff5e65a9a8137796034280cb1892a6eb9 Mon Sep 17 00:00:00 2001 From: justusschock Date: Sat, 30 Jan 2021 15:03:29 +0100 Subject: [PATCH 02/13] add basic trainign type plugin. 
Co-Authored with @awaelchi --- pytorch_lightning/plugins/training_type/training_type_plugin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index d1e7907d5d97f..795638efb5ee1 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -18,7 +18,6 @@ import torch from pytorch_lightning import _logger as log -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.base_plugin import Plugin From 5ab4b9ecec3ff940e5500c0caa8910f80eff82cc Mon Sep 17 00:00:00 2001 From: justusschock Date: Sat, 30 Jan 2021 15:22:58 +0100 Subject: [PATCH 03/13] pep8 Co-authored-by: @awaelchi --- pytorch_lightning/accelerators/accelerator.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index efb6b51a821a7..7087f6a261010 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -17,11 +17,13 @@ from torch.optim import Optimizer from pytorch_lightning.core import LightningModule -from pytorch_lightning.plugins import TrainingTypePlugin -from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin -from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin -from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin -from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin +from pytorch_lightning.plugins.training_type import TrainingTypePlugin, HorovodPlugin +from pytorch_lightning.plugins.precision import ( + PrecisionPlugin, + MixedPrecisionPlugin, + ApexMixedPrecisionPlugin, + NativeMixedPrecisionPlugin, +) from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.enums import AMPType, LightningEnum @@ -327,7 +329,7 @@ def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: Lightn """ plugin.connect(model) - def connect_precision_plugin(self, plugin): #: PrecisionPlugin # fixme + def connect_precision_plugin(self, plugin: PrecisionPlugin): """Attaches the precision plugin to the accelerator""" model, optimizers, schedulers = plugin.connect(self.model, self.optimizers, self.lr_schedulers) self.model = model From ab4660d85fb4c9f496dfbe22539bfdbedf93d173 Mon Sep 17 00:00:00 2001 From: justusschock Date: Sat, 30 Jan 2021 17:57:31 +0100 Subject: [PATCH 04/13] update copyright MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/plugins/base_plugin.py | 24 +++++++++-------- .../plugins/precision/precision_plugin.py | 14 +++++----- .../training_type/training_type_plugin.py | 27 ++++++++++--------- setup.cfg | 4 +++ 4 files changed, 39 insertions(+), 30 deletions(-) diff --git a/pytorch_lightning/plugins/base_plugin.py b/pytorch_lightning/plugins/base_plugin.py index c4eeff52751a6..f5cbf1f14acf5 100644 --- a/pytorch_lightning/plugins/base_plugin.py +++ b/pytorch_lightning/plugins/base_plugin.py @@ -12,46 +12,48 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import contextlib +from abc import ABC, abstractmethod +from typing import Any, Generator, Optional, overload, Sequence, Tuple import torch -class Plugin(object): +class Plugin(ABC): """Basic Plugin class to derive precision and training type plugins from.""" - def connect(self, model: torch.nn.Module, *args, **kwargs): + @abstractmethod + def connect(self, model: torch.nn.Module, *args: Sequence, **kwargs: Sequence) -> Optional[Tuple[torch.nn.Module, Sequence, Sequence]]: """Connects the plugin with the accelerator (and thereby with trainer and model). Will be called by the accelerator. """ - pass - def pre_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int): + def pre_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None: """Hook to do something before each optimizer step.""" pass - def post_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int): + def post_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None: """Hook to do something after each optimizer step.""" pass - def pre_training(self): + def pre_training(self) -> None: """Hook to do something before the training starts.""" pass - def post_training(self): + def post_training(self) -> None: """Hook to do something after the training finishes.""" pass @contextlib.contextmanager - def train_step_context(self): + def train_step_context(self) -> Generator: """A contextmanager for the trainstep""" yield @contextlib.contextmanager - def val_step_context(self): + def val_step_context(self) -> Generator: """A contextmanager for the validation step""" yield @contextlib.contextmanager - def test_step_context(self): + def test_step_context(self) -> Generator: """A contextmanager for the teststep""" - yield \ No newline at end of file + yield diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 0ff54bf1e8515..22c1a7539fff4 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import math -from typing import Generator, Union +from typing import Any, Generator, Sequence, Tuple, Union import torch from torch.optim import Optimizer @@ -34,7 +34,7 @@ def master_params(self, optimizer: torch.optim.Optimizer) -> Generator[torch.Ten for p in group["params"]: yield p - def connect(self, model: torch.nn.Module, optimizers, lr_schedulers): + def connect(self, model: torch.nn.Module, optimizers: Sequence, lr_schedulers: Sequence) -> Tuple[torch.nn.Module, Sequence, Sequence]: """Connects this plugin to the accelerator and the training process""" return model, optimizers, lr_schedulers @@ -45,9 +45,9 @@ def backward( optimizer: torch.optim.Optimizer, opt_idx: int, should_accumulate: bool, - *args, - **kwargs, - ): + *args: Any, + **kwargs: Any, + ) -> torch.Tensor: """performs the actual backpropagation Args: @@ -71,7 +71,7 @@ def backward( return closure_loss - def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)): + def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)) -> None: """Clips the gradients to a specific value""" # TODO: separate TPU case from here if clip_val is None: @@ -82,7 +82,7 @@ def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm if grad_clip_val <= 0: return - parameters = self.master_params(optimizer) + parameters = list(self.master_params(optimizer)) max_norm = grad_clip_val diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 795638efb5ee1..349ed689254ad 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -13,18 +13,20 @@ # limitations under the License. import os from abc import ABC, abstractmethod -from typing import Optional +from typing import Any, Optional, Sequence, Union import torch from pytorch_lightning import _logger as log +from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.base_plugin import Plugin +from pytorch_lightning.trainer import Trainer class TrainingTypePlugin(Plugin, ABC): """A Plugin to change the behaviour of the training, validation and test-loop.""" - def __init__(self): + def __init__(self) -> None: self._model = None self._results = None self.global_rank = 0 @@ -40,7 +42,7 @@ def root_device(self) -> torch.device: """Returns the root device""" @abstractmethod - def model_to_device(self): + def model_to_device(self) -> None: """Moves the model to the correct device""" @property @@ -49,11 +51,11 @@ def is_global_zero(self) -> bool: """Whether the current process is the rank zero process not only on the local node, but for all nodes.""" @abstractmethod - def reduce(self, output, *args, **kwargs): + def reduce(self, output: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]: """Reduces the given output (e.g. across GPUs/Processes)""" @abstractmethod - def barrier(self, name: Optional[str] = None): + def barrier(self, name: Optional[str] = None) -> None: """Forces all possibly joined processes to wait for each other""" @abstractmethod @@ -61,7 +63,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: """Broadcasts an object to all processes""" # TODO method this is currently unused. 
Check after complete refactors are pushed - def set_nvidia_flags(self, is_slurm_managing_tasks, device_ids): + def set_nvidia_flags(self, is_slurm_managing_tasks: bool, device_ids: Optional[Sequence]) -> None: if device_ids is None: return @@ -69,7 +71,8 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, device_ids): os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())]) devices = os.environ.get("CUDA_VISIBLE_DEVICES", all_gpu_ids) - log.info(f"LOCAL_RANK: {self.trainer.local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]") + if self.lightning_module is not None: + log.info(f"LOCAL_RANK: {self.lightning_module.trainer.local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]") def reduce_early_stopping_decision(self, should_stop: bool) -> bool: """Reduce the early stopping decision across all possibly spawned processes""" @@ -81,16 +84,16 @@ def model(self) -> torch.nn.Module: return self._model @model.setter - def model(self, new_model: torch.nn.Module): + def model(self, new_model: torch.nn.Module) -> None: self._model = new_model @property - def lightning_module(self) -> LightningModule: + def lightning_module(self) -> Optional[LightningModule]: """Returns the pure LightningModule without potential wrappers""" return self._model @property - def results(self): + def results(self) -> Any: """ The results of the last training/testing run will be cached here. In distributed training, we make sure to transfer the results to the appropriate master process. @@ -102,10 +105,10 @@ def results(self): def rpc_enabled(self) -> bool: return False - def start_training(self, trainer: "Trainer") -> None: + def start_training(self, trainer: Trainer) -> None: # double dispatch to initiate the training loop self._results = trainer.train() - def start_testing(self, trainer: "Trainer") -> None: + def start_testing(self, trainer: Trainer) -> None: # double dispatch to initiate the test loop self._results = trainer.run_test() diff --git a/setup.cfg b/setup.cfg index 31046d3d8b30c..deccd35af8f98 100644 --- a/setup.cfg +++ b/setup.cfg @@ -165,6 +165,10 @@ ignore_errors = True [mypy-pytorch_lightning.profiler.*] ignore_errors = True +# todo: add proper typing to this module... +[mypy-pytorch_lightning.plugins.*] +ignore_errors = True + # todo: add proper typing to this module... [mypy-pytorch_lightning.pt_overrides.*] ignore_errors = True From 1ade664d1152588c5f3e480d158ee7a367dd997f Mon Sep 17 00:00:00 2001 From: justusschock Date: Sat, 30 Jan 2021 17:58:21 +0100 Subject: [PATCH 05/13] add apex_amp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- .../plugins/precision/apex_amp.py | 146 ++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 pytorch_lightning/plugins/precision/apex_amp.py diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py new file mode 100644 index 0000000000000..b9720f19fe3eb --- /dev/null +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -0,0 +1,146 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Tuple + +import torch +from torch.optim import Optimizer + +from pytorch_lightning.core import LightningModule +from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin +from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType, rank_zero_warn + +if _APEX_AVAILABLE: + from apex import amp + + +class ApexMixedPrecisionPlugin(MixedPrecisionPlugin): + """Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)""" + + def __init__(self, amp_level: str): + self.backend = AMPType.APEX + self.amp_level = amp_level + + def master_params(self, optimizer: torch.optim.Optimizer): + return amp.master_params(optimizer) + + def connect(self, model: torch.nn.Module, optimizers, lr_schedulers): + """Connects the precision plugin to the training process, + configures apex and reinits the schedulers + """ + model, optimizers = self.configure_apex(amp, model, optimizers, self.amp_level) + self.reinit_scheduler_properties(optimizers, lr_schedulers) + return model, optimizers, lr_schedulers + + def backward( + self, + model: LightningModule, + closure_loss: torch.Tensor, + optimizer: torch.optim.Optimizer, + opt_idx: int, + should_accumulate: bool, + *args, + **kwargs, + ): + """performs the actual backpropagation + + Args: + model: the model to be optimized + closure_loss: the loss value obtained from the closure + optimizer: the optimizer to perform the step lateron + opt_idx: the optimizer's index + should_accumulate: whether to accumulate gradients or not + + """ + closure_loss = amp.scale_loss(closure_loss, optimizer) + + # enter apex context + context = closure_loss + closure_loss = closure_loss.__enter__() + + # do backward pass + # TODO: not entirely sure, why we need this + if model is not None and isinstance(model, LightningModule): + model.backward(closure_loss, optimizer, opt_idx) + else: + closure_loss.backward(*args, **kwargs) + + # exit amp context + a, b, c = None, None, None + error = context.__exit__(a, b, c) + if error: + rank_zero_warn(a, b, c) + raise Exception("apex unscale error") + + # once backward has been applied, release graph + closure_loss = closure_loss.detach() + return closure_loss + + def configure_apex( + self, + amp: object, + model: LightningModule, + optimizers: List[Optimizer], + amp_level: str, + ) -> Tuple[LightningModule, List[Optimizer]]: + r""" + Override to init AMP your own way. + Must return a model and list of optimizers. + + Args: + amp: pointer to amp library object. + model: pointer to current :class:`LightningModule`. + optimizers: list of optimizers passed in :meth:`configure_optimizers`. + amp_level: AMP mode chosen ('O1', 'O2', etc...) + + Return: + Apex wrapped model and optimizers + + Examples: + .. code-block:: python + + # Default implementation used by Trainer. 
+ def configure_apex(self, amp, model, optimizers, amp_level): + model, optimizers = amp.initialize( + model, optimizers, opt_level=amp_level, + ) + + return model, optimizers + """ + model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level) + return model, optimizers + + @staticmethod + def reinit_scheduler_properties(optimizers: list, schedulers: list): + """Reinitializes schedulers with correct properties""" + # Reinitialize optimizer.step properties added by schedulers + for scheduler in schedulers: + scheduler = scheduler["scheduler"] + + for optimizer in optimizers: + state = None + idx = 0 + + # check that we dont mix users optimizers and schedulers + if scheduler.optimizer == optimizer: + # Find the mro belonging to the base lr scheduler class + for i, mro in enumerate(scheduler.__class__.__mro__): + if mro in (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + idx = i + state = scheduler.state_dict() + else: + state = None + + scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer) + if state is not None: + scheduler.load_state_dict(state) From bcd0e7a385e06b708dfd66d377c5e5ae20896633 Mon Sep 17 00:00:00 2001 From: justusschock Date: Sat, 30 Jan 2021 17:58:52 +0100 Subject: [PATCH 06/13] add mixed base class MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/plugins/precision/mixed.py | 23 ++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 pytorch_lightning/plugins/precision/mixed.py diff --git a/pytorch_lightning/plugins/precision/mixed.py b/pytorch_lightning/plugins/precision/mixed.py new file mode 100644 index 0000000000000..7e7716c3559f5 --- /dev/null +++ b/pytorch_lightning/plugins/precision/mixed.py @@ -0,0 +1,23 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin +from pytorch_lightning.utilities import AMPType + + +class MixedPrecisionPlugin(PrecisionPlugin): + """Base Class for mixed precision""" + + EPSILON = 1e-5 + backend: AMPType + precision = "mixed" From a287567031d2d787bcc8b8a1301c4e8193423ef2 Mon Sep 17 00:00:00 2001 From: justusschock Date: Sat, 30 Jan 2021 17:59:19 +0100 Subject: [PATCH 07/13] add native amp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- .../plugins/precision/native_amp.py | 79 +++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 pytorch_lightning/plugins/precision/native_amp.py diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py new file mode 100644 index 0000000000000..daba223169fc6 --- /dev/null +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -0,0 +1,79 @@ +# Copyright The PyTorch Lightning team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from contextlib import contextmanager +from typing import Generator + +import torch + +from pytorch_lightning.core import LightningModule +from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin +from pytorch_lightning.utilities import AMPType +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class NativeMixedPrecisionPlugin(MixedPrecisionPlugin): + def __init__(self): + self.backend = AMPType.NATIVE + self.scaler = torch.cuda.amp.GradScaler() + + def pre_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None: + """always called before the optimizer step. + Checks that the optimizer is not LBFGS, as this one is not supported by native amp + """ + if isinstance(optimizer, torch.optim.LBFGS): + raise MisconfigurationException( + f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})." + " To request, please file a Github issue in PyTorch and tag @mcarilli" + ) + + def post_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None: + """Updates the GradScaler""" + self.scaler.update() + + def backward( + self, + model: LightningModule, + closure_loss: torch.Tensor, + optimizer: torch.optim.Optimizer, + opt_idx: int, + should_accumulate: bool, + *args, + **kwargs, + ) -> torch.Tensor: + """performs the actual backpropagation + + Args: + model: the model to be optimized + closure_loss: the loss value obtained from the closure + optimizer: the optimizer to perform the step lateron + opt_idx: the optimizer's index + should_accumulate: whether to accumulate gradients or not + + """ + closure_loss = self.scaler.scale(closure_loss) + + automatic_optimization = model.automatic_optimization + + closure_loss = super().backward(model, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs) + + # unscale gradient to allow analyze within `on_after_backward` + if not should_accumulate and automatic_optimization: + self.scaler.unscale_(optimizer) + + return closure_loss + + @contextmanager + def train_step_context(self) -> Generator[torch.cuda.amp.autocast, None, None]: + """Enable autocast context""" + yield torch.cuda.amp.autocast() From 1ff9b42211e9bf2bda0a2dbd778f1f049ef3bfcb Mon Sep 17 00:00:00 2001 From: justusschock Date: Sat, 30 Jan 2021 17:59:40 +0100 Subject: [PATCH 08/13] add native amp sharded MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- .../plugins/precision/sharded_native_amp.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 pytorch_lightning/plugins/precision/sharded_native_amp.py diff --git a/pytorch_lightning/plugins/precision/sharded_native_amp.py b/pytorch_lightning/plugins/precision/sharded_native_amp.py new file mode 100644 index 0000000000000..ef8e1b8a95efe --- /dev/null +++ b/pytorch_lightning/plugins/precision/sharded_native_amp.py @@ -0,0 +1,35 @@ +# Copyright The PyTorch Lightning team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import cast, Union + +from torch.optim import Optimizer + +from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin +from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE + +if _NATIVE_AMP_AVAILABLE and _FAIRSCALE_AVAILABLE: + from fairscale.optim import OSS + from fairscale.optim.grad_scaler import ShardedGradScaler + + +class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin): + """Mixed Precision for Sharded Training + """ + def __init__(self): + super().__init__() + self.scaler = ShardedGradScaler() + + def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)): + optimizer = cast(OSS, optimizer) + optimizer.clip_grad_norm(clip_val, norm_type=norm_type) From 1b82554f74d661adfd7356ff7c162811b3775dbf Mon Sep 17 00:00:00 2001 From: justusschock Date: Sat, 30 Jan 2021 18:00:04 +0100 Subject: [PATCH 09/13] add tpu bfloat MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- .../plugins/precision/tpu_bfloat.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 pytorch_lightning/plugins/precision/tpu_bfloat.py diff --git a/pytorch_lightning/plugins/precision/tpu_bfloat.py b/pytorch_lightning/plugins/precision/tpu_bfloat.py new file mode 100644 index 0000000000000..7f4916dd26a46 --- /dev/null +++ b/pytorch_lightning/plugins/precision/tpu_bfloat.py @@ -0,0 +1,28 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os + +import torch + +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin + + +class TPUHalfPrecisionPlugin(PrecisionPlugin): + """Plugin that enables bfloats on TPUs""" + + precision = 16 + + def connect(self, model: torch.nn.Module, optimizers, lr_schedulers): + os.environ["XLA_USE_BF16"] = str(1) + return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers) From 791fe19b54cf857050cb06f2daa8fe886755a9d3 Mon Sep 17 00:00:00 2001 From: justusschock Date: Sat, 30 Jan 2021 18:00:34 +0100 Subject: [PATCH 10/13] add inits MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/plugins/precision/__init__.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/precision/__init__.py b/pytorch_lightning/plugins/precision/__init__.py index 8b137891791fe..5249343f023b1 100644 --- a/pytorch_lightning/plugins/precision/__init__.py +++ b/pytorch_lightning/plugins/precision/__init__.py @@ -1 +1,6 @@ - +from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin +from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin +from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin +from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin +from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugin From 94e0b28661e026c179b9e8f333840b955072e74b Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Sun, 31 Jan 2021 11:34:26 +0100 Subject: [PATCH 11/13] Update precision_plugin.py --- pytorch_lightning/plugins/precision/precision_plugin.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 22c1a7539fff4..031b588737614 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -22,6 +22,9 @@ class PrecisionPlugin(Plugin): + """ Plugin handling the precision-specific parts of the training. 
+ The static classattributes EPSILON and precision must be overwritten in child-classes and their default values reflect fp32 training + """ EPSILON = 1e-6 precision = 32 From 97eab50264d95122dc75f3ae0a3ad2bbbdb8fda4 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Sun, 31 Jan 2021 18:17:05 +0100 Subject: [PATCH 12/13] Update pytorch_lightning/plugins/base_plugin.py Co-authored-by: Jirka Borovec --- pytorch_lightning/plugins/base_plugin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/plugins/base_plugin.py b/pytorch_lightning/plugins/base_plugin.py index f5cbf1f14acf5..0160afa559496 100644 --- a/pytorch_lightning/plugins/base_plugin.py +++ b/pytorch_lightning/plugins/base_plugin.py @@ -29,7 +29,6 @@ def connect(self, model: torch.nn.Module, *args: Sequence, **kwargs: Sequence) - def pre_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None: """Hook to do something before each optimizer step.""" - pass def post_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None: """Hook to do something after each optimizer step.""" From de8fe1b3e2e20851849567a1ad401296b8887b47 Mon Sep 17 00:00:00 2001 From: justusschock Date: Sun, 31 Jan 2021 18:45:30 +0100 Subject: [PATCH 13/13] fix imports --- pytorch_lightning/accelerators/accelerator.py | 9 +++++---- .../plugins/training_type/training_type_plugin.py | 10 ++++++---- setup.cfg | 2 +- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 7087f6a261010..984f9a6842b4a 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -17,13 +17,14 @@ from torch.optim import Optimizer from pytorch_lightning.core import LightningModule -from pytorch_lightning.plugins.training_type import TrainingTypePlugin, HorovodPlugin from pytorch_lightning.plugins.precision import ( - PrecisionPlugin, - MixedPrecisionPlugin, ApexMixedPrecisionPlugin, + MixedPrecisionPlugin, NativeMixedPrecisionPlugin, + PrecisionPlugin, ) +from pytorch_lightning.plugins.training_type import TrainingTypePlugin +from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.enums import AMPType, LightningEnum @@ -374,4 +375,4 @@ def optimizer_state(self, optimizer: Optimizer) -> dict: return optimizer.state_dict() def on_save(self, checkpoint): - return checkpoint \ No newline at end of file + return checkpoint diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 349ed689254ad..5dbbf23881373 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -13,14 +13,16 @@ # limitations under the License. 
import os from abc import ABC, abstractmethod -from typing import Any, Optional, Sequence, Union +from typing import Any, Optional, Sequence, TYPE_CHECKING, Union import torch from pytorch_lightning import _logger as log from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.base_plugin import Plugin -from pytorch_lightning.trainer import Trainer + +if TYPE_CHECKING: + from pytorch_lightning.trainer.trainer import Trainer class TrainingTypePlugin(Plugin, ABC): @@ -105,10 +107,10 @@ def results(self) -> Any: def rpc_enabled(self) -> bool: return False - def start_training(self, trainer: Trainer) -> None: + def start_training(self, trainer: 'Trainer') -> None: # double dispatch to initiate the training loop self._results = trainer.train() - def start_testing(self, trainer: Trainer) -> None: + def start_testing(self, trainer: 'Trainer') -> None: # double dispatch to initiate the test loop self._results = trainer.run_test() diff --git a/setup.cfg b/setup.cfg index deccd35af8f98..ee23dd130de10 100644 --- a/setup.cfg +++ b/setup.cfg @@ -142,7 +142,7 @@ ignore_errors = True ignore_errors = True # todo: add proper typing to this module... -[mypy-pytorch_lightning.accelerators.legacy.*] +[mypy-pytorch_lightning.accelerators.*] ignore_errors = True # todo: add proper typing to this module...
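
The sketch below illustrates how the classes introduced in this series are meant to compose: an ``Accelerator`` is built from a ``PrecisionPlugin`` and a concrete ``TrainingTypePlugin``. It is an illustration only, not code from these patches; ``ToySingleDevicePlugin``, its ``device`` argument, and the choice of hooks implemented are hypothetical stand-ins based on the abstract API visible in the diffs above (the full base class may declare further abstract hooks, e.g. an ``on_gpu`` property, which are assumed here).

.. code-block:: python

    import torch

    from pytorch_lightning.accelerators.accelerator import Accelerator
    from pytorch_lightning.plugins.precision import PrecisionPlugin
    from pytorch_lightning.plugins.training_type import TrainingTypePlugin


    class ToySingleDevicePlugin(TrainingTypePlugin):
        """Hypothetical minimal training type plugin for a single CPU/GPU device."""

        def __init__(self, device: torch.device) -> None:
            super().__init__()
            self.device = device

        def connect(self, model: torch.nn.Module) -> None:
            # called by Accelerator.connect_training_type_plugin(plugin, model)
            self._model = model

        @property
        def on_gpu(self) -> bool:
            # assumed abstract property of the full base class
            return self.device.type == "cuda"

        @property
        def root_device(self) -> torch.device:
            return self.device

        def model_to_device(self) -> None:
            self._model.to(self.device)

        @property
        def is_global_zero(self) -> bool:
            return True  # single process, so always rank zero

        def reduce(self, output, *args, **kwargs):
            return output  # nothing to reduce across processes

        def barrier(self, name=None) -> None:
            pass  # no peers to wait for

        def broadcast(self, obj, src: int = 0):
            return obj


    accelerator = Accelerator(
        precision_plugin=PrecisionPlugin(),  # plain fp32, no AMP backend
        training_type_plugin=ToySingleDevicePlugin(torch.device("cpu")),
    )
    assert accelerator.amp_backend is None            # neither Apex nor native AMP plugin
    assert accelerator.precision_plugin.precision == 32

Swapping ``NativeMixedPrecisionPlugin`` or ``ApexMixedPrecisionPlugin`` in for the plain ``PrecisionPlugin`` is exactly what ``Accelerator.amp_backend`` inspects in order to report ``AMPType.NATIVE`` or ``AMPType.APEX``.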