From ab4012d0eaa4d1bab8d0caa41626cee3e8d680b5 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Wed, 10 Mar 2021 22:47:06 -0800
Subject: [PATCH] make-private

make attributes private to the class
---
 .../callbacks/model_checkpoint.py            | 32 +++++++++----------
 tests/checkpointing/test_model_checkpoint.py |  4 +--
 2 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 731232612cd0c..9c9244b9db317 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -217,7 +217,7 @@ def on_train_batch_end(self, trainer, pl_module, *args, **kwargs) -> None:
         if self._should_skip_saving_checkpoint(trainer):
             return
         step = trainer.global_step
-        skip_batch = self.every_n_train_steps < 1 or ((step + 1) % self.every_n_train_steps != 0)
+        skip_batch = self._every_n_train_steps < 1 or ((step + 1) % self._every_n_train_steps != 0)
         if skip_batch:
             return
         self.save_checkpoint(trainer, pl_module)
@@ -227,8 +227,8 @@ def on_validation_end(self, trainer, pl_module):
         checkpoints can be saved at the end of the val loop
         """
         skip = (
-            self._should_skip_saving_checkpoint(trainer) or self.every_n_val_epochs is None
-            or self.every_n_val_epochs < 1 or (trainer.current_epoch + 1) % self.every_n_val_epochs != 0
+            self._should_skip_saving_checkpoint(trainer) or self._every_n_val_epochs < 1
+            or (trainer.current_epoch + 1) % self._every_n_val_epochs != 0
         )
         if skip:
             return
@@ -291,18 +291,18 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool:
     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
             raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
-        if self.every_n_train_steps < 0:
+        if self._every_n_train_steps < 0:
             raise MisconfigurationException(
-                f'Invalid value for every_n_train_steps={self.every_n_train_steps}. Must be >= 0'
+                f'Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0'
             )
-        if self.every_n_val_epochs < 0:
+        if self._every_n_val_epochs < 0:
             raise MisconfigurationException(
-                f'Invalid value for every_n_val_epochs={self.every_n_val_epochs}. Must be >= 0'
+                f'Invalid value for every_n_val_epochs={self._every_n_val_epochs}. Must be >= 0'
             )
-        if self.every_n_train_steps > 0 and self.every_n_val_epochs > 0:
+        if self._every_n_train_steps > 0 and self._every_n_val_epochs > 0:
             raise MisconfigurationException(
-                f'Invalid values for every_n_train_steps={self.every_n_train_steps}'
-                ' and every_n_val_epochs={self.every_n_val_epochs}.'
+                f'Invalid values for every_n_train_steps={self._every_n_train_steps}'
+                f' and every_n_val_epochs={self._every_n_val_epochs}.'
                 'Both cannot be enabled at the same time.'
             )
         if self.monitor is None:
@@ -358,12 +358,12 @@ def __init_triggers(
         # Default to running once after each validation epoch if neither
         # every_n_train_steps nor every_n_val_epochs is set
         if every_n_train_steps is None and every_n_val_epochs is None:
-            self.every_n_val_epochs = 1
-            self.every_n_train_steps = 0
+            self._every_n_val_epochs = 1
+            self._every_n_train_steps = 0
             log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1")
         else:
-            self.every_n_val_epochs = every_n_val_epochs or 0
-            self.every_n_train_steps = every_n_train_steps or 0
+            self._every_n_val_epochs = every_n_val_epochs or 0
+            self._every_n_train_steps = every_n_train_steps or 0
 
         # period takes precedence over every_n_val_epochs for backwards compatibility
         if period is not None:
@@ -371,9 +371,9 @@ def __init_triggers(
                 'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
                 ' Please use `every_n_val_epochs` instead.', DeprecationWarning
             )
-            self.every_n_val_epochs = period
+            self._every_n_val_epochs = period
 
-        self._period = self.every_n_val_epochs
+        self._period = self._every_n_val_epochs
 
     @property
     def period(self) -> Optional[int]:
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index e9d5e2daa85a7..a9e86b0b2a223 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -559,8 +559,8 @@ def test_invalid_every_n_train_steps_val_epochs_combination(tmpdir):
 def test_none_every_n_train_steps_val_epochs(tmpdir):
     checkpoint_callback = ModelCheckpoint(dirpath=tmpdir)
     assert checkpoint_callback.period == 1
-    assert checkpoint_callback.every_n_val_epochs == 1
-    assert checkpoint_callback.every_n_train_steps == 0
+    assert checkpoint_callback._every_n_val_epochs == 1
+    assert checkpoint_callback._every_n_train_steps == 0
 
 
 def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog):