diff --git a/CHANGELOG.md b/CHANGELOG.md index cfcebf775b437..764378d9e3983 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -240,6 +240,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `sync_dist` for tpus ([#6950](https://github.com/PyTorchLightning/pytorch-lightning/pull/6950)) +- Fixed process rank not being available right away after `Trainer` instantiation ([#6941](https://github.com/PyTorchLightning/pytorch-lightning/pull/6941)) + + ## [1.2.7] - 2021-04-06 ### Fixed diff --git a/pytorch_lightning/plugins/environments/cluster_environment.py b/pytorch_lightning/plugins/environments/cluster_environment.py index c9e054c032804..588b103640ccc 100644 --- a/pytorch_lightning/plugins/environments/cluster_environment.py +++ b/pytorch_lightning/plugins/environments/cluster_environment.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from abc import ABC, abstractmethod class ClusterEnvironment: @@ -24,9 +25,23 @@ def master_address(self): def master_port(self): pass + @abstractmethod def world_size(self) -> int: - return self._world_size + """ The number of processes across all devices and nodes. """ + @abstractmethod + def set_world_size(self, size: int) -> None: + pass + + @abstractmethod + def global_rank(self) -> int: + """ The rank (index) of the currently running process across all nodes and devices. """ + + @abstractmethod + def set_global_rank(self, rank: int) -> None: + pass + + @abstractmethod def local_rank(self) -> int: pass diff --git a/pytorch_lightning/plugins/environments/lightning_environment.py b/pytorch_lightning/plugins/environments/lightning_environment.py new file mode 100644 index 0000000000000..67752535fe4e1 --- /dev/null +++ b/pytorch_lightning/plugins/environments/lightning_environment.py @@ -0,0 +1,83 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import socket + +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.utilities import rank_zero_only + + +class LightningEnvironment(ClusterEnvironment): + """ + The default environment used by Lightning for a single node or free cluster (not managed). + + The master process must be launched by the user and Lightning will spawn new + worker processes for distributed training, either in a single node or across multiple nodes. + + If the master address and port are not provided, the default environment will choose them + automatically. It is recommended to use this default environment for single-node distributed + training as it provides the most convenient way to launch the training script. + """ + + def __init__(self): + super().__init__() + self._master_port = None + self._global_rank: int = 0 + self._world_size: int = 1 + + def creates_children(self) -> bool: + return False + + def master_address(self) -> str: + return os.environ.get("MASTER_ADDR", "127.0.0.1") + + def master_port(self) -> int: + if self._master_port is None: + self._master_port = os.environ.get("MASTER_PORT", find_free_network_port()) + return int(self._master_port) + + def world_size(self) -> int: + return self._world_size + + def set_world_size(self, size: int) -> None: + self._world_size = size + + def global_rank(self) -> int: + return self._global_rank + + def set_global_rank(self, rank: int) -> None: + self._global_rank = rank + rank_zero_only.rank = rank + + def local_rank(self) -> int: + return int(os.environ.get("LOCAL_RANK", 0)) + + def node_rank(self) -> int: + group_rank = os.environ.get("GROUP_RANK", 0) + return int(os.environ.get("NODE_RANK", group_rank)) + + +def find_free_network_port() -> int: + """ + Finds a free port on localhost. + It is useful in single-node training when we don't want to connect to a real master node but + have to set the `MASTER_PORT` environment variable. + """ + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("", 0)) + s.listen(1) + port = s.getsockname()[1] + s.close() + return port diff --git a/pytorch_lightning/plugins/environments/slurm_environment.py b/pytorch_lightning/plugins/environments/slurm_environment.py index 7f9586cab0ace..2cb5f2fcf8dc7 100644 --- a/pytorch_lightning/plugins/environments/slurm_environment.py +++ b/pytorch_lightning/plugins/environments/slurm_environment.py @@ -23,10 +23,10 @@ class SLURMEnvironment(ClusterEnvironment): - def __init__(self): - super().__init__() + def creates_children(self) -> bool: + return True - def master_address(self): + def master_address(self) -> str: # figure out the root node addr slurm_nodelist = os.environ.get("SLURM_NODELIST") if slurm_nodelist: @@ -66,10 +66,19 @@ def master_port(self): return default_port - def world_size(self): - return self._world_size + def world_size(self) -> int: + return int(os.environ["SLURM_NTASKS"]) - def local_rank(self): + def set_world_size(self, size: int) -> None: + log.debug("SLURMEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.") + + def global_rank(self) -> int: + return int(os.environ["SLURM_PROCID"]) + + def set_global_rank(self, rank: int) -> None: + log.debug("SLURMEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.") + + def local_rank(self) -> int: return int(os.environ['SLURM_LOCALID']) def node_rank(self): diff --git a/pytorch_lightning/plugins/environments/torchelastic_environment.py b/pytorch_lightning/plugins/environments/torchelastic_environment.py index 5ac7d9f1c9a40..85aa1c7c2169a 100644 --- a/pytorch_lightning/plugins/environments/torchelastic_environment.py +++ b/pytorch_lightning/plugins/environments/torchelastic_environment.py @@ -23,8 +23,11 @@ class TorchElasticEnvironment(ClusterEnvironment): - def __init__(self): - super().__init__() + @staticmethod + def is_using_torchelastic() -> bool: + """ Returns ``True`` if the current process was launched using the torchelastic command. """ + required_env_vars = ("RANK", "GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE") + return all(v in os.environ for v in required_env_vars) def master_address(self): if "MASTER_ADDR" not in os.environ: @@ -46,7 +49,18 @@ def master_port(self): def world_size(self): return os.environ.get('WORLD_SIZE') - def local_rank(self): + def set_world_size(self, size: int) -> None: + log.debug("TorchElasticEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.") + + def global_rank(self) -> int: + return int(os.environ["RANK"]) + + def set_global_rank(self, rank: int) -> None: + log.debug( + "TorchElasticEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored." + ) + + def local_rank(self) -> int: return int(os.environ['LOCAL_RANK']) def node_rank(self) -> int: diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index e6ece8c8cffb1..46e049d9fe1c3 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -74,8 +74,8 @@ def __init__( self._ddp_kwargs = kwargs self._has_spawned_children = False self.task_idx = None - self.node_rank = 0 self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices + self.set_world_ranks() @property def root_device(self): @@ -171,6 +171,34 @@ def _call_children_scripts(self): delay = np.random.uniform(1, 5, 1)[0] sleep(delay) + def setup_distributed(self): + # TODO: check if needed + seed = os.environ.get("PL_GLOBAL_SEED") + if seed is not None: + seed_everything(int(seed)) + + # determine which process we are and world size + self.set_world_ranks() + + # set warning rank + rank_zero_only.rank = self.global_rank + + # set up server using proc 0's ip address + # try to init for 20 times at max in case ports are taken + # where to store ip_table + self.init_ddp_connection() + + # on world_size=0 let everyone know training is starting + if self.is_global_zero and not torch.distributed.is_initialized(): + log.info("-" * 100) + log.info(f"distributed_backend={self.distributed_backend}") + log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes") + log.info("-" * 100) + + # set the ranks and devices + self.dist.rank = self.global_rank + self.dist.device = self.root_device + def _check_can_spawn_children(self): if self._has_spawned_children: raise RuntimeError( @@ -178,11 +206,11 @@ def _check_can_spawn_children(self): " This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead." ) - def set_world_ranks(self): - self.local_rank = self.task_idx - self.node_rank = self.cluster_environment.node_rank() - self.global_rank = self.node_rank * self.num_processes + self.local_rank - self.world_size = self.num_nodes * self.num_processes + def set_world_ranks(self) -> None: + if self.cluster_environment is not None: + self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) + self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) + rank_zero_only.rank = self.cluster_environment.global_rank() def pre_configure_ddp(self): # if unset, default `find_unused_parameters` `True` @@ -215,12 +243,11 @@ def determine_ddp_device_ids(self): return None return [self.root_device.index] - def init_ddp_connection(self, global_rank: int, world_size: int) -> None: - # TODO: From where to get cluster environment? - os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address()) + def init_ddp_connection(self, global_rank: Optional[int] = None, world_size: Optional[int] = None) -> None: + global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank() + world_size = world_size if world_size is not None else self.cluster_environment.world_size() + os.environ["MASTER_ADDR"] = self.cluster_environment.master_address() os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size()) - if not torch.distributed.is_initialized(): log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size) diff --git a/pytorch_lightning/plugins/training_type/ddp2.py b/pytorch_lightning/plugins/training_type/ddp2.py index a7c8477a40c2d..0a4970f10efe9 100644 --- a/pytorch_lightning/plugins/training_type/ddp2.py +++ b/pytorch_lightning/plugins/training_type/ddp2.py @@ -19,6 +19,14 @@ class DDP2Plugin(DDPPlugin): + @property + def global_rank(self) -> int: + return self.node_rank + + @property + def world_size(self) -> int: + return self.num_nodes + def setup(self, model): self._model = model # set the task idx @@ -48,7 +56,5 @@ def distributed_sampler_kwargs(self): return distributed_sampler_kwargs def set_world_ranks(self): - self.local_rank = self.task_idx - self.node_rank = self.cluster_environment.node_rank() - self.global_rank = self.node_rank - self.world_size = self.num_nodes + self.cluster_environment.set_global_rank(self.node_rank) + self.cluster_environment.set_world_size(self.num_nodes) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index dcd6443b0e6fd..572aa27781e4c 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -59,9 +59,14 @@ def __init__( self.sync_batchnorm = sync_batchnorm self._ddp_kwargs = kwargs self.dist = LightningDistributed() - self.num_processes = len(parallel_devices) - self.node_rank = 0 + self.num_processes = len(parallel_devices) if parallel_devices is not None else 0 self.mp_queue = None + self._local_rank = 0 + self.set_world_ranks() + + @property + def local_rank(self) -> int: + return self._local_rank def __getstate__(self): """ Makes this plugin pickleable without destroying the queue in the current process. """ @@ -90,12 +95,12 @@ def setup(self, model): smp = mp.get_context("spawn") self.mp_queue = smp.SimpleQueue() - def set_world_ranks(self, process_idx): - self.local_rank = process_idx - self.node_rank = self.cluster_environment.node_rank() - self.task_idx = self.cluster_local_rank - self.global_rank = self.node_rank * self.num_processes + self.local_rank - self.world_size = self.num_nodes * self.num_processes + def set_world_ranks(self, process_idx: int = 0) -> None: + self._local_rank = process_idx + if self.cluster_environment is not None: + self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) + self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) + rank_zero_only.rank = self.cluster_environment.global_rank() @property def mp_spawn_kwargs(self): @@ -198,11 +203,12 @@ def configure_ddp(self): **self._ddp_kwargs, ) - def init_ddp_connection(self, global_rank: int, world_size: int) -> None: + def init_ddp_connection(self, global_rank: Optional[int], world_size: Optional[int]) -> None: # TODO: this code is duplicated in DDP and DDPSpawn, make this a function - os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address()) + global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank() + world_size = world_size if world_size is not None else self.cluster_environment.world_size() + os.environ["MASTER_ADDR"] = self.cluster_environment.master_address() os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size()) if not torch.distributed.is_initialized(): log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index 7a1f7ac1201c0..a0e55b2788afb 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -27,6 +27,22 @@ class DataParallelPlugin(ParallelPlugin): def __init__(self, parallel_devices: Optional[List[torch.device]]): super().__init__(parallel_devices=parallel_devices, cluster_environment=None) + @property + def global_rank(self) -> int: + return 0 + + @property + def local_rank(self) -> int: + return 0 + + @property + def node_rank(self) -> int: + return 0 + + @property + def world_size(self) -> int: + return 1 + def setup(self, model): # model needs to be moved to the device before it is wrapped model.to(self.root_device) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index ee6ba2ef9ef78..7c6c2d5525e8b 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -31,6 +31,19 @@ class HorovodPlugin(ParallelPlugin): def __init__(self, parallel_devices: Optional[List[torch.device]] = None): super().__init__(parallel_devices=parallel_devices, cluster_environment=None) + rank_zero_only.rank = self.global_rank + + @property + def global_rank(self) -> int: + return hvd.rank() + + @property + def local_rank(self) -> int: + return hvd.local_rank() + + @property + def world_size(self) -> int: + return hvd.size() @property def root_device(self): @@ -38,17 +51,11 @@ def root_device(self): @property def distributed_sampler_kwargs(self): - distributed_sampler_kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank()) + distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank) return distributed_sampler_kwargs def setup(self, model): self._model = model - - self.global_rank = hvd.rank() - self.local_rank = hvd.local_rank() - self.world_size = hvd.size() - rank_zero_only.rank = self.global_rank - self.model_to_device() def pre_dispatch(self): @@ -63,14 +70,14 @@ def _unpack_lightning_optimizer(opt): # increased total batch size for optimizer in optimizers: for param_group in optimizer.param_groups: - param_group["lr"] *= hvd.size() + param_group["lr"] *= self.world_size # Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR lr_schedulers = self.lightning_module.trainer.lr_schedulers for scheduler in lr_schedulers: scheduler = scheduler["scheduler"] if isinstance(scheduler, _LRScheduler): - scheduler.base_lrs = [lr * hvd.size() for lr in scheduler.base_lrs] + scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs] # Horovod: broadcast parameters & optimizer state to ensure consistent initialization hvd.broadcast_parameters(self.lightning_module.state_dict(), root_rank=0) diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index 9809443aff3fb..d13d85ab65a4c 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -36,9 +36,6 @@ def __init__( super().__init__() self.parallel_devices = parallel_devices self.cluster_environment = cluster_environment - self.global_rank = 0 - self.world_size = 1 - self.local_rank = 0 @property def cluster_local_rank(self): @@ -68,6 +65,22 @@ def connect(self, model, *args, **kwargs): self.setup(model) return self.model + @property + def global_rank(self) -> int: + return self.cluster_environment.global_rank() if self.cluster_environment is not None else 0 + + @property + def local_rank(self) -> int: + return self.cluster_environment.local_rank() if self.cluster_environment is not None else 0 + + @property + def node_rank(self) -> int: + return self.cluster_environment.node_rank() if self.cluster_environment is not None else 0 + + @property + def world_size(self) -> int: + return self.cluster_environment.world_size() if self.cluster_environment is not None else 1 + @property def is_global_zero(self) -> bool: return self.global_rank == 0 diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py index 3878aa9db3ea4..bfe891f5f0b94 100644 --- a/pytorch_lightning/plugins/training_type/rpc_sequential.py +++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -100,8 +100,8 @@ def __init__( def init_ddp_connection( self, - global_rank: int, - world_size: int, + global_rank: Optional[int] = None, + world_size: Optional[int] = None, ) -> None: if self.lightning_module.trainer.amp_backend is not None: raise MisconfigurationException( @@ -110,10 +110,10 @@ def init_ddp_connection( if self._skip_init_connections(): return - super().init_ddp_connection( - global_rank=global_rank, - world_size=world_size, - ) + + global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank() + world_size = world_size if world_size is not None else self.cluster_environment.world_size() + super().init_ddp_connection(global_rank, world_size) super().init_rpc_connection(global_rank=global_rank, world_size=world_size) model = self.lightning_module self.gpus_per_model = self._infer_check_num_gpus() diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index a29503d9c673b..b6c2b9a390337 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -50,8 +50,21 @@ class TPUSpawnPlugin(DDPSpawnPlugin): def __init__(self, parallel_devices: Optional[List[int]] = None, **kwargs: Dict[str, Any]) -> None: super().__init__(parallel_devices, num_nodes=1, cluster_environment=None, sync_batchnorm=False) self.tpu_local_core_rank = 0 + self.tpu_global_core_rank = 0 self.start_method = None + @property + def global_rank(self) -> int: + return self.tpu_local_core_rank + + @property + def local_rank(self) -> int: + return self.tpu_local_core_rank + + @property + def world_size(self) -> int: + return self.num_processes + @staticmethod def _validate_dataloader(dataloaders: Union[List['DataLoader'], 'DataLoader']): if not isinstance(dataloaders, list): @@ -111,11 +124,9 @@ def configure_ddp(self) -> None: def init_ddp_connection(self, global_rank: int, world_size: int) -> None: pass - def set_world_ranks(self, process_idx: int) -> None: + def set_world_ranks(self, process_idx: int = 0) -> None: self.tpu_local_core_rank = xm.get_local_ordinal() self.tpu_global_core_rank = xm.get_ordinal() - self.global_rank = self.tpu_local_core_rank - self.world_size = self.num_nodes * self.num_processes def new_process(self, process_idx: int, trainer, mp_queue) -> None: self.mp_queue = mp_queue @@ -124,7 +135,7 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None: if seed is not None: seed_everything(int(seed)) - self.set_world_ranks(process_idx) + self.set_world_ranks() # set warning rank rank_zero_only.rank = self.global_rank diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index e49d524ec48e9..a336945d83f8a 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -54,7 +54,7 @@ DeviceType, DistributedType, ) -from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn +from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException if _HOROVOD_AVAILABLE: @@ -296,8 +296,18 @@ def root_gpu(self) -> Optional[int]: @property def is_using_torchelastic(self) -> bool: - te_flags_passed = "WORLD_SIZE" in os.environ and ("GROUP_RANK" in os.environ or "NODE_RANK" in os.environ) - return te_flags_passed + """ + .. deprecated:: v1.3 + Will be removed in v1.5.0. + + Returns: + ``True`` if the current process was launched using the torchelastic command. + """ + rank_zero_deprecation( + "The property `AcceleratorConnector.is_using_torchelastic` was deprecated in v1.3" + " and will be removed in 1.5. Use `TorchElasticEnvironment.is_using_torchelastic()` instead.", + ) + return TorchElasticEnvironment.is_using_torchelastic() def select_precision_plugin(self) -> PrecisionPlugin: # set precision type @@ -351,7 +361,12 @@ def select_precision_plugin(self) -> PrecisionPlugin: def select_training_type_plugin(self) -> TrainingTypePlugin: if self.use_ddp2: - plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=self.cluster_environment) + plugin = DDP2Plugin( + parallel_devices=self.parallel_devices, + num_nodes=self.num_nodes, + cluster_environment=self.cluster_environment, + sync_batchnorm=self.sync_batchnorm, + ) elif self.use_ddp and self.use_deepspeed: plugin = DeepSpeedPlugin( num_nodes=self.num_nodes, @@ -360,11 +375,11 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: ) elif self.use_ddp: use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks - use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic + use_torchelastic_ddp = self.use_ddp and TorchElasticEnvironment.is_using_torchelastic() use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN use_ddp_cpu_spawn = self.use_ddp and self.on_cpu use_tpu_spawn = self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN - use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self.is_using_torchelastic + use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and TorchElasticEnvironment.is_using_torchelastic() use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED use_ddp_sharded_spawn = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN @@ -455,7 +470,7 @@ def select_cluster_environment(self) -> ClusterEnvironment: # TODO: decouple DDP from SLURM # refactor and let generic cluster env hold the information about who spawns the processes os.environ["PL_IN_DDP_SUBPROCESS"] = "1" - elif self.is_using_torchelastic: + elif TorchElasticEnvironment.is_using_torchelastic(): env = TorchElasticEnvironment() # TODO: decouple DDP from TE # refactor and let generic cluster env hold the information about who spawns the processes diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index 9eaa4c7e2b57e..9a402ded7b3a9 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -92,7 +92,8 @@ def test_accelerator_choice_ddp_spawn(cuda_available_mock, device_count_mock): "SLURM_NTASKS": "2", "SLURM_JOB_NAME": "SOME_NAME", "SLURM_NODEID": "0", - "SLURM_LOCALID": "10" + "SLURM_PROCID": "1", + "SLURM_LOCALID": "1", } ) def test_accelerator_choice_ddp_slurm(): @@ -105,8 +106,8 @@ def on_fit_start(self, trainer, pl_module): assert isinstance(trainer.accelerator, GPUAccelerator) assert isinstance(trainer.training_type_plugin, DDPPlugin) assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment) - assert trainer.training_type_plugin.cluster_environment.local_rank() == 10 - assert trainer.training_type_plugin.task_idx == 10 + assert trainer.training_type_plugin.cluster_environment.local_rank() == 1 + assert trainer.training_type_plugin.task_idx == 1 raise SystemExit() model = BoringModel() @@ -121,15 +122,15 @@ def on_fit_start(self, trainer, pl_module): trainer.fit(model) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU") +@RunIf(min_gpus=2) @mock.patch.dict( os.environ, { "CUDA_VISIBLE_DEVICES": "0,1", "SLURM_NTASKS": "2", "SLURM_JOB_NAME": "SOME_NAME", "SLURM_NODEID": "0", - "LOCAL_RANK": "0", - "SLURM_LOCALID": "10" + "SLURM_PROCID": "1", + "SLURM_LOCALID": "1" } ) @mock.patch('torch.cuda.device_count', return_value=2) @@ -143,8 +144,8 @@ def on_fit_start(self, trainer, pl_module): assert isinstance(trainer.accelerator, GPUAccelerator) assert isinstance(trainer.training_type_plugin, DDP2Plugin) assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment) - assert trainer.training_type_plugin.cluster_environment.local_rank() == 10 - assert trainer.training_type_plugin.task_idx == 10 + assert trainer.training_type_plugin.cluster_environment.local_rank() == 1 + assert trainer.training_type_plugin.task_idx == 1 raise SystemExit() model = BoringModel() @@ -159,8 +160,17 @@ def on_fit_start(self, trainer, pl_module): trainer.fit(model) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU") -@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2", "LOCAL_RANK": "10", "NODE_RANK": "0"}) +@RunIf(min_gpus=1) +@mock.patch.dict( + os.environ, { + "CUDA_VISIBLE_DEVICES": "0,1", + "WORLD_SIZE": "2", + "LOCAL_WORLD_SIZE": "2", + "RANK": "1", + "LOCAL_RANK": "1", + "GROUP_RANK": "0", + } +) @mock.patch('torch.cuda.device_count', return_value=2) def test_accelerator_choice_ddp_te(device_count_mock): @@ -171,8 +181,8 @@ def on_fit_start(self, trainer, pl_module): assert isinstance(trainer.accelerator, GPUAccelerator) assert isinstance(trainer.training_type_plugin, DDPPlugin) assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment) - assert trainer.training_type_plugin.cluster_environment.local_rank() == 10 - assert trainer.training_type_plugin.task_idx == 10 + assert trainer.training_type_plugin.cluster_environment.local_rank() == 1 + assert trainer.training_type_plugin.task_idx == 1 raise SystemExit() model = BoringModel() @@ -187,8 +197,17 @@ def on_fit_start(self, trainer, pl_module): trainer.fit(model) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU") -@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2", "LOCAL_RANK": "10", "NODE_RANK": "0"}) +@RunIf(min_gpus=1) +@mock.patch.dict( + os.environ, { + "CUDA_VISIBLE_DEVICES": "0,1", + "WORLD_SIZE": "2", + "LOCAL_WORLD_SIZE": "2", + "RANK": "1", + "LOCAL_RANK": "1", + "GROUP_RANK": "0", + } +) @mock.patch('torch.cuda.device_count', return_value=2) def test_accelerator_choice_ddp2_te(device_count_mock): @@ -199,8 +218,8 @@ def on_fit_start(self, trainer, pl_module): assert isinstance(trainer.accelerator, GPUAccelerator) assert isinstance(trainer.training_type_plugin, DDP2Plugin) assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment) - assert trainer.training_type_plugin.cluster_environment.local_rank() == 10 - assert trainer.training_type_plugin.task_idx == 10 + assert trainer.training_type_plugin.cluster_environment.local_rank() == 1 + assert trainer.training_type_plugin.task_idx == 1 raise SystemExit() model = BoringModel() @@ -215,11 +234,15 @@ def on_fit_start(self, trainer, pl_module): trainer.fit(model) -@mock.patch.dict(os.environ, { - "WORLD_SIZE": "1", - "LOCAL_RANK": "10", - "NODE_RANK": "0", -}) +@mock.patch.dict( + os.environ, { + "WORLD_SIZE": "2", + "LOCAL_WORLD_SIZE": "2", + "RANK": "1", + "LOCAL_RANK": "1", + "GROUP_RANK": "0", + } +) @mock.patch('torch.cuda.device_count', return_value=0) def test_accelerator_choice_ddp_cpu_te(device_count_mock): @@ -230,8 +253,8 @@ def on_fit_start(self, trainer, pl_module): assert isinstance(trainer.accelerator, CPUAccelerator) assert isinstance(trainer.training_type_plugin, DDPPlugin) assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment) - assert trainer.training_type_plugin.cluster_environment.local_rank() == 10 - assert trainer.training_type_plugin.task_idx == 10 + assert trainer.training_type_plugin.cluster_environment.local_rank() == 1 + assert trainer.training_type_plugin.task_idx == 1 raise SystemExit() model = BoringModel() @@ -252,7 +275,8 @@ def on_fit_start(self, trainer, pl_module): "SLURM_JOB_NAME": "SOME_NAME", "SLURM_NODEID": "0", "LOCAL_RANK": "0", - "SLURM_LOCALID": "0" + "SLURM_PROCID": "0", + "SLURM_LOCALID": "0", } ) @mock.patch('torch.cuda.device_count', return_value=0) @@ -287,7 +311,8 @@ def on_fit_start(self, trainer, pl_module): "SLURM_JOB_NAME": "SOME_NAME", "SLURM_NODEID": "0", "LOCAL_RANK": "0", - "SLURM_LOCALID": "0" + "SLURM_PROCID": "0", + "SLURM_LOCALID": "0", } ) @mock.patch('torch.cuda.device_count', return_value=0) @@ -364,7 +389,8 @@ class TrainTypePlugin(SingleDevicePlugin): "SLURM_JOB_NAME": "SOME_NAME", "SLURM_NODEID": "0", "LOCAL_RANK": "0", - "SLURM_LOCALID": "0" + "SLURM_PROCID": "0", + "SLURM_LOCALID": "0", } ) @mock.patch('torch.cuda.device_count', return_value=0) diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py index 2f16f2fe64e75..69a8131007cad 100644 --- a/tests/models/test_amp.py +++ b/tests/models/test_amp.py @@ -147,7 +147,8 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir): "SLURM_JOB_NAME": "SOME_NAME", "SLURM_NODEID": "0", "LOCAL_RANK": "0", - "SLURM_LOCALID": "0" + "SLURM_LOCALID": "0", + "SLURM_PROCID": "0", } ) def test_amp_gpu_ddp_slurm_managed(tmpdir): diff --git a/tests/plugins/environments/test_lightning_environment.py b/tests/plugins/environments/test_lightning_environment.py new file mode 100644 index 0000000000000..3f89b88bfc215 --- /dev/null +++ b/tests/plugins/environments/test_lightning_environment.py @@ -0,0 +1,57 @@ +import os +from unittest import mock + +from pytorch_lightning.plugins.environments import LightningEnvironment + + +@mock.patch.dict(os.environ, {}) +def test_default_attributes(): + """ Test the default attributes when no environment variables are set. """ + env = LightningEnvironment() + assert not env.creates_children() + assert env.master_address() == "127.0.0.1" + assert isinstance(env.master_port(), int) + assert env.world_size() == 1 + assert env.local_rank() == 0 + assert env.node_rank() == 0 + + +@mock.patch.dict(os.environ, { + "MASTER_ADDR": "1.2.3.4", + "MASTER_PORT": "500", + "LOCAL_RANK": "2", + "NODE_RANK": "3", +}) +def test_attributes_from_environment_variables(): + """ Test that the default cluster environment takes the attributes from the environment variables. """ + env = LightningEnvironment() + assert env.master_address() == "1.2.3.4" + assert env.master_port() == 500 + assert env.world_size() == 1 + assert env.global_rank() == 0 + assert env.local_rank() == 2 + assert env.node_rank() == 3 + env.set_global_rank(100) + assert env.global_rank() == 100 + env.set_world_size(100) + assert env.world_size() == 100 + + +@mock.patch.dict(os.environ, { + "GROUP_RANK": "1", +}) +def test_node_rank_from_group_rank(): + """ Test that the GROUP_RANK substitutes NODE_RANK. """ + env = LightningEnvironment() + assert "NODE_RANK" not in os.environ + assert env.node_rank() == 1 + + +@mock.patch.dict(os.environ, {}) +def test_random_master_port(): + """ Test randomly chosen master port when no master port was given by user. """ + env = LightningEnvironment() + port = env.master_port() + assert isinstance(port, int) + # repeated calls do not generate a new port number + assert env.master_port() == port diff --git a/tests/plugins/environments/test_slurm_environment.py b/tests/plugins/environments/test_slurm_environment.py new file mode 100644 index 0000000000000..0be88dbeb91c6 --- /dev/null +++ b/tests/plugins/environments/test_slurm_environment.py @@ -0,0 +1,72 @@ +import logging +import os +from unittest import mock + +import pytest + +from pytorch_lightning.plugins.environments import SLURMEnvironment + + +@mock.patch.dict(os.environ, {}) +def test_default_attributes(): + """ Test the default attributes when no environment variables are set. """ + env = SLURMEnvironment() + assert env.creates_children() + assert env.master_address() == "127.0.0.1" + assert env.master_port() == 12910 + with pytest.raises(KeyError): + # world size is required to be passed as env variable + env.world_size() + with pytest.raises(KeyError): + # local rank is required to be passed as env variable + env.local_rank() + with pytest.raises(KeyError): + # node_rank is required to be passed as env variable + env.node_rank() + + +@mock.patch.dict( + os.environ, { + "SLURM_NODELIST": "1.1.1.1, 1.1.1.2", + "SLURM_JOB_ID": "0001234", + "SLURM_NTASKS": "20", + "SLURM_LOCALID": "2", + "SLURM_PROCID": "1", + "SLURM_NODEID": "3", + } +) +def test_attributes_from_environment_variables(caplog): + """ Test that the SLURM cluster environment takes the attributes from the environment variables. """ + env = SLURMEnvironment() + assert env.master_address() == "1.1.1.1" + assert env.master_port() == 15000 + 1234 + assert env.world_size() == 20 + assert env.global_rank() == 1 + assert env.local_rank() == 2 + assert env.node_rank() == 3 + # setter should be no-op + with caplog.at_level(logging.DEBUG, logger="pytorch_lightning.plugins.environments"): + env.set_global_rank(100) + assert env.global_rank() == 1 + assert "setting global rank is not allowed" in caplog.text + + caplog.clear() + + with caplog.at_level(logging.DEBUG, logger="pytorch_lightning.plugins.environments"): + env.set_world_size(100) + assert env.world_size() == 20 + assert "setting world size is not allowed" in caplog.text + + +@pytest.mark.parametrize( + "slurm_node_list,expected", [ + ("alpha,beta,gamma", "alpha"), + ("alpha beta gamma", "alpha"), + ("1.2.3.[100-110]", "1.2.3.100"), + ] +) +def test_master_address_from_slurm_node_list(slurm_node_list, expected): + """ Test extracting the master node from different formats for the SLURM_NODELIST. """ + with mock.patch.dict(os.environ, {"SLURM_NODELIST": slurm_node_list}): + env = SLURMEnvironment() + assert env.master_address() == expected diff --git a/tests/plugins/environments/test_torchelastic_environment.py b/tests/plugins/environments/test_torchelastic_environment.py new file mode 100644 index 0000000000000..2b9efafbbcc67 --- /dev/null +++ b/tests/plugins/environments/test_torchelastic_environment.py @@ -0,0 +1,54 @@ +import logging +import os +from unittest import mock + +import pytest + +from pytorch_lightning.plugins.environments import TorchElasticEnvironment + + +@mock.patch.dict(os.environ, {}) +def test_default_attributes(): + """ Test the default attributes when no environment variables are set. """ + env = TorchElasticEnvironment() + assert env.creates_children() + assert env.master_address() == "127.0.0.1" + assert env.master_port() == 12910 + assert env.world_size() is None + with pytest.raises(KeyError): + # local rank is required to be passed as env variable + env.local_rank() + assert env.node_rank() == 0 + + +@mock.patch.dict( + os.environ, { + "MASTER_ADDR": "1.2.3.4", + "MASTER_PORT": "500", + "WORLD_SIZE": "20", + "RANK": "1", + "LOCAL_RANK": "2", + "GROUP_RANK": "3", + } +) +def test_attributes_from_environment_variables(caplog): + """ Test that the torchelastic cluster environment takes the attributes from the environment variables. """ + env = TorchElasticEnvironment() + assert env.master_address() == "1.2.3.4" + assert env.master_port() == 500 + assert env.world_size() == 20 + assert env.global_rank() == 1 + assert env.local_rank() == 2 + assert env.node_rank() == 3 + # setter should be no-op + with caplog.at_level(logging.DEBUG, logger="pytorch_lightning.plugins.environments"): + env.set_global_rank(100) + assert env.global_rank() == 1 + assert "setting global rank is not allowed" in caplog.text + + caplog.clear() + + with caplog.at_level(logging.DEBUG, logger="pytorch_lightning.plugins.environments"): + env.set_world_size(100) + assert env.world_size() == 20 + assert "setting world size is not allowed" in caplog.text diff --git a/tests/plugins/test_amp_plugins.py b/tests/plugins/test_amp_plugins.py new file mode 100644 index 0000000000000..328cb0a59f08e --- /dev/null +++ b/tests/plugins/test_amp_plugins.py @@ -0,0 +1,85 @@ +import os +from unittest import mock + +import pytest +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin +from pytorch_lightning.plugins.precision import MixedPrecisionPlugin +from tests.helpers import BoringModel +from tests.helpers.runif import RunIf + + +class MyNativeAMP(NativeMixedPrecisionPlugin): + pass + + +class MyApexPlugin(ApexMixedPrecisionPlugin): + pass + + +@mock.patch.dict( + os.environ, { + "CUDA_VISIBLE_DEVICES": "0,1", + "SLURM_NTASKS": "2", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + "SLURM_PROCID": "0", + "SLURM_LOCALID": "0", + } +) +@mock.patch('torch.cuda.device_count', return_value=2) +@pytest.mark.parametrize('ddp_backend,gpus', [('ddp', 2), ('ddp2', 2), ('ddp_spawn', 2)]) +@pytest.mark.parametrize( + 'amp,custom_plugin,plugin_cls', [ + pytest.param('native', False, NativeMixedPrecisionPlugin, marks=RunIf(amp_native=True)), + pytest.param('native', True, MyNativeAMP, marks=RunIf(amp_native=True)), + pytest.param('apex', False, ApexMixedPrecisionPlugin, marks=RunIf(amp_apex=True)), + pytest.param('apex', True, MyApexPlugin, marks=RunIf(amp_apex=True)) + ] +) +def test_amp_apex_ddp( + mocked_device_count, ddp_backend: str, gpus: int, amp: str, custom_plugin: bool, plugin_cls: MixedPrecisionPlugin +): + + trainer = Trainer( + fast_dev_run=True, + precision=16, + amp_backend=amp, + gpus=gpus, + accelerator=ddp_backend, + plugins=[plugin_cls()] if custom_plugin else None, + ) + assert isinstance(trainer.precision_plugin, plugin_cls) + + +class GradientUnscaleBoringModel(BoringModel): + + def on_after_backward(self): + norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2) + if not (torch.isinf(norm) or torch.isnan(norm)): + assert norm.item() < 15. + + +@RunIf(min_gpus=2, amp_native=True) +@pytest.mark.parametrize('accum', [1, 2]) +def test_amp_gradient_unscale(tmpdir, accum: int): + model = GradientUnscaleBoringModel() + + trainer = Trainer( + max_epochs=2, + default_root_dir=tmpdir, + limit_train_batches=2, + limit_test_batches=2, + limit_val_batches=2, + amp_backend='native', + accelerator='ddp_spawn', + gpus=2, + precision=16, + track_grad_norm=2, + log_every_n_steps=1, + accumulate_grad_batches=accum, + ) + trainer.fit(model) diff --git a/tests/plugins/test_cluster_integration.py b/tests/plugins/test_cluster_integration.py new file mode 100644 index 0000000000000..032276dd674d0 --- /dev/null +++ b/tests/plugins/test_cluster_integration.py @@ -0,0 +1,114 @@ +import os +from unittest import mock + +import pytest +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.plugins import DDP2Plugin, DDPPlugin, DDPShardedPlugin, DeepSpeedPlugin, RPCSequentialPlugin +from pytorch_lightning.plugins.environments import LightningEnvironment, SLURMEnvironment, TorchElasticEnvironment +from pytorch_lightning.utilities import rank_zero_only +from tests.helpers.runif import RunIf + + +def environment_combinations(): + expected = dict(global_rank=3, local_rank=1, node_rank=1, world_size=4) + # Lightning + variables = { + "CUDA_VISIBLE_DEVICES": "0,1,2,4", + "LOCAL_RANK": "1", + "NODE_RANK": "1", + "WORLD_SIZE": "8", + } + environment = LightningEnvironment() + yield environment, variables, expected + # SLURM + variables = { + "CUDA_VISIBLE_DEVICES": "0,1,2,4", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_LOCALID": "1", + "SLURM_NODEID": "1", + "SLURM_PROCID": "3", + "SLURM_NTASKS": "4", + } + environment = SLURMEnvironment() + yield environment, variables, expected + # TorchElastic + variables = { + "CUDA_VISIBLE_DEVICES": "0,1,2,4", + "LOCAL_RANK": "1", + "GROUP_RANK": "1", + "RANK": "3", + "WORLD_SIZE": "4", + "LOCAL_WORLD_SIZE": "2", + } + environment = TorchElasticEnvironment() + yield environment, variables, expected + + +@pytest.mark.parametrize( + "plugin_cls", [ + DDPPlugin, + DDPShardedPlugin, + DDP2Plugin, + pytest.param(DeepSpeedPlugin, marks=RunIf(deepspeed=True)), + pytest.param(RPCSequentialPlugin, marks=RunIf(fairscale_pipe=True)), + ] +) +def test_ranks_available_manual_plugin_selection(plugin_cls): + """ Test that the rank information is readily available after Trainer initialization. """ + num_nodes = 2 + for cluster, variables, expected in environment_combinations(): + + if plugin_cls == DDP2Plugin: + expected.update(global_rank=expected["node_rank"], world_size=num_nodes) + + with mock.patch.dict(os.environ, variables): + plugin = plugin_cls( + parallel_devices=[torch.device("cuda", 1), torch.device("cuda", 2)], + num_nodes=num_nodes, + cluster_environment=cluster, + ) + trainer = Trainer(plugins=[plugin]) + assert rank_zero_only.rank == expected["global_rank"] + assert trainer.global_rank == expected["global_rank"] + assert trainer.local_rank == expected["local_rank"] + assert trainer.node_rank == expected["node_rank"] + assert trainer.world_size == expected["world_size"] + + +@pytest.mark.parametrize( + "trainer_kwargs", [ + dict(accelerator="ddp", gpus=[1, 2]), + dict(accelerator="ddp_sharded", gpus=[1, 2]), + dict(accelerator="ddp2", gpus=[1, 2]), + dict(accelerator="ddp_cpu", num_processes=2), + dict(accelerator="ddp_spawn", gpus=[1, 2]), + ] +) +@mock.patch("torch.cuda.is_available", return_value=True) +@mock.patch("torch.cuda.device_count", return_value=4) +def test_ranks_available_automatic_plugin_selection(mock0, mock1, trainer_kwargs): + """ Test that the rank information is readily available after Trainer initialization. """ + num_nodes = 2 + trainer_kwargs.update(num_nodes=num_nodes) + + for cluster, variables, expected in environment_combinations(): + + if trainer_kwargs["accelerator"] == "ddp2": + expected.update(global_rank=expected["node_rank"], world_size=num_nodes) + if trainer_kwargs["accelerator"] in ("ddp_cpu", "ddp_spawn"): + if isinstance(cluster, (SLURMEnvironment, TorchElasticEnvironment)): + # slurm and torchelastic do not work with spawn plugins + continue + # when using spawn, we don't reach rank > 0 until we call Trainer.fit() + expected.update(global_rank=(expected["node_rank"] * 2), local_rank=0) + + with mock.patch.dict(os.environ, variables): + trainer = Trainer(**trainer_kwargs) + assert type(trainer.training_type_plugin.cluster_environment) == type(cluster) + assert rank_zero_only.rank == expected["global_rank"] + assert trainer.global_rank == expected["global_rank"] + assert trainer.local_rank == expected["local_rank"] + assert trainer.node_rank == expected["node_rank"] + assert trainer.world_size == expected["world_size"] diff --git a/tests/plugins/test_rpc_plugin.py b/tests/plugins/test_rpc_plugin.py index 2c074e6c3afda..0cf16b6e78df5 100644 --- a/tests/plugins/test_rpc_plugin.py +++ b/tests/plugins/test_rpc_plugin.py @@ -20,6 +20,7 @@ "SLURM_JOB_NAME": "SOME_NAME", "SLURM_NODEID": "0", "LOCAL_RANK": "0", + "SLURM_PROCID": "0", "SLURM_LOCALID": "0", }, ) diff --git a/tests/plugins/test_tpu_spawn.py b/tests/plugins/test_tpu_spawn.py index bb587827c3a3f..53baa8e54461b 100644 --- a/tests/plugins/test_tpu_spawn.py +++ b/tests/plugins/test_tpu_spawn.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest import mock from unittest.mock import MagicMock import pytest @@ -51,8 +52,9 @@ def predict_dataloader(self): (None, [_loader, _loader_no_len], None, None), ], ) +@mock.patch("pytorch_lightning.plugins.training_type.tpu_spawn.xm") def test_error_patched_iterable_dataloaders( - tmpdir, train_dataloader, val_dataloaders, test_dataloaders, predict_dataloaders + _, tmpdir, train_dataloader, val_dataloaders, test_dataloaders, predict_dataloaders ): model = BoringModelNoDataloaders() connector = DataConnector(MagicMock()) @@ -69,6 +71,7 @@ def test_error_patched_iterable_dataloaders( TPUSpawnPlugin(MagicMock()).connect(model) -def test_error_process_iterable_dataloader(tmpdir): +@mock.patch("pytorch_lightning.plugins.training_type.tpu_spawn.xm") +def test_error_process_iterable_dataloader(_): with pytest.raises(MisconfigurationException, match="TPUs do not currently support"): TPUSpawnPlugin(MagicMock()).process_dataloader(_loader_no_len)