Skip to content

Commit

Permalink
use warning instead
Browse files Browse the repository at this point in the history
  • Loading branch information
rohitgr7 committed Dec 31, 2021
1 parent 0d20696 commit a5ca72d
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 9 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Renamed `training_type_plugin` file to `strategy` ([#11239](https://github.com/PyTorchLightning/pytorch-lightning/pull/11239))


- Raised `MisconfigurationException` if evaluation is triggered with `best` ckpt but trainer is configured with multiple checkpoint callbacks ([#11274](https://github.com/PyTorchLightning/pytorch-lightning/pull/11274))
- Raised `UserWarning` if evaluation is triggered with `best` ckpt and trainer is configured with multiple checkpoint callbacks ([#11274](https://github.com/PyTorchLightning/pytorch-lightning/pull/11274))


### Deprecated
Expand Down
10 changes: 5 additions & 5 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1384,9 +1384,9 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_

if ckpt_path == "best":
if len(self.checkpoint_callbacks) > 1:
raise MisconfigurationException(
f'.{fn}(ckpt_path="best" is not supported with multiple `ModelCheckpoint` callbacks.'
" Please pass in the exact checkpoint path."
rank_zero_warn(
f'`.{fn}(ckpt_path="best")` is called with Trainer configured with multiple `ModelCheckpoint`'
" callbacks. It will use the best checkpoint path from first checkpoint callback."
)

if not self.checkpoint_callback:
Expand All @@ -1397,8 +1397,8 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_
if not self.checkpoint_callback.best_model_path:
if self.fast_dev_run:
raise MisconfigurationException(
f'You cannot execute `.{fn}(ckpt_path="best")` with `fast_dev_run=True` unless you do'
f" `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting."
f'You cannot execute `.{fn}(ckpt_path="best")` with `fast_dev_run=True`.'
f" an exact checkpoint path"
)
raise MisconfigurationException(
f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
Expand Down
26 changes: 23 additions & 3 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,10 +778,30 @@ def predict_step(self, batch, *_):
def test_best_ckpt_evaluate_raises_error_with_multiple_ckpt_callbacks(tmpdir, fn):
    """Test that a ``UserWarning`` is raised when evaluation is run with ``ckpt_path="best"`` and the trainer is
    configured with multiple ``ModelCheckpoint`` callbacks.

    Since commit a5ca72d this situation warns (and falls back to the best path of the first checkpoint callback)
    instead of raising a ``MisconfigurationException``.
    """

    class TestModel(BoringModel):
        def validation_step(self, batch, batch_idx):
            # Log two different metrics so each ModelCheckpoint below has its own monitored quantity
            # and therefore tracks its own "best" checkpoint.
            self.log("foo", batch_idx)
            self.log("bar", batch_idx + 1)
            return super().validation_step(batch, batch_idx)

    # Two checkpoint callbacks monitoring different metrics -> ambiguous "best" checkpoint.
    ckpt_callbacks = [ModelCheckpoint(monitor="foo", save_top_k=1), ModelCheckpoint(monitor="bar", save_top_k=1)]
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=1,
        callbacks=ckpt_callbacks,
        limit_test_batches=1,
        limit_val_batches=1,
        limit_predict_batches=1,
    )

    # Fit first so that checkpoints actually exist before evaluation is triggered.
    model = TestModel()
    trainer.fit(model)

    # `fn` is one of the evaluation entry points (e.g. "validate"/"test"/"predict", parametrized by the decorator).
    trainer_fn = getattr(trainer, fn)
    with pytest.warns(UserWarning, match="best checkpoint path from first checkpoint callback"):
        trainer_fn(ckpt_path="best")


def test_disabled_training(tmpdir):
Expand Down

0 comments on commit a5ca72d

Please sign in to comment.