Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise a warning if evaluation is triggered with best ckpt in case of multiple checkpoint callbacks #11274

Merged
merged 10 commits into from
Jan 4, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `DeviceStatsMonitor` to group metrics based on the logger's `group_separator` ([#11254](https://github.com/PyTorchLightning/pytorch-lightning/pull/11254))


- Raised `UserWarning` if evaluation is triggered with `best` ckpt and trainer is configured with multiple checkpoint callbacks ([#11274](https://github.com/PyTorchLightning/pytorch-lightning/pull/11274))


- `Trainer.logged_metrics` now always contains scalar tensors, even when a Python scalar was logged ([#11270](https://github.com/PyTorchLightning/pytorch-lightning/pull/11270))


Expand Down
12 changes: 9 additions & 3 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,16 +1381,22 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_
ckpt_path = "best"

if ckpt_path == "best":
# if user requests the best checkpoint but we don't have it, error
if len(self.checkpoint_callbacks) > 1:
rank_zero_warn(
f'`.{fn}(ckpt_path="best")` is called with Trainer configured with multiple `ModelCheckpoint`'
" callbacks. It will use the best checkpoint path from first checkpoint callback."
)

if not self.checkpoint_callback:
raise MisconfigurationException(
f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured.'
)

if not self.checkpoint_callback.best_model_path:
if self.fast_dev_run:
raise MisconfigurationException(
f"You cannot execute `.{fn}()` with `fast_dev_run=True` unless you do"
f" `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting."
f'You cannot execute `.{fn}(ckpt_path="best")` with `fast_dev_run=True`.'
f" Please pass an exact checkpoint path to `.{fn}(ckpt_path=...)`"
)
raise MisconfigurationException(
f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
Expand Down
29 changes: 20 additions & 9 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,21 @@ def predict_step(self, batch, *_):
trainer_fn(model, ckpt_path="best")


def test_best_ckpt_evaluate_raises_warning_with_multiple_ckpt_callbacks():
    """Test that a ``UserWarning`` is raised when evaluation is triggered with ``ckpt_path="best"``
    and the Trainer is configured with multiple ``ModelCheckpoint`` callbacks."""

    # Two checkpoint callbacks, each reporting its own best model path; the warning
    # tells the user that only the first callback's best path will be used.
    ckpt_callback1 = ModelCheckpoint()
    ckpt_callback1.best_model_path = "foo_best_model.ckpt"
    ckpt_callback2 = ModelCheckpoint()
    ckpt_callback2.best_model_path = "bar_best_model.ckpt"
    trainer = Trainer(callbacks=[ckpt_callback1, ckpt_callback2])
    trainer.state.fn = TrainerFn.TESTING

    with pytest.warns(UserWarning, match="best checkpoint path from first checkpoint callback"):
        # `__set_ckpt_path` is name-mangled, hence the `_Trainer` prefix.
        trainer._Trainer__set_ckpt_path(ckpt_path="best", model_provided=False, model_connected=True)


def test_disabled_training(tmpdir):
"""Verify that `limit_train_batches=0` disables the training loop unless `fast_dev_run=True`."""

Expand Down Expand Up @@ -1799,15 +1814,11 @@ def on_fit_start(self, trainer, pl_module: LightningModule) -> None:
trainer.fit(model, datamodule=dm)


def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
model = BoringModel()
trainer.fit(model)

with pytest.raises(MisconfigurationException, match=r"\.validate\(\)` with `fast_dev_run=True"):
trainer.validate()
with pytest.raises(MisconfigurationException, match=r"\.test\(\)` with `fast_dev_run=True"):
trainer.test()
def test_exception_when_testing_or_validating_with_fast_dev_run():
    """Test that resolving ``ckpt_path="best"`` raises a ``MisconfigurationException`` when the
    Trainer is configured with ``fast_dev_run=True`` (no checkpoint is saved in that mode)."""
    trainer = Trainer(fast_dev_run=True)
    trainer.state.fn = TrainerFn.TESTING
    # `__set_ckpt_path` is name-mangled, hence the `_Trainer` prefix.
    with pytest.raises(MisconfigurationException, match=r"with `fast_dev_run=True`. .* pass an exact checkpoint path"):
        trainer._Trainer__set_ckpt_path(ckpt_path="best", model_provided=False, model_connected=True)


class TrainerStagesModel(BoringModel):
Expand Down