From d40385effeb63210633b5e7679abca1d90fffb1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 28 Jun 2021 11:30:48 +0200 Subject: [PATCH 1/8] device ids in barrier x x s same fix for spawn fix non-nccl x --- pytorch_lightning/plugins/training_type/ddp.py | 9 +++++++-- pytorch_lightning/plugins/training_type/ddp_spawn.py | 9 +++++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index b855d100b1f12..6e91f1e967554 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -323,8 +323,13 @@ def pre_dispatch(self): def post_dispatch(self) -> None: self.cluster_environment.teardown() - def barrier(self, *args, **kwargs): - if torch_distrib.is_available() and torch_distrib.is_initialized(): + def barrier(self, *args, **kwargs) -> None: + if not torch_distrib.is_initialized(): + return + device_ids = self.determine_ddp_device_ids() + if _TORCH_GREATER_EQUAL_1_8 and device_ids is not None: + torch_distrib.barrier(device_ids=device_ids) + else: torch_distrib.barrier() def broadcast(self, obj: object, src: int = 0) -> object: diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index b61f9a6052630..da346907e0583 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -309,8 +309,13 @@ def __recover_child_process_weights(self, best_path, last_path): ckpt = pl_load(last_path, map_location=lambda storage, loc: storage) self.lightning_module.load_state_dict(ckpt) - def barrier(self, *args, **kwargs): - if torch_distrib.is_initialized(): + def barrier(self, *args, **kwargs) -> None: + if not torch_distrib.is_initialized(): + return + device_ids = self.determine_ddp_device_ids() + if _TORCH_GREATER_EQUAL_1_8 and device_ids is not None: + torch_distrib.barrier(device_ids=device_ids) + else: torch_distrib.barrier() def broadcast(self, obj: object, src: int = 0) -> object: From 6979e5044715adbe0b1f845c690e63d68b450209 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 28 Jun 2021 12:07:10 +0200 Subject: [PATCH 2/8] add changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 980d2a450f786..4fa5738eefb7a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -315,6 +315,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed a DDP info message that was never shown ([#8111](https://github.com/PyTorchLightning/pytorch-lightning/pull/8111)) +- Fixed NCCL error when selecting non-consecutive device ids ([#8165](https://github.com/PyTorchLightning/pytorch-lightning/pull/8165)) + + ## [1.3.7] - 2021-06-22 - Fixed a bug where skipping an optimizer while using amp causes amp to trigger an assertion error ([#7975](https://github.com/PyTorchLightning/pytorch-lightning/pull/7975)) From 2fdee971064ffd15bab930a9adac57e63a1606ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 28 Jun 2021 15:17:27 +0200 Subject: [PATCH 3/8] get nccl backend --- pytorch_lightning/plugins/training_type/ddp.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 6e91f1e967554..8efbac4b3b5e2 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -326,9 +326,8 @@ def post_dispatch(self) -> None: def barrier(self, *args, **kwargs) -> None: if not torch_distrib.is_initialized(): return - device_ids = self.determine_ddp_device_ids() - if _TORCH_GREATER_EQUAL_1_8 and device_ids is not None: - torch_distrib.barrier(device_ids=device_ids) + if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl": + torch_distrib.barrier(device_ids=self.determine_ddp_device_ids()) else: torch_distrib.barrier() From 7283b384f0ecddb48462fc1aabf95fff44f28f3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 28 Jun 2021 15:18:36 +0200 Subject: [PATCH 4/8] get backend --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index da346907e0583..cdf53b4854a10 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -312,9 +312,8 @@ def __recover_child_process_weights(self, best_path, last_path): def barrier(self, *args, **kwargs) -> None: if not torch_distrib.is_initialized(): return - device_ids = self.determine_ddp_device_ids() - if _TORCH_GREATER_EQUAL_1_8 and device_ids is not None: - torch_distrib.barrier(device_ids=device_ids) + if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl": + torch_distrib.barrier(device_ids=self.determine_ddp_device_ids()) else: torch_distrib.barrier() From 7a913d893d5c803cfb65695e1cb9b33c53f2f237 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 28 Jun 2021 16:35:17 +0200 Subject: [PATCH 5/8] add test --- tests/plugins/test_ddp_plugin.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/plugins/test_ddp_plugin.py b/tests/plugins/test_ddp_plugin.py index d236dc145d96c..f999490bb78ce 100644 --- a/tests/plugins/test_ddp_plugin.py +++ b/tests/plugins/test_ddp_plugin.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import torch +from torch.nn.parallel import DistributedDataParallel from pytorch_lightning import Trainer from pytorch_lightning.plugins import DDPPlugin @@ -46,3 +47,28 @@ def test_ddp_with_2_gpus(): assert model.device == torch.device("cpu") cuda_memory = torch.cuda.memory_allocated() assert cuda_memory < model.start_cuda_memory + + +class BarrierModel(BoringModel): + + def setup(self, stage=None): + assert not isinstance(self.trainer.accelerator.model, DistributedDataParallel) + self.trainer.accelerator.barrier("barrier before model is wrapped") + + def on_train_start(self): + assert isinstance(self.trainer.accelerator.model, DistributedDataParallel) + self.trainer.accelerator.barrier("barrier after model is wrapped") + + +@RunIf(min_gpus=4, special=True) +def test_ddp_barrier_non_consecutive_device_ids(tmpdir): + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=1, + gpus=[1, 3], + accelerator="ddp", + ) + trainer.fit(model) From 6974299444ecbd9e666d00340e150a7873b4ba99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 28 Jun 2021 16:39:13 +0200 Subject: [PATCH 6/8] test --- tests/plugins/test_ddp_plugin.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/plugins/test_ddp_plugin.py b/tests/plugins/test_ddp_plugin.py index f999490bb78ce..0f17a1fa716f4 100644 --- a/tests/plugins/test_ddp_plugin.py +++ b/tests/plugins/test_ddp_plugin.py @@ -66,8 +66,7 @@ def test_ddp_barrier_non_consecutive_device_ids(tmpdir): model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, - limit_train_batches=1, - limit_val_batches=1, + max_steps=1, gpus=[1, 3], accelerator="ddp", ) From 1439be18709d0bae7cd716fba703cb1520f87999 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 28 Jun 2021 14:39:14 +0000 Subject: [PATCH 7/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/plugins/test_ddp_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plugins/test_ddp_plugin.py b/tests/plugins/test_ddp_plugin.py index f999490bb78ce..1900dc92c17bb 100644 --- a/tests/plugins/test_ddp_plugin.py +++ b/tests/plugins/test_ddp_plugin.py @@ -62,7 +62,7 @@ def on_train_start(self): @RunIf(min_gpus=4, special=True) def test_ddp_barrier_non_consecutive_device_ids(tmpdir): - + model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, From 8677cba9dace6a37638a5e86179a78b15271bd22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 28 Jun 2021 17:45:03 +0200 Subject: [PATCH 8/8] update ddp test --- tests/plugins/test_ddp_plugin.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/plugins/test_ddp_plugin.py b/tests/plugins/test_ddp_plugin.py index 4ccc2c21956f2..61c5d70191db2 100644 --- a/tests/plugins/test_ddp_plugin.py +++ b/tests/plugins/test_ddp_plugin.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from unittest import mock + import torch from torch.nn.parallel import DistributedDataParallel @@ -61,13 +63,16 @@ def on_train_start(self): @RunIf(min_gpus=4, special=True) -def test_ddp_barrier_non_consecutive_device_ids(tmpdir): - +@mock.patch("torch.distributed.barrier") +def test_ddp_barrier_non_consecutive_device_ids(barrier_mock, tmpdir): + """ Test correct usage of barriers when device ids do not start at 0 or are not consecutive. """ model = BoringModel() + gpus = [1, 3] trainer = Trainer( default_root_dir=tmpdir, max_steps=1, - gpus=[1, 3], + gpus=gpus, accelerator="ddp", ) trainer.fit(model) + barrier_mock.assert_any_call(device_ids=[gpus[trainer.local_rank]])
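
Background sketch (not part of the patches above): a minimal, standalone illustration of the behaviour this series relies on. On torch >= 1.8 with the NCCL backend, `torch.distributed.barrier` should be told which GPU each rank actually drives, which is what the plugin derives via `determine_ddp_device_ids()`. With a non-consecutive selection such as `gpus=[1, 3]`, omitting `device_ids` lets NCCL assume GPU `rank` (0, 1, ...), and that mismatch is the source of the reported NCCL error. The `worker` helper, the `SELECTED_GPUS` constant, and the `mp.spawn` launch wiring below are illustrative assumptions for a single-node run, not code from this PR.

# Illustrative sketch only; assumes torch >= 1.8 and CUDA devices 1 and 3 are usable.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

SELECTED_GPUS = [1, 3]  # non-consecutive device ids, as in the test above


def worker(rank: int) -> None:
    # env:// rendezvous needs an address/port; adjust for your machine.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")

    # Rank 0 drives cuda:1 and rank 1 drives cuda:3, mirroring gpus=[1, 3].
    device_id = SELECTED_GPUS[rank]
    torch.cuda.set_device(device_id)
    dist.init_process_group("nccl", rank=rank, world_size=len(SELECTED_GPUS))

    if dist.get_backend() == "nccl":
        # Pass the GPU this rank actually uses instead of letting NCCL
        # fall back to GPU ``rank``, which would point at the wrong device here.
        dist.barrier(device_ids=[device_id])
    else:
        # Other backends (e.g. gloo) do not accept ``device_ids``.
        dist.barrier()

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, nprocs=len(SELECTED_GPUS))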