Skip to content

Commit

Permalink
Deprecate checkpoint_callback from the Trainer constructor in fav…
Browse files Browse the repository at this point in the history
…our of `enable_checkpointing` (#9754)

* enable_chekpointing

* update codebase

* chlog

* update tests

* fix warning

* Apply suggestions from code review

Co-authored-by: ananthsub <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestions from code review

Co-authored-by: ananthsub <[email protected]>

* Apply suggestions from code review

Co-authored-by: ananthsub <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 12, 2021
1 parent 14fb076 commit db322f4
Show file tree
Hide file tree
Showing 33 changed files with 130 additions and 109 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated Accelerator collective API `barrier`, `broadcast`, and `all_gather`, call `TrainingTypePlugin` collective API directly ([#9677](https://github.com/PyTorchLightning/pytorch-lightning/pull/9677))


- Deprecated `checkpoint_callback` from the `Trainer` constructor in favour of `enable_checkpointing` ([#9754](https://github.com/PyTorchLightning/pytorch-lightning/pull/9754))


- Deprecated the `LightningModule.on_post_move_to_device` method ([#9525](https://github.com/PyTorchLightning/pytorch-lightning/pull/9525))


Expand Down
2 changes: 1 addition & 1 deletion docs/source/common/hyperparameters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ To recap, add ALL possible trainer flags to the argparser and init the ``Trainer
trainer = Trainer.from_argparse_args(hparams)
# or if you need to pass in callbacks
trainer = Trainer.from_argparse_args(hparams, checkpoint_callback=..., callbacks=[...])
trainer = Trainer.from_argparse_args(hparams, enable_checkpointing=..., callbacks=[...])
----------

Expand Down
70 changes: 35 additions & 35 deletions docs/source/common/trainer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,38 @@ Example::
checkpoint_callback
^^^^^^^^^^^^^^^^^^^

Deprecated: This has been deprecated in v1.5 and will be removed in v.17. Please use ``enable_checkpointing`` instead.

default_root_dir
^^^^^^^^^^^^^^^^

.. raw:: html

<video width="50%" max-width="400px" controls
poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/default%E2%80%A8_root_dir.jpg"
src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/default_root_dir.mp4"></video>

|
Default path for logs and weights when no logger or
:class:`pytorch_lightning.callbacks.ModelCheckpoint` callback passed. On
certain clusters you might want to separate where logs and checkpoints are
stored. If you don't then use this argument for convenience. Paths can be local
paths or remote paths such as `s3://bucket/path` or 'hdfs://path/'. Credentials
will need to be set up to use remote filepaths.

.. testcode::

# default used by the Trainer
trainer = Trainer(default_root_dir=os.getcwd())

distributed_backend
^^^^^^^^^^^^^^^^^^^
Deprecated: This has been renamed ``accelerator``.

enable_checkpointing
^^^^^^^^^^^^^^^^^^^^

.. raw:: html

<video width="50%" max-width="400px" controls
Expand All @@ -542,11 +574,11 @@ To disable automatic checkpointing, set this to `False`.

.. code-block:: python
# default used by Trainer
trainer = Trainer(checkpoint_callback=True)
# default used by Trainer, saves the most recent model to a single checkpoint after each epoch
trainer = Trainer(enable_checkpointing=True)
# turn off automatic checkpointing
trainer = Trainer(checkpoint_callback=False)
trainer = Trainer(enable_checkpointing=False)
You can override the default behavior by initializing the :class:`~pytorch_lightning.callbacks.ModelCheckpoint`
Expand All @@ -563,38 +595,6 @@ See :doc:`Saving and Loading Weights <../common/weights_loading>` for how to cus
# Add your callback to the callbacks list
trainer = Trainer(callbacks=[checkpoint_callback])


.. warning:: Passing a ModelCheckpoint instance to this argument is deprecated since
v1.1 and will be unsupported from v1.3. Use `callbacks` argument instead.


default_root_dir
^^^^^^^^^^^^^^^^

.. raw:: html

<video width="50%" max-width="400px" controls
poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/default%E2%80%A8_root_dir.jpg"
src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/default_root_dir.mp4"></video>

|
Default path for logs and weights when no logger or
:class:`pytorch_lightning.callbacks.ModelCheckpoint` callback passed. On
certain clusters you might want to separate where logs and checkpoints are
stored. If you don't then use this argument for convenience. Paths can be local
paths or remote paths such as `s3://bucket/path` or 'hdfs://path/'. Credentials
will need to be set up to use remote filepaths.

.. testcode::

# default used by the Trainer
trainer = Trainer(default_root_dir=os.getcwd())

distributed_backend
^^^^^^^^^^^^^^^^^^^
Deprecated: This has been renamed ``accelerator``.

fast_dev_run
^^^^^^^^^^^^

Expand Down
30 changes: 20 additions & 10 deletions pytorch_lightning/trainer/connectors/callback_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def __init__(self, trainer):
def on_trainer_init(
self,
callbacks: Optional[Union[List[Callback], Callback]],
checkpoint_callback: bool,
checkpoint_callback: Optional[bool],
enable_checkpointing: bool,
enable_progress_bar: bool,
progress_bar_refresh_rate: Optional[int],
process_position: int,
Expand Down Expand Up @@ -67,7 +68,7 @@ def on_trainer_init(

# configure checkpoint callback
# pass through the required args to figure out defaults
self._configure_checkpoint_callbacks(checkpoint_callback)
self._configure_checkpoint_callbacks(checkpoint_callback, enable_checkpointing)

# configure swa callback
self._configure_swa_callbacks()
Expand Down Expand Up @@ -140,22 +141,31 @@ def _configure_accumulated_gradients(
self.trainer.accumulate_grad_batches = grad_accum_callback.get_accumulate_grad_batches(0)
self.trainer.accumulation_scheduler = grad_accum_callback

def _configure_checkpoint_callbacks(self, checkpoint_callback: bool) -> None:
def _configure_checkpoint_callbacks(self, checkpoint_callback: Optional[bool], enable_checkpointing: bool) -> None:
if checkpoint_callback is not None:
rank_zero_deprecation(
f"Setting `Trainer(checkpoint_callback={checkpoint_callback})` is deprecated in v1.5 and will "
f"be removed in v1.7. Please consider using `Trainer(enable_checkpointing={checkpoint_callback})`."
)
# if both are set then checkpoint only if both are True
enable_checkpointing = checkpoint_callback and enable_checkpointing

# TODO: Remove this error in v1.5 so we rely purely on the type signature
if not isinstance(checkpoint_callback, bool):
if not isinstance(enable_checkpointing, bool):
error_msg = (
"Invalid type provided for checkpoint_callback:"
f" Expected bool but received {type(checkpoint_callback)}."
"Invalid type provided for `enable_checkpointing`: "
f"Expected bool but received {type(enable_checkpointing)}."
)
if isinstance(checkpoint_callback, Callback):
if isinstance(enable_checkpointing, Callback):
error_msg += " Pass callback instances to the `callbacks` argument in the Trainer constructor instead."
raise MisconfigurationException(error_msg)
if self._trainer_has_checkpoint_callbacks() and checkpoint_callback is False:
if self._trainer_has_checkpoint_callbacks() and enable_checkpointing is False:
raise MisconfigurationException(
"Trainer was configured with checkpoint_callback=False but found ModelCheckpoint in callbacks list."
"Trainer was configured with `enable_checkpointing=False`"
" but found `ModelCheckpoint` in callbacks list."
)

if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True:
if not self._trainer_has_checkpoint_callbacks() and enable_checkpointing is True:
self.trainer.callbacks.append(ModelCheckpoint())

def _configure_model_summary_callback(self, weights_summary: Optional[str] = None) -> None:
Expand Down
10 changes: 9 additions & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ class Trainer(
def __init__(
self,
logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
checkpoint_callback: bool = True,
checkpoint_callback: Optional[bool] = None,
enable_checkpointing: bool = True,
callbacks: Optional[Union[List[Callback], Callback]] = None,
default_root_dir: Optional[str] = None,
gradient_clip_val: Union[int, float] = 0.0,
Expand Down Expand Up @@ -215,6 +216,12 @@ def __init__(
callbacks: Add a callback or list of callbacks.
checkpoint_callback: If ``True``, enable checkpointing.
.. deprecated:: v1.5
``checkpoint_callback`` has been deprecated in v1.5 and will be removed in v1.7.
Please consider using ``enable_checkpointing`` instead.
enable_checkpointing: If ``True``, enable checkpointing.
It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`.
Expand Down Expand Up @@ -465,6 +472,7 @@ def __init__(
self.callback_connector.on_trainer_init(
callbacks,
checkpoint_callback,
enable_checkpointing,
enable_progress_bar,
progress_bar_refresh_rate,
process_position,
Expand Down
6 changes: 2 additions & 4 deletions tests/accelerators/test_tpu_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_resume_training_on_cpu(tmpdir):
"""Checks if training can be resumed from a saved checkpoint on CPU."""
# Train a model on TPU
model = BoringModel()
trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=8)
trainer = Trainer(max_epochs=1, tpu_cores=8)
trainer.fit(model)

model_path = trainer.checkpoint_callback.best_model_path
Expand All @@ -62,9 +62,7 @@ def test_resume_training_on_cpu(tmpdir):
assert weight_tensor.device == torch.device("cpu")

# Verify that training is resumed on CPU
trainer = Trainer(
resume_from_checkpoint=model_path, checkpoint_callback=True, max_epochs=1, default_root_dir=tmpdir
)
trainer = Trainer(resume_from_checkpoint=model_path, max_epochs=1, default_root_dir=tmpdir)
trainer.fit(model)
assert trainer.state.finished, f"Training failed with {trainer.state}"

Expand Down
4 changes: 2 additions & 2 deletions tests/callbacks/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def configure_callbacks(self):

model = TestModel()
trainer_options = dict(
default_root_dir=tmpdir, checkpoint_callback=False, fast_dev_run=True, enable_progress_bar=False
default_root_dir=tmpdir, enable_checkpointing=False, fast_dev_run=True, enable_progress_bar=False
)

def assert_expected_calls(_trainer, model_callback, trainer_callback):
Expand Down Expand Up @@ -86,7 +86,7 @@ def configure_callbacks(self):
return [model_callback_mock]

model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, checkpoint_callback=False)
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, enable_checkpointing=False)

callbacks_before_fit = trainer.callbacks.copy()
assert callbacks_before_fit
Expand Down
2 changes: 1 addition & 1 deletion tests/callbacks/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_early_stopping_no_extraneous_invocations(tmpdir):
limit_train_batches=4,
limit_val_batches=4,
max_epochs=expected_count,
checkpoint_callback=False,
enable_checkpointing=False,
)
trainer.fit(model, datamodule=dm)

Expand Down
2 changes: 1 addition & 1 deletion tests/callbacks/test_lr_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int):
callbacks=[TestFinetuning(), lr_monitor, Check()],
enable_progress_bar=False,
weights_summary=None,
checkpoint_callback=False,
enable_checkpointing=False,
)
model = TestModel()
model.training_epoch_end = None
Expand Down
8 changes: 4 additions & 4 deletions tests/callbacks/test_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def on_validation_epoch_end(self, *args):
limit_val_batches=limit_val_batches,
callbacks=[progress_bar],
logger=False,
checkpoint_callback=False,
enable_checkpointing=False,
)
trainer.fit(model)

Expand Down Expand Up @@ -342,7 +342,7 @@ def test_main_progress_bar_update_amount(
limit_val_batches=val_batches,
callbacks=[progress_bar],
logger=False,
checkpoint_callback=False,
enable_checkpointing=False,
)
trainer.fit(model)
if train_batches > 0:
Expand All @@ -362,7 +362,7 @@ def test_test_progress_bar_update_amount(tmpdir, test_batches: int, refresh_rate
limit_test_batches=test_batches,
callbacks=[progress_bar],
logger=False,
checkpoint_callback=False,
enable_checkpointing=False,
)
trainer.test(model)
progress_bar.test_progress_bar.update.assert_has_calls([call(delta) for delta in test_deltas])
Expand All @@ -379,7 +379,7 @@ def training_step(self, batch, batch_idx):
return super().training_step(batch, batch_idx)

trainer = Trainer(
default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, logger=False, checkpoint_callback=False
default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, logger=False, enable_checkpointing=False
)
trainer.fit(TestModel())

Expand Down
6 changes: 3 additions & 3 deletions tests/callbacks/test_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def train_with_pruning_callback(
default_root_dir=tmpdir,
enable_progress_bar=False,
weights_summary=None,
checkpoint_callback=False,
enable_checkpointing=False,
logger=False,
limit_train_batches=10,
limit_val_batches=2,
Expand Down Expand Up @@ -227,7 +227,7 @@ def apply_lottery_ticket_hypothesis(self):
default_root_dir=tmpdir,
enable_progress_bar=False,
weights_summary=None,
checkpoint_callback=False,
enable_checkpointing=False,
logger=False,
limit_train_batches=10,
limit_val_batches=2,
Expand All @@ -254,7 +254,7 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent: bool
default_root_dir=tmpdir,
enable_progress_bar=False,
weights_summary=None,
checkpoint_callback=False,
enable_checkpointing=False,
logger=False,
limit_train_batches=10,
limit_val_batches=2,
Expand Down
2 changes: 1 addition & 1 deletion tests/callbacks/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
assert torch.allclose(org_score, quant_score, atol=0.45)
model_path = trainer.checkpoint_callback.best_model_path

trainer_args.update(dict(max_epochs=1, checkpoint_callback=False))
trainer_args.update(dict(max_epochs=1, enable_checkpointing=False))
if not convert:
trainer = Trainer(callbacks=[QuantizationAwareTraining()], **trainer_args)
trainer.fit(qmodel, datamodule=dm)
Expand Down
4 changes: 2 additions & 2 deletions tests/checkpointing/test_checkpoint_callback_frequency.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
from tests.helpers.runif import RunIf


def test_checkpoint_callback_disabled(tmpdir):
def test_disabled_checkpointing(tmpdir):
# no callback
trainer = Trainer(max_epochs=3, checkpoint_callback=False)
trainer = Trainer(max_epochs=3, enable_checkpointing=False)
assert not trainer.checkpoint_callbacks
trainer.fit(BoringModel())
assert not trainer.checkpoint_callbacks
Expand Down
1 change: 0 additions & 1 deletion tests/checkpointing/test_legacy_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str):
default_root_dir=str(tmpdir),
gpus=int(torch.cuda.is_available()),
precision=(16 if torch.cuda.is_available() else 32),
checkpoint_callback=True,
callbacks=[es, stop],
max_epochs=21,
accumulate_grad_batches=2,
Expand Down
16 changes: 8 additions & 8 deletions tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,17 +998,17 @@ def test_configure_model_checkpoint(tmpdir):
callback2 = ModelCheckpoint()

# no callbacks
trainer = Trainer(checkpoint_callback=False, callbacks=[], **kwargs)
trainer = Trainer(enable_checkpointing=False, callbacks=[], **kwargs)
assert not any(isinstance(c, ModelCheckpoint) for c in trainer.callbacks)
assert trainer.checkpoint_callback is None

# default configuration
trainer = Trainer(checkpoint_callback=True, callbacks=[], **kwargs)
trainer = Trainer(callbacks=[], **kwargs)
assert sum(1 for c in trainer.callbacks if isinstance(c, ModelCheckpoint)) == 1
assert isinstance(trainer.checkpoint_callback, ModelCheckpoint)

# custom callback passed to callbacks list, checkpoint_callback=True is ignored
trainer = Trainer(checkpoint_callback=True, callbacks=[callback1], **kwargs)
# custom callback passed to callbacks list, enable_checkpointing=True is ignored
trainer = Trainer(enable_checkpointing=True, callbacks=[callback1], **kwargs)
assert [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] == [callback1]
assert trainer.checkpoint_callback == callback1

Expand All @@ -1017,8 +1017,8 @@ def test_configure_model_checkpoint(tmpdir):
assert trainer.checkpoint_callback == callback1
assert trainer.checkpoint_callbacks == [callback1, callback2]

with pytest.raises(MisconfigurationException, match="checkpoint_callback=False but found ModelCheckpoint"):
Trainer(checkpoint_callback=False, callbacks=[callback1], **kwargs)
with pytest.raises(MisconfigurationException, match="`enable_checkpointing=False` but found `ModelCheckpoint`"):
Trainer(enable_checkpointing=False, callbacks=[callback1], **kwargs)


def test_val_check_interval_checkpoint_files(tmpdir):
Expand Down Expand Up @@ -1189,8 +1189,8 @@ def test_model_checkpoint_mode_options():

def test_trainer_checkpoint_callback_bool(tmpdir):
mc = ModelCheckpoint(dirpath=tmpdir)
with pytest.raises(MisconfigurationException, match="Invalid type provided for checkpoint_callback"):
Trainer(checkpoint_callback=mc)
with pytest.raises(MisconfigurationException, match="Invalid type provided for `enable_checkpointing`"):
Trainer(enable_checkpointing=mc)


def test_check_val_every_n_epochs_top_k_integration(tmpdir):
Expand Down
Loading

0 comments on commit db322f4

Please sign in to comment.