diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 007f898a27cc7..e6ece8c8cffb1 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -257,12 +257,12 @@ def pre_dispatch(self): self.dist.rank = self.global_rank self.dist.device = self.root_device - if self.sync_batchnorm: - self.model = self.configure_sync_batchnorm(self.model) - # move the model to the correct device self.model_to_device() + if self.sync_batchnorm: + self.model = self.configure_sync_batchnorm(self.model) + self.configure_ddp() self.barrier() diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index fdb88a3c5cba5..dcd6443b0e6fd 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -148,12 +148,12 @@ def new_process(self, process_idx, trainer, mp_queue): self.dist.rank = self.global_rank self.dist.device = self.root_device - if self.sync_batchnorm: - self.model = self.configure_sync_batchnorm(self.model) - # move the model to the correct device self.model_to_device() + if self.sync_batchnorm: + self.model = self.configure_sync_batchnorm(self.model) + self.configure_ddp() self.barrier() diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index a26057da32b4f..69d199a76dfff 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -721,8 +721,6 @@ def __len__(self): assert has_len(dataloader) assert has_iterable_dataset(dataloader) trainer = Trainer(default_root_dir=tmpdir, max_steps=3) - with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'): - trainer.validate(model, val_dataloaders=[dataloader]) with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'): trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader]) with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'): @@ -735,7 +733,6 @@ def __len__(self): assert not has_len(dataloader) assert has_iterable_dataset(dataloader) trainer = Trainer(default_root_dir=tmpdir, max_steps=3) - trainer.validate(model, val_dataloaders=dataloader) trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader]) trainer.test(model, test_dataloaders=dataloader) trainer.predict(model, dataloaders=dataloader)