Skip to content

Commit

Permalink
More explicit exception message when testing with fast_dev_run=True (#6667)
Browse files Browse the repository at this point in the history

Co-authored-by: Carlos Mocholí <[email protected]>
  • Loading branch information
ashleve and carmocca authored Mar 29, 2021
1 parent dcf6e4e commit cca0eca
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 20 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added


- Added more explicit exception message when trying to execute `trainer.test()` or `trainer.validate()` with `fast_dev_run=True` ([#6667](https://github.com/PyTorchLightning/pytorch-lightning/pull/6667))


- Trigger a warning when a non-metric value logged across multiple processes hasn't been reduced ([#6417](https://github.com/PyTorchLightning/pytorch-lightning/pull/6417))


Expand Down
47 changes: 27 additions & 20 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,31 +955,38 @@ def __load_ckpt_weights(
model,
ckpt_path: Optional[str] = None,
) -> Optional[str]:
# if user requests the best checkpoint but we don't have it, error
if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path:
if ckpt_path is None:
return

fn = self.state.value

if ckpt_path == 'best':
# if user requests the best checkpoint but we don't have it, error
if not self.checkpoint_callback.best_model_path:
if self.fast_dev_run:
raise MisconfigurationException(
f'You cannot execute `.{fn}()` with `fast_dev_run=True` unless you do'
f' `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting.'
)
raise MisconfigurationException(
f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
)
# load best weights
ckpt_path = self.checkpoint_callback.best_model_path

if not ckpt_path:
raise MisconfigurationException(
'ckpt_path is "best", but `ModelCheckpoint` is not configured to save the best model.'
f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please'
f' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`'
)

# load best weights
if ckpt_path is not None:
# ckpt_path is 'best' so load the best model
if ckpt_path == 'best':
ckpt_path = self.checkpoint_callback.best_model_path

if not ckpt_path:
fn = self.state.value
raise MisconfigurationException(
f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please'
' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`'
)
# only one process running at this point for TPUs, as spawn isn't triggered yet
if self._device_type != DeviceType.TPU:
self.training_type_plugin.barrier()

# only one process running at this point for TPUs, as spawn isn't triggered yet
if not self._device_type == DeviceType.TPU:
self.training_type_plugin.barrier()
ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt['state_dict'])

ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt['state_dict'])
return ckpt_path

def predict(
Expand Down
9 changes: 9 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1777,3 +1777,12 @@ def on_fit_start(self, trainer, pl_module: LightningModule) -> None:

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=[TestCallback()])
trainer.fit(model, datamodule=dm)


def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir):
    """`.validate()` / `.test()` with `fast_dev_run=True` must raise a clear MisconfigurationException."""
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)

    # Both evaluation entry points should fail the same way, since fast_dev_run
    # produces no checkpoint to load from.
    for method_name in ("validate", "test"):
        with pytest.raises(MisconfigurationException, match=rf"\.{method_name}\(\)` with `fast_dev_run=True"):
            getattr(trainer, method_name)()

0 comments on commit cca0eca

Please sign in to comment.