From 5e8de9ad20a208a0a01da0ee18215aaa4d7bb5df Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Wed, 10 Mar 2021 22:47:06 -0800
Subject: [PATCH] make-private

make attributes private to the class
---
 .../callbacks/model_checkpoint.py            | 32 +++++++++----------
 tests/checkpointing/test_model_checkpoint.py |  4 ++--
 2 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 073b437b4cf89e..1fc5712e8fd87b 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -205,7 +205,7 @@ def on_train_batch_end(self, trainer, pl_module, *args, **kwargs) -> None:
         if self._should_skip_saving_checkpoint(trainer):
             return
         step = trainer.global_step
-        skip_batch = self.every_n_train_steps < 1 or ((step + 1) % self.every_n_train_steps != 0)
+        skip_batch = self._every_n_train_steps < 1 or ((step + 1) % self._every_n_train_steps != 0)
         if skip_batch:
             return
         self.save_checkpoint(trainer, pl_module)
@@ -215,8 +215,8 @@ def on_validation_end(self, trainer, pl_module):
         checkpoints can be saved at the end of the val loop
         """
         skip = (
-            self._should_skip_saving_checkpoint(trainer) or self.every_n_val_epochs is None
-            or self.every_n_val_epochs < 1 or (trainer.current_epoch + 1) % self.every_n_val_epochs != 0
+            self._should_skip_saving_checkpoint(trainer) or self._every_n_val_epochs < 1
+            or (trainer.current_epoch + 1) % self._every_n_val_epochs != 0
         )
         if skip:
             return
@@ -279,18 +279,18 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool:
     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
             raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
-        if self.every_n_train_steps < 0:
+        if self._every_n_train_steps < 0:
             raise MisconfigurationException(
-                f'Invalid value for every_n_train_steps={self.every_n_train_steps}. Must be >= 0'
+                f'Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0'
             )
-        if self.every_n_val_epochs < 0:
+        if self._every_n_val_epochs < 0:
             raise MisconfigurationException(
-                f'Invalid value for every_n_val_epochs={self.every_n_val_epochs}. Must be >= 0'
+                f'Invalid value for every_n_val_epochs={self._every_n_val_epochs}. Must be >= 0'
             )
-        if self.every_n_train_steps > 0 and self.every_n_val_epochs > 0:
+        if self._every_n_train_steps > 0 and self._every_n_val_epochs > 0:
             raise MisconfigurationException(
-                f'Invalid values for every_n_train_steps={self.every_n_train_steps}'
-                ' and every_n_val_epochs={self.every_n_val_epochs}.'
+                f'Invalid values for every_n_train_steps={self._every_n_train_steps}'
+                ' and every_n_val_epochs={self._every_n_val_epochs}.'
                 'Both cannot be enabled at the same time.'
             )
         if self.monitor is None:
@@ -346,12 +346,12 @@ def __init_triggers(
         # Default to running once after each validation epoch if neither
         # every_n_train_steps nor every_n_val_epochs is set
         if every_n_train_steps is None and every_n_val_epochs is None:
-            self.every_n_val_epochs = 1
-            self.every_n_train_steps = 0
+            self._every_n_val_epochs = 1
+            self._every_n_train_steps = 0
             log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1")
         else:
-            self.every_n_val_epochs = every_n_val_epochs or 0
-            self.every_n_train_steps = every_n_train_steps or 0
+            self._every_n_val_epochs = every_n_val_epochs or 0
+            self._every_n_train_steps = every_n_train_steps or 0
 
         # period takes precedence over every_n_val_epochs for backwards compatibility
         if period is not None:
@@ -359,9 +359,9 @@ def __init_triggers(
                 'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
                 ' Please use `every_n_val_epochs` instead.', DeprecationWarning
             )
-            self.every_n_val_epochs = period
+            self._every_n_val_epochs = period
 
-        self._period = self.every_n_val_epochs
+        self._period = self._every_n_val_epochs
 
     @property
     def period(self) -> Optional[int]:
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index ae7d5b772652f3..c5f540f8de1056 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -550,8 +550,8 @@ def test_invalid_every_n_train_steps_val_epochs_combination(tmpdir):
 def test_none_every_n_train_steps_val_epochs(tmpdir):
     checkpoint_callback = ModelCheckpoint(dirpath=tmpdir)
     assert checkpoint_callback.period == 1
-    assert checkpoint_callback.every_n_val_epochs == 1
-    assert checkpoint_callback.every_n_train_steps == 0
+    assert checkpoint_callback._every_n_val_epochs == 1
+    assert checkpoint_callback._every_n_train_steps == 0
 
 
 def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog):
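
With the interval attributes now private, downstream code is expected to set the save cadence through the constructor arguments of the same names rather than by reading or assigning `_every_n_train_steps` / `_every_n_val_epochs` directly. A minimal usage sketch, assuming only the constructor arguments exercised in this patch; the directory paths and interval values are illustrative, not taken from the patch:

    from pytorch_lightning.callbacks import ModelCheckpoint

    # Checkpoint every 100 training steps; the epoch-based trigger stays at 0
    # (disabled), since __validate_init_configuration rejects enabling both.
    step_ckpt = ModelCheckpoint(dirpath="checkpoints/steps", every_n_train_steps=100)

    # Checkpoint at the end of every second validation epoch (the default is 1
    # when neither argument is given).
    epoch_ckpt = ModelCheckpoint(dirpath="checkpoints/epochs", every_n_val_epochs=2)

    # `period` is still accepted but emits a DeprecationWarning and is mapped
    # onto every_n_val_epochs, as shown in __init_triggers above.
    legacy_ckpt = ModelCheckpoint(dirpath="checkpoints/legacy", period=3)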