Skip to content

Commit

Permalink
update test for resume_from_checkpoint on missing file (#7255)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored May 4, 2021
1 parent d413bab commit b780af5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 17 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `LightningModule.from_datasets()` now accepts `IterableDataset` instances as training datasets. ([#7503](https://github.com/PyTorchLightning/pytorch-lightning/pull/7503))


- Changed `resume_from_checkpoint` warning to an error when the checkpoint file does not exist ([#7075](https://github.com/PyTorchLightning/pytorch-lightning/pull/7075))


### Deprecated


Expand Down
22 changes: 5 additions & 17 deletions tests/models/test_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,24 +132,12 @@ def test_model_properties_resume_from_checkpoint(tmpdir):


def test_try_resume_from_non_existing_checkpoint(tmpdir):
""" Test that trying to resume from non-existing `resume_from_checkpoint` fail without error."""
dm = ClassifDataModule()
model = ClassificationModel()
checkpoint_cb = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
logger=False,
callbacks=[checkpoint_cb],
limit_train_batches=2,
limit_val_batches=2,
)
# Generate checkpoint `last.ckpt` with BoringModel
trainer.fit(model, datamodule=dm)
# `True` if resume/restore successfully else `False`
assert trainer.checkpoint_connector.restore(str(tmpdir / "last.ckpt"), trainer.on_gpu)
""" Test that trying to resume from non-existing `resume_from_checkpoint` fails with an error."""
model = BoringModel()
trainer = Trainer(resume_from_checkpoint=str(tmpdir / "non_existing.ckpt"))

with pytest.raises(FileNotFoundError, match="Aborting training"):
trainer.checkpoint_connector.restore(str(tmpdir / "last_non_existing.ckpt"), trainer.on_gpu)
trainer.fit(model)


class CaptureCallbacksBeforeTraining(Callback):
Expand Down

0 comments on commit b780af5

Please sign in to comment.