diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index c5c77d4711e6a..984f9a6842b4a 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -17,9 +17,16 @@
 from torch.optim import Optimizer
 
 from pytorch_lightning.core import LightningModule
-from pytorch_lightning.plugins import TrainingTypePlugin
+from pytorch_lightning.plugins.precision import (
+    ApexMixedPrecisionPlugin,
+    MixedPrecisionPlugin,
+    NativeMixedPrecisionPlugin,
+    PrecisionPlugin,
+)
+from pytorch_lightning.plugins.training_type import TrainingTypePlugin
+from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin
 from pytorch_lightning.utilities.apply_func import move_data_to_device
-from pytorch_lightning.utilities.enums import LightningEnum
+from pytorch_lightning.utilities.enums import AMPType, LightningEnum
 
 
 class Accelerator(object):
@@ -39,7 +46,7 @@ class Accelerator(object):
 
     def __init__(
         self,
-        precision_plugin,  #: PrecisionPlugin # fixme
+        precision_plugin: PrecisionPlugin,
         training_type_plugin: TrainingTypePlugin,
     ) -> None:
         """
@@ -230,9 +237,8 @@ def backward(
         )
 
         # TODO: this is a hack, find a better solution for this (hook?)
-        # fixme: uncomment when this class is added
-        # if isinstance(self.training_type_plugin, HorovodPlugin):
-        #     optimizer.synchronize()
+        if isinstance(self.training_type_plugin, HorovodPlugin):
+            optimizer.synchronize()
 
         return output
 
@@ -256,11 +262,9 @@ def optimizer_step(
         """
         model_ref = self.lightning_module
         is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
-        # fixme: uncomment when this class is added
-        # is_native_amp = (
-        #     isinstance(self.precision_plugin, MixedPrecisionPlugin) and self.precision_plugin.backend == AMPType.NATIVE
-        # )
-        is_native_amp = False
+        native_amp = (
+            isinstance(self.precision_plugin, MixedPrecisionPlugin) and self.precision_plugin.backend == AMPType.NATIVE
+        )
 
         self.precision_plugin.pre_optimizer_step(optimizer, opt_idx)
         self.training_type_plugin.pre_optimizer_step(optimizer, opt_idx)
@@ -273,7 +277,7 @@ def optimizer_step(
             optimizer_idx=opt_idx,
             optimizer_closure=lambda_closure,
             on_tpu=False,  # TPUAccelerator class sets this as True
-            using_native_amp=is_native_amp,
+            using_native_amp=native_amp,
             using_lbfgs=is_lbfgs,
         )
 
