From a16807993b4dfe1ce92d517c02e41f4fafad4b67 Mon Sep 17 00:00:00 2001
From: Eugene Khvedchenya
Date: Mon, 5 Apr 2021 14:48:34 +0300
Subject: [PATCH 1/2] Fix DDP + SyncBN

Ensure that the model is already on the correct GPU before applying the
SyncBN conversion

---
 pytorch_lightning/plugins/training_type/ddp.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 58e26e7db32d8..b8437b0d418e4 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -241,12 +241,12 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
         torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)
 
     def pre_dispatch(self):
-        if self.sync_batchnorm:
-            self.model = self.configure_sync_batchnorm(self.model)
-
         # move the model to the correct device
         self.model_to_device()
 
+        if self.sync_batchnorm:
+            self.model = self.configure_sync_batchnorm(self.model)
+
         self.configure_ddp()
 
         self.barrier()

From 9b4a71dfcb430dfe7777722193ef51cfb295002b Mon Sep 17 00:00:00 2001
From: Eugene Khvedchenya
Date: Mon, 5 Apr 2021 16:18:36 +0300
Subject: [PATCH 2/2] Fix order of SyncBN for ddp_spawn

---
 pytorch_lightning/plugins/training_type/ddp_spawn.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 87d7fa5faecac..126afc9be6040 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -141,12 +141,12 @@ def new_process(self, process_idx, trainer, mp_queue):
         self.dist.rank = self.global_rank
         self.dist.device = self.root_device
 
-        if self.sync_batchnorm:
-            self.model = self.configure_sync_batchnorm(self.model)
-
         # move the model to the correct device
         self.model_to_device()
 
+        if self.sync_batchnorm:
+            self.model = self.configure_sync_batchnorm(self.model)
+
         self.configure_ddp()
 
         self.barrier()