Skip to content

Commit

Permalink
fix tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
ananthsub committed Mar 5, 2021
1 parent d91636d commit 6fffdec
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 26 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- `period` has been deprecated in favor of `every_n_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))


### Removed

Expand Down
8 changes: 3 additions & 5 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,8 @@ def __init__(
self.save_last = save_last
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
self.period = every_n_epochs
self.every_n_epochs = every_n_epochs
self.every_n_epochs = period or every_n_epochs
self.period = self.every_n_epochs
self.every_n_batches = every_n_batches
self._last_global_step_saved = -1
self.current_score = None
Expand All @@ -192,8 +192,6 @@ def __init__(
'Argument `period` is deprecated in v1.3 and will be removed in v1.5.'
' Please use `every_n_epochs` instead.', DeprecationWarning
)
self.every_n_epochs = period
self.period = period

self.__init_monitor_mode(monitor, mode)
self.__init_ckpt_dir(dirpath, filename, save_top_k)
Expand All @@ -209,7 +207,7 @@ def on_pretrain_routine_start(self, trainer, pl_module):
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None:
if self._should_skip_saving_checkpoint(trainer):
return
step = trainer.global_step
step = trainer.total_batch_idx
skip_batch = self.every_n_steps < 1 or ((step + 1) % self.every_n_steps != 0)
if skip_batch:
return
Expand Down
28 changes: 7 additions & 21 deletions tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,9 +592,8 @@ def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs):
default_root_dir=tmpdir,
callbacks=[checkpoint_callback],
max_epochs=epochs,
limit_train_batches=0.1,
limit_val_batches=0.1,
val_check_interval=1.0,
limit_train_batches=1,
limit_val_batches=1,
logger=False,
)
trainer.fit(model)
Expand All @@ -615,9 +614,8 @@ def test_model_checkpoint_every_n_epochs_and_no_period(tmpdir, every_n_epochs):
default_root_dir=tmpdir,
callbacks=[checkpoint_callback],
max_epochs=epochs,
limit_train_batches=0.1,
limit_val_batches=0.1,
val_check_interval=1.0,
limit_train_batches=1,
limit_val_batches=1,
logger=False,
)
trainer.fit(model)
Expand All @@ -631,16 +629,15 @@ def test_ckpt_every_n_batches(tmpdir):
""" Tests that the checkpoints are saved every n training steps. """

model = LogInTwoMethods()

every_n_batches = 16
trainer = Trainer(
default_root_dir=tmpdir,
min_epochs=2,
max_epochs=2,
progress_bar_refresh_rate=0,
checkpoint_callback=ModelCheckpoint(
filename="{step}",
every_n_epochs=0,
every_n_batches=16,
every_n_batches=every_n_batches,
dirpath=tmpdir,
save_top_k=-1,
save_last=False,
Expand All @@ -649,16 +646,7 @@ def test_ckpt_every_n_batches(tmpdir):
)

trainer.fit(model)
expected = [
"step=15.ckpt",
"step=31.ckpt",
"step=47.ckpt",
"step=63.ckpt",
"step=79.ckpt",
"step=95.ckpt",
"step=111.ckpt",
"step=127.ckpt",
]
expected=[f"step={i}.ckpt" for i in range(15, 128, every_n_batches)]
assert set(os.listdir(tmpdir)) == set(expected)


Expand All @@ -667,9 +655,7 @@ def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir):
model = LogInTwoMethods()
trainer = Trainer(
default_root_dir=tmpdir,
min_epochs=2,
max_epochs=2,
progress_bar_refresh_rate=0,
checkpoint_callback=ModelCheckpoint(
every_n_epochs=1,
every_n_batches=30,
Expand Down

0 comments on commit 6fffdec

Please sign in to comment.