From b530b7afd2799daacfcf84d4d6773c48ee911557 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 12 Oct 2021 12:45:28 +0200 Subject: [PATCH] update tests to not rely on patched dataloaders (#9905) --- tests/callbacks/test_early_stopping.py | 2 +- tests/models/test_restore.py | 4 ++-- tests/trainer/test_data_loading.py | 8 +++---- tests/trainer/test_dataloaders.py | 16 ++++++-------- tests/trainer/test_trainer_tricks.py | 30 ++++++++++++++++++++------ tests/tuner/test_scale_batch_size.py | 7 ++++-- 6 files changed, 42 insertions(+), 25 deletions(-) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index b44dee9abb0b6..3a4fc7c4f5541 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -95,7 +95,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir): ) with pytest.raises(MisconfigurationException, match=r"You restored a checkpoint with current_epoch"): - new_trainer.fit(model) + new_trainer.fit(model, datamodule=dm) def test_early_stopping_no_extraneous_invocations(tmpdir): diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 1034fa26a3ac5..f20bff249cd09 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -340,7 +340,7 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir): new_trainer.test(pretrained_model) pretrained_model.cpu() - dataloaders = model.test_dataloader() + dataloaders = dm.test_dataloader() if not isinstance(dataloaders, list): dataloaders = [dataloaders] @@ -539,7 +539,7 @@ def on_pretrain_routine_end(self): # haven't trained with the new loaded model new_trainer.state.stage = RunningStage.VALIDATING - dataloader = self.train_dataloader() + dataloader = dm.train_dataloader() tpipes.run_prediction_eval_model_template(self.trainer.lightning_module, dataloader=dataloader) self.on_pretrain_routine_end_called = True diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py index 6e91cf926723c..31e18b2bfb578 100644 --- a/tests/trainer/test_data_loading.py +++ b/tests/trainer/test_data_loading.py @@ -267,19 +267,19 @@ def test_loader_detaching(): class LoaderTestModel(BoringModel): def training_step(self, batch, batch_idx): - assert len(model.train_dataloader()) == 10 + assert len(self.trainer.train_dataloader.loaders) == 10 return super().training_step(batch, batch_idx) def validation_step(self, batch, batch_idx): - assert len(model.val_dataloader()) == 10 + assert len(self.trainer.val_dataloaders[0]) == 10 return super().validation_step(batch, batch_idx) def test_step(self, batch, batch_idx): - assert len(model.test_dataloader()) == 10 + assert len(self.trainer.test_dataloaders[0]) == 10 return super().test_step(batch, batch_idx) def predict_step(self, batch, batch_idx, dataloader_idx=None): - assert len(model.predict_dataloader()) == 10 + assert len(self.trainer.predict_dataloaders[0]) == 10 return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) loader = DataLoader(RandomDataset(32, 10), batch_size=1) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 3b2b0e687aebb..9a3c79bdd3cf6 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -184,7 +184,7 @@ def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n): model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders model.test_step = model.test_step__multiple_dataloaders - # train, multiple val and multiple test 
passed to fit + # multiple val dataloaders passed to fit trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2) trainer.fit(model, train_dataloader=model.dataloader(train=True), val_dataloaders=dataloaders) @@ -195,10 +195,10 @@ def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n): ckpt_path = trainer.checkpoint_callback.best_model_path trainer.test(test_dataloaders=dataloaders, ckpt_path=ckpt_path) - trainer.validate(val_dataloaders=dataloaders, ckpt_path=ckpt_path) + assert len(trainer.test_dataloaders) == n + trainer.validate(val_dataloaders=dataloaders, ckpt_path=ckpt_path) assert len(trainer.val_dataloaders) == n - assert len(trainer.test_dataloaders) == n class DummyModel(BoringModel): @@ -551,17 +551,15 @@ def test_mixing_of_dataloader_options(tmpdir, ckpt_path): # fit model trainer = Trainer(**trainer_options) trainer.fit(model, val_dataloaders=model.dataloader(train=False)) - assert trainer.state.finished, f"Training failed with {trainer.state}" # fit model trainer = Trainer(**trainer_options) trainer.fit(model, val_dataloaders=model.dataloader(train=False)) - assert trainer.state.finished, f"Training failed with {trainer.state}" + assert len(trainer.val_dataloaders) == 1, f"`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}" + if ckpt_path == "specific": ckpt_path = trainer.checkpoint_callback.best_model_path trainer.test(test_dataloaders=model.dataloader(train=False), ckpt_path=ckpt_path) - - assert len(trainer.val_dataloaders) == 1, f"`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}" assert ( len(trainer.test_dataloaders) == 1 ), f"`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}" @@ -1313,8 +1311,8 @@ def test_dataloaders_load_only_once_passed_loaders(tmpdir): def test_dataloaders_reset_and_attach(tmpdir): - """Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset and dataloaders before - attaching the new one.""" + """Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset dataloaders before attaching + the new one.""" # the assertions compare the datasets and not dataloaders since we patch and replace the samplers dataloader_0 = DataLoader(dataset=RandomDataset(32, 64)) dataloader_1 = DataLoader(dataset=RandomDataset(32, 64)) diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py index 922dbdd13ab41..a1bc7e6cafd49 100644 --- a/tests/trainer/test_trainer_tricks.py +++ b/tests/trainer/test_trainer_tricks.py @@ -84,6 +84,7 @@ def test_overfit_batch_limits(tmpdir): # test train loader applies correct limits # ------------------------------------------------------ trainer = Trainer(overfit_batches=4) + trainer.data_connector.attach_dataloaders(model=model) trainer.reset_train_dataloader(model) assert trainer.num_training_batches == 4 @@ -93,6 +94,7 @@ def test_overfit_batch_limits(tmpdir): assert torch.eq(ya, yb).all() trainer = Trainer(overfit_batches=0.11) + trainer.data_connector.attach_dataloaders(model=model) trainer.reset_train_dataloader(model) # The dataloader should have been overwritten with a Sequential sampler. 
assert trainer.train_dataloader is not train_loader @@ -111,7 +113,9 @@ def test_overfit_batch_limits(tmpdir): # ------------------------------------------------------ # test overfit_batches as percent # ------------------------------------------------------ - loader_num_batches, dataloaders = Trainer(overfit_batches=0.11)._reset_eval_dataloader(split, model=model) + trainer = Trainer(overfit_batches=0.11) + trainer.data_connector.attach_dataloaders(model) + loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == num_train_samples # make sure we turned off shuffle for the user @@ -125,23 +129,35 @@ def test_overfit_batch_limits(tmpdir): # ------------------------------------------------------ # test overfit_batches as int # ------------------------------------------------------ - loader_num_batches, dataloaders = Trainer(overfit_batches=1)._reset_eval_dataloader(split, model=model) + trainer = Trainer(overfit_batches=1) + trainer.data_connector.attach_dataloaders(model) + loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == 1 - loader_num_batches, dataloaders = Trainer(overfit_batches=5)._reset_eval_dataloader(split, model=model) + trainer = Trainer(overfit_batches=5) + trainer.data_connector.attach_dataloaders(model) + loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == 5 # ------------------------------------------------------ # test limit_xxx_batches as percent AND int # ------------------------------------------------------ if split == RunningStage.VALIDATING: - loader_num_batches, dataloaders = Trainer(limit_val_batches=0.1)._reset_eval_dataloader(split, model=model) + trainer = Trainer(limit_val_batches=0.1) + trainer.data_connector.attach_dataloaders(model) + loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == int(0.1 * len(val_loader)) - loader_num_batches, dataloaders = Trainer(limit_val_batches=10)._reset_eval_dataloader(split, model=model) + trainer = Trainer(limit_val_batches=10) + trainer.data_connector.attach_dataloaders(model) + loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == 10 else: - loader_num_batches, dataloaders = Trainer(limit_test_batches=0.1)._reset_eval_dataloader(split, model=model) + trainer = Trainer(limit_test_batches=0.1) + trainer.data_connector.attach_dataloaders(model) + loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == int(0.1 * len(test_loader)) - loader_num_batches, dataloaders = Trainer(limit_test_batches=10)._reset_eval_dataloader(split, model=model) + trainer = Trainer(limit_test_batches=10) + trainer.data_connector.attach_dataloaders(model) + loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == 10 diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py index 32b6f1db41ac9..5e4d1af1277c7 100644 --- a/tests/tuner/test_scale_batch_size.py +++ b/tests/tuner/test_scale_batch_size.py @@ -220,9 +220,12 @@ def test_error_on_dataloader_passed_to_fit(tmpdir): limit_train_batches=0.2, auto_scale_batch_size="power", ) - fit_options = dict(train_dataloader=model.dataloader(train=True)) + fit_options = dict(train_dataloaders=model.dataloader(train=True)) - with 
pytest.raises(MisconfigurationException): + with pytest.raises( + MisconfigurationException, + match="The batch scaling feature cannot be used with dataloaders passed directly", + ): trainer.tune(model, **fit_options)
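
Note on the pattern (editor's addendum): below is a minimal sketch of the convention this patch adopts, assuming the `BoringModel`/`RandomDataset` helpers from the Lightning test suite (`tests/helpers/boring_model.py`) and the Trainer API as used in the hunks above; the test name is hypothetical. Dataloaders are passed to `Trainer.fit` explicitly and read back from the trainer, rather than relying on `model.*_dataloader()` hooks being monkey-patched onto the model.

    from torch.utils.data import DataLoader

    from pytorch_lightning import Trainer
    from tests.helpers.boring_model import BoringModel, RandomDataset


    class LoaderAssertModel(BoringModel):
        def training_step(self, batch, batch_idx):
            # Read the loader the trainer actually uses; `trainer.train_dataloader`
            # is a CombinedLoader wrapper, so the raw DataLoader sits under
            # `.loaders` (mirroring the assertions in test_data_loading.py above).
            assert len(self.trainer.train_dataloader.loaders) == 10
            return super().training_step(batch, batch_idx)


    def test_trainer_owns_dataloaders(tmpdir):  # hypothetical test name
        model = LoaderAssertModel()
        loader = DataLoader(RandomDataset(32, 10), batch_size=1)
        trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
        # The loader goes to `fit` directly; nothing is patched onto `model`.
        trainer.fit(model, train_dataloaders=loader)
        # Compare datasets rather than loader objects, since the trainer may
        # re-create the loader with a replaced sampler (see the docstring fix
        # in test_dataloaders.py above).
        assert trainer.train_dataloader.loaders.dataset is loader.dataset

Likewise, tests that drive `reset_train_dataloader` or `_reset_eval_dataloader` directly must now attach the loaders first, e.g. `trainer.data_connector.attach_dataloaders(model=model)` before `trainer.reset_train_dataloader(model)` as in the test_trainer_tricks.py hunks, since the trainer no longer finds dataloaders on patched model hooks.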