@@ -326,7 +330,7 @@ def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: Lightn
         """
         plugin.connect(model)
 
-    def connect_precision_plugin(self, plugin):  #: PrecisionPlugin # fixme
+    def connect_precision_plugin(self, plugin: PrecisionPlugin):
         """Attaches the precision plugin to the accelerator"""
         model, optimizers, schedulers = plugin.connect(self.model, self.optimizers, self.lr_schedulers)
         self.model = model
@@ -339,13 +343,12 @@ def to_device(self, batch: Any) -> Any:
 
     @property
     def amp_backend(self) -> Optional[LightningEnum]:
-        # fixme: uncomment when this class is added
-        # if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin):
-        #     return AMPType.APEX
-        # elif isinstance(self.precision_plugin, NativeMixedPrecisionPlugin):
-        #     return AMPType.NATIVE
-        # return None
-        pass
+        if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin):
+            return AMPType.APEX
+        elif isinstance(self.precision_plugin, NativeMixedPrecisionPlugin):
+            return AMPType.NATIVE
+        else:
+            return None
 
     @property
     def precision(self) -> int:
@@ -372,4 +375,4 @@ def optimizer_state(self, optimizer: Optimizer) -> dict:
         return optimizer.state_dict()
 
     def on_save(self, checkpoint):
-        return checkpoint
\ No newline at end of file
+        return checkpoint
diff --git a/pytorch_lightning/plugins/base_plugin.py b/pytorch_lightning/plugins/base_plugin.py
index c4eeff52751a6..0160afa559496 100644
--- a/pytorch_lightning/plugins/base_plugin.py
+++ b/pytorch_lightning/plugins/base_plugin.py
@@ -12,46 +12,47 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
+from abc import ABC, abstractmethod
+from typing import Any, Generator, Optional, overload, Sequence, Tuple
 
 import torch
 
 
-class Plugin(object):
+class Plugin(ABC):
     """Basic Plugin class to derive precision and training type plugins from."""
 
-    def connect(self, model: torch.nn.Module, *args, **kwargs):
+    @abstractmethod
+    def connect(self, model: torch.nn.Module, *args: Sequence, **kwargs: Sequence) -> Optional[Tuple[torch.nn.Module, Sequence, Sequence]]:
         """Connects the plugin with the accelerator (and thereby with trainer and model).
         Will be called by the accelerator.
         """
-        pass
 
-    def pre_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int):
+    def pre_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None:
         """Hook to do something before each optimizer step."""
-        pass
 
-    def post_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int):
+    def post_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None:
         """Hook to do something after each optimizer step."""
         pass
 
-    def pre_training(self):
+    def pre_training(self) -> None:
         """Hook to do something before the training starts."""
         pass
 
-    def post_training(self):
+    def post_training(self) -> None:
         """Hook to do something after the training finishes."""
         pass
 
     @contextlib.contextmanager
-    def train_step_context(self):
+    def train_step_context(self) -> Generator:
         """A contextmanager for the trainstep"""
         yield
 
     @contextlib.contextmanager
-    def val_step_context(self):
+    def val_step_context(self) -> Generator:
         """A contextmanager for the validation step"""
         yield
 
     @contextlib.contextmanager
-    def test_step_context(self):
+    def test_step_context(self) -> Generator:
         """A contextmanager for the teststep"""
-        yield
\ No newline at end of file
+        yield
diff --git a/pytorch_lightning/plugins/precision/__init__.py b/pytorch_lightning/plugins/precision/__init__.py
index 8b137891791fe..5249343f023b1 100644
--- a/pytorch_lightning/plugins/precision/__init__.py
+++ b/pytorch_lightning/plugins/precision/__init__.py
@@ -1 +1,6 @@
-
+from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
+from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugin
diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py
new file mode 100644
index 0000000000000..b9720f19fe3eb
--- /dev/null
+++ b/pytorch_lightning/plugins/precision/apex_amp.py
@@ -0,0 +1,146 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Tuple
+
+import torch
+from torch.optim import Optimizer
+
+from pytorch_lightning.core import LightningModule
+from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
+from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType, rank_zero_warn
+
+if _APEX_AVAILABLE:
+    from apex import amp
+
+
+class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
+    """Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)"""
+
+    def __init__(self, amp_level: str):
+        self.backend = AMPType.APEX
+        self.amp_level = amp_level
+
+    def master_params(self, optimizer: torch.optim.Optimizer):
+        return amp.master_params(optimizer)
+
+    def connect(self, model: torch.nn.Module, optimizers, lr_schedulers):
+        """Connects the precision plugin to the training process,
+        configures apex and reinits the schedulers
+        """
+        model, optimizers = self.configure_apex(amp, model, optimizers, self.amp_level)
+        self.reinit_scheduler_properties(optimizers, lr_schedulers)
+        return model, optimizers, lr_schedulers
+
+    def backward(
+        self,
+        model: LightningModule,
+        closure_loss: torch.Tensor,
+        optimizer: torch.optim.Optimizer,
+        opt_idx: int,
+        should_accumulate: bool,
+        *args,
+        **kwargs,
+    ):
+        """performs the actual backpropagation
+
+        Args:
+            model: the model to be optimized
+            closure_loss: the loss value obtained from the closure
+            optimizer: the optimizer to perform the step lateron
+            opt_idx: the optimizer's index
+            should_accumulate: whether to accumulate gradients or not
+
+        """
+        closure_loss = amp.scale_loss(closure_loss, optimizer)
+
+        # enter apex context
+        context = closure_loss
+        closure_loss = closure_loss.__enter__()
+
+        # do backward pass
+        # TODO: not entirely sure, why we need this
+        if model is not None and isinstance(model, LightningModule):
+            model.backward(closure_loss, optimizer, opt_idx)
+        else:
+            closure_loss.backward(*args, **kwargs)
+
+        # exit amp context
+        a, b, c = None, None, None
+        error = context.__exit__(a, b, c)
+        if error:
+            rank_zero_warn(a, b, c)
+            raise Exception("apex unscale error")
+
+        # once backward has been applied, release graph
+        closure_loss = closure_loss.detach()
+        return closure_loss
+
+    def configure_apex(
+        self,
+        amp: object,
+        model: LightningModule,
+        optimizers: List[Optimizer],
+        amp_level: str,
+    ) -> Tuple[LightningModule, List[Optimizer]]:
+        r"""
+        Override to init AMP your own way.
+        Must return a model and list of optimizers.
+
+        Args:
+            amp: pointer to amp library object.
+            model: pointer to current :class:`LightningModule`.
+            optimizers: list of optimizers passed in :meth:`configure_optimizers`.
+            amp_level: AMP mode chosen ('O1', 'O2', etc...)
+
+        Return:
+            Apex wrapped model and optimizers
+
+        Examples:
+            .. code-block:: python
+
+                # Default implementation used by Trainer.
+                def configure_apex(self, amp, model, optimizers, amp_level):
+                    model, optimizers = amp.initialize(
+                        model, optimizers, opt_level=amp_level,
+                    )
+
+                    return model, optimizers
+        """
+        model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
+        return model, optimizers
+
+    @staticmethod
+    def reinit_scheduler_properties(optimizers: list, schedulers: list):
+        """Reinitializes schedulers with correct properties"""
+        # Reinitialize optimizer.step properties added by schedulers
+        for scheduler in schedulers:
+            scheduler = scheduler["scheduler"]
+
+            for optimizer in optimizers:
+                state = None
+                idx = 0
+
+                # check that we dont mix users optimizers and schedulers
+                if scheduler.optimizer == optimizer:
+                    # Find the mro belonging to the base lr scheduler class
+                    for i, mro in enumerate(scheduler.__class__.__mro__):
+                        if mro in (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+                            idx = i
+                            state = scheduler.state_dict()
+                        else:
+                            state = None
+
+                    scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
+                    if state is not None:
+                        scheduler.load_state_dict(state)
diff --git a/pytorch_lightning/plugins/precision/mixed.py b/pytorch_lightning/plugins/precision/mixed.py
new file mode 100644
index 0000000000000..7e7716c3559f5
--- /dev/null
+++ b/pytorch_lightning/plugins/precision/mixed.py
@@ -0,0 +1,23 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
+from pytorch_lightning.utilities import AMPType
+
+
+class MixedPrecisionPlugin(PrecisionPlugin):
+    """Base Class for mixed precision"""
+
+    EPSILON = 1e-5
+    backend: AMPType
+    precision = "mixed"
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
new file mode 100644
index 0000000000000..daba223169fc6
--- /dev/null
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -0,0 +1,79 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from contextlib import contextmanager
+from typing import Generator
+
+import torch
+
+from pytorch_lightning.core import LightningModule
+from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
+from pytorch_lightning.utilities import AMPType
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+
+class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
+    def __init__(self):
+        self.backend = AMPType.NATIVE
+        self.scaler = torch.cuda.amp.GradScaler()
+
+    def pre_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None:
+        """always called before the optimizer step.
+        Checks that the optimizer is not LBFGS, as this one is not supported by native amp
+        """
+        if isinstance(optimizer, torch.optim.LBFGS):
+            raise MisconfigurationException(
+                f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
+                " To request, please file a Github issue in PyTorch and tag @mcarilli"
+            )
+
+    def post_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None:
+        """Updates the GradScaler"""
+        self.scaler.update()
+
+    def backward(
+        self,
+        model: LightningModule,
+        closure_loss: torch.Tensor,
+        optimizer: torch.optim.Optimizer,
+        opt_idx: int,
+        should_accumulate: bool,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        """performs the actual backpropagation
+
+        Args:
+            model: the model to be optimized
+            closure_loss: the loss value obtained from the closure
+            optimizer: the optimizer to perform the step lateron
+            opt_idx: the optimizer's index
+            should_accumulate: whether to accumulate gradients or not
+
+        """
+        closure_loss = self.scaler.scale(closure_loss)
+
+        automatic_optimization = model.automatic_optimization
+
+        closure_loss = super().backward(model, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs)
+
+        # unscale gradient to allow analyze within `on_after_backward`
+        if not should_accumulate and automatic_optimization:
+            self.scaler.unscale_(optimizer)
+
+        return closure_loss
+
+    @contextmanager
+    def train_step_context(self) -> Generator[torch.cuda.amp.autocast, None, None]:
+        """Enable autocast context"""
+        yield torch.cuda.amp.autocast()
diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py
index 0ff54bf1e8515..031b588737614 100644
--- a/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import Generator, Union
+from typing import Any, Generator, Sequence, Tuple, Union
 
 import torch
 from torch.optim import Optimizer
@@ -22,6 +22,9 @@
 
 
 class PrecisionPlugin(Plugin):
+    """ Plugin handling the precision-specific parts of the training.
+    The static classattributes EPSILON and precision must be overwritten in child-classes and their default values reflect fp32 training
+    """
     EPSILON = 1e-6
     precision = 32
 
@@ -34,7 +37,7 @@ def master_params(self, optimizer: torch.optim.Optimizer) -> Generator[torch.Ten
             for p in group["params"]:
                 yield p
 
-    def connect(self, model: torch.nn.Module, optimizers, lr_schedulers):
+    def connect(self, model: torch.nn.Module, optimizers: Sequence, lr_schedulers: Sequence) -> Tuple[torch.nn.Module, Sequence, Sequence]:
         """Connects this plugin to the accelerator and the training process"""
         return model, optimizers, lr_schedulers
 
@@ -45,9 +48,9 @@ def backward(
         optimizer: torch.optim.Optimizer,
         opt_idx: int,
         should_accumulate: bool,
-        *args,
-        **kwargs,
-    ):
+        *args: Any,
+        **kwargs: Any,
+    ) -> torch.Tensor:
         """performs the actual backpropagation
 
         Args:
