diff --git a/CHANGELOG.md b/CHANGELOG.md
index 79b253061ddb2..70d2abcd7a7af 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+
+- Added more explicit exception message when trying to execute `trainer.test()` or `trainer.validate()` with `fast_dev_run=True` ([#6667](https://github.com/PyTorchLightning/pytorch-lightning/pull/6667))
+
+
 - Trigger warning when non-metric logged value with multi processes hasn't been reduced ([#6417](https://github.com/PyTorchLightning/pytorch-lightning/pull/6417))
 
 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index f0f1d3e6b11e1..fa02df7fb7ad1 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -955,31 +955,38 @@ def __load_ckpt_weights(
         model,
         ckpt_path: Optional[str] = None,
     ) -> Optional[str]:
-        # if user requests the best checkpoint but we don't have it, error
-        if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path:
+        if ckpt_path is None:
+            return
+
+        fn = self.state.value
+
+        if ckpt_path == 'best':
+            # if user requests the best checkpoint but we don't have it, error
+            if not self.checkpoint_callback.best_model_path:
+                if self.fast_dev_run:
+                    raise MisconfigurationException(
+                        f'You cannot execute `.{fn}()` with `fast_dev_run=True` unless you do'
+                        f' `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting.'
+                    )
+                raise MisconfigurationException(
+                    f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
+                )
+            # load best weights
+            ckpt_path = self.checkpoint_callback.best_model_path
+
+        if not ckpt_path:
             raise MisconfigurationException(
-                'ckpt_path is "best", but `ModelCheckpoint` is not configured to save the best model.'
+                f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please'
+                f' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`'
             )
 
-        # load best weights
-        if ckpt_path is not None:
-            # ckpt_path is 'best' so load the best model
-            if ckpt_path == 'best':
-                ckpt_path = self.checkpoint_callback.best_model_path
-
-            if not ckpt_path:
-                fn = self.state.value
-                raise MisconfigurationException(
-                    f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please'
-                    ' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`'
-                )
+        # only one process running at this point for TPUs, as spawn isn't triggered yet
+        if self._device_type != DeviceType.TPU:
+            self.training_type_plugin.barrier()
 
-            # only one process running at this point for TPUs, as spawn isn't triggered yet
-            if not self._device_type == DeviceType.TPU:
-                self.training_type_plugin.barrier()
+        ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
+        model.load_state_dict(ckpt['state_dict'])
 
-            ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
-            model.load_state_dict(ckpt['state_dict'])
         return ckpt_path
 
     def predict(
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 4ca2f737f5106..ee93ca59eca76 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1777,3 +1777,12 @@ def on_fit_start(self, trainer, pl_module: LightningModule) -> None:
 
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=[TestCallback()])
     trainer.fit(model, datamodule=dm)
+
+
+def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir):
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+
+    with pytest.raises(MisconfigurationException, match=r"\.validate\(\)` with `fast_dev_run=True"):
+        trainer.validate()
+    with pytest.raises(MisconfigurationException, match=r"\.test\(\)` with `fast_dev_run=True"):
+        trainer.test()
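
Reviewer note: `fast_dev_run=True` disables checkpointing, so the default `ckpt_path='best'` has no `best_model_path` to resolve; that is exactly the case the new exception covers. A minimal sketch of the user-facing behavior after this patch (the commented-out workaround path is an illustrative placeholder, not part of the diff):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

trainer = Trainer(fast_dev_run=True)

try:
    # equivalent to trainer.test(ckpt_path='best'); fast_dev_run generated
    # no checkpoint during fitting, so this now raises the explicit error
    trainer.test()
except MisconfigurationException as err:
    print(err)  # suggests calling `.test(ckpt_path=PATH)` instead

# the workaround named in the message: pass an explicit checkpoint, e.g.
# trainer.test(ckpt_path='/path/to/model.ckpt')  # placeholder path
```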