Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ananthsub committed Feb 24, 2021
1 parent b88b888 commit 2848bdc
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support to checkpoint after training batches in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))

- Added support to checkpoint after training batches in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))

### Changed

Expand Down
6 changes: 1 addition & 5 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,9 @@ def __init__(
every_n_epochs: int = 1,
every_n_batches: int = -1,
mode: str = "min",
<<<<<<< HEAD
period: Optional[int] = None,
=======
period: int = 1,
every_n_epochs: int = 1,
every_n_batches: int = -1,
>>>>>>> Update model_checkpoint.py
period: Optional[int] = None,
):
super().__init__()
self.monitor = monitor
Expand Down
30 changes: 30 additions & 0 deletions tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,36 @@ def test_none_monitor_top_k(tmpdir):
ModelCheckpoint(dirpath=tmpdir, save_top_k=-1)
ModelCheckpoint(dirpath=tmpdir, save_top_k=0)

def test_invalid_every_n_epoch(tmpdir):
""" Test that an exception is raised for every_n_epochs = 0 or < -1. """
with pytest.raises(
MisconfigurationException, match=r'Invalid value for every_n_epochs=0*'
):
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0)
with pytest.raises(
MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*'
):
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-2)

# These should not fail
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1)
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3)

def test_invalid_every_n_batches(tmpdir):
""" Test that an exception is raised for every_n_batches = 0 or < -1. """
with pytest.raises(
MisconfigurationException, match=r'Invalid value for every_n_batches=0*'
):
ModelCheckpoint(dirpath=tmpdir, every_n_batches=0)
with pytest.raises(
MisconfigurationException, match=r'Invalid value for every_n_batches=-2*'
):
ModelCheckpoint(dirpath=tmpdir, every_n_batches=-2)

# These should not fail
ModelCheckpoint(dirpath=tmpdir, every_n_batches=-1)
ModelCheckpoint(dirpath=tmpdir, every_n_batches=3)


def test_none_monitor_save_last(tmpdir):
""" Test that a warning appears for save_last=True with monitor=None. """
Expand Down

0 comments on commit 2848bdc

Please sign in to comment.