@@ -71,7 +74,7 @@ def backward(
 
         return closure_loss
 
-    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)):
+    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)) -> None:
         """Clips the gradients to a specific value"""
         # TODO: separate TPU case from here
         if clip_val is None:
@@ -82,7 +85,7 @@ def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm
         if grad_clip_val <= 0:
             return
 
-        parameters = self.master_params(optimizer)
+        parameters = list(self.master_params(optimizer))
 
         max_norm = grad_clip_val
 
diff --git a/pytorch_lightning/plugins/precision/sharded_native_amp.py b/pytorch_lightning/plugins/precision/sharded_native_amp.py
new file mode 100644
index 0000000000000..ef8e1b8a95efe
--- /dev/null
+++ b/pytorch_lightning/plugins/precision/sharded_native_amp.py
@@ -0,0 +1,35 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import cast, Union
+
+from torch.optim import Optimizer
+
+from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
+from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE
+
+if _NATIVE_AMP_AVAILABLE and _FAIRSCALE_AVAILABLE:
+    from fairscale.optim import OSS
+    from fairscale.optim.grad_scaler import ShardedGradScaler
+
+
+class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
+    """Mixed Precision for Sharded Training
+    """
+    def __init__(self):
+        super().__init__()
+        self.scaler = ShardedGradScaler()
+
+    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)):
+        optimizer = cast(OSS, optimizer)
+        optimizer.clip_grad_norm(clip_val, norm_type=norm_type)
diff --git a/pytorch_lightning/plugins/precision/tpu_bfloat.py b/pytorch_lightning/plugins/precision/tpu_bfloat.py
new file mode 100644
index 0000000000000..7f4916dd26a46
--- /dev/null
+++ b/pytorch_lightning/plugins/precision/tpu_bfloat.py
@@ -0,0 +1,28 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import torch
+
+from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
+
+
+class TPUHalfPrecisionPlugin(PrecisionPlugin):
+    """Plugin that enables bfloats on TPUs"""
+
+    precision = 16
+
+    def connect(self, model: torch.nn.Module, optimizers, lr_schedulers):
+        os.environ["XLA_USE_BF16"] = str(1)
+        return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index d1e7907d5d97f..5dbbf23881373 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import Any, Optional, Sequence, TYPE_CHECKING, Union
 
 import torch
 
@@ -21,11 +21,14 @@
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.base_plugin import Plugin
 
+if TYPE_CHECKING:
+    from pytorch_lightning.trainer.trainer import Trainer
+
 
 class TrainingTypePlugin(Plugin, ABC):
     """A Plugin to change the behaviour of the training, validation and test-loop."""
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._model = None
         self._results = None
         self.global_rank = 0
@@ -41,7 +44,7 @@ def root_device(self) -> torch.device:
         """Returns the root device"""
 
     @abstractmethod
-    def model_to_device(self):
+    def model_to_device(self) -> None:
         """Moves the model to the correct device"""
 
     @property
@@ -50,11 +53,11 @@ def is_global_zero(self) -> bool:
         """Whether the current process is the rank zero process not only on the local node, but for all nodes."""
 
     @abstractmethod
-    def reduce(self, output, *args, **kwargs):
+    def reduce(self, output: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
         """Reduces the given output (e.g. across GPUs/Processes)"""
 
     @abstractmethod
-    def barrier(self, name: Optional[str] = None):
+    def barrier(self, name: Optional[str] = None) -> None:
         """Forces all possibly joined processes to wait for each other"""
 
     @abstractmethod
@@ -62,7 +65,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         """Broadcasts an object to all processes"""
 
     # TODO method this is currently unused. Check after complete refactors are pushed
-    def set_nvidia_flags(self, is_slurm_managing_tasks, device_ids):
+    def set_nvidia_flags(self, is_slurm_managing_tasks: bool, device_ids: Optional[Sequence]) -> None:
         if device_ids is None:
             return
 
@@ -70,7 +73,8 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, device_ids):
         os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
         all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())])
         devices = os.environ.get("CUDA_VISIBLE_DEVICES", all_gpu_ids)
-        log.info(f"LOCAL_RANK: {self.trainer.local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")
+        if self.lightning_module is not None:
+            log.info(f"LOCAL_RANK: {self.lightning_module.trainer.local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")
 
     def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
         """Reduce the early stopping decision across all possibly spawned processes"""
@@ -82,16 +86,16 @@ def model(self) -> torch.nn.Module:
         return self._model
 
     @model.setter
-    def model(self, new_model: torch.nn.Module):
+    def model(self, new_model: torch.nn.Module) -> None:
         self._model = new_model
 
     @property
-    def lightning_module(self) -> LightningModule:
+    def lightning_module(self) -> Optional[LightningModule]:
         """Returns the pure LightningModule without potential wrappers"""
         return self._model
 
     @property
-    def results(self):
+    def results(self) -> Any:
         """
         The results of the last training/testing run will be cached here.
         In distributed training, we make sure to transfer the results to the appropriate master process.
@@ -103,10 +107,10 @@ def results(self):
     def rpc_enabled(self) -> bool:
         return False
 
-    def start_training(self, trainer: "Trainer") -> None:
+    def start_training(self, trainer: 'Trainer') -> None:
         # double dispatch to initiate the training loop
         self._results = trainer.train()
 
-    def start_testing(self, trainer: "Trainer") -> None:
+    def start_testing(self, trainer: 'Trainer') -> None:
         # double dispatch to initiate the test loop
         self._results = trainer.run_test()
diff --git a/setup.cfg b/setup.cfg
index 31046d3d8b30c..ee23dd130de10 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -142,7 +142,7 @@ ignore_errors = True
 ignore_errors = True
 
 # todo: add proper typing to this module...
-[mypy-pytorch_lightning.accelerators.legacy.*]
+[mypy-pytorch_lightning.accelerators.*]
 ignore_errors = True
 
 # todo: add proper typing to this module...
@@ -165,6 +165,10 @@ ignore_errors = True
 [mypy-pytorch_lightning.profiler.*]
 ignore_errors = True
 
+# todo: add proper typing to this module...
+[mypy-pytorch_lightning.plugins.*]
+ignore_errors = True
+
 # todo: add proper typing to this module...
 [mypy-pytorch_lightning.pt_overrides.*]
 ignore_errors = True