diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 58e26e7db32d8..b8437b0d418e4 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -241,12 +241,12 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
         torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)
 
     def pre_dispatch(self):
-        if self.sync_batchnorm:
-            self.model = self.configure_sync_batchnorm(self.model)
-
         # move the model to the correct device
         self.model_to_device()
 
+        if self.sync_batchnorm:
+            self.model = self.configure_sync_batchnorm(self.model)
+
         self.configure_ddp()
 
         self.barrier()
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 87d7fa5faecac..126afc9be6040 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -141,12 +141,12 @@ def new_process(self, process_idx, trainer, mp_queue):
         self.dist.rank = self.global_rank
         self.dist.device = self.root_device
 
-        if self.sync_batchnorm:
-            self.model = self.configure_sync_batchnorm(self.model)
-
         # move the model to the correct device
         self.model_to_device()
 
+        if self.sync_batchnorm:
+            self.model = self.configure_sync_batchnorm(self.model)
+
         self.configure_ddp()
 
         self.barrier()
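
Not part of the patch: a minimal standalone sketch of the call order that pre_dispatch / new_process follow after this change, assuming configure_sync_batchnorm wraps torch.nn.SyncBatchNorm.convert_sync_batchnorm as these plugins do. The model and device below are illustrative, not taken from the patch.

import torch
from torch import nn

# Hypothetical model used only for illustration.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))

# 1. Move the model to its target device first (what self.model_to_device() does).
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

# 2. Only then replace BatchNorm layers with SyncBatchNorm
#    (what self.configure_sync_batchnorm(self.model) amounts to),
#    so the converted layers are created on the already-placed model.
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

# 3. self.configure_ddp() would follow here; the DistributedDataParallel wrapping
#    is omitted in this sketch because it needs an initialized process group.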