From 782c70f46996af99b71679b72ad7bb00254e23f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 27 Oct 2021 09:25:36 +0200 Subject: [PATCH 01/47] lightning lite package and tests --- pytorch_lightning/lite/__init__.py | 17 ++ pytorch_lightning/lite/lite.py | 470 +++++++++++++++++++++++++++++ pytorch_lightning/lite/wrappers.py | 151 +++++++++ tests/lite/__init__.py | 0 tests/lite/test_lite.py | 392 ++++++++++++++++++++++++ tests/lite/test_parity.py | 237 +++++++++++++++ tests/lite/test_wrappers.py | 106 +++++++ 7 files changed, 1373 insertions(+) create mode 100644 pytorch_lightning/lite/__init__.py create mode 100644 pytorch_lightning/lite/lite.py create mode 100644 pytorch_lightning/lite/wrappers.py create mode 100644 tests/lite/__init__.py create mode 100644 tests/lite/test_lite.py create mode 100644 tests/lite/test_parity.py create mode 100644 tests/lite/test_wrappers.py diff --git a/pytorch_lightning/lite/__init__.py b/pytorch_lightning/lite/__init__.py new file mode 100644 index 0000000000000..f4634fe54e548 --- /dev/null +++ b/pytorch_lightning/lite/__init__.py @@ -0,0 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pytorch_lightning.lite.lite import LightningLite + +__all__ = ["LightningLite"] diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py new file mode 100644 index 0000000000000..49798b138567e --- /dev/null +++ b/pytorch_lightning/lite/lite.py @@ -0,0 +1,470 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +from abc import ABC, abstractmethod +from contextlib import contextmanager +from functools import partial +from pathlib import Path +from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, overload, Sequence, Tuple, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler + +from pytorch_lightning import Trainer +from pytorch_lightning.accelerators import Accelerator +from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer +from pytorch_lightning.plugins import ( + DDPShardedPlugin, + DDPSpawnPlugin, + DeepSpeedPlugin, + PLUGIN_INPUT, + TPUSpawnPlugin, + TrainingTypePlugin, +) +from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector +from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin +from pytorch_lightning.utilities import DeviceType, DistributedType, move_data_to_device +from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors +from pytorch_lightning.utilities.data import has_iterable_dataset +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class LightningLite(ABC): + """Lite accelerates your PyTorch training or inference code with minimal changes required. + + - Automatic placement of models and data onto the device + - Automatic support for mixed and double precision (smaller memory footprint) + - Seamless switching between hardware (CPU, GPU, TPU) and distributed training strategies + (data-parallel training, sharded training, etc.) + - Automated spawning of processes, no launch utilities required + - Multi-node support + + Args: + accelerator: The hardware to run on. Possible choices are: cpu, gpu, tpu, auto. + strategy: Strategy for how to run across multiple devices. Possible choices are: + dp, ddp, ddp_spawn, tpu_spawn, deepspeed, ddp_sharded. + devices: Number of devices to train on (int) or which GPUs to train on (list or str). The value applies + per node. + num_nodes: Number of GPU nodes for distributed training. + precision: Double precision (64), full precision (32), half precision (16) or bfloat16 precision (bf16). + plugins: One or several custom plugins + gpus: Provides the same function as the ``devices`` argument but implies ``accelerator="gpu"``. + tpu_cores: Provides the same function as the ``devices`` argument but implies ``accelerator="tpu"``. + """ + + def __init__( + self, + accelerator: Optional[Union[str, Accelerator]] = None, + strategy: Optional[Union[str, TrainingTypePlugin]] = None, + devices: Optional[Union[List[int], str, int]] = None, + num_nodes: int = 1, + precision: Union[int, str] = 32, + plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None, + gpus: Optional[Union[List[int], str, int]] = None, + tpu_cores: Optional[Union[List[int], str, int]] = None, + ) -> None: + self._check_accelerator_support(accelerator) + self._check_strategy_support(strategy) + gpu_ids, tpu_cores = Trainer._parse_devices(gpus=gpus, auto_select_gpus=False, tpu_cores=tpu_cores) + self._accelerator_connector = AcceleratorConnector( + num_processes=1, + devices=devices, + tpu_cores=tpu_cores, + ipus=None, + accelerator=accelerator, + strategy=strategy, + gpus=gpus, + gpu_ids=gpu_ids, + num_nodes=num_nodes, + sync_batchnorm=False, # TODO: add support? 
+ benchmark=False, + replace_sampler_ddp=True, + deterministic=False, + precision=precision, + amp_type="native", + amp_level=None, + plugins=plugins, + ) + self._accelerator = self._accelerator_connector.accelerator + self._strategy = self._accelerator.training_type_plugin + self._precision_plugin = self._accelerator.precision_plugin + self._num_models: int = 0 + + # wrap the run method so we can inject setup logic or spawn processes for the user + setattr(self, "run", self._run_wrapper(self.run)) + + @property + def device(self) -> torch.device: + """The current device this process runs on. + + Use this to create tensors directly on the device if needed. + """ + return self._accelerator.root_device + + @property + def global_rank(self) -> int: + """The global index of the current process across all devices and nodes.""" + return getattr(self._strategy, "global_rank", 0) + + @property + def local_rank(self) -> int: + """The index of the current process among the processes running on the local node.""" + return getattr(self._strategy, "local_rank", 0) + + @property + def node_rank(self) -> int: + """The index of the current node.""" + return getattr(self._strategy, "node_rank", 0) + + @property + def world_size(self) -> int: + """The total number of processes running across all devices and nodes.""" + return getattr(self._strategy, "world_size", 1) + + @property + def is_global_zero(self) -> bool: + """Wether this rank is rank zero.""" + return self._strategy.is_global_zero + + @abstractmethod + def run(self, *args: Any, **kwargs: Any) -> Any: + """All the code inside this run method gets accelerated by Lite. + + Args: + *args: Add any positional arguments you need, e.g., the hyperparameters for your model + **kwargs: Add any keyword arguments you need, e.g., the hyperparameters for your model + """ + + def setup( + self, + model: nn.Module, + *optimizers: Optimizer, + move_to_device: bool = True, + ) -> Union[_LiteModule, List[Union[_LiteModule, _LiteOptimizer]]]: + """Setup a model and its optimizers for accelerated training. + + Args: + model: A model to setup + *optimizers: The optimizer(s) to setup (no optimizers is also possible) + move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False`` + and alternatively use :meth:`to_device` manually. + + Returns: + The tuple of the wrapped model and list of optimizers, in the same order they were passed in. + """ + self._validate_setup(model, optimizers) + + if move_to_device: + model = self._move_model_to_device(model=model, optimizers=list(optimizers)) + + # Let accelerator/plugin wrap and connect the models and optimizers + model, optimizers = self._strategy._setup_model_and_optimizers(model, list(optimizers)) + model = _LiteModule(model, self._accelerator) + optimizers = [_LiteOptimizer(optimizer=optimizer, accelerator=self._accelerator) for optimizer in optimizers] + self._num_models += 1 + if optimizers: + return [model] + optimizers # type: ignore + return model + + def setup_dataloaders( + self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True + ) -> Union[DataLoader, List[DataLoader], Iterable]: + """Setup one or multiple dataloaders for accelerated training. If you need different settings for each + dataloader, call this method individually for each one. + + Args: + *dataloaders: A single dataloader or a sequence of dataloaders. 
+ replace_sampler: If set ``True`` (default), automatically wraps or replaces the sampler on the dataloader(s) + for distributed training. If you have a custom sampler defined, set this to this argument to ``False``. + move_to_device: If set ``True`` (default), moves the data returned by the dataloader(s) automatially to + the correct device. Set this to ``False`` and alternatively use :meth:`to_device` manually on the + returned data. + + Returns: + The wrapped dataloaders, in the same order they were passed in. + """ + self._validate_setup_dataloaders(dataloaders) + dataloaders = [ + self._setup_dataloader(dataloader, replace_sampler=replace_sampler, move_to_device=move_to_device) + for dataloader in dataloaders + ] + dataloaders = dataloaders[0] if len(dataloaders) == 1 else dataloaders + return dataloaders + + def _setup_dataloader( + self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True + ) -> Union[Iterable, DataLoader]: + """Setup a single dataloader for accelerated training. + + Args: + dataloader: The dataloader to accelerate. + replace_sampler: If set ``True`` (default), automatically wraps or replaces the sampler on the dataloader + for distributed training. If you have a custom sampler defined, set this to this argument to ``False``. + move_to_device: If set ``True`` (default), moves the data returned by the dataloader automatially to + the correct device. Set this to ``False`` and alternatively use :meth:`to_device` manually on the + returned data. + + Returns: + The wrapped dataloader. + """ + sampler = dataloader.sampler + if replace_sampler and self._requires_distributed_sampler(dataloader): + if not isinstance(sampler, (SequentialSampler, RandomSampler)): + raise MisconfigurationException( + "You seem to have configured a sampler in your DataLoader. This will be replaced " + " by `DistributedSampler` since `replace_sampler_ddp` is True and you are using" + " distributed training. Either remove the sampler from your DataLoader or set" + " `replace_sampler=False` if you want to use your custom sampler." + ) + sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs) + + kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler) + device = self.device if move_to_device else None + if isinstance(self._strategy, TPUSpawnPlugin): + dataloader = DataLoader(**kwargs) + else: + dataloader = _LiteDataLoader(device=device, **kwargs) + return self._strategy.process_dataloader(dataloader) + + def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None: + """Replaces ``loss.backward()`` in your training loop. Handles precision and automatically for you. + + Args: + tensor: The tensor (loss) to back-propagate gradients from. + *args: Optional positional arguments passed to the underlying backward function. + model: Optional model instance for plugins that require the model for backward(). + **kwargs: Optional named keyword arguments passed to the underlying backward function. + + Note: + When using ``strategy='deepspeed'`` and multiple models were setup, it is required to pass in the + model as argument here. + """ + module = model.module if model is not None else model + if self._num_models > 0 and isinstance(self._strategy, DeepSpeedPlugin): + if model is None: + raise MisconfigurationException( + "When using multiple models + deepspeed, please provide the model used to perform the optimization." 
+ ) + + # requires to attach the current `DeepSpeedEngine` for the `_LiteOptimizer.step` call. + self._strategy.model = module + + self._precision_plugin._run_backward(tensor, module, *args, **kwargs) + + @contextmanager + def cast(self) -> Generator[None, None, None]: + """A context manager to automatically convert operations for the chosen precision. + + Use this only if the `forward` method of your model does not cover all operations you wish to run with the + chosen precision setting. + """ + with self._precision_plugin.forward_context(): + yield + + def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tensor, Any]: + """Move a :class:`torch.nn.Module` or a collection of tensors to the current device, if it is not already + on that device. + + Args: + obj: An object to move to the device. Can be an instance of :class:`torch.nn.Module`, a tensor, or a + (nested) collection of tensors (e.g., a dictionary). + + Returns: + A reference to the object that was moved to the new device. + """ + if isinstance(obj, nn.Module): + if self.device.type == "cuda": + # need to call this manually here again in case we spawned with DDPSpawnPlugin + # TODO: refactor to let plugin handle this cleanly + torch.cuda.set_device(self.device) + return obj.to(self.device) + return move_data_to_device(obj, device=self.device) + + def print(self, *args: Any, **kwargs: Any) -> None: + """Print something only on the first process. + + Arguments passed to this method are forwarded to the Python built-in :func:`print` function. + """ + if self.local_rank == 0: + print(*args, **kwargs) + + def barrier(self, name: Optional[str] = None) -> None: + """Wait for all processes to enter this call. Use this to synchronize all parallel processes, but only if + necessary, otherwise the overhead of synchronization will cause your program to slow down. + + Example:: + + if self.global_rank == 0: + # let process 0 download the dataset + dataset.download_files() + + # let all processes wait before reading the dataset + self.barrier() + + # now all processes can read the files and start training + """ + self._strategy.barrier() + + def all_gather( + self, data: Union[torch.Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False + ) -> Union[torch.Tensor, Dict, List, Tuple]: + r""" + Gather tensors or collections of tensors from multiple processes. + + Args: + data: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof. + group: the process group to gather results from. Defaults to all processes (world) + sync_grads: flag that allows users to synchronize gradients for the all_gather operation + + Return: + A tensor of shape (world_size, batch, ...), or if the input was a collection + the output will also be a collection with tensors of this shape. + """ + group = group if group is not None else torch.distributed.group.WORLD + data = convert_to_tensors(data, device=self.device) + return apply_to_collection(data, torch.Tensor, self._strategy.all_gather, group=group, sync_grads=sync_grads) + + def broadcast(self, obj: object, src: int = 0) -> object: + return self._strategy.broadcast(obj, src=src) + + def save_checkpoint(self, filepath: Union[str, Path], content: Dict[str, Any]) -> None: + """Save a checkpoint contents to a file. + + How and which processes save gets determined by the `strategy`. For example, the `ddp` strategy + saves checkpoints only on process 0. 
+ + Args: + filepath: A path to where the file should be saved + content: A dictionary with contents, i.e., the state dict of your model + """ + self._strategy.save_checkpoint(content, filepath) + + def _run_wrapper(self, run_method: Callable) -> Callable: + return partial(self._run_impl, run_method) + + def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: + self._set_plugin_specific_precision_variables() + self._accelerator.setup_environment() + + # apply sharded context to prevent OOM + run_method = partial(self._run_with_sharded_context, run_method) + + if isinstance(self._strategy, DDPSpawnPlugin): + return self._strategy.spawn(run_method, *args, **kwargs) + else: + return run_method(*args, **kwargs) + + def _run_with_sharded_context(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: + with self._strategy.model_sharded_context(): + return run_method(*args, **kwargs) + + def _set_plugin_specific_precision_variables(self) -> None: + # todo: these are hacks as plugins rely on access to the precision plugin + if isinstance(self._strategy, DeepSpeedPlugin): + self._set_deepspeed_precision_variables() + if isinstance(self._strategy, DDPShardedPlugin): + self._strategy._precision = self._accelerator_connector.precision + + def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module: + if isinstance(self._strategy, TPUSpawnPlugin): + # When the user creates the optimizer, they reference the parameters on the CPU. + # However, when running with TPU the parameters get copied and the reference in the optimizer + # remains invalid. We need to update the references to point to the parameter tensors on the device. + params_on_cpu = dict(model.named_parameters()) + model = self.to_device(model) + params_on_device = dict(model.named_parameters()) + + mapping = {param: params_on_device[name] for name, param in params_on_cpu.items()} + for optimizer in optimizers: + for param_group in optimizer.param_groups: + param_group["params"] = [mapping.get(p, p) for p in param_group["params"]] + else: + model = self.to_device(model) + return model + + def _set_deepspeed_precision_variables(self) -> None: + # TODO: Refactor this once precision pluging is part of the strategy. + amp_type = self._accelerator_connector.amp_type + amp_level = self._accelerator_connector.amp_level + precision = self._accelerator_connector.precision + self._strategy.amp_level, self._strategy.amp_type, self._strategy._precision = amp_level, amp_type, precision + + def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool: + return ( + self._accelerator_connector.is_distributed + and not isinstance(dataloader.sampler, DistributedSampler) + and not has_iterable_dataset(dataloader) + ) + + @staticmethod + def _get_distributed_sampler(dataloader: DataLoader, **kwargs: Any) -> DistributedSampler: + kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0))) + return DistributedSampler(dataloader.dataset, **kwargs) + + def _check_accelerator_support(self, accelerator: Optional[Union[str, Accelerator]]) -> None: + supported = [t.value.lower() for t in self._supported_device_types()] + ["auto"] + valid = accelerator is None or isinstance(accelerator, Accelerator) or accelerator in supported + if not valid: + raise MisconfigurationException( + f"`accelerator={repr(accelerator)}` is not a valid choice." + f" Choose one of {supported} or pass in a `Accelerator` instance." 
+ ) + + def _check_strategy_support(self, strategy: Optional[Union[str, TrainingTypePlugin]]) -> None: + supported = [t.lower() for t in self._supported_strategy_types()] + valid = strategy is None or isinstance(strategy, TrainingTypePlugin) or strategy in supported + if not valid: + raise MisconfigurationException( + f"`strategy={repr(strategy)}` is not a valid choice." + f" Choose one of {supported} or pass in a `TrainingTypePlugin` instance." + ) + + @staticmethod + def _supported_device_types() -> Sequence[DeviceType]: + return ( + DeviceType.CPU, + DeviceType.GPU, + DeviceType.TPU, + ) + + @staticmethod + def _supported_strategy_types() -> Sequence[str]: + return ( + DistributedType.DP, + DistributedType.DDP, + DistributedType.DDP_SPAWN, + DistributedType.TPU_SPAWN, + DistributedType.DEEPSPEED, + DistributedType.DDP_SHARDED, + DistributedType.DDP_SHARDED_SPAWN, + ) + + @staticmethod + def _validate_setup(model: nn.Module, optimizers: Sequence[Optimizer]) -> None: + if isinstance(model, _LiteModule): + raise MisconfigurationException("A model should be passed only once to the `setup` method.") + + if any(isinstance(opt, _LiteOptimizer) for opt in optimizers): + raise MisconfigurationException("An optimizer should be passed only once to the `setup` method.") + + @staticmethod + def _validate_setup_dataloaders(dataloaders: Sequence[DataLoader]) -> None: + if any(isinstance(dl, _LiteDataLoader) for dl in dataloaders): + raise MisconfigurationException("A dataloader should be passed only once to the `setup_dataloaders` method") + + if any(not isinstance(dl, DataLoader) for dl in dataloaders): + raise MisconfigurationException("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.") diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py new file mode 100644 index 0000000000000..e1d16ca8a3384 --- /dev/null +++ b/pytorch_lightning/lite/wrappers.py @@ -0,0 +1,151 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Union + +import torch +from torch import nn as nn +from torch import Tensor +from torch.optim import Optimizer +from torch.utils.data import DataLoader + +from pytorch_lightning.accelerators import Accelerator +from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device + + +def _do_nothing_closure() -> None: + return None + + +class _LiteOptimizer: + def __init__(self, optimizer: Optimizer, accelerator: Accelerator) -> None: + """LiteOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the optimizer + step calls to the accelerator/strategy plugin. + + The underlying wrapped optimizer object can be accessed via the property :attr:`optimizer`. 
+ + Args: + optimizer: The optimizer to wrap + accelerator: Reference to the accelerator for handling the optimizer step + """ + self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")} + self.__class__ = type("Lite" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {}) + self._optimizer = optimizer + self._accelerator = accelerator + + @property + def optimizer(self) -> Optimizer: + return self._optimizer + + @property + def state(self) -> Dict[str, torch.Tensor]: + return self._optimizer.state + + @state.setter + def state(self, state: Dict[str, torch.Tensor]) -> None: + self._optimizer.state = state + + @property + def defaults(self) -> Dict[str, Any]: + return self._optimizer.defaults + + @defaults.setter + def defaults(self, defaults: Dict[str, Any]) -> None: + self._optimizer.defaults = defaults + + @property + def param_groups(self) -> List[Dict[str, torch.Tensor]]: + return self._optimizer.param_groups + + @param_groups.setter + def param_groups(self, param_groups: List[Dict[str, torch.Tensor]]) -> None: + self._optimizer.param_groups = param_groups + + def step(self, closure: Optional[Callable] = None) -> None: + closure = closure or _do_nothing_closure + self._accelerator.optimizer_step( + self._optimizer, + opt_idx=0, + lambda_closure=closure, + model=self._accelerator.model, + ) + + def zero_grad(self, *args: Any, **kwargs: Any) -> None: + self._optimizer.zero_grad(*args, **kwargs) + + +class _LiteModule(nn.Module): + # TODO: Pass in the precision plugin instead of accelerator + def __init__(self, module: nn.Module, accelerator: Accelerator) -> None: + """The LiteModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast + automatically for the forward pass. + + The underlying wrapped module can be accessed via the property :attr:`module`. + + Args: + module: The module to wrap + accelerator: Reference to the accelerator for handling precision context + """ + super().__init__() + self._module = module + self._accelerator = accelerator + + @property + def module(self) -> nn.Module: + return self._module + + def forward(self, *args: Any, **kwargs: Any) -> Any: + """Casts all inputs to the right precision and handles autocast for operations in the module forward + method.""" + precision = self._accelerator.precision_plugin.precision + precision_to_type = { + "mixed": torch.float16, + 16: torch.float16, + 32: torch.float32, + 64: torch.float64, + } + # TODO (@awaelchli): let the precision plugin handle the conversion + to_type = precision_to_type[precision] + args, kwargs = apply_to_collection([args, kwargs], function=lambda t: t.to(to_type), dtype=Tensor) + + with self._accelerator.precision_plugin.forward_context(): + output = self.module(*args, **kwargs) + + output = apply_to_collection(output, function=lambda t: t.to(torch.get_default_dtype()), dtype=Tensor) + return output + + +class _LiteDataLoader(DataLoader): + def __init__(self, device: Optional[torch.device] = None, **dl_kwargs: Any) -> None: + """The LiteDataLoader is an extension of the PyTorch :class:`~torch.utils.data.DataLoader` that adds + additional features such as moving the data to the device automatically. + + Args: + device: The device to which the data should be moved. By default the device is `None` and no data + transfers will be made (identical behavior as :class:`~torch.utils.data.DataLoader`). + **dl_kwargs: Accepts all arguments that the PyTorch :class:`~torch.utils.data.DataLoader` accepts. 
+ """ + super().__init__(**dl_kwargs) + self._device = device + + @property + def device(self) -> Optional[torch.device]: + return self._device + + def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]: + iterator = super().__iter__() + if self._device is None: + return iterator + + for item in iterator: + yield move_data_to_device(item, self._device) diff --git a/tests/lite/__init__.py b/tests/lite/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py new file mode 100644 index 0000000000000..def9ce29ac9dc --- /dev/null +++ b/tests/lite/test_lite.py @@ -0,0 +1,392 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from unittest import mock +from unittest.mock import Mock, PropertyMock + +import pytest +import torch +import torch.distributed +import torch.nn.functional +from torch import nn +from torch.utils.data import DataLoader, DistributedSampler, Sampler + +from pytorch_lightning import seed_everything +from pytorch_lightning.lite import LightningLite +from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer +from pytorch_lightning.plugins import DeepSpeedPlugin, PrecisionPlugin, TrainingTypePlugin +from pytorch_lightning.utilities import DistributedType +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers.runif import RunIf + + +class EmptyLite(LightningLite): + def run(self): + pass + + +class BoringModel(nn.Module): + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(32, 2, bias=False) + + def forward(self, x): + x = self.layer(x) + return torch.nn.functional.mse_loss(x, torch.ones_like(x)) + + +@pytest.mark.parametrize("accelerator", ["coconut"]) +def test_unsupported_accelerator(accelerator): + with pytest.raises(MisconfigurationException, match=f"`accelerator={repr(accelerator)}` is not a valid choice"): + EmptyLite(accelerator=accelerator) + + +@pytest.mark.parametrize("strategy", ["coconut"]) +def test_unsupported_strategy(strategy): + with pytest.raises(MisconfigurationException, match=f"`strategy={repr(strategy)}` is not a valid choice"): + EmptyLite(strategy=strategy) + + +def test_run_input_output(): + """Test that the dynamically patched run() method receives the input arguments and returns the result.""" + + class Lite(LightningLite): + + run_args = () + run_kwargs = {} + + def run(self, *args, **kwargs): + self.run_args = args + self.run_kwargs = kwargs + return "result" + + lite = Lite() + result = lite.run(1, 2, three=3) + assert result == "result" + assert lite.run_args == (1, 2) + assert lite.run_kwargs == {"three": 3} + + +def test_setup_optimizers(): + """Test that setup_optimizers can handle no optimizers, one optimizer, or multiple optimizers.""" + lite = EmptyLite() + model = nn.Linear(1, 2) + optimizer0 = torch.optim.SGD(model.parameters(), lr=0.1) + optimizer1 = torch.optim.Adam(model.parameters(), 
lr=0.1) + + # no optimizer + lite_model = lite.setup(model) + assert isinstance(lite_model, _LiteModule) + assert lite_model.module is model + + # single optimizer + lite_model, lite_optimizer = lite.setup(model, optimizer0) + assert isinstance(lite_model, _LiteModule) + assert isinstance(lite_optimizer, _LiteOptimizer) + assert lite_model.module is model + assert lite_optimizer.optimizer is optimizer0 + + # multiple optimizers + lite_model, lite_optimizer0, lite_optimizer1 = lite.setup(model, optimizer0, optimizer1) + assert isinstance(lite_model, _LiteModule) + assert isinstance(lite_optimizer0, _LiteOptimizer) + assert isinstance(lite_optimizer1, _LiteOptimizer) + assert lite_model.module is model + assert lite_optimizer0.optimizer is optimizer0 + assert lite_optimizer1.optimizer is optimizer1 + + +def test_setup_twice_fails(): + """Test that calling setup with a model or optimizer that is already wrapped fails.""" + lite = EmptyLite() + model = nn.Linear(1, 2) + optimizer = torch.optim.Adam(model.parameters()) + + lite_model, lite_optimizer = lite.setup(model, optimizer) + with pytest.raises(MisconfigurationException, match="A model should be passed only once to the"): + lite.setup(lite_model, optimizer) + + lite_model, lite_optimizer = lite.setup(model, optimizer) + with pytest.raises(MisconfigurationException, match="An optimizer should be passed only once to the"): + lite.setup(model, lite_optimizer) + + +def test_setup_tracks_num_models(): + """Test that setup() tracks how many times it has setup a model.""" + lite = EmptyLite() + model = nn.Linear(1, 2) + optimizer = torch.optim.Adam(model.parameters()) + + assert lite._num_models == 0 + lite.setup(model, optimizer) + assert lite._num_models == 1 + + lite.setup(model, optimizer) + assert lite._num_models == 2 + + +def test_setup_dataloaders_unsupported_type(): + """Test that the setup_dataloaders method fails when provided with non-DataLoader objects.""" + lite = EmptyLite() + with pytest.raises(MisconfigurationException, match="Only PyTorch DataLoader are currently supported"): + lite.setup_dataloaders(range(2)) # type: ignore + + +def test_setup_dataloaders_return_type(): + """Test that the setup method returns the dataloaders wrapped as LiteDataLoader and in the right order.""" + lite = EmptyLite() + + # single dataloader + lite_dataloader = lite.setup_dataloaders(DataLoader(range(2))) + assert isinstance(lite_dataloader, _LiteDataLoader) + + # multiple dataloaders + dataset0 = Mock() + dataset1 = Mock() + dataloader0 = DataLoader(dataset0) + dataloader1 = DataLoader(dataset1) + lite_dataloader0, lite_dataloader1 = lite.setup_dataloaders(dataloader0, dataloader1) + assert isinstance(lite_dataloader0, _LiteDataLoader) + assert isinstance(lite_dataloader1, _LiteDataLoader) + assert lite_dataloader0.dataset is dataset0 + assert lite_dataloader1.dataset is dataset1 + + +def test_setup_dataloaders_twice_fails(): + """Test that calling setup_dataloaders with a dataloader that is already wrapped fails.""" + lite = EmptyLite() + dataloader = DataLoader(range(2)) + lite_dataloader = lite.setup_dataloaders(dataloader) + + with pytest.raises(MisconfigurationException, match="A dataloader should be passed only once to the"): + lite.setup_dataloaders(lite_dataloader) + + +@mock.patch( + "pytorch_lightning.lite.lite.LightningLite.device", + new_callable=PropertyMock, + return_value=torch.device("cuda", 1), +) +def test_setup_dataloaders_move_to_device(lite_device_mock): + """Test that the setup configures LiteDataLoader to move the data to 
the device automatically.""" + lite = EmptyLite() + lite_dataloaders = lite.setup_dataloaders(DataLoader(Mock()), DataLoader(Mock()), move_to_device=False) + assert all(dl.device is None for dl in lite_dataloaders) + lite_device_mock.assert_not_called() + + lite = EmptyLite() + lite_dataloaders = lite.setup_dataloaders(DataLoader(Mock()), DataLoader(Mock()), move_to_device=True) + assert all(dl.device == torch.device("cuda", 1) for dl in lite_dataloaders) + lite_device_mock.assert_called() + + +def test_setup_dataloaders_distributed_sampler_not_needed(): + """Test that replace_sampler option has no effect when no distributed sampler is needed.""" + custom_sampler = Mock(spec=Sampler) + dataloader = DataLoader(Mock(), sampler=custom_sampler) + + # keep the custom sampler when not needed to replace + lite = EmptyLite() + lite_dataloader = lite.setup_dataloaders(dataloader, replace_sampler=True) + assert lite_dataloader.sampler is custom_sampler + + +@pytest.mark.parametrize( + "strategy", + [ + DistributedType.DP, + DistributedType.DDP, + DistributedType.DDP_SPAWN, + DistributedType.TPU_SPAWN, + pytest.param(DistributedType.DEEPSPEED, marks=RunIf(deepspeed=True)), + pytest.param(DistributedType.DDP_SHARDED, marks=RunIf(fairscale=True)), + pytest.param(DistributedType.DDP_SHARDED_SPAWN, marks=RunIf(fairscale=True)), + ], +) +def test_setup_dataloaders_replace_custom_sampler(strategy): + """Test that asking to replace a custom sampler results in an error when a distributed sampler would be + needed.""" + custom_sampler = Mock(spec=Sampler) + dataloader = DataLoader(Mock(), sampler=custom_sampler) + + # explicitly asking to replace when a custom sampler is already configured raises an exception + lite = EmptyLite(accelerator="cpu", strategy=strategy, devices=2) + if lite._accelerator_connector.is_distributed: + with pytest.raises(MisconfigurationException, match="You seem to have configured a sampler in your DataLoader"): + lite.setup_dataloaders(dataloader, replace_sampler=True) + + # setting `replace_sampler=False` leaves the sampler untouched + lite_dataloader = lite.setup_dataloaders(dataloader, replace_sampler=False) + assert lite_dataloader.sampler is custom_sampler + + +@pytest.mark.parametrize( + "strategy", + [ + DistributedType.DP, + DistributedType.DDP, + DistributedType.DDP_SPAWN, + DistributedType.TPU_SPAWN, + pytest.param(DistributedType.DEEPSPEED, marks=RunIf(deepspeed=True)), + pytest.param(DistributedType.DDP_SHARDED, marks=RunIf(fairscale=True)), + pytest.param(DistributedType.DDP_SHARDED_SPAWN, marks=RunIf(fairscale=True)), + ], +) +@pytest.mark.parametrize("shuffle", [True, False]) +def test_setup_dataloaders_replace_standard_sampler(shuffle, strategy): + """Test that Lite replaces the default samplers with DistributedSampler automatically.""" + lite = EmptyLite(accelerator="cpu", strategy=strategy, devices=2) + is_distributed = lite._accelerator_connector.is_distributed + lite_dataloader = lite.setup_dataloaders(DataLoader(range(3), shuffle=shuffle)) + assert not is_distributed or isinstance(lite_dataloader.sampler, DistributedSampler) + + +@pytest.mark.parametrize( + "accelerator, expected", + [ + ("cpu", torch.device("cpu")), + pytest.param("gpu", torch.device("cuda", 0), marks=RunIf(min_gpus=1)), + pytest.param("tpu", torch.device("xla", 0), marks=RunIf(tpu=True)), + ], +) +def test_to_device(accelerator, expected): + """Test that the to_device method can move various objects to the device determined by the accelerator.""" + lite = EmptyLite(accelerator=accelerator, 
devices=1) + + # module + module = torch.nn.Linear(2, 3) + module = lite.to_device(module) + assert all(param.device == expected for param in module.parameters()) + + # tensor + tensor = torch.rand(2, 2) + tensor = lite.to_device(tensor) + assert tensor.device == expected + + # collection + collection = {"data": torch.rand(2, 2), "int": 1} + collection = lite.to_device(collection) + assert collection["data"].device == expected + + +def test_rank_properties(): + """Test that the rank properties are determined by the strategy.""" + lite = EmptyLite() + lite._strategy = Mock(spec=TrainingTypePlugin) + lite._strategy.world_size = 1000 + assert lite.world_size == 1000 + lite._strategy.global_rank = 100 + assert lite.global_rank == 100 + lite._strategy.local_rank = 10 + assert lite.local_rank == 10 + lite._strategy.node_rank = 1 + assert lite.node_rank == 1 + + +def test_backward(): + """Test that backward() calls into the precision plugin.""" + lite = EmptyLite() + lite._precision_plugin = Mock(spec=PrecisionPlugin) + loss = Mock() + lite.backward(loss, "arg", keyword="kwarg") + lite._precision_plugin._run_backward.assert_called_with(loss, None, "arg", keyword="kwarg") + + +@RunIf(deepspeed=True) +def test_backward_model_input_required(): + """Test that when using deepspeed and multiple models, backward() requires the model as input.""" + lite = EmptyLite(strategy="deepspeed") + + model0 = nn.Linear(1, 2) + model1 = nn.Linear(1, 2) + + optimizer0 = torch.optim.Adam(model0.parameters()) + optimizer1 = torch.optim.Adam(model1.parameters()) + + lite._strategy._setup_model_and_optimizer = lambda *args: args + + lite.setup(model0, optimizer0) + lite.setup(model1, optimizer1) + + loss = model0(torch.randn(1, 1)).sum() + + with pytest.raises(MisconfigurationException, match="please provide the model used to perform"): + lite.backward(loss) + + +@RunIf(min_gpus=2, deepspeed=True, special=True) +def test_deepspeed_multiple_models(): + class Lite(LightningLite): + def run(self): + model = BoringModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.0001) + model, optimizer = self.setup(model, optimizer) + state_dict = deepcopy(model.state_dict()) + + for _ in range(2): + optimizer.zero_grad() + x = model(torch.randn(1, 32).to(self.device)) + loss = x.sum() + self.backward(loss, model=model) + optimizer.step() + + for mw_b, mw_a in zip(state_dict.values(), model.state_dict().values()): + assert not torch.equal(mw_b, mw_a) + + seed_everything(42) + model_1 = BoringModel() + optimizer_1 = torch.optim.SGD(model_1.parameters(), lr=0.0001) + + seed_everything(42) + model_2 = BoringModel() + optimizer_2 = torch.optim.SGD(model_2.parameters(), lr=0.0001) + + for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()): + assert torch.equal(mw_1, mw_2) + + model_1, optimizer_1 = self.setup(model_1, optimizer_1) + model_2, optimizer_2 = self.setup(model_2, optimizer_2) + + seed_everything(42) + data_list = [] + for _ in range(2): + optimizer_1.zero_grad() + data = torch.randn(1, 32).to(self.device) + data_list.append(data) + x = model_1(data) + loss = x.sum() + self.backward(loss, model=model_1) + optimizer_1.step() + + for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()): + assert not torch.equal(mw_1, mw_2) + + for data in data_list: + optimizer_2.zero_grad() + x = model_2(data) + loss = x.sum() + self.backward(loss, model=model_2) + optimizer_2.step() + + for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()): + assert 
torch.equal(mw_1, mw_2) + + # Verify collectives works as expected + ranks = self.all_gather(torch.tensor([self.local_rank]).to(self.device)) + assert torch.equal(ranks.cpu(), torch.tensor([[0], [1]])) + assert self.broadcast(True) + assert self.is_global_zero == (self.local_rank == 0) + + Lite(strategy=DeepSpeedPlugin(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu").run() diff --git a/tests/lite/test_parity.py b/tests/lite/test_parity.py new file mode 100644 index 0000000000000..4b52448ceff71 --- /dev/null +++ b/tests/lite/test_parity.py @@ -0,0 +1,237 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from contextlib import contextmanager +from copy import deepcopy +from functools import partial +from typing import Callable, Generator + +import pytest +import torch +import torch.distributed +import torch.multiprocessing as mp +import torch.nn.functional +from torch import nn +from torch.cuda import is_available +from torch.nn.parallel.distributed import DistributedDataParallel +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from pytorch_lightning import seed_everything +from pytorch_lightning.lite import LightningLite +from pytorch_lightning.plugins.environments.lightning_environment import find_free_network_port +from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin +from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device +from pytorch_lightning.utilities.cloud_io import atomic_save +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_DEV_1_10 +from tests.helpers.boring_model import RandomDataset +from tests.helpers.runif import RunIf + + +class BoringModel(nn.Module): + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(32, 2, bias=False) + + def forward(self, x): + x = self.layer(x) + return torch.nn.functional.mse_loss(x, torch.ones_like(x)) + + +def configure_optimizers(module: nn.Module): + return torch.optim.SGD(module.parameters(), lr=0.0001) + + +def main( + move_to_device: Callable, + model: nn.Module, + train_dataloader: DataLoader, + num_epochs: int = 10, +): + model = move_to_device(model) + optimizer = configure_optimizers(model) + + for _ in range(num_epochs): + model.train() + for batch in train_dataloader: + batch = move_to_device(batch) + optimizer.zero_grad() + loss = model(batch) + loss.backward() + optimizer.step() + + return model.state_dict() + + +class LiteRunner(LightningLite): + def run(self, model: nn.Module, train_dataloader: DataLoader, num_epochs: int = 10, tmpdir: str = None): + optimizer = configure_optimizers(model) + model, optimizer = self.setup(model, optimizer) + train_dataloader = self.setup_dataloaders(train_dataloader) + + model.train() + for _ in range(num_epochs): + for batch in train_dataloader: + batch = self.to_device(batch) + optimizer.zero_grad() + loss = model(batch) + self.backward(loss) + optimizer.step() + + if 
isinstance(self._strategy, DDPSpawnPlugin) and tmpdir and self.global_rank == 0: + checkpoint_path = os.path.join(tmpdir, "model.pt") + atomic_save(model.state_dict(), checkpoint_path) + return checkpoint_path + + +@contextmanager +def precision_context(precision, accelerator) -> Generator[None, None, None]: + if precision == 32: + yield + return + if precision == 16 and accelerator == "gpu": + with torch.cuda.amp.autocast(): + yield + elif accelerator == "cpu": + with torch.cpu.amp.autocast(dtype=torch.float16 if precision == 16 else torch.bfloat16): + yield + else: + with torch.cuda.amp.autocast(): + yield + + +@pytest.mark.parametrize( + "precision, strategy, devices, accelerator", + [ + pytest.param(32, None, 1, "cpu"), + pytest.param(32, None, 1, "gpu", marks=pytest.mark.skipif(not is_available(), reason="requires a GPU")), + pytest.param(16, None, 1, "gpu", marks=pytest.mark.skipif(not is_available(), reason="requires a GPU")), + pytest.param( + "bf16", + None, + 1, + "gpu", + marks=pytest.mark.skipif( + not (_TORCH_GREATER_EQUAL_DEV_1_10 and is_available()), + reason="bfloat16 and requires GPU isn't available.", + ), + ), + ], +) +def test_boring_lite_model_single_device(precision, strategy, devices, accelerator, tmpdir): + seed_everything(42) + train_dataloader = DataLoader(RandomDataset(32, 8)) + model = BoringModel() + num_epochs = 1 + state_dict = deepcopy(model.state_dict()) + + lite = LiteRunner(precision=precision, strategy=strategy, devices=devices, accelerator=accelerator) + lite.run(model, train_dataloader, num_epochs=num_epochs) + lite_state_dict = model.state_dict() + + with precision_context(precision, accelerator): + model.load_state_dict(state_dict) + pure_state_dict = main(lite.to_device, model, train_dataloader, num_epochs=num_epochs) + + state_dict = apply_to_collection(state_dict, torch.Tensor, lite.to_device) + for w_pure, w_lite in zip(state_dict.values(), lite_state_dict.values()): + assert not torch.equal(w_pure, w_lite) + + for w_pure, w_lite in zip(pure_state_dict.values(), lite_state_dict.values()): + assert torch.equal(w_pure, w_lite) + + +def run(rank, model, train_dataloader, num_epochs, precision, accelerator, tmpdir): + os.environ["LOCAL_RANK"] = str(rank) + if torch.distributed.is_available() and not torch.distributed.is_initialized(): + torch.distributed.init_process_group("gloo", rank=rank, world_size=2) + + to_device = partial(move_data_to_device, device=torch.device("cuda", rank)) + model = DistributedDataParallel( + to_device(model), + device_ids=[rank], + ) + train_dataloader = DataLoader( + train_dataloader.dataset, + sampler=DistributedSampler(train_dataloader.dataset, rank=rank, num_replicas=2, seed=42, drop_last=False), + ) + with precision_context(precision, accelerator): + main(to_device, model, train_dataloader, num_epochs=num_epochs) + + if rank == 0: + atomic_save(model.state_dict(), os.path.join(tmpdir, "model_spawn.pt")) + + +# @pytest.mark.skipif(True, reason="Skipping as it takes 80 seconds.") +@RunIf(min_gpus=2) +@pytest.mark.parametrize( + "precision, strategy, devices, accelerator", + [ + (32, "ddp_spawn", 2, "gpu"), + ], +) +def test_boring_lite_model_ddp_spawn(precision, strategy, devices, accelerator, tmpdir): + seed_everything(42) + train_dataloader = DataLoader(RandomDataset(32, 8)) + model = BoringModel() + num_epochs = 1 + state_dict = deepcopy(model.state_dict()) + + lite = LiteRunner(precision=precision, strategy=strategy, devices=devices, accelerator=accelerator) + checkpoint_path = lite.run(model, train_dataloader, 
num_epochs=num_epochs, tmpdir=tmpdir) + spawn_model_state_dict = torch.load(checkpoint_path) + + for w_pure, w_lite in zip(state_dict.values(), spawn_model_state_dict.values()): + assert not torch.equal(w_pure.cpu(), w_lite.cpu()) + + model.load_state_dict(state_dict) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(find_free_network_port()) + mp.spawn(run, args=(model, train_dataloader, num_epochs, precision, accelerator, tmpdir), nprocs=2) + spawn_pure_model_state_dict = torch.load(os.path.join(tmpdir, "model_spawn.pt")) + + for w_pure, w_lite in zip(spawn_pure_model_state_dict.values(), spawn_model_state_dict.values()): + assert torch.equal(w_pure.cpu(), w_lite.cpu()) + + +@RunIf(min_gpus=2, special=True) +@pytest.mark.parametrize( + "precision, strategy, devices, accelerator", + [ + (32, "ddp", 2, "gpu"), + ], +) +def test_boring_lite_model_ddp(precision, strategy, devices, accelerator, tmpdir): + seed_everything(42) + train_dataloader = DataLoader(RandomDataset(32, 4)) + model = BoringModel() + num_epochs = 1 + state_dict = deepcopy(model.state_dict()) + + lite = LiteRunner(precision=precision, strategy=strategy, devices=devices, accelerator=accelerator) + lite.run(model, train_dataloader, num_epochs=num_epochs, tmpdir=tmpdir) + + lite_model_state_dict = model.state_dict() + + for w_pure, w_lite in zip(state_dict.values(), lite_model_state_dict.values()): + assert not torch.equal(w_pure.cpu(), w_lite.cpu()) + + seed_everything(42) + train_dataloader = DataLoader(RandomDataset(32, 4)) + model = BoringModel() + run(lite.global_rank, model, train_dataloader, num_epochs, precision, accelerator, tmpdir) + pure_model_state_dict = model.state_dict() + + for w_pure, w_lite in zip(pure_model_state_dict.values(), lite_model_state_dict.values()): + assert torch.equal(w_pure.cpu(), w_lite.cpu()) diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py new file mode 100644 index 0000000000000..faed290b75629 --- /dev/null +++ b/tests/lite/test_wrappers.py @@ -0,0 +1,106 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from unittest.mock import ANY, Mock + +import pytest +import torch + +from pytorch_lightning.lite import LightningLite +from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer +from tests.helpers.runif import RunIf + + +class EmptyLite(LightningLite): + def run(self): + pass + + +def test_lite_module_wraps(): + """Test that the wrapped module is accessible via the property.""" + module = Mock() + assert _LiteModule(module, Mock()).module is module + + +@RunIf(min_gpus=1) +@pytest.mark.parametrize( + "precision, input_type, expected_type", + [ + (32, torch.float16, torch.float32), + (32, torch.float32, torch.float32), + (32, torch.float64, torch.float32), + (16, torch.float32, torch.float16), + (16, torch.float64, torch.float16), + # ("mixed", torch.float32, torch.float16), # TODO: support precision="mixed" + ], +) +def test_lite_module_forward_conversion(precision, input_type, expected_type): + """Test that the LiteModule performs autocasting on the input tensors and during forward().""" + lite = EmptyLite(precision=precision, accelerator="gpu", devices=1) + device = torch.device("cuda", 0) + + def check_autocast(forward_input): + assert precision not in (16, "mixed") or torch.is_autocast_enabled() + return forward_input + + module = Mock(wraps=torch.nn.Linear(1, 1), side_effect=check_autocast) + lite_module = _LiteModule(module, lite._accelerator).to(device) + out = lite_module(torch.rand(1, dtype=input_type, device=device)) + assert module.call_args[0][0].dtype == expected_type + assert out.dtype == torch.get_default_dtype() + + +@pytest.mark.parametrize( + "src_device, dest_device", + [ + (torch.device("cpu"), torch.device("cpu")), + pytest.param(torch.device("cpu"), torch.device("cuda", 0), marks=RunIf(min_gpus=1)), + pytest.param(torch.device("cuda", 0), torch.device("cpu"), marks=RunIf(min_gpus=1)), + ], +) +def test_lite_dataloader_device_placement(src_device, dest_device): + """Test that the LiteDataLoader moves data to the device in its iterator.""" + sample0 = torch.tensor(0, device=src_device) + sample1 = torch.tensor(1, device=src_device) + sample2 = {"data": torch.tensor(2, device=src_device)} + sample3 = {"data": torch.tensor(3, device=src_device)} + data = [sample0, sample1, sample2, sample3] + lite_dataloader = _LiteDataLoader(device=dest_device, dataset=data, batch_size=2) + iterator = iter(lite_dataloader) + + batch0 = next(iterator) + assert torch.equal(batch0, torch.tensor([0, 1], device=dest_device)) + + batch1 = next(iterator) + assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device)) + + +def test_lite_optimizer_wraps(): + """Test that the LiteOptimizer fully wraps the optimizer.""" + optimizer_cls = torch.optim.SGD + optimizer = Mock(spec=optimizer_cls) + lite_optimizer = _LiteOptimizer(optimizer, Mock()) + assert lite_optimizer.optimizer is optimizer + assert isinstance(lite_optimizer, optimizer_cls) + + +def test_lite_optimizer_steps(): + """Test that the LiteOptimizer forwards the step() and zero_grad() calls to the wrapped optimizer.""" + optimizer = Mock() + accelerator = Mock() + lite_optimizer = _LiteOptimizer(optimizer=optimizer, accelerator=accelerator) + lite_optimizer.step() + accelerator.optimizer_step.assert_called_once() + accelerator.optimizer_step.assert_called_with(optimizer, opt_idx=0, lambda_closure=ANY, model=accelerator.model) + lite_optimizer.zero_grad() + optimizer.zero_grad.assert_called_once() From be390986809d39f6a509e6847f44bfa32c31c2de Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 27 Oct 2021 09:59:06 +0200 Subject: [PATCH 02/47] update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e145012a2c914..462c4dc1f70ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -220,7 +220,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Implemented `DeepSpeedPlugin._setup_model_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009), [#10064](https://github.com/PyTorchLightning/pytorch-lightning/pull/10064)) * Implemented `{DDPShardedPlugin,DDPShardedSpawnPlugin}._setup_model_and_optimizers` ([#10028](https://github.com/PyTorchLightning/pytorch-lightning/pull/10028), [#10064](https://github.com/PyTorchLightning/pytorch-lightning/pull/10064)) * Added optional `model` argument to the `optimizer_step` methods in accelerators and plugins ([#10023](https://github.com/PyTorchLightning/pytorch-lightning/pull/10023)) - + * Added `pytorch_lightning.lite` package ([#?](https://github.com/PyTorchLightning/pytorch-lightning/pull/?)) - Added `XLACheckpointIO` plugin ([#9972](https://github.com/PyTorchLightning/pytorch-lightning/pull/9972)) From e45f73661093a3ac5436cc66b86c25a2c9c3593a Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 27 Oct 2021 17:28:27 +0100 Subject: [PATCH 03/47] update --- tests/lite/test_parity.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/lite/test_parity.py b/tests/lite/test_parity.py index 4b52448ceff71..48ed2bb22cd98 100644 --- a/tests/lite/test_parity.py +++ b/tests/lite/test_parity.py @@ -34,7 +34,7 @@ from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.cloud_io import atomic_save -from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_DEV_1_10 +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_10 from tests.helpers.boring_model import RandomDataset from tests.helpers.runif import RunIf @@ -123,7 +123,7 @@ def precision_context(precision, accelerator) -> Generator[None, None, None]: 1, "gpu", marks=pytest.mark.skipif( - not (_TORCH_GREATER_EQUAL_DEV_1_10 and is_available()), + not (_TORCH_GREATER_EQUAL_1_10 and is_available()), reason="bfloat16 and requires GPU isn't available.", ), ), From 0decebae66341d9035faf9fa5ca27a12282381dd Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 27 Oct 2021 19:20:37 +0200 Subject: [PATCH 04/47] Docstrings and CHANGELOG --- CHANGELOG.md | 2 +- pytorch_lightning/lite/lite.py | 25 +++++++++++++------------ 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 462c4dc1f70ce..5f3ca024aeddc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -220,7 +220,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
* Implemented `DeepSpeedPlugin._setup_model_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009), [#10064](https://github.com/PyTorchLightning/pytorch-lightning/pull/10064)) * Implemented `{DDPShardedPlugin,DDPShardedSpawnPlugin}._setup_model_and_optimizers` ([#10028](https://github.com/PyTorchLightning/pytorch-lightning/pull/10028), [#10064](https://github.com/PyTorchLightning/pytorch-lightning/pull/10064)) * Added optional `model` argument to the `optimizer_step` methods in accelerators and plugins ([#10023](https://github.com/PyTorchLightning/pytorch-lightning/pull/10023)) - * Added `pytorch_lightning.lite` package ([#?](https://github.com/PyTorchLightning/pytorch-lightning/pull/?)) + * Added `pytorch_lightning.lite` package ([#10175](https://github.com/PyTorchLightning/pytorch-lightning/pull/10175)) - Added `XLACheckpointIO` plugin ([#9972](https://github.com/PyTorchLightning/pytorch-lightning/pull/9972)) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 49798b138567e..0394b6d1d7884 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -16,7 +16,7 @@ from contextlib import contextmanager from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, overload, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union import torch import torch.nn as nn @@ -46,21 +46,22 @@ class LightningLite(ABC): """Lite accelerates your PyTorch training or inference code with minimal changes required. - - Automatic placement of models and data onto the device - - Automatic support for mixed and double precision (smaller memory footprint) + - Automatic placement of models and data onto the device. + - Automatic support for mixed and double precision (smaller memory footprint). - Seamless switching between hardware (CPU, GPU, TPU) and distributed training strategies - (data-parallel training, sharded training, etc.) - - Automated spawning of processes, no launch utilities required - - Multi-node support + (data-parallel training, sharded training, etc.). + - Automated spawning of processes, no launch utilities required. + - Multi-node support. Args: - accelerator: The hardware to run on. Possible choices are: cpu, gpu, tpu, auto. + accelerator: The hardware to run on. Possible choices are: ```cpu"``, ```gpu"``, ```tpu"``, ```auto"``. strategy: Strategy for how to run across multiple devices. Possible choices are: - dp, ddp, ddp_spawn, tpu_spawn, deepspeed, ddp_sharded. - devices: Number of devices to train on (int) or which GPUs to train on (list or str). The value applies - per node. + ```dp"``, ```ddp"``, ```ddp_spawn"``, ```deepspeed"``, ```ddp_sharded"``. + devices: Number of devices to train on (``int``) or which GPUs to train on (``list`` or ``str``). + The value applies per node. num_nodes: Number of GPU nodes for distributed training. - precision: Double precision (64), full precision (32), half precision (16) or bfloat16 precision (bf16). + precision: Double precision (``64``), full precision (``32``), half precision (``16``), + or bfloat16 precision (```bf16"``). plugins: One or several custom plugins gpus: Provides the same function as the ``devices`` argument but implies ``accelerator="gpu"``. tpu_cores: Provides the same function as the ``devices`` argument but implies ``accelerator="tpu"``. 
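[Editor's note] The docstring hunk above documents the public constructor arguments; as quick orientation, here is a minimal usage sketch of the API this patch series introduces. It is illustrative only and not part of the patch; class and variable names are placeholders.

    import torch
    from pytorch_lightning.lite import LightningLite

    class MyLite(LightningLite):
        def run(self):
            model = torch.nn.Linear(32, 2)
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
            # setup() wraps them into _LiteModule / _LiteOptimizer and moves them to the device
            model, optimizer = self.setup(model, optimizer)
            for _ in range(4):
                optimizer.zero_grad()
                loss = model(torch.randn(8, 32)).sum()
                # backward() is routed through the configured precision plugin
                self.backward(loss)
                optimizer.step()

    # accelerator / strategy / devices / precision as documented in the docstring above
    MyLite(accelerator="cpu").run()
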
@@ -250,7 +251,7 @@ def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = No **kwargs: Optional named keyword arguments passed to the underlying backward function. Note: - When using ``strategy='deepspeed'`` and multiple models were setup, it is required to pass in the + When using ``strategy='deepspeed"`` and multiple models were setup, it is required to pass in the model as argument here. """ module = model.module if model is not None else model From 5d14e832bae5c4a43c961e298541107331425bcf Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 27 Oct 2021 19:24:08 +0200 Subject: [PATCH 05/47] Fixes to previous commit. Mention devices=auto (not yet implemented). Remove tpu spawn --- pytorch_lightning/lite/lite.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 0394b6d1d7884..075b3ad729b36 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -54,14 +54,14 @@ class LightningLite(ABC): - Multi-node support. Args: - accelerator: The hardware to run on. Possible choices are: ```cpu"``, ```gpu"``, ```tpu"``, ```auto"``. + accelerator: The hardware to run on. Possible choices are: ``"cpu"``, ``"gpu"``, ``"tpu"``, ``"auto"``. strategy: Strategy for how to run across multiple devices. Possible choices are: - ```dp"``, ```ddp"``, ```ddp_spawn"``, ```deepspeed"``, ```ddp_sharded"``. - devices: Number of devices to train on (``int``) or which GPUs to train on (``list`` or ``str``). + ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"ddp_sharded"``. + devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``. The value applies per node. num_nodes: Number of GPU nodes for distributed training. precision: Double precision (``64``), full precision (``32``), half precision (``16``), - or bfloat16 precision (```bf16"``). + or bfloat16 precision (``"bf16"``). plugins: One or several custom plugins gpus: Provides the same function as the ``devices`` argument but implies ``accelerator="gpu"``. tpu_cores: Provides the same function as the ``devices`` argument but implies ``accelerator="tpu"``. 
@@ -448,7 +448,6 @@ def _supported_strategy_types() -> Sequence[str]: DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, - DistributedType.TPU_SPAWN, DistributedType.DEEPSPEED, DistributedType.DDP_SHARDED, DistributedType.DDP_SHARDED_SPAWN, From 11862e869e04c121c71c2268320382c8819a0f48 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 27 Oct 2021 22:34:29 +0200 Subject: [PATCH 06/47] Fix test --- tests/lite/test_lite.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py index def9ce29ac9dc..b9508a64ec0e4 100644 --- a/tests/lite/test_lite.py +++ b/tests/lite/test_lite.py @@ -209,7 +209,6 @@ def test_setup_dataloaders_distributed_sampler_not_needed(): DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, - DistributedType.TPU_SPAWN, pytest.param(DistributedType.DEEPSPEED, marks=RunIf(deepspeed=True)), pytest.param(DistributedType.DDP_SHARDED, marks=RunIf(fairscale=True)), pytest.param(DistributedType.DDP_SHARDED_SPAWN, marks=RunIf(fairscale=True)), From ffed5ced1c791f714c01b44ba6f42dc30f532179 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 27 Oct 2021 22:34:29 +0200 Subject: [PATCH 07/47] Fix test --- tests/lite/test_lite.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py index def9ce29ac9dc..f47f9f1df1434 100644 --- a/tests/lite/test_lite.py +++ b/tests/lite/test_lite.py @@ -209,7 +209,6 @@ def test_setup_dataloaders_distributed_sampler_not_needed(): DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, - DistributedType.TPU_SPAWN, pytest.param(DistributedType.DEEPSPEED, marks=RunIf(deepspeed=True)), pytest.param(DistributedType.DDP_SHARDED, marks=RunIf(fairscale=True)), pytest.param(DistributedType.DDP_SHARDED_SPAWN, marks=RunIf(fairscale=True)), @@ -238,7 +237,6 @@ def test_setup_dataloaders_replace_custom_sampler(strategy): DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, - DistributedType.TPU_SPAWN, pytest.param(DistributedType.DEEPSPEED, marks=RunIf(deepspeed=True)), pytest.param(DistributedType.DDP_SHARDED, marks=RunIf(fairscale=True)), pytest.param(DistributedType.DDP_SHARDED_SPAWN, marks=RunIf(fairscale=True)), From 93b7940b10575a6c90806f5ed380daac4f61a7b2 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 28 Oct 2021 09:28:01 +0100 Subject: [PATCH 08/47] update --- tests/helpers/pipelines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/helpers/pipelines.py b/tests/helpers/pipelines.py index 3e5066d708da0..643d3e50cb894 100644 --- a/tests/helpers/pipelines.py +++ b/tests/helpers/pipelines.py @@ -67,7 +67,7 @@ def run_model_test( assert trainer.state.finished, f"Training failed with {trainer.state}" # Check that the model is actually changed post-training change_ratio = torch.norm(initial_values - post_train_values) - assert change_ratio > 0.1, f"the model is changed of {change_ratio}" + assert change_ratio > 0.03, f"the model is changed of {change_ratio}" # test model loading pretrained_model = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path, type(model)) From a6414a2f25cbe6bbc4fed6b87225ff4fee0ee7c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 28 Oct 2021 11:41:53 +0200 Subject: [PATCH 09/47] update access to deepspeed internal vars --- pytorch_lightning/lite/lite.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 
075b3ad729b36..e373ce3bfa3a7 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -365,7 +365,7 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: run_method = partial(self._run_with_sharded_context, run_method) if isinstance(self._strategy, DDPSpawnPlugin): - return self._strategy.spawn(run_method, *args, **kwargs) + return self._strategy.spawn(run_method, *args, return_result=True, **kwargs) else: return run_method(*args, **kwargs) @@ -402,7 +402,7 @@ def _set_deepspeed_precision_variables(self) -> None: amp_type = self._accelerator_connector.amp_type amp_level = self._accelerator_connector.amp_level precision = self._accelerator_connector.precision - self._strategy.amp_level, self._strategy.amp_type, self._strategy._precision = amp_level, amp_type, precision + self._strategy._amp_level, self._strategy._amp_type, self._strategy._precision = amp_level, amp_type, precision def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool: return ( From 5e1aeb89569484c76b6dec7a1cdcbc38ad04b8ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 28 Oct 2021 12:38:31 +0200 Subject: [PATCH 10/47] fix check for multiple models in deepspeed --- pytorch_lightning/lite/lite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index e373ce3bfa3a7..1471a13c2e70d 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -255,7 +255,7 @@ def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = No model as argument here. """ module = model.module if model is not None else model - if self._num_models > 0 and isinstance(self._strategy, DeepSpeedPlugin): + if self._num_models > 1 and isinstance(self._strategy, DeepSpeedPlugin): if model is None: raise MisconfigurationException( "When using multiple models + deepspeed, please provide the model used to perform the optimization." From f885b35c7e267e2b43007a70903eb73b104e2166 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 28 Oct 2021 13:32:52 +0200 Subject: [PATCH 11/47] fix deepspeed precision --- pytorch_lightning/lite/lite.py | 20 ++++++++++++------- .../plugins/training_type/deepspeed.py | 3 ++- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 1471a13c2e70d..71779670121ed 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -255,14 +255,20 @@ def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = No model as argument here. """ module = model.module if model is not None else model - if self._num_models > 1 and isinstance(self._strategy, DeepSpeedPlugin): + if isinstance(self._strategy, DeepSpeedPlugin): if model is None: - raise MisconfigurationException( - "When using multiple models + deepspeed, please provide the model used to perform the optimization." - ) - - # requires to attach the current `DeepSpeedEngine` for the `_LiteOptimizer.step` call. - self._strategy.model = module + if self._num_models == 0: + raise MisconfigurationException( + "No models were setup for backward. Did you forget to call `self.setup`?" + ) + if self._num_models > 1: + raise MisconfigurationException( + "When using multiple models + deepspeed, please provide the model used to perform the optimization." 
+ ) + module = self._strategy.model + else: + # requires to attach the current `DeepSpeedEngine` for the `_LiteOptimizer.step` call. + self._strategy.model = module self._precision_plugin._run_backward(tensor, module, *args, **kwargs) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index f6b5481dd5ef9..b06406570306c 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -336,7 +336,8 @@ def precision(self) -> Union[str, int]: @property def amp_level(self) -> Optional[str]: - return self._amp_level or self.lightning_module.trainer._accelerator_connector.amp_level + if self._amp_type == AMPType.APEX: + return self._amp_level or self.lightning_module.trainer._accelerator_connector.amp_level @property def amp_type(self) -> Optional[str]: From db34e0907e433b9d798679382119f75e2a355100 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 28 Oct 2021 15:25:07 +0200 Subject: [PATCH 12/47] fix line too long --- pytorch_lightning/lite/lite.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 71779670121ed..1862143307663 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -263,7 +263,8 @@ def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = No ) if self._num_models > 1: raise MisconfigurationException( - "When using multiple models + deepspeed, please provide the model used to perform the optimization." + "When using multiple models + deepspeed, please provide the model used to perform" + " the optimization." ) module = self._strategy.model else: From 992fd45661631b65986028e1fe17ec627a1a3564 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 28 Oct 2021 15:57:37 +0200 Subject: [PATCH 13/47] Minor changes --- pytorch_lightning/lite/lite.py | 20 ++++++++++---------- tests/lite/test_lite.py | 6 +++--- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 1862143307663..7d751adcc139b 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -103,7 +103,7 @@ def __init__( self._accelerator = self._accelerator_connector.accelerator self._strategy = self._accelerator.training_type_plugin self._precision_plugin = self._accelerator.precision_plugin - self._num_models: int = 0 + self._models_setup: int = 0 # wrap the run method so we can inject setup logic or spawn processes for the user setattr(self, "run", self._run_wrapper(self.run)) @@ -146,8 +146,8 @@ def run(self, *args: Any, **kwargs: Any) -> Any: """All the code inside this run method gets accelerated by Lite. Args: - *args: Add any positional arguments you need, e.g., the hyperparameters for your model - **kwargs: Add any keyword arguments you need, e.g., the hyperparameters for your model + *args: Add any positional arguments you need, e.g., the hyperparameters for your model. + **kwargs: Add any keyword arguments you need, e.g., the hyperparameters for your model. 
""" def setup( @@ -176,7 +176,7 @@ def setup( model, optimizers = self._strategy._setup_model_and_optimizers(model, list(optimizers)) model = _LiteModule(model, self._accelerator) optimizers = [_LiteOptimizer(optimizer=optimizer, accelerator=self._accelerator) for optimizer in optimizers] - self._num_models += 1 + self._models_setup += 1 if optimizers: return [model] + optimizers # type: ignore return model @@ -251,20 +251,20 @@ def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = No **kwargs: Optional named keyword arguments passed to the underlying backward function. Note: - When using ``strategy='deepspeed"`` and multiple models were setup, it is required to pass in the + When using ``strategy="deepspeed"`` and multiple models were setup, it is required to pass in the model as argument here. """ module = model.module if model is not None else model if isinstance(self._strategy, DeepSpeedPlugin): if model is None: - if self._num_models == 0: + if self._models_setup == 0: raise MisconfigurationException( - "No models were setup for backward. Did you forget to call `self.setup`?" + "No models were setup for backward. Did you forget to call `self.setup()`?" ) - if self._num_models > 1: + if self._models_setup > 1: raise MisconfigurationException( "When using multiple models + deepspeed, please provide the model used to perform" - " the optimization." + " the optimization: `self.backward(loss, model=model)`" ) module = self._strategy.model else: @@ -450,7 +450,7 @@ def _supported_device_types() -> Sequence[DeviceType]: ) @staticmethod - def _supported_strategy_types() -> Sequence[str]: + def _supported_strategy_types() -> Sequence[DistributedType]: return ( DistributedType.DP, DistributedType.DDP, diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py index f47f9f1df1434..8ff8ccf863649 100644 --- a/tests/lite/test_lite.py +++ b/tests/lite/test_lite.py @@ -129,12 +129,12 @@ def test_setup_tracks_num_models(): model = nn.Linear(1, 2) optimizer = torch.optim.Adam(model.parameters()) - assert lite._num_models == 0 + assert lite._models_setup == 0 lite.setup(model, optimizer) - assert lite._num_models == 1 + assert lite._models_setup == 1 lite.setup(model, optimizer) - assert lite._num_models == 2 + assert lite._models_setup == 2 def test_setup_dataloaders_unsupported_type(): From b8d44ce09085775c9165c5a69c7b6c67e28403f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 28 Oct 2021 16:11:55 +0200 Subject: [PATCH 14/47] remove identity wrapper --- pytorch_lightning/lite/lite.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 1862143307663..d83b2bd6cf165 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -106,7 +106,7 @@ def __init__( self._num_models: int = 0 # wrap the run method so we can inject setup logic or spawn processes for the user - setattr(self, "run", self._run_wrapper(self.run)) + setattr(self, "run", partial(self._run_impl, self.run)) @property def device(self) -> torch.device: @@ -361,9 +361,6 @@ def save_checkpoint(self, filepath: Union[str, Path], content: Dict[str, Any]) - """ self._strategy.save_checkpoint(content, filepath) - def _run_wrapper(self, run_method: Callable) -> Callable: - return partial(self._run_impl, run_method) - def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: self._set_plugin_specific_precision_variables() self._accelerator.setup_environment() From 
04094c39cd69567407603ac54d0952b6d03f06c4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 28 Oct 2021 16:15:55 +0200 Subject: [PATCH 15/47] Same annotations as Lightning which are identical to those in torch --- pytorch_lightning/lite/wrappers.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index e1d16ca8a3384..9c6641bfb4413 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Union +from typing import Any, Callable, Generator, Iterator, List, Optional, Union import torch from torch import nn as nn @@ -48,27 +48,27 @@ def optimizer(self) -> Optimizer: return self._optimizer @property - def state(self) -> Dict[str, torch.Tensor]: - return self._optimizer.state - - @state.setter - def state(self, state: Dict[str, torch.Tensor]) -> None: - self._optimizer.state = state - - @property - def defaults(self) -> Dict[str, Any]: + def defaults(self) -> dict: return self._optimizer.defaults @defaults.setter - def defaults(self, defaults: Dict[str, Any]) -> None: + def defaults(self, defaults: dict) -> None: self._optimizer.defaults = defaults @property - def param_groups(self) -> List[Dict[str, torch.Tensor]]: + def state(self) -> dict: + return self._optimizer.state + + @state.setter + def state(self, state: dict) -> None: + self._optimizer.state = state + + @property + def param_groups(self) -> List[dict]: return self._optimizer.param_groups @param_groups.setter - def param_groups(self, param_groups: List[Dict[str, torch.Tensor]]) -> None: + def param_groups(self, param_groups: List[dict]) -> None: self._optimizer.param_groups = param_groups def step(self, closure: Optional[Callable] = None) -> None: From 1d9920ac090698ad4eebc1984b08a50822b48eca Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 28 Oct 2021 16:17:21 +0200 Subject: [PATCH 16/47] Add comment --- pytorch_lightning/lite/lite.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index e9770531f66ed..0b749fc85c861 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -391,6 +391,7 @@ def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) - # remains invalid. We need to update the references to point to the parameter tensors on the device. params_on_cpu = dict(model.named_parameters()) model = self.to_device(model) + # XLA makes a copy on the parameters, so the device should is not the same before and after to_device. 
params_on_device = dict(model.named_parameters()) mapping = {param: params_on_device[name] for name, param in params_on_cpu.items()} From a6df052f902312d3c43bba41129b3543a9837556 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 28 Oct 2021 16:31:15 +0200 Subject: [PATCH 17/47] Simplify _LiteOptimizer --- pytorch_lightning/lite/wrappers.py | 39 +++--------------------------- 1 file changed, 4 insertions(+), 35 deletions(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 9c6641bfb4413..c1612e7298ab3 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Generator, Iterator, List, Optional, Union +from typing import Any, Callable, Generator, Iterator, Optional, Union import torch from torch import nn as nn @@ -38,51 +38,20 @@ def __init__(self, optimizer: Optimizer, accelerator: Accelerator) -> None: optimizer: The optimizer to wrap accelerator: Reference to the accelerator for handling the optimizer step """ - self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")} + self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step",)} self.__class__ = type("Lite" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {}) - self._optimizer = optimizer + self.optimizer = optimizer self._accelerator = accelerator - @property - def optimizer(self) -> Optimizer: - return self._optimizer - - @property - def defaults(self) -> dict: - return self._optimizer.defaults - - @defaults.setter - def defaults(self, defaults: dict) -> None: - self._optimizer.defaults = defaults - - @property - def state(self) -> dict: - return self._optimizer.state - - @state.setter - def state(self, state: dict) -> None: - self._optimizer.state = state - - @property - def param_groups(self) -> List[dict]: - return self._optimizer.param_groups - - @param_groups.setter - def param_groups(self, param_groups: List[dict]) -> None: - self._optimizer.param_groups = param_groups - def step(self, closure: Optional[Callable] = None) -> None: closure = closure or _do_nothing_closure self._accelerator.optimizer_step( - self._optimizer, + self.optimizer, opt_idx=0, lambda_closure=closure, model=self._accelerator.model, ) - def zero_grad(self, *args: Any, **kwargs: Any) -> None: - self._optimizer.zero_grad(*args, **kwargs) - class _LiteModule(nn.Module): # TODO: Pass in the precision plugin instead of accelerator From 5208e19544005c34a603da8f9a30f9c9861302bc Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 28 Oct 2021 16:44:29 +0200 Subject: [PATCH 18/47] Didn't mean to remove this :) --- pytorch_lightning/lite/wrappers.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index c1612e7298ab3..fe42778dcb39e 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -40,9 +40,13 @@ def __init__(self, optimizer: Optimizer, accelerator: Accelerator) -> None: """ self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step",)} self.__class__ = type("Lite" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {}) - self.optimizer = optimizer + self._optimizer = optimizer self._accelerator = 
accelerator + @property + def optimizer(self) -> Optimizer: + return self._optimizer + def step(self, closure: Optional[Callable] = None) -> None: closure = closure or _do_nothing_closure self._accelerator.optimizer_step( From 31406aeaef9211f3478a4b62c29741396a4aa5ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 28 Oct 2021 16:58:22 +0200 Subject: [PATCH 19/47] rename cast to autocast --- pytorch_lightning/lite/lite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 0b749fc85c861..42eb49f7a9e44 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -274,7 +274,7 @@ def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = No self._precision_plugin._run_backward(tensor, module, *args, **kwargs) @contextmanager - def cast(self) -> Generator[None, None, None]: + def autocast(self) -> Generator[None, None, None]: """A context manager to automatically convert operations for the chosen precision. Use this only if the `forward` method of your model does not cover all operations you wish to run with the From bda0f8ab6c92ac8a09c6f9556185c980435d5710 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 28 Oct 2021 17:00:21 +0200 Subject: [PATCH 20/47] test: Remove unused parametrization --- tests/lite/test_lite.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py index 8ff8ccf863649..6135d67d0d026 100644 --- a/tests/lite/test_lite.py +++ b/tests/lite/test_lite.py @@ -47,14 +47,14 @@ def forward(self, x): return torch.nn.functional.mse_loss(x, torch.ones_like(x)) -@pytest.mark.parametrize("accelerator", ["coconut"]) -def test_unsupported_accelerator(accelerator): +def test_unsupported_accelerator(): + accelerator = "coconut" with pytest.raises(MisconfigurationException, match=f"`accelerator={repr(accelerator)}` is not a valid choice"): EmptyLite(accelerator=accelerator) -@pytest.mark.parametrize("strategy", ["coconut"]) -def test_unsupported_strategy(strategy): +def test_unsupported_strategy(): + strategy = "coconut" with pytest.raises(MisconfigurationException, match=f"`strategy={repr(strategy)}` is not a valid choice"): EmptyLite(strategy=strategy) From c34d00606de8f71dea1a00ca9efe63919b534985 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 28 Oct 2021 17:01:20 +0200 Subject: [PATCH 21/47] rename save_checkpoint to save --- pytorch_lightning/lite/lite.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 42eb49f7a9e44..a54e59c7d5017 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -349,15 +349,15 @@ def all_gather( def broadcast(self, obj: object, src: int = 0) -> object: return self._strategy.broadcast(obj, src=src) - def save_checkpoint(self, filepath: Union[str, Path], content: Dict[str, Any]) -> None: + def save(self, content: Dict[str, Any], filepath: Union[str, Path]) -> None: """Save a checkpoint contents to a file. How and which processes save gets determined by the `strategy`. For example, the `ddp` strategy saves checkpoints only on process 0. 
Args: - filepath: A path to where the file should be saved content: A dictionary with contents, i.e., the state dict of your model + filepath: A path to where the file should be saved """ self._strategy.save_checkpoint(content, filepath) From f45c2c8223633590f6ffd3ac56689415ed67b020 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 28 Oct 2021 17:01:58 +0200 Subject: [PATCH 22/47] update docstring --- pytorch_lightning/lite/lite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index a54e59c7d5017..3230f63db89cb 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -350,7 +350,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: return self._strategy.broadcast(obj, src=src) def save(self, content: Dict[str, Any], filepath: Union[str, Path]) -> None: - """Save a checkpoint contents to a file. + """Save checkpoint contents to a file. How and which processes save gets determined by the `strategy`. For example, the `ddp` strategy saves checkpoints only on process 0. From c84acb1743016303a206957c640aceff016d15e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 28 Oct 2021 17:02:41 +0200 Subject: [PATCH 23/47] update comment --- pytorch_lightning/lite/lite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 3230f63db89cb..2c1bf2a8ce393 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -391,7 +391,7 @@ def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) - # remains invalid. We need to update the references to point to the parameter tensors on the device. params_on_cpu = dict(model.named_parameters()) model = self.to_device(model) - # XLA makes a copy on the parameters, so the device should is not the same before and after to_device. + # XLA makes a copy on the parameters, so the device is not the same before and after to_device. params_on_device = dict(model.named_parameters()) mapping = {param: params_on_device[name] for name, param in params_on_cpu.items()} From 92752e66e2cc988ebc2e30a5dd5ce018ff43012d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 28 Oct 2021 17:05:51 +0200 Subject: [PATCH 24/47] add load --- pytorch_lightning/lite/lite.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 2c1bf2a8ce393..d913ea580f912 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -361,6 +361,16 @@ def save(self, content: Dict[str, Any], filepath: Union[str, Path]) -> None: """ self._strategy.save_checkpoint(content, filepath) + def load(self, filepath: Union[str, Path]) -> Any: + """Load a checkpoint from a file. 
+ + How and which processes load gets determined by the `strategy` + + Args: + filepath: A path to where the file is located + """ + return self._strategy.load_checkpoint(filepath) + def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: self._set_plugin_specific_precision_variables() self._accelerator.setup_environment() From c0ffc712d67d5f1afd9033b69a3dea376efa1885 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 28 Oct 2021 17:11:22 +0200 Subject: [PATCH 25/47] tests: update autocast use --- tests/lite/test_parity.py | 24 +++++------------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/tests/lite/test_parity.py b/tests/lite/test_parity.py index 48ed2bb22cd98..2c7c58d249a13 100644 --- a/tests/lite/test_parity.py +++ b/tests/lite/test_parity.py @@ -23,7 +23,6 @@ import torch.multiprocessing as mp import torch.nn.functional from torch import nn -from torch.cuda import is_available from torch.nn.parallel.distributed import DistributedDataParallel from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler @@ -34,7 +33,6 @@ from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.cloud_io import atomic_save -from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_10 from tests.helpers.boring_model import RandomDataset from tests.helpers.runif import RunIf @@ -100,14 +98,11 @@ def precision_context(precision, accelerator) -> Generator[None, None, None]: if precision == 32: yield return - if precision == 16 and accelerator == "gpu": + if accelerator == "gpu": with torch.cuda.amp.autocast(): yield elif accelerator == "cpu": - with torch.cpu.amp.autocast(dtype=torch.float16 if precision == 16 else torch.bfloat16): - yield - else: - with torch.cuda.amp.autocast(): + with torch.cpu.amp.autocast(): yield @@ -115,18 +110,9 @@ def precision_context(precision, accelerator) -> Generator[None, None, None]: "precision, strategy, devices, accelerator", [ pytest.param(32, None, 1, "cpu"), - pytest.param(32, None, 1, "gpu", marks=pytest.mark.skipif(not is_available(), reason="requires a GPU")), - pytest.param(16, None, 1, "gpu", marks=pytest.mark.skipif(not is_available(), reason="requires a GPU")), - pytest.param( - "bf16", - None, - 1, - "gpu", - marks=pytest.mark.skipif( - not (_TORCH_GREATER_EQUAL_1_10 and is_available()), - reason="bfloat16 and requires GPU isn't available.", - ), - ), + pytest.param(32, None, 1, "gpu", marks=RunIf(min_gpus=1)), + pytest.param(16, None, 1, "gpu", marks=RunIf(min_gpus=1)), + pytest.param("bf16", None, 1, "gpu", marks=RunIf(min_torch="1.10", min_gpus=1)), ], ) def test_boring_lite_model_single_device(precision, strategy, devices, accelerator, tmpdir): From af400092d3087b1439f61d9ca499db97556372ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 28 Oct 2021 17:26:59 +0200 Subject: [PATCH 26/47] add test for autocast --- tests/lite/test_lite.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py index 6135d67d0d026..60da70c09afe0 100644 --- a/tests/lite/test_lite.py +++ b/tests/lite/test_lite.py @@ -324,6 +324,25 @@ def test_backward_model_input_required(): lite.backward(loss) +@mock.patch("torch.cuda.is_available") +@mock.patch("torch.cuda.device_count", return_value=2) +@pytest.mark.parametrize( + "plugin_context, precision", + [ 
+ ("pytorch_lightning.plugins.precision.double.DoublePrecisionPlugin.forward_context", 64), + ("pytorch_lightning.plugins.precision.precision_plugin.PrecisionPlugin.forward_context", 32), + ("pytorch_lightning.plugins.precision.native_amp.NativeMixedPrecisionPlugin.forward_context", 16), + ], +) +def test_autocast(_, __, plugin_context, precision): + lite = EmptyLite(gpus=1, precision=precision) + with mock.patch(plugin_context) as context: + context().__enter__.assert_not_called() + with lite.autocast(): + context().__enter__.assert_called() + context().__exit__.assert_called() + + @RunIf(min_gpus=2, deepspeed=True, special=True) def test_deepspeed_multiple_models(): class Lite(LightningLite): From eb9b92ec7abf624722bff5d7dde832c528f2fa6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 28 Oct 2021 17:32:10 +0200 Subject: [PATCH 27/47] simplify test --- tests/lite/test_lite.py | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py index 60da70c09afe0..0fde335a60bdb 100644 --- a/tests/lite/test_lite.py +++ b/tests/lite/test_lite.py @@ -14,7 +14,7 @@ from copy import deepcopy from unittest import mock -from unittest.mock import Mock, PropertyMock +from unittest.mock import Mock, PropertyMock, MagicMock import pytest import torch @@ -324,23 +324,14 @@ def test_backward_model_input_required(): lite.backward(loss) -@mock.patch("torch.cuda.is_available") -@mock.patch("torch.cuda.device_count", return_value=2) -@pytest.mark.parametrize( - "plugin_context, precision", - [ - ("pytorch_lightning.plugins.precision.double.DoublePrecisionPlugin.forward_context", 64), - ("pytorch_lightning.plugins.precision.precision_plugin.PrecisionPlugin.forward_context", 32), - ("pytorch_lightning.plugins.precision.native_amp.NativeMixedPrecisionPlugin.forward_context", 16), - ], -) -def test_autocast(_, __, plugin_context, precision): - lite = EmptyLite(gpus=1, precision=precision) - with mock.patch(plugin_context) as context: - context().__enter__.assert_not_called() - with lite.autocast(): - context().__enter__.assert_called() - context().__exit__.assert_called() +def test_autocast(): + lite = EmptyLite() + lite._precision_plugin.forward_context = MagicMock() + + lite._precision_plugin.forward_context().__enter__.assert_not_called() + with lite.autocast(): + lite._precision_plugin.forward_context().__enter__.assert_called() + lite._precision_plugin.forward_context().__exit__.assert_called() @RunIf(min_gpus=2, deepspeed=True, special=True) From 3e261e100c5cd6fe8203599effb2acdd87118c01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 28 Oct 2021 17:33:14 +0200 Subject: [PATCH 28/47] add test description --- tests/lite/test_lite.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py index 0fde335a60bdb..301ed6e9dad01 100644 --- a/tests/lite/test_lite.py +++ b/tests/lite/test_lite.py @@ -325,6 +325,7 @@ def test_backward_model_input_required(): def test_autocast(): + """Test that the Lite autocast context manager lets the precision plugin handle casting.""" lite = EmptyLite() lite._precision_plugin.forward_context = MagicMock() From 5754ad744616c1bb4980ca3a7baecd53f8ea7977 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Oct 2021 15:33:30 +0000 Subject: [PATCH 29/47] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see 
https://pre-commit.ci --- tests/lite/test_lite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py index 301ed6e9dad01..1e10b13f612a7 100644 --- a/tests/lite/test_lite.py +++ b/tests/lite/test_lite.py @@ -14,7 +14,7 @@ from copy import deepcopy from unittest import mock -from unittest.mock import Mock, PropertyMock, MagicMock +from unittest.mock import MagicMock, Mock, PropertyMock import pytest import torch From 85fe0cf125b708921fbc29f9f284f503d5f242da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 28 Oct 2021 17:44:53 +0200 Subject: [PATCH 30/47] remove "mixed" string support --- pytorch_lightning/lite/wrappers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index fe42778dcb39e..c7705f4f9130a 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -82,7 +82,6 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: method.""" precision = self._accelerator.precision_plugin.precision precision_to_type = { - "mixed": torch.float16, 16: torch.float16, 32: torch.float32, 64: torch.float64, From 91a6b3c27c8f03e64a2bc7b657c6dd287c4f3551 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 28 Oct 2021 17:48:34 +0200 Subject: [PATCH 31/47] More mixed references --- tests/lite/test_wrappers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py index faed290b75629..9750b8f30da6c 100644 --- a/tests/lite/test_wrappers.py +++ b/tests/lite/test_wrappers.py @@ -41,7 +41,6 @@ def test_lite_module_wraps(): (32, torch.float64, torch.float32), (16, torch.float32, torch.float16), (16, torch.float64, torch.float16), - # ("mixed", torch.float32, torch.float16), # TODO: support precision="mixed" ], ) def test_lite_module_forward_conversion(precision, input_type, expected_type): @@ -50,7 +49,7 @@ def test_lite_module_forward_conversion(precision, input_type, expected_type): device = torch.device("cuda", 0) def check_autocast(forward_input): - assert precision not in (16, "mixed") or torch.is_autocast_enabled() + assert precision != 16 or torch.is_autocast_enabled() return forward_input module = Mock(wraps=torch.nn.Linear(1, 1), side_effect=check_autocast) From f45a97ad36a95bcb8885378dc1b99ffcd5aaefec Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 28 Oct 2021 18:00:14 +0200 Subject: [PATCH 32/47] Implement `seed_everything` --- pytorch_lightning/lite/lite.py | 17 +++++++++++++++-- tests/lite/test_lite.py | 7 +++---- tests/lite/test_parity.py | 9 ++++----- 3 files changed, 22 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index d913ea580f912..9e91e07d0e1f5 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -24,8 +24,7 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler -from pytorch_lightning import Trainer -from pytorch_lightning.accelerators import Accelerator +from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer from pytorch_lightning.plugins import ( DDPShardedPlugin, @@ -37,10 +36,12 @@ ) from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.data_loading import 
TrainerDataLoadingMixin +from pytorch_lightning.trainer.trainer import Trainer from pytorch_lightning.utilities import DeviceType, DistributedType, move_data_to_device from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.data import has_iterable_dataset from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.seed import seed_everything class LightningLite(ABC): @@ -371,6 +372,18 @@ def load(self, filepath: Union[str, Path]) -> Any: """ return self._strategy.load_checkpoint(filepath) + @staticmethod + def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None) -> int: + """Helper function to seed everything without explicitly importing Lightning. + + See :func:`pytorch_lightning.seed_everything` for more details. + """ + if workers is None: + # Lightning sets `workers=False` by default to avoid breaking reproducibility, but since this is a new + # release, we can afford to do it. + workers = True + return seed_everything(seed=seed, workers=workers) + def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: self._set_plugin_specific_precision_variables() self._accelerator.setup_environment() diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py index 1e10b13f612a7..f0c760d3c4a16 100644 --- a/tests/lite/test_lite.py +++ b/tests/lite/test_lite.py @@ -23,7 +23,6 @@ from torch import nn from torch.utils.data import DataLoader, DistributedSampler, Sampler -from pytorch_lightning import seed_everything from pytorch_lightning.lite import LightningLite from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer from pytorch_lightning.plugins import DeepSpeedPlugin, PrecisionPlugin, TrainingTypePlugin @@ -354,11 +353,11 @@ def run(self): for mw_b, mw_a in zip(state_dict.values(), model.state_dict().values()): assert not torch.equal(mw_b, mw_a) - seed_everything(42) + self.seed_everything(42) model_1 = BoringModel() optimizer_1 = torch.optim.SGD(model_1.parameters(), lr=0.0001) - seed_everything(42) + self.seed_everything(42) model_2 = BoringModel() optimizer_2 = torch.optim.SGD(model_2.parameters(), lr=0.0001) @@ -368,7 +367,7 @@ def run(self): model_1, optimizer_1 = self.setup(model_1, optimizer_1) model_2, optimizer_2 = self.setup(model_2, optimizer_2) - seed_everything(42) + self.seed_everything(42) data_list = [] for _ in range(2): optimizer_1.zero_grad() diff --git a/tests/lite/test_parity.py b/tests/lite/test_parity.py index 2c7c58d249a13..b1578f3f47232 100644 --- a/tests/lite/test_parity.py +++ b/tests/lite/test_parity.py @@ -27,7 +27,6 @@ from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from pytorch_lightning import seed_everything from pytorch_lightning.lite import LightningLite from pytorch_lightning.plugins.environments.lightning_environment import find_free_network_port from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin @@ -116,7 +115,7 @@ def precision_context(precision, accelerator) -> Generator[None, None, None]: ], ) def test_boring_lite_model_single_device(precision, strategy, devices, accelerator, tmpdir): - seed_everything(42) + LightningLite.seed_everything(42) train_dataloader = DataLoader(RandomDataset(32, 8)) model = BoringModel() num_epochs = 1 @@ -168,7 +167,7 @@ def run(rank, model, train_dataloader, num_epochs, precision, accelerator, tmpdi ], ) def test_boring_lite_model_ddp_spawn(precision, 
strategy, devices, accelerator, tmpdir): - seed_everything(42) + LightningLite.seed_everything(42) train_dataloader = DataLoader(RandomDataset(32, 8)) model = BoringModel() num_epochs = 1 @@ -199,7 +198,7 @@ def test_boring_lite_model_ddp_spawn(precision, strategy, devices, accelerator, ], ) def test_boring_lite_model_ddp(precision, strategy, devices, accelerator, tmpdir): - seed_everything(42) + LightningLite.seed_everything(42) train_dataloader = DataLoader(RandomDataset(32, 4)) model = BoringModel() num_epochs = 1 @@ -213,7 +212,7 @@ def test_boring_lite_model_ddp(precision, strategy, devices, accelerator, tmpdir for w_pure, w_lite in zip(state_dict.values(), lite_model_state_dict.values()): assert not torch.equal(w_pure.cpu(), w_lite.cpu()) - seed_everything(42) + LightningLite.seed_everything(42) train_dataloader = DataLoader(RandomDataset(32, 4)) model = BoringModel() run(lite.global_rank, model, train_dataloader, num_epochs, precision, accelerator, tmpdir) From ba7ac5f311b9083fa0ba22f463e86706b91b68ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 28 Oct 2021 18:30:26 +0200 Subject: [PATCH 33/47] add isinstance check --- pytorch_lightning/accelerators/accelerator.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index a8acec23c6ed3..d2b44fc3fca1c 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -333,6 +333,11 @@ def optimizer_step( """ model = model or self.lightning_module self.precision_plugin.optimizer_step(model, optimizer, opt_idx, lambda_closure, **kwargs) + + if not isinstance(model, pl.LightningModule): + # gradient clipping and norm tracking only available with a LightingModule/Trainer + return + trainer = model.trainer assert isinstance(trainer, pl.Trainer) # TODO: this is done for the entire model but should be changed to per-optimizer From f04b39877fe0e9b6810e77b0a4728ae41c41ab06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 28 Oct 2021 18:35:02 +0200 Subject: [PATCH 34/47] add bfloat16 --- pytorch_lightning/lite/wrappers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index c7705f4f9130a..d2508f6527745 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -82,6 +82,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: method.""" precision = self._accelerator.precision_plugin.precision precision_to_type = { + "bf16": torch.bfloat16, 16: torch.float16, 32: torch.float32, 64: torch.float64, From 229b02405a055edfd95720025e77a8a1011f6e58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 28 Oct 2021 18:41:49 +0200 Subject: [PATCH 35/47] rename params_on_cpu --- pytorch_lightning/lite/lite.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 9e91e07d0e1f5..17dbf47401b71 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -412,12 +412,12 @@ def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) - # When the user creates the optimizer, they reference the parameters on the CPU. # However, when running with TPU the parameters get copied and the reference in the optimizer # remains invalid. We need to update the references to point to the parameter tensors on the device. 
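# [Editor's note] An illustrative aside, not part of the patch: the hunk below renames
# ``params_on_cpu`` and relies on the fact that on XLA ``to(device)`` produces *new* parameter
# tensors, so an optimizer created beforehand still references the stale ones. Roughly
# (``xla_device`` is a hypothetical placeholder):
#
#     model = torch.nn.Linear(1, 1)
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#     moved = model.to(xla_device)
#     optimizer.param_groups[0]["params"][0] is next(moved.parameters())   # False on XLA
#
# which is why the code below builds an old-parameter -> new-parameter mapping and rewrites
# ``optimizer.param_groups`` to point at the tensors that now live on the device.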
- params_on_cpu = dict(model.named_parameters()) + params_before_move = dict(model.named_parameters()) model = self.to_device(model) # XLA makes a copy on the parameters, so the device is not the same before and after to_device. params_on_device = dict(model.named_parameters()) - mapping = {param: params_on_device[name] for name, param in params_on_cpu.items()} + mapping = {param: params_on_device[name] for name, param in params_before_move.items()} for optimizer in optimizers: for param_group in optimizer.param_groups: param_group["params"] = [mapping.get(p, p) for p in param_group["params"]] From 95db246e8cb89125f820ead9a2944401f7d0c0e4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 28 Oct 2021 20:10:03 +0200 Subject: [PATCH 36/47] Pass down the barrier name --- pytorch_lightning/lite/lite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 17dbf47401b71..7b3ae24ee8b8a 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -326,7 +326,7 @@ def barrier(self, name: Optional[str] = None) -> None: # now all processes can read the files and start training """ - self._strategy.barrier() + self._strategy.barrier(name=name) def all_gather( self, data: Union[torch.Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False From 0c8e9141f8310750a410362ee6afb97fb63c0d16 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 28 Oct 2021 20:11:13 +0200 Subject: [PATCH 37/47] Add back __del__ --- pytorch_lightning/lite/wrappers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index d2508f6527745..c97c280699b05 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -38,7 +38,9 @@ def __init__(self, optimizer: Optimizer, accelerator: Accelerator) -> None: optimizer: The optimizer to wrap accelerator: Reference to the accelerator for handling the optimizer step """ - self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step",)} + # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would + # not want to call on desturction of the `_LiteOptimizer` + self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")} self.__class__ = type("Lite" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {}) self._optimizer = optimizer self._accelerator = accelerator From a93278d68bd347d5e47953e41c59888b64467f01 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 28 Oct 2021 20:52:21 +0200 Subject: [PATCH 38/47] Fix mypy --- pytorch_lightning/lite/lite.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 7b3ae24ee8b8a..16898e73d2d79 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -143,12 +143,10 @@ def is_global_zero(self) -> bool: return self._strategy.is_global_zero @abstractmethod - def run(self, *args: Any, **kwargs: Any) -> Any: + def run(self) -> Any: """All the code inside this run method gets accelerated by Lite. - Args: - *args: Add any positional arguments you need, e.g., the hyperparameters for your model. - **kwargs: Add any keyword arguments you need, e.g., the hyperparameters for your model. + You can pass arbitrary arguments to this function when overriding it. 
""" def setup( @@ -156,7 +154,7 @@ def setup( model: nn.Module, *optimizers: Optimizer, move_to_device: bool = True, - ) -> Union[_LiteModule, List[Union[_LiteModule, _LiteOptimizer]]]: + ) -> Any: # no specific return because the way we want our API to look does not play well with mypy """Setup a model and its optimizers for accelerated training. Args: @@ -179,6 +177,7 @@ def setup( optimizers = [_LiteOptimizer(optimizer=optimizer, accelerator=self._accelerator) for optimizer in optimizers] self._models_setup += 1 if optimizers: + # join both types in a list for API convenience return [model] + optimizers # type: ignore return model From 65e289b625acc4292424f83cf1e0085b99ef4acc Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 28 Oct 2021 21:00:16 +0200 Subject: [PATCH 39/47] Fix test --- pytorch_lightning/lite/wrappers.py | 2 +- tests/lite/test_wrappers.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index c97c280699b05..370e84fa10940 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -39,7 +39,7 @@ def __init__(self, optimizer: Optimizer, accelerator: Accelerator) -> None: accelerator: Reference to the accelerator for handling the optimizer step """ # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would - # not want to call on desturction of the `_LiteOptimizer` + # not want to call on destruction of the `_LiteOptimizer self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")} self.__class__ = type("Lite" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {}) self._optimizer = optimizer diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py index 9750b8f30da6c..cbb359a4043ae 100644 --- a/tests/lite/test_wrappers.py +++ b/tests/lite/test_wrappers.py @@ -101,5 +101,3 @@ def test_lite_optimizer_steps(): lite_optimizer.step() accelerator.optimizer_step.assert_called_once() accelerator.optimizer_step.assert_called_with(optimizer, opt_idx=0, lambda_closure=ANY, model=accelerator.model) - lite_optimizer.zero_grad() - optimizer.zero_grad.assert_called_once() From d40822870d8e07e834e89bcba8cfe9c3ff3c82f1 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 28 Oct 2021 21:07:22 +0200 Subject: [PATCH 40/47] Add worker init fn --- pytorch_lightning/lite/lite.py | 4 ++++ pytorch_lightning/trainer/data_loading.py | 9 +++++---- tests/trainer/test_dataloaders.py | 8 ++++---- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 16898e73d2d79..d349672c4d176 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -239,6 +239,10 @@ def _setup_dataloader( dataloader = DataLoader(**kwargs) else: dataloader = _LiteDataLoader(device=device, **kwargs) + + # add worker_init_fn for correct seeding in worker processes + TrainerDataLoadingMixin._auto_add_worker_init_fn(dataloader, self.global_rank) + return self._strategy.process_dataloader(dataloader) def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None: diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 24206b8af1fc1..726336820b28a 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -114,9 +114,10 @@ def _worker_check(self, dataloader: 
DataLoader, name: str) -> None: " in the `DataLoader` init to improve performance." ) - def auto_add_worker_init_fn(self, dataloader: DataLoader) -> None: + @staticmethod + def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None: if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None: - dataloader.worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank) + dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank) def _requires_distributed_sampler(self, dataloader) -> bool: return ( @@ -336,7 +337,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) - apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, "train_dataloader") # add worker_init_fn for correct seeding in worker processes - apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn) + apply_to_collection(self.train_dataloader, DataLoader, self._auto_add_worker_init_fn, rank=self.global_rank) # add collate_fn to collect metadata for fault tolerant training if _fault_tolerant_training(): @@ -443,7 +444,7 @@ def _reset_eval_dataloader( dataloaders = [self.prepare_dataloader(dl, False, mode=mode) for dl in dataloaders if dl is not None] # add worker_init_fn for correct seeding in worker processes - apply_to_collection(dataloaders, dtype=DataLoader, function=self.auto_add_worker_init_fn) + apply_to_collection(dataloaders, dtype=DataLoader, function=self._auto_add_worker_init_fn) loader_num_batches = [] diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 2793c71560a81..2e8d552083b99 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -771,24 +771,24 @@ def test_auto_add_worker_init_fn(): trainer = Trainer() # without pl.seed_everything() - trainer.auto_add_worker_init_fn(dataloader) + trainer._auto_add_worker_init_fn(dataloader) assert dataloader.worker_init_fn is None # with forcefully avoiding it seed_everything(0, workers=False) - trainer.auto_add_worker_init_fn(dataloader) + trainer._auto_add_worker_init_fn(dataloader) assert dataloader.worker_init_fn is None # when user already has a worker_init_fn user_function = _user_worker_init_fn dataloader.worker_init_fn = user_function - trainer.auto_add_worker_init_fn(dataloader) + trainer._auto_add_worker_init_fn(dataloader) assert dataloader.worker_init_fn is user_function dataloader.worker_init_fn = None # main use case seed_everything(0, workers=True) - trainer.auto_add_worker_init_fn(dataloader) + trainer._auto_add_worker_init_fn(dataloader) assert dataloader.worker_init_fn is not None From 952e11c1b7c3acd8f0f8ccd91e121db4fc898b9f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 28 Oct 2021 21:15:59 +0200 Subject: [PATCH 41/47] Forgot to pass the global rank --- tests/trainer/test_dataloaders.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 2e8d552083b99..ea31dbaf7d0a1 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -771,24 +771,24 @@ def test_auto_add_worker_init_fn(): trainer = Trainer() # without pl.seed_everything() - trainer._auto_add_worker_init_fn(dataloader) + trainer._auto_add_worker_init_fn(dataloader, 0) assert dataloader.worker_init_fn is None # with forcefully avoiding it seed_everything(0, workers=False) - trainer._auto_add_worker_init_fn(dataloader) + trainer._auto_add_worker_init_fn(dataloader, 
0) assert dataloader.worker_init_fn is None # when user already has a worker_init_fn user_function = _user_worker_init_fn dataloader.worker_init_fn = user_function - trainer._auto_add_worker_init_fn(dataloader) + trainer._auto_add_worker_init_fn(dataloader, 0) assert dataloader.worker_init_fn is user_function dataloader.worker_init_fn = None # main use case seed_everything(0, workers=True) - trainer._auto_add_worker_init_fn(dataloader) + trainer._auto_add_worker_init_fn(dataloader, 0) assert dataloader.worker_init_fn is not None From 50d51248c3955734f26688cfba05e131728eac46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 28 Oct 2021 21:16:02 +0200 Subject: [PATCH 42/47] add back skip of expensive spawn test --- tests/lite/test_parity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lite/test_parity.py b/tests/lite/test_parity.py index b1578f3f47232..bec9339ec8e2f 100644 --- a/tests/lite/test_parity.py +++ b/tests/lite/test_parity.py @@ -158,7 +158,7 @@ def run(rank, model, train_dataloader, num_epochs, precision, accelerator, tmpdi atomic_save(model.state_dict(), os.path.join(tmpdir, "model_spawn.pt")) -# @pytest.mark.skipif(True, reason="Skipping as it takes 80 seconds.") +@pytest.mark.skipif(True, reason="Skipping as it takes 80 seconds.") @RunIf(min_gpus=2) @pytest.mark.parametrize( "precision, strategy, devices, accelerator", From 13fb58a763f1fe372cc5655f129982fb8d2be999 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 28 Oct 2021 21:19:05 +0200 Subject: [PATCH 43/47] resolve todo in _LiteModule --- pytorch_lightning/lite/lite.py | 2 +- pytorch_lightning/lite/wrappers.py | 12 ++++++------ tests/lite/test_wrappers.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index d349672c4d176..7d0ff6a436b61 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -173,7 +173,7 @@ def setup( # Let accelerator/plugin wrap and connect the models and optimizers model, optimizers = self._strategy._setup_model_and_optimizers(model, list(optimizers)) - model = _LiteModule(model, self._accelerator) + model = _LiteModule(model, self._precision_plugin) optimizers = [_LiteOptimizer(optimizer=optimizer, accelerator=self._accelerator) for optimizer in optimizers] self._models_setup += 1 if optimizers: diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 370e84fa10940..991b86f25085b 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -20,6 +20,7 @@ from torch.utils.data import DataLoader from pytorch_lightning.accelerators import Accelerator +from pytorch_lightning.plugins import PrecisionPlugin from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device @@ -60,8 +61,7 @@ def step(self, closure: Optional[Callable] = None) -> None: class _LiteModule(nn.Module): - # TODO: Pass in the precision plugin instead of accelerator - def __init__(self, module: nn.Module, accelerator: Accelerator) -> None: + def __init__(self, module: nn.Module, precision_plugin: PrecisionPlugin) -> None: """The LiteModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast automatically for the forward pass. 
@@ -69,11 +69,11 @@ def __init__(self, module: nn.Module, accelerator: Accelerator) -> None: Args: module: The module to wrap - accelerator: Reference to the accelerator for handling precision context + precision_plugin: Reference to the precision plugin for handling precision context """ super().__init__() self._module = module - self._accelerator = accelerator + self._precision_plugin = precision_plugin @property def module(self) -> nn.Module: @@ -82,7 +82,7 @@ def module(self) -> nn.Module: def forward(self, *args: Any, **kwargs: Any) -> Any: """Casts all inputs to the right precision and handles autocast for operations in the module forward method.""" - precision = self._accelerator.precision_plugin.precision + precision = self._precision_plugin.precision precision_to_type = { "bf16": torch.bfloat16, 16: torch.float16, @@ -93,7 +93,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: to_type = precision_to_type[precision] args, kwargs = apply_to_collection([args, kwargs], function=lambda t: t.to(to_type), dtype=Tensor) - with self._accelerator.precision_plugin.forward_context(): + with self._precision_plugin.forward_context(): output = self.module(*args, **kwargs) output = apply_to_collection(output, function=lambda t: t.to(torch.get_default_dtype()), dtype=Tensor) diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py index cbb359a4043ae..14a443c042601 100644 --- a/tests/lite/test_wrappers.py +++ b/tests/lite/test_wrappers.py @@ -53,7 +53,7 @@ def check_autocast(forward_input): return forward_input module = Mock(wraps=torch.nn.Linear(1, 1), side_effect=check_autocast) - lite_module = _LiteModule(module, lite._accelerator).to(device) + lite_module = _LiteModule(module, lite._precision_plugin).to(device) out = lite_module(torch.rand(1, dtype=input_type, device=device)) assert module.call_args[0][0].dtype == expected_type assert out.dtype == torch.get_default_dtype() From 9a1e93fa1d87c47ff8fa2d8301ae5e43ccd046d1 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 28 Oct 2021 21:27:48 +0200 Subject: [PATCH 44/47] Add seed everything test --- tests/lite/test_lite.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py index f0c760d3c4a16..916e0aa542b32 100644 --- a/tests/lite/test_lite.py +++ b/tests/lite/test_lite.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +import os from copy import deepcopy from unittest import mock from unittest.mock import MagicMock, Mock, PropertyMock @@ -28,6 +28,7 @@ from pytorch_lightning.plugins import DeepSpeedPlugin, PrecisionPlugin, TrainingTypePlugin from pytorch_lightning.utilities import DistributedType from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.seed import pl_worker_init_function from tests.helpers.runif import RunIf @@ -202,6 +203,18 @@ def test_setup_dataloaders_distributed_sampler_not_needed(): assert lite_dataloader.sampler is custom_sampler +@mock.patch.dict(os.environ, {}, clear=True) +def test_seed_everything(): + """Test that seed everything is static and sets the worker init function on the dataloader.""" + EmptyLite.seed_everything(3) + + lite = EmptyLite() + lite_dataloader = lite.setup_dataloaders(DataLoader(Mock())) + + assert lite_dataloader.worker_init_fn.func is pl_worker_init_function + assert os.environ == {"PL_GLOBAL_SEED": "3", "PL_SEED_WORKERS": "1"} + + @pytest.mark.parametrize( "strategy", [ From f47c2ad69bd488f896ca1a243641c67dcb2bce2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 28 Oct 2021 22:29:00 +0200 Subject: [PATCH 45/47] fix type error --- pytorch_lightning/trainer/data_loading.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 726336820b28a..071eead5613b4 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -444,7 +444,9 @@ def _reset_eval_dataloader( dataloaders = [self.prepare_dataloader(dl, False, mode=mode) for dl in dataloaders if dl is not None] # add worker_init_fn for correct seeding in worker processes - apply_to_collection(dataloaders, dtype=DataLoader, function=self._auto_add_worker_init_fn) + apply_to_collection( + dataloaders, dtype=DataLoader, function=self._auto_add_worker_init_fn, rank=self.global_rank + ) loader_num_batches = [] From 947f9caff1d4fdc062aceebc54998ca213f0304c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 29 Oct 2021 23:02:06 +0200 Subject: [PATCH 46/47] update lambda_closure -> closure --- tests/lite/test_wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py index 14a443c042601..3e2e9ac7a9f9a 100644 --- a/tests/lite/test_wrappers.py +++ b/tests/lite/test_wrappers.py @@ -100,4 +100,4 @@ def test_lite_optimizer_steps(): lite_optimizer = _LiteOptimizer(optimizer=optimizer, accelerator=accelerator) lite_optimizer.step() accelerator.optimizer_step.assert_called_once() - accelerator.optimizer_step.assert_called_with(optimizer, opt_idx=0, lambda_closure=ANY, model=accelerator.model) + accelerator.optimizer_step.assert_called_with(optimizer, opt_idx=0, closure=ANY, model=accelerator.model) From 015051a5e0491ed71cd6c5a9b579f9a0a88d509a Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 29 Oct 2021 22:15:13 +0100 Subject: [PATCH 47/47] update --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 373608c3c8b28..74522424c5326 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -559,7 +559,7 @@ def __init__( if gradient_clip_algorithm is not None else gradient_clip_algorithm ) - self.track_grad_norm = 
float(track_grad_norm) + self.track_grad_norm: float = float(track_grad_norm) self._detect_anomaly: bool = detect_anomaly self._setup_on_init(num_sanity_val_steps)
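
A few illustrative sketches follow for techniques touched by the patches above; they are simplified stand-ins written against plain PyTorch, not the actual pytorch_lightning implementations.

PATCH 35/47 renames `params_on_cpu` to `params_before_move` because the remapping matters whenever `to_device` produces new parameter objects, which is the case on XLA: optimizers built before the move would otherwise keep references to the stale pre-move parameters. A minimal sketch of that remapping, where `move_and_remap` is a made-up helper name standing in for the Lite-internal logic:

    from typing import List, Tuple

    import torch
    import torch.nn as nn
    from torch.optim import Optimizer


    def move_and_remap(
        model: nn.Module, optimizers: List[Optimizer], device: torch.device
    ) -> Tuple[nn.Module, List[Optimizer]]:
        # Capture the parameter objects before the move. On XLA, moving the model
        # creates *new* parameter objects, so optimizers created beforehand would
        # keep references to the old (pre-move) parameters.
        params_before_move = dict(model.named_parameters())
        model = model.to(device)  # stand-in for Lite's `self.to_device(model)`
        params_on_device = dict(model.named_parameters())
        # Map each pre-move parameter to the parameter of the same name after the
        # move and rewrite the optimizer's param_groups accordingly.
        mapping = {param: params_on_device[name] for name, param in params_before_move.items()}
        for optimizer in optimizers:
            for param_group in optimizer.param_groups:
                param_group["params"] = [mapping.get(p, p) for p in param_group["params"]]
        return model, optimizers


    # On CPU/GPU the moved parameters are typically the same objects, so this is a
    # no-op; on TPU the mapping keeps the optimizer pointing at the on-device parameters.
    model = nn.Linear(2, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, (optimizer,) = move_and_remap(model, [optimizer], torch.device("cpu"))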
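
PATCH 37/47 and PATCH 39/47 copy the wrapped optimizer's `__dict__` while excluding `step` (so the wrapper's own `step`, which routes through the accelerator, takes precedence) and `__del__` (so a custom destructor on the wrapped optimizer does not also run when the wrapper itself is destroyed), then build a dynamic subclass so the wrapper still passes `isinstance` checks against the original optimizer class. A reduced sketch of the same pattern outside Lightning; `WrappedOptimizer` and `step_hook` are stand-in names:

    from typing import Callable, Optional

    import torch
    from torch.optim import Optimizer


    class WrappedOptimizer:
        """Behaves like the wrapped optimizer but routes `step` through a custom hook."""

        def __init__(self, optimizer: Optimizer, step_hook: Callable) -> None:
            # Copy the optimizer's instance attributes, skipping `step` (the wrapper
            # provides its own) and `__del__` (a custom destructor should only run
            # for the real optimizer, not again for the wrapper).
            self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")}
            # Dynamic subclass so `isinstance(wrapped, type(optimizer))` still holds.
            self.__class__ = type(
                "Wrapped" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {}
            )
            self._optimizer = optimizer
            self._step_hook = step_hook

        def step(self, closure: Optional[Callable] = None) -> None:
            # In Lite this is where `accelerator.optimizer_step(...)` gets called.
            self._step_hook(self._optimizer, closure)


    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    wrapped = WrappedOptimizer(optimizer, step_hook=lambda opt, closure: opt.step(closure))
    assert isinstance(wrapped, torch.optim.SGD)
    assert wrapped.param_groups is optimizer.param_groups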
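
PATCH 40/47 and PATCH 41/47 turn the worker-seeding helper into a staticmethod that takes the rank explicitly, so Lite can call it without a Trainer instance. The guard is the important part: the hook is only attached when `PL_SEED_WORKERS` is set (done by `seed_everything(..., workers=True)`) and the user has not supplied their own `worker_init_fn`. A standalone sketch of that guard, with a simplified seeding function in place of Lightning's `pl_worker_init_function`:

    import os
    import random
    from functools import partial

    import torch
    from torch.utils.data import DataLoader


    def seed_worker(worker_id: int, rank: int) -> None:
        # Derive a distinct, deterministic seed per (rank, worker) pair. Lightning's
        # `pl_worker_init_function` is more elaborate; this is a simplified stand-in.
        base_seed = torch.initial_seed() % 2**31
        seed = base_seed + rank * 1000 + worker_id
        random.seed(seed)
        torch.manual_seed(seed)


    def auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
        # Attach the seeding hook only if worker seeding was requested
        # (`seed_everything(..., workers=True)` sets PL_SEED_WORKERS=1) and the
        # user has not registered a `worker_init_fn` of their own.
        if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
            dataloader.worker_init_fn = partial(seed_worker, rank=rank)


    os.environ["PL_SEED_WORKERS"] = "1"
    loader = DataLoader(range(10), num_workers=2)
    auto_add_worker_init_fn(loader, rank=0)
    assert loader.worker_init_fn is not None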
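
PATCH 43/47 hands `_LiteModule` the precision plugin directly instead of the accelerator: the forward pass casts tensor inputs to the precision's dtype, runs under the plugin's `forward_context()`, and casts outputs back to the default dtype so downstream user code keeps seeing float32. A reduced sketch of that flow; `PrecisionWrappedModule` is a stand-in name, a null context replaces the plugin's autocast context, and only floating-point tensors are cast here for simplicity:

    from contextlib import nullcontext
    from typing import Any

    import torch
    import torch.nn as nn


    class PrecisionWrappedModule(nn.Module):
        """Cast floating-point inputs to the target precision, run the wrapped
        forward, and cast outputs back to the default dtype."""

        def __init__(self, module: nn.Module, precision: Any = 32) -> None:
            super().__init__()
            self._module = module
            self._precision = precision

        @staticmethod
        def _cast(value: Any, dtype: torch.dtype) -> Any:
            if isinstance(value, torch.Tensor) and value.is_floating_point():
                return value.to(dtype)
            return value

        def forward(self, *args: Any, **kwargs: Any) -> Any:
            precision_to_type = {"bf16": torch.bfloat16, 16: torch.float16, 32: torch.float32, 64: torch.float64}
            to_type = precision_to_type[self._precision]
            args = tuple(self._cast(a, to_type) for a in args)
            kwargs = {k: self._cast(v, to_type) for k, v in kwargs.items()}
            # The real `_LiteModule` enters `precision_plugin.forward_context()` here
            # (e.g. an autocast context); a null context keeps the sketch self-contained.
            with nullcontext():
                output = self._module(*args, **kwargs)
            # Downstream code keeps seeing the default (usually float32) dtype.
            return self._cast(output, torch.get_default_dtype())


    model = nn.Linear(4, 2).to(torch.float64)
    wrapped = PrecisionWrappedModule(model, precision=64)
    out = wrapped(torch.rand(1, 4))  # input cast to float64, output cast back
    assert out.dtype == torch.get_default_dtype()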