diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py
index 922dbdd13ab41..a1bc7e6cafd49 100644
--- a/tests/trainer/test_trainer_tricks.py
+++ b/tests/trainer/test_trainer_tricks.py
@@ -84,6 +84,7 @@ def test_overfit_batch_limits(tmpdir):
     # test train loader applies correct limits
     # ------------------------------------------------------
     trainer = Trainer(overfit_batches=4)
+    trainer.data_connector.attach_dataloaders(model=model)
     trainer.reset_train_dataloader(model)
     assert trainer.num_training_batches == 4
 
@@ -93,6 +94,7 @@ def test_overfit_batch_limits(tmpdir):
         assert torch.eq(ya, yb).all()
 
     trainer = Trainer(overfit_batches=0.11)
+    trainer.data_connector.attach_dataloaders(model=model)
     trainer.reset_train_dataloader(model)
     # The dataloader should have been overwritten with a Sequential sampler.
     assert trainer.train_dataloader is not train_loader
@@ -111,7 +113,9 @@ def test_overfit_batch_limits(tmpdir):
         # ------------------------------------------------------
         # test overfit_batches as percent
         # ------------------------------------------------------
-        loader_num_batches, dataloaders = Trainer(overfit_batches=0.11)._reset_eval_dataloader(split, model=model)
+        trainer = Trainer(overfit_batches=0.11)
+        trainer.data_connector.attach_dataloaders(model)
+        loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
         assert loader_num_batches[0] == num_train_samples
 
         # make sure we turned off shuffle for the user
@@ -125,23 +129,35 @@ def test_overfit_batch_limits(tmpdir):
         # ------------------------------------------------------
         # test overfit_batches as int
         # ------------------------------------------------------
-        loader_num_batches, dataloaders = Trainer(overfit_batches=1)._reset_eval_dataloader(split, model=model)
+        trainer = Trainer(overfit_batches=1)
+        trainer.data_connector.attach_dataloaders(model)
+        loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
         assert loader_num_batches[0] == 1
-        loader_num_batches, dataloaders = Trainer(overfit_batches=5)._reset_eval_dataloader(split, model=model)
+        trainer = Trainer(overfit_batches=5)
+        trainer.data_connector.attach_dataloaders(model)
+        loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
        assert loader_num_batches[0] == 5
 
         # ------------------------------------------------------
         # test limit_xxx_batches as percent AND int
         # ------------------------------------------------------
         if split == RunningStage.VALIDATING:
-            loader_num_batches, dataloaders = Trainer(limit_val_batches=0.1)._reset_eval_dataloader(split, model=model)
+            trainer = Trainer(limit_val_batches=0.1)
+            trainer.data_connector.attach_dataloaders(model)
+            loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
             assert loader_num_batches[0] == int(0.1 * len(val_loader))
 
-            loader_num_batches, dataloaders = Trainer(limit_val_batches=10)._reset_eval_dataloader(split, model=model)
+            trainer = Trainer(limit_val_batches=10)
+            trainer.data_connector.attach_dataloaders(model)
+            loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
             assert loader_num_batches[0] == 10
         else:
-            loader_num_batches, dataloaders = Trainer(limit_test_batches=0.1)._reset_eval_dataloader(split, model=model)
+            trainer = Trainer(limit_test_batches=0.1)
+            trainer.data_connector.attach_dataloaders(model)
+            loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
             assert loader_num_batches[0] == int(0.1 * len(test_loader))
 
-            loader_num_batches, dataloaders = Trainer(limit_test_batches=10)._reset_eval_dataloader(split, model=model)
+            trainer = Trainer(limit_test_batches=10)
+            trainer.data_connector.attach_dataloaders(model)
+            loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
             assert loader_num_batches[0] == 10