diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 58e26e7db32d8..b8437b0d418e4 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -241,12 +241,12 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None: torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size) def pre_dispatch(self): - if self.sync_batchnorm: - self.model = self.configure_sync_batchnorm(self.model) - # move the model to the correct device self.model_to_device() + if self.sync_batchnorm: + self.model = self.configure_sync_batchnorm(self.model) + self.configure_ddp() self.barrier()