diff --git a/CHANGELOG.md b/CHANGELOG.md index 24e629494f87f..263f23d25899e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,7 +24,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `AttributeError` when `logger=None` on TPU ([#6221](https://github.com/PyTorchLightning/pytorch-lightning/pull/6221)) - - Fixed PyTorch Profiler with `emit_nvtx` ([#6260](https://github.com/PyTorchLightning/pytorch-lightning/pull/6260)) diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 8f63bc7b86b11..cad6e1a40adea 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -28,7 +28,7 @@ def setup(self, trainer, model): return super().setup(trainer, model) def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs): - xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs}) + xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs}) def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): """ diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 351d945675a0c..13585f8f368f4 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -15,6 +15,7 @@ from typing import Any, List, Optional, Union import torch +import torch.distributed as torch_distrib from torch.optim.lr_scheduler import _LRScheduler, Optimizer from pytorch_lightning.core.optimizer import LightningOptimizer @@ -116,7 +117,8 @@ def start_predicting(self, trainer): hvd.join() def barrier(self, *args, **kwargs): - hvd.join() + if torch_distrib.is_initialized(): + hvd.join() def broadcast(self, obj: object, src: int = 0) -> object: obj = hvd.broadcast_object(obj, src) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index d4b8974617ac1..371649057909b 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -4,6 +4,7 @@ from typing import Any, Dict, Iterable, List, Optional, Union import torch +import torch.distributed as torch_distrib import torch.multiprocessing as mp from pytorch_lightning.core.lightning import LightningModule @@ -112,7 +113,8 @@ def model_to_device(self) -> None: self._model.to(xm.xla_device()) def barrier(self, name: Optional[str] = None) -> None: - rendezvous(f"pl.Trainer.{name}") + if torch_distrib.is_initialized(): + rendezvous(f"pl.Trainer.{name}") def transfer_distrib_spawn_state_on_fit_end(self, results): # TODO: is there a better way than accessing callback through model -> trainer -> callback? @@ -126,7 +128,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): # TODO: is there a better way than accessing trainer through model -> trainer? if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0: last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - xm.save(self.lightning_module.state_dict(), last_path) + self.save(self.lightning_module.state_dict(), last_path) if self.global_rank == 0: # todo, pass complete checkpoint as state dictionary @@ -134,6 +136,18 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): self.mp_queue.put(last_path) self.mp_queue.put(results) + def save(self, state_dict: Dict, path: str) -> None: + """ + Saving with ``xm.save`` can be unstable and miss the rendez-vous after ``torch.save``. + The rendez-vous doesn't affect directly saving. + We can ignore the ``RuntimeError`` to reduce friction with TPUs. + """ + try: + xm.save(state_dict, path) + except RuntimeError as e: + if "Failed to meet rendezvous" not in str(e): + raise e + def broadcast(self, obj: object, src: int = 0) -> object: buffer = io.BytesIO() torch.save(obj, buffer) @@ -281,4 +295,4 @@ def save_checkpoint(self, filepath, weights_only: bool = False): # dump states as a checkpoint dictionary object _checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only) # Todo: TypeError: 'mappingproxy' object does not support item assignment - xm.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath) + self.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index fb6a1d4ab8442..59d406b0479c6 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -494,7 +494,7 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None): # define the max CPU available self.num_processes = os.cpu_count() # special case with TPUs - elif self.distributed_backend == 'tpu': + elif self.distributed_backend == 'tpu' or self.tpu_cores is not None: self._device_type = DeviceType.TPU elif self.distributed_backend and self._distrib_type is None: self._distrib_type = DistributedType(self.distributed_backend) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index dd81f4b53ce13..e123c1af5a5d0 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -56,7 +56,7 @@ from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.tuning import Tuner -from pytorch_lightning.utilities import DeviceType, rank_zero_warn +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.enums import LightningEnum @@ -949,8 +949,8 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): f'specify a path for a checkpoint .test(ckpt_path=PATH)' ) return {} - if not self._device_type == DeviceType.TPU: - self.accelerator.barrier() + + self.training_type_plugin.barrier() ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage) model.load_state_dict(ckpt['state_dict']) diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index bfa8f2432e3a2..db96e6854db90 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -177,8 +177,6 @@ def test_model_16bit_tpu_cores_8(tmpdir): def test_model_tpu_early_stop(tmpdir): """Test if single TPU core training works""" - # todo: Test on 8 cores - hanging. - class CustomBoringModel(BoringModel): def validation_step(self, *args, **kwargs): @@ -195,9 +193,10 @@ def validation_step(self, *args, **kwargs): max_epochs=2, limit_train_batches=2, limit_val_batches=2, - tpu_cores=[1], + tpu_cores=8, ) trainer.fit(model) + trainer.test(test_dataloaders=DataLoader(RandomDataset(32, 2000), batch_size=32)) @pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")