update tests to not rely on patched model dataloader
awaelchli committed Oct 12, 2021
1 parent f16bfe9 commit 772cd27
Showing 5 changed files with 19 additions and 18 deletions.
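
All five files make the same substitution: instead of calling `*_dataloader()` on the model, which only returned the right loaders while the tests monkey-patched them onto the model, the tests now either query the `LightningDataModule` passed to the `Trainer` or read the loaders the `Trainer` itself attached. A minimal sketch of the new pattern, assuming the repository's `BoringModel`/`RandomDataset` test helpers (the import path and the `BoringDataModule` name are assumptions, not part of this commit):

from torch.utils.data import DataLoader
from pytorch_lightning import LightningDataModule, Trainer
from tests.helpers.boring_model import BoringModel, RandomDataset  # assumed helper path


class BoringDataModule(LightningDataModule):
    # the datamodule, not the model, owns the dataloaders
    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64))

    def test_dataloader(self):
        return DataLoader(RandomDataset(32, 64))


model = BoringModel()
dm = BoringDataModule()
trainer = Trainer(fast_dev_run=True)

# old style: trainer.fit(model), with the loaders patched onto the model
# new style: pass the datamodule explicitly ...
trainer.fit(model, datamodule=dm)
# ... and query it directly afterwards, as the updated tests do
dataloaders = dm.test_dataloader()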
2 changes: 1 addition & 1 deletion tests/callbacks/test_early_stopping.py
@@ -95,7 +95,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
     )

     with pytest.raises(MisconfigurationException, match=r"You restored a checkpoint with current_epoch"):
-        new_trainer.fit(model)
+        new_trainer.fit(model, datamodule=dm)


 def test_early_stopping_no_extraneous_invocations(tmpdir):
4 changes: 2 additions & 2 deletions tests/models/test_restore.py
@@ -340,7 +340,7 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir):
     new_trainer.test(pretrained_model)
     pretrained_model.cpu()

-    dataloaders = model.test_dataloader()
+    dataloaders = dm.test_dataloader()
     if not isinstance(dataloaders, list):
         dataloaders = [dataloaders]

@@ -539,7 +539,7 @@ def on_pretrain_routine_end(self):
             # haven't trained with the new loaded model
             new_trainer.state.stage = RunningStage.VALIDATING

-            dataloader = self.train_dataloader()
+            dataloader = dm.train_dataloader()
             tpipes.run_prediction_eval_model_template(self.trainer.lightning_module, dataloader=dataloader)
             self.on_pretrain_routine_end_called = True
8 changes: 4 additions & 4 deletions tests/trainer/test_data_loading.py
@@ -267,19 +267,19 @@ def test_loader_detaching():

     class LoaderTestModel(BoringModel):
         def training_step(self, batch, batch_idx):
-            assert len(model.train_dataloader()) == 10
+            assert len(self.trainer.train_dataloader.loaders) == 10
             return super().training_step(batch, batch_idx)

         def validation_step(self, batch, batch_idx):
-            assert len(model.val_dataloader()) == 10
+            assert len(self.trainer.val_dataloaders[0]) == 10
             return super().validation_step(batch, batch_idx)

         def test_step(self, batch, batch_idx):
-            assert len(model.test_dataloader()) == 10
+            assert len(self.trainer.test_dataloaders[0]) == 10
             return super().test_step(batch, batch_idx)

         def predict_step(self, batch, batch_idx, dataloader_idx=None):
-            assert len(model.predict_dataloader()) == 10
+            assert len(self.trainer.predict_dataloaders[0]) == 10
             return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)

     loader = DataLoader(RandomDataset(32, 10), batch_size=1)
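
The length of 10 asserted above comes from the `DataLoader(RandomDataset(32, 10), batch_size=1)` that the test attaches. A condensed sketch of the trainer-side attributes these asserts rely on, assuming the attribute layout of this Lightning version as shown in the diff (`trainer.train_dataloader` wraps the raw loader behind `.loaders`, eval loaders are kept as lists):

from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel, RandomDataset  # assumed helper path

loader = DataLoader(RandomDataset(32, 10), batch_size=1)
model = BoringModel()
trainer = Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=2)
trainer.fit(model, train_dataloader=loader, val_dataloaders=loader)

# the Trainer, not the model, is the source of truth for what is attached
assert len(trainer.train_dataloader.loaders) == 10  # raw loader behind the wrapper
assert len(trainer.val_dataloaders[0]) == 10  # eval loaders are stored in a list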
16 changes: 7 additions & 9 deletions tests/trainer/test_dataloaders.py
@@ -184,7 +184,7 @@ def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n):
     model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
     model.test_step = model.test_step__multiple_dataloaders

-    # train, multiple val and multiple test passed to fit
+    # multiple val dataloaders passed to fit
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2)
     trainer.fit(model, train_dataloader=model.dataloader(train=True), val_dataloaders=dataloaders)

@@ -195,10 +195,10 @@ def test_dataloaders_passed_to_fn(tmpdir, ckpt_path, n):
         ckpt_path = trainer.checkpoint_callback.best_model_path

     trainer.test(test_dataloaders=dataloaders, ckpt_path=ckpt_path)
-    trainer.validate(val_dataloaders=dataloaders, ckpt_path=ckpt_path)
+    assert len(trainer.test_dataloaders) == n

+    trainer.validate(val_dataloaders=dataloaders, ckpt_path=ckpt_path)
     assert len(trainer.val_dataloaders) == n
-    assert len(trainer.test_dataloaders) == n


 class DummyModel(BoringModel):
@@ -551,17 +551,15 @@ def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
     # fit model
     trainer = Trainer(**trainer_options)
     trainer.fit(model, val_dataloaders=model.dataloader(train=False))
-    assert trainer.state.finished, f"Training failed with {trainer.state}"
-
     # fit model
     trainer = Trainer(**trainer_options)
     trainer.fit(model, val_dataloaders=model.dataloader(train=False))
     assert trainer.state.finished, f"Training failed with {trainer.state}"
+    assert len(trainer.val_dataloaders) == 1, f"`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}"
+
     if ckpt_path == "specific":
         ckpt_path = trainer.checkpoint_callback.best_model_path
     trainer.test(test_dataloaders=model.dataloader(train=False), ckpt_path=ckpt_path)
-
-    assert len(trainer.val_dataloaders) == 1, f"`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}"
     assert (
         len(trainer.test_dataloaders) == 1
     ), f"`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}"
@@ -1313,8 +1311,8 @@ def test_dataloaders_load_only_once_passed_loaders(tmpdir):


 def test_dataloaders_reset_and_attach(tmpdir):
-    """Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset and dataloaders before
-    attaching the new one."""
+    """Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset dataloaders before attaching
+    the new one."""
     # the assertions compare the datasets and not dataloaders since we patch and replace the samplers
     dataloader_0 = DataLoader(dataset=RandomDataset(32, 64))
     dataloader_1 = DataLoader(dataset=RandomDataset(32, 64))
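
The corrected docstring above describes the behavior under test. A condensed sketch of the check, comparing datasets rather than dataloaders because, per the comment in the diff, the Trainer patches and replaces the samplers when it attaches a loader (helper import path assumed):

from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel, RandomDataset  # assumed helper path

dataloader_0 = DataLoader(dataset=RandomDataset(32, 64))
dataloader_1 = DataLoader(dataset=RandomDataset(32, 64))
model = BoringModel()
trainer = Trainer(fast_dev_run=True)

trainer.fit(model, train_dataloader=dataloader_0, val_dataloaders=dataloader_0)
# the second entry point must detach dataloader_0 before attaching dataloader_1
trainer.validate(model, val_dataloaders=dataloader_1)
assert trainer.val_dataloaders[0].dataset is dataloader_1.dataset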
7 changes: 5 additions & 2 deletions tests/tuner/test_scale_batch_size.py
@@ -220,9 +220,12 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):
         limit_train_batches=0.2,
         auto_scale_batch_size="power",
     )
-    fit_options = dict(train_dataloader=model.dataloader(train=True))
+    fit_options = dict(train_dataloaders=model.dataloader(train=True))

-    with pytest.raises(MisconfigurationException):
+    with pytest.raises(
+        MisconfigurationException,
+        match="The batch scaling feature cannot be used with dataloaders passed directly",
+    ):
         trainer.tune(model, **fit_options)
