From 6fffdec640739fa16eafae9e7c071b6d2fd342bb Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 4 Mar 2021 21:56:50 -0800 Subject: [PATCH] fix tests --- CHANGELOG.md | 2 ++ .../callbacks/model_checkpoint.py | 8 ++---- tests/checkpointing/test_model_checkpoint.py | 28 +++++-------------- 3 files changed, 12 insertions(+), 26 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c5b35bd0c197d..b1faaf6c5d325c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- `period` has been deprecated in favor of `every_n_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) + ### Removed diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 1ece791c89b96f..719892266231d9 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -174,8 +174,8 @@ def __init__( self.save_last = save_last self.save_top_k = save_top_k self.save_weights_only = save_weights_only - self.period = every_n_epochs - self.every_n_epochs = every_n_epochs + self.every_n_epochs = period or every_n_epochs + self.period = self.every_n_epochs self.every_n_batches = every_n_batches self._last_global_step_saved = -1 self.current_score = None @@ -192,8 +192,6 @@ def __init__( 'Argument `period` is deprecated in v1.3 and will be removed in v1.5.' ' Please use `every_n_epochs` instead.', DeprecationWarning ) - self.every_n_epochs = period - self.period = period self.__init_monitor_mode(monitor, mode) self.__init_ckpt_dir(dirpath, filename, save_top_k) @@ -209,7 +207,7 @@ def on_pretrain_routine_start(self, trainer, pl_module): def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None: if self._should_skip_saving_checkpoint(trainer): return - step = trainer.global_step + step = trainer.total_batch_idx skip_batch = self.every_n_steps < 1 or ((step + 1) % self.every_n_steps != 0) if skip_batch: return diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 4ea0a4d1a1de22..476850a995d34f 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -592,9 +592,8 @@ def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs): default_root_dir=tmpdir, callbacks=[checkpoint_callback], max_epochs=epochs, - limit_train_batches=0.1, - limit_val_batches=0.1, - val_check_interval=1.0, + limit_train_batches=1, + limit_val_batches=1, logger=False, ) trainer.fit(model) @@ -615,9 +614,8 @@ def test_model_checkpoint_every_n_epochs_and_no_period(tmpdir, every_n_epochs): default_root_dir=tmpdir, callbacks=[checkpoint_callback], max_epochs=epochs, - limit_train_batches=0.1, - limit_val_batches=0.1, - val_check_interval=1.0, + limit_train_batches=1, + limit_val_batches=1, logger=False, ) trainer.fit(model) @@ -631,16 +629,15 @@ def test_ckpt_every_n_batches(tmpdir): """ Tests that the checkpoints are saved every n training steps. """ model = LogInTwoMethods() - + every_n_batches = 16 trainer = Trainer( default_root_dir=tmpdir, - min_epochs=2, max_epochs=2, progress_bar_refresh_rate=0, checkpoint_callback=ModelCheckpoint( filename="{step}", every_n_epochs=0, - every_n_batches=16, + every_n_batches=every_n_batches, dirpath=tmpdir, save_top_k=-1, save_last=False, @@ -649,16 +646,7 @@ def test_ckpt_every_n_batches(tmpdir): ) trainer.fit(model) - expected = [ - "step=15.ckpt", - "step=31.ckpt", - "step=47.ckpt", - "step=63.ckpt", - "step=79.ckpt", - "step=95.ckpt", - "step=111.ckpt", - "step=127.ckpt", - ] + expected=[f"step={i}.ckpt" for i in range(15, 128, every_n_batches)] assert set(os.listdir(tmpdir)) == set(expected) @@ -667,9 +655,7 @@ def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir): model = LogInTwoMethods() trainer = Trainer( default_root_dir=tmpdir, - min_epochs=2, max_epochs=2, - progress_bar_refresh_rate=0, checkpoint_callback=ModelCheckpoint( every_n_epochs=1, every_n_batches=30,