Commit b530b7a
update tests to not rely on patched dataloaders (#9905)
awaelchli authored Oct 12, 2021
1 parent 98c0a11 commit b530b7a
Showing 6 changed files with 42 additions and 25 deletions.
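The running theme: instead of monkey-patching dataloader methods onto the model so that model-only fit/test calls appear to carry their own data, the tests now hand the data source to the Trainer explicitly. A before/after sketch of the two styles (the "before" patching line is hypothetical, illustrating the style the commit title refers to; model, dm, and trainer are assumed to be built elsewhere):

# before: graft the datamodule's loader method onto the model
model.val_dataloader = dm.val_dataloader  # hypothetical patched style
trainer.fit(model)

# after: pass the datamodule at the call site
trainer.fit(model, datamodule=dm)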
tests/callbacks/test_early_stopping.py (1 addition, 1 deletion)
@@ -95,7 +95,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
     )

     with pytest.raises(MisconfigurationException, match=r"You restored a checkpoint with current_epoch"):
-        new_trainer.fit(model)
+        new_trainer.fit(model, datamodule=dm)


 def test_early_stopping_no_extraneous_invocations(tmpdir):
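The restored trainer now resolves its loaders from the datamodule passed to fit rather than from a method patched onto the model. A self-contained sketch of a stand-in datamodule playing the role of dm (hypothetical RandomDataset and RandomDataModule, mirroring the repository's test helpers):

import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl


class RandomDataset(Dataset):
    # hypothetical stand-in for the test helper: random 32-dim feature vectors
    def __init__(self, size: int = 32, length: int = 64):
        self.data = torch.randn(length, size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class RandomDataModule(pl.LightningDataModule):
    # hypothetical datamodule in the role of `dm` above
    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=4)

    def val_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=4)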
tests/models/test_restore.py (2 additions, 2 deletions)
@@ -340,7 +340,7 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir):
     new_trainer.test(pretrained_model)
     pretrained_model.cpu()

-    dataloaders = model.test_dataloader()
+    dataloaders = dm.test_dataloader()
     if not isinstance(dataloaders, list):
         dataloaders = [dataloaders]

@@ -539,7 +539,7 @@ def on_pretrain_routine_end(self):
         # haven't trained with the new loaded model
         new_trainer.state.stage = RunningStage.VALIDATING

-        dataloader = self.train_dataloader()
+        dataloader = dm.train_dataloader()
         tpipes.run_prediction_eval_model_template(self.trainer.lightning_module, dataloader=dataloader)
         self.on_pretrain_routine_end_called = True

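Both hunks swap a patched model accessor for the datamodule's own. Because test_dataloader() may return a single DataLoader or a list of them, the test normalizes the result; a minimal sketch of that idiom as a hypothetical helper:

def as_dataloader_list(dm):
    # normalize a datamodule's test loaders to a list, as in the diff above
    dataloaders = dm.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]
    return dataloaders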
tests/trainer/test_data_loading.py (4 additions, 4 deletions)
@@ -267,19 +267,19 @@ def test_loader_detaching():

     class LoaderTestModel(BoringModel):
         def training_step(self, batch, batch_idx):
-            assert len(model.train_dataloader()) == 10
+            assert len(self.trainer.train_dataloader.loaders) == 10
             return super().training_step(batch, batch_idx)

         def validation_step(self, batch, batch_idx):
-            assert len(model.val_dataloader()) == 10
+            assert len(self.trainer.val_dataloaders[0]) == 10
             return super().validation_step(batch, batch_idx)

         def test_step(self, batch, batch_idx):
-            assert len(model.test_dataloader()) == 10
+            assert len(self.trainer.test_dataloaders[0]) == 10
             return super().test_step(batch, batch_idx)

         def predict_step(self, batch, batch_idx, dataloader_idx=None):
-            assert len(model.predict_dataloader()) == 10
+            assert len(self.trainer.predict_dataloaders[0]) == 10
             return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)

     loader = DataLoader(RandomDataset(32, 10), batch_size=1)
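The assertions now read loader lengths from the Trainer's own attributes rather than from patched model.*_dataloader() methods. Note the asymmetry in the attribute layout this version of Lightning exposes: the train loader is wrapped, so its raw loaders sit under .loaders, while the evaluation stages hold plain lists. A hypothetical helper condensing the four assertions:

def assert_trainer_loader_lengths(trainer, expected: int) -> None:
    # attribute names exactly as used in the diff above
    assert len(trainer.train_dataloader.loaders) == expected  # wrapped train loader
    assert len(trainer.val_dataloaders[0]) == expected  # list of val loaders
    assert len(trainer.test_dataloaders[0]) == expected  # list of test loaders
    assert len(trainer.predict_dataloaders[0]) == expected  # list of predict loaders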
tests/trainer/test_dataloaders.py (7 additions, 9 deletions)
@@ -184,7 +184,7 @@ def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n):
     model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
     model.test_step = model.test_step__multiple_dataloaders

-    # train, multiple val and multiple test passed to fit
+    # multiple val dataloaders passed to fit
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)
     trainer.fit(model, train_dataloader=model.dataloader(train=True), val_dataloaders=dataloaders)

@@ -195,10 +195,10 @@ def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n):
     ckpt_path = trainer.checkpoint_callback.best_model_path

     trainer.test(test_dataloaders=dataloaders, ckpt_path=ckpt_path)
-    trainer.validate(val_dataloaders=dataloaders, ckpt_path=ckpt_path)
+    assert len(trainer.test_dataloaders) == n

+    trainer.validate(val_dataloaders=dataloaders, ckpt_path=ckpt_path)
     assert len(trainer.val_dataloaders) == n
-    assert len(trainer.test_dataloaders) == n


 class DummyModel(BoringModel):
@@ -551,17 +551,15 @@ def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
     # fit model
     trainer = Trainer(**trainer_options)
     trainer.fit(model, val_dataloaders=model.dataloader(train=False))
     assert trainer.state.finished, f"Training failed with {trainer.state}"

     # fit model
     trainer = Trainer(**trainer_options)
     trainer.fit(model, val_dataloaders=model.dataloader(train=False))
     assert trainer.state.finished, f"Training failed with {trainer.state}"
+    assert len(trainer.val_dataloaders) == 1, f"`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}"

     if ckpt_path == "specific":
         ckpt_path = trainer.checkpoint_callback.best_model_path
     trainer.test(test_dataloaders=model.dataloader(train=False), ckpt_path=ckpt_path)
-
-    assert len(trainer.val_dataloaders) == 1, f"`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}"
     assert (
         len(trainer.test_dataloaders) == 1
     ), f"`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}"
@@ -1313,8 +1311,8 @@ def test_dataloaders_load_only_once_passed_loaders(tmpdir):


 def test_dataloaders_reset_and_attach(tmpdir):
-    """Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset and dataloaders before
-    attaching the new one."""
+    """Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset dataloaders before attaching
+    the new one."""
     # the assertions compare the datasets and not dataloaders since we patch and replace the samplers
     dataloader_0 = DataLoader(dataset=RandomDataset(32, 64))
     dataloader_1 = DataLoader(dataset=RandomDataset(32, 64))
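Besides the comment and docstring touch-ups, the middle hunk regroups the assertions so that each sits directly after the Trainer call whose attached loaders it checks. The resulting pattern, assuming dataloaders, n, and ckpt_path from the surrounding test:

# each evaluation call attaches its loaders; assert immediately after it
trainer.test(test_dataloaders=dataloaders, ckpt_path=ckpt_path)
assert len(trainer.test_dataloaders) == n

trainer.validate(val_dataloaders=dataloaders, ckpt_path=ckpt_path)
assert len(trainer.val_dataloaders) == n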
tests/trainer/test_trainer_tricks.py (23 additions, 7 deletions)
@@ -84,6 +84,7 @@ def test_overfit_batch_limits(tmpdir):
     # test train loader applies correct limits
     # ------------------------------------------------------
     trainer = Trainer(overfit_batches=4)
+    trainer.data_connector.attach_dataloaders(model=model)
     trainer.reset_train_dataloader(model)
     assert trainer.num_training_batches == 4

@@ -93,6 +94,7 @@ def test_overfit_batch_limits(tmpdir):
     assert torch.eq(ya, yb).all()

     trainer = Trainer(overfit_batches=0.11)
+    trainer.data_connector.attach_dataloaders(model=model)
     trainer.reset_train_dataloader(model)
     # The dataloader should have been overwritten with a Sequential sampler.
     assert trainer.train_dataloader is not train_loader
@@ -111,7 +113,9 @@ def test_overfit_batch_limits(tmpdir):
     # ------------------------------------------------------
     # test overfit_batches as percent
     # ------------------------------------------------------
-    loader_num_batches, dataloaders = Trainer(overfit_batches=0.11)._reset_eval_dataloader(split, model=model)
+    trainer = Trainer(overfit_batches=0.11)
+    trainer.data_connector.attach_dataloaders(model)
+    loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
     assert loader_num_batches[0] == num_train_samples

     # make sure we turned off shuffle for the user
# make sure we turned off shuffle for the user
Expand All @@ -125,23 +129,35 @@ def test_overfit_batch_limits(tmpdir):
# ------------------------------------------------------
# test overfit_batches as int
# ------------------------------------------------------
loader_num_batches, dataloaders = Trainer(overfit_batches=1)._reset_eval_dataloader(split, model=model)
trainer = Trainer(overfit_batches=1)
trainer.data_connector.attach_dataloaders(model)
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
assert loader_num_batches[0] == 1
loader_num_batches, dataloaders = Trainer(overfit_batches=5)._reset_eval_dataloader(split, model=model)
trainer = Trainer(overfit_batches=5)
trainer.data_connector.attach_dataloaders(model)
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
assert loader_num_batches[0] == 5

# ------------------------------------------------------
# test limit_xxx_batches as percent AND int
# ------------------------------------------------------
if split == RunningStage.VALIDATING:
loader_num_batches, dataloaders = Trainer(limit_val_batches=0.1)._reset_eval_dataloader(split, model=model)
trainer = Trainer(limit_val_batches=0.1)
trainer.data_connector.attach_dataloaders(model)
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
assert loader_num_batches[0] == int(0.1 * len(val_loader))

loader_num_batches, dataloaders = Trainer(limit_val_batches=10)._reset_eval_dataloader(split, model=model)
trainer = Trainer(limit_val_batches=10)
trainer.data_connector.attach_dataloaders(model)
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
assert loader_num_batches[0] == 10
else:
loader_num_batches, dataloaders = Trainer(limit_test_batches=0.1)._reset_eval_dataloader(split, model=model)
trainer = Trainer(limit_test_batches=0.1)
trainer.data_connector.attach_dataloaders(model)
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
assert loader_num_batches[0] == int(0.1 * len(test_loader))

loader_num_batches, dataloaders = Trainer(limit_test_batches=10)._reset_eval_dataloader(split, model=model)
trainer = Trainer(limit_test_batches=10)
trainer.data_connector.attach_dataloaders(model)
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
assert loader_num_batches[0] == 10
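This file carries most of the additions: every freshly constructed Trainer now attaches the model's dataloaders through the DataConnector before resetting them, since reset_train_dataloader and _reset_eval_dataloader no longer find loaders patched onto the model. A hypothetical helper that would condense the repeated three-line pattern, using the same internal APIs the diff calls verbatim:

from pytorch_lightning import Trainer


def reset_eval_loaders(model, split, **trainer_kwargs):
    # build a Trainer, attach the model's loaders, then reset for `split`
    trainer = Trainer(**trainer_kwargs)
    trainer.data_connector.attach_dataloaders(model)
    return trainer._reset_eval_dataloader(split, model=model)


# usage, mirroring one of the limit_val_batches cases above:
# loader_num_batches, dataloaders = reset_eval_loaders(model, split, limit_val_batches=10)
# assert loader_num_batches[0] == 10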
tests/tuner/test_scale_batch_size.py (5 additions, 2 deletions)
@@ -220,9 +220,12 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):
         limit_train_batches=0.2,
         auto_scale_batch_size="power",
     )
-    fit_options = dict(train_dataloader=model.dataloader(train=True))
+    fit_options = dict(train_dataloaders=model.dataloader(train=True))

-    with pytest.raises(MisconfigurationException):
+    with pytest.raises(
+        MisconfigurationException,
+        match="The batch scaling feature cannot be used with dataloaders passed directly",
+    ):
         trainer.tune(model, **fit_options)

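Two tightenings here: the keyword argument moves to the plural train_dataloaders spelling, and pytest.raises now pins the expected message with match=, so the test can no longer pass on an unrelated MisconfigurationException. The combined call, assuming trainer and model from the surrounding test:

import pytest
from pytorch_lightning.utilities.exceptions import MisconfigurationException

with pytest.raises(
    MisconfigurationException,
    match="The batch scaling feature cannot be used with dataloaders passed directly",
):
    trainer.tune(model, train_dataloaders=model.dataloader(train=True))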
