From 8fee4ab2f68e98121a9e3bc2be13252ff015b46e Mon Sep 17 00:00:00 2001 From: hobogalaxy Date: Wed, 24 Mar 2021 21:02:31 +0100 Subject: [PATCH 01/23] Add more explicit error message when testing --- pytorch_lightning/trainer/trainer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f0f1d3e6b11e1..711a19e5c8bc7 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -939,6 +939,11 @@ def test( # Attach dataloaders (if given) self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) + if not model_provided and self.fast_dev_run: + raise MisconfigurationException( + f"You cannot execute testing when model is not provided and {self.fast_dev_run=}." + ) + if not model_provided: self.tested_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path) From c26f7748f5bb28371e321b61ba83699b701e1ba7 Mon Sep 17 00:00:00 2001 From: hobogalaxy Date: Wed, 24 Mar 2021 21:09:23 +0100 Subject: [PATCH 02/23] Change err message --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 711a19e5c8bc7..5ee8428661e8e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -941,7 +941,7 @@ def test( if not model_provided and self.fast_dev_run: raise MisconfigurationException( - f"You cannot execute testing when model is not provided and {self.fast_dev_run=}." + f"You cannot execute testing when model is not provided and {self.fast_dev_run=}. Provide model with trainer.test(model=...)" ) if not model_provided: From 9903875a6b383ed02029e2cd62b409a313261c19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Zalewski?= Date: Wed, 24 Mar 2021 21:15:52 +0100 Subject: [PATCH 03/23] Update trainer.py --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 5ee8428661e8e..ed6a086911241 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -941,7 +941,7 @@ def test( if not model_provided and self.fast_dev_run: raise MisconfigurationException( - f"You cannot execute testing when model is not provided and {self.fast_dev_run=}. Provide model with trainer.test(model=...)" + f'You cannot execute testing when model is not provided and {self.fast_dev_run=}. Provide model with `trainer.test(model=...)`' ) if not model_provided: From 93ad46346deabdd3e17a66989a7327a7b48ca8cc Mon Sep 17 00:00:00 2001 From: hobogalaxy Date: Wed, 24 Mar 2021 21:17:57 +0100 Subject: [PATCH 04/23] Split message str --- pytorch_lightning/trainer/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ed6a086911241..2d89da5d7db9f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -941,7 +941,8 @@ def test( if not model_provided and self.fast_dev_run: raise MisconfigurationException( - f'You cannot execute testing when model is not provided and {self.fast_dev_run=}. Provide model with `trainer.test(model=...)`' + f'You cannot execute testing when model is not provided and {self.fast_dev_run=}.' + 'Provide model with `trainer.test(model=...)`' ) if not model_provided: From a9715c6906251c1e55e4ab4d89fe7e090bf86358 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Zalewski?= Date: Wed, 24 Mar 2021 21:20:51 +0100 Subject: [PATCH 05/23] Update trainer.py --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2d89da5d7db9f..0fd53ab22f23e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -941,7 +941,7 @@ def test( if not model_provided and self.fast_dev_run: raise MisconfigurationException( - f'You cannot execute testing when model is not provided and {self.fast_dev_run=}.' + f'You cannot execute testing when the model is not provided and {self.fast_dev_run=}.' 'Provide model with `trainer.test(model=...)`' ) From db2c31fec5f873aa7eb46c4caebbd35f455e11b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Zalewski?= Date: Wed, 24 Mar 2021 21:30:42 +0100 Subject: [PATCH 06/23] Update trainer.py --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0fd53ab22f23e..78affb776df34 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -941,7 +941,7 @@ def test( if not model_provided and self.fast_dev_run: raise MisconfigurationException( - f'You cannot execute testing when the model is not provided and {self.fast_dev_run=}.' + f'You cannot execute testing when the model is not provided and {self.fast_dev_run=}. ' 'Provide model with `trainer.test(model=...)`' ) From d91574fdba04828bd4592e3481a5db77790305ef Mon Sep 17 00:00:00 2001 From: hobogalaxy Date: Wed, 24 Mar 2021 21:37:34 +0100 Subject: [PATCH 07/23] fix err --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 78affb776df34..99180ba5e3918 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -941,7 +941,7 @@ def test( if not model_provided and self.fast_dev_run: raise MisconfigurationException( - f'You cannot execute testing when the model is not provided and {self.fast_dev_run=}. ' + f'You cannot execute testing when the model is not provided and `fast_dev_run=True`. ' 'Provide model with `trainer.test(model=...)`' ) From 61d95941a6795504b5c9247843889bb6dca9c8c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Zalewski?= Date: Wed, 24 Mar 2021 21:43:56 +0100 Subject: [PATCH 08/23] Update trainer.py --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 99180ba5e3918..d01528f7146d7 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -941,7 +941,7 @@ def test( if not model_provided and self.fast_dev_run: raise MisconfigurationException( - f'You cannot execute testing when the model is not provided and `fast_dev_run=True`. ' + 'You cannot execute testing when the model is not provided and `fast_dev_run=True`. ' 'Provide model with `trainer.test(model=...)`' ) From 3c3a16208a2b4a45b480472b33e308ce6d3b1071 Mon Sep 17 00:00:00 2001 From: hobogalaxy Date: Wed, 24 Mar 2021 21:58:27 +0100 Subject: [PATCH 09/23] fix --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d01528f7146d7..cef3ee19379d8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -939,7 +939,7 @@ def test( # Attach dataloaders (if given) self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) - if not model_provided and self.fast_dev_run: + if not model_provided and ckpt_path == 'best' and self.fast_dev_run: raise MisconfigurationException( 'You cannot execute testing when the model is not provided and `fast_dev_run=True`. ' 'Provide model with `trainer.test(model=...)`' From 321ce93512f7c4e4b4e37e5d12fcfb237b67604f Mon Sep 17 00:00:00 2001 From: hobogalaxy Date: Wed, 24 Mar 2021 22:00:27 +0100 Subject: [PATCH 10/23] improve message --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cef3ee19379d8..377ebe0a79445 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -942,7 +942,7 @@ def test( if not model_provided and ckpt_path == 'best' and self.fast_dev_run: raise MisconfigurationException( 'You cannot execute testing when the model is not provided and `fast_dev_run=True`. ' - 'Provide model with `trainer.test(model=...)`' + 'Provide model with `trainer.test(model=...)` or `trainer.test(ckpt_path=...)`' ) if not model_provided: From 1e77c5f18ace190aee8641f54306e1911974bef4 Mon Sep 17 00:00:00 2001 From: hobogalaxy Date: Fri, 26 Mar 2021 15:27:50 +0100 Subject: [PATCH 11/23] add test --- pytorch_lightning/trainer/trainer.py | 21 +++++++++++---------- tests/trainer/test_trainer.py | 16 ++++++++++++++++ 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 377ebe0a79445..6fcb563749154 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -939,12 +939,6 @@ def test( # Attach dataloaders (if given) self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) - if not model_provided and ckpt_path == 'best' and self.fast_dev_run: - raise MisconfigurationException( - 'You cannot execute testing when the model is not provided and `fast_dev_run=True`. ' - 'Provide model with `trainer.test(model=...)` or `trainer.test(ckpt_path=...)`' - ) - if not model_provided: self.tested_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path) @@ -962,10 +956,17 @@ def __load_ckpt_weights( ckpt_path: Optional[str] = None, ) -> Optional[str]: # if user requests the best checkpoint but we don't have it, error - if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path: - raise MisconfigurationException( - 'ckpt_path is "best", but `ModelCheckpoint` is not configured to save the best model.' - ) + if ckpt_path == 'best': + if not self.checkpoint_callback.best_model_path and self.fast_dev_run: + raise MisconfigurationException( + 'You cannot execute `trainer.test()` or trainer.validate()`' + ' when `fast_dev_run=True`.' + ) + + if not self.checkpoint_callback.best_model_path: + raise MisconfigurationException( + 'ckpt_path is "best", but `ModelCheckpoint` is not configured to save the best model.' + ) # load best weights if ckpt_path is not None: diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 4ca2f737f5106..e1c26e80094ef 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1777,3 +1777,19 @@ def on_fit_start(self, trainer, pl_module: LightningModule) -> None: trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=[TestCallback()]) trainer.fit(model, datamodule=dm) + + +@pytest.mark.parametrize("fast_dev_run", [True, False]) +def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir, fast_dev_run): + model = BoringModel() + + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, fast_dev_run=fast_dev_run) + + trainer.fit(model) + + if fast_dev_run: + with pytest.raises(MisconfigurationException, match=".*when `fast_dev_run=True`*"): + trainer.validate() + + with pytest.raises(MisconfigurationException, match=".*when `fast_dev_run=True`*"): + trainer.test() From 1510cd56009e2fb985e01578dbac4b353aee8664 Mon Sep 17 00:00:00 2001 From: hobogalaxy Date: Fri, 26 Mar 2021 15:40:51 +0100 Subject: [PATCH 12/23] update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 780a8790b9fdd..96d9a98d7083f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added more explicit exception message when trying to execute trainer.test() or trainer.validate() with fast_dev_run=True - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) From e3f8502573d92f7afce1530b45e54e62f9f54d41 Mon Sep 17 00:00:00 2001 From: hobogalaxy Date: Fri, 26 Mar 2021 15:46:53 +0100 Subject: [PATCH 13/23] update changelog --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 96d9a98d7083f..0dc0ef5cc3ab8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- Added more explicit exception message when trying to execute trainer.test() or trainer.validate() with fast_dev_run=True +- Added more explicit exception message when trying to execute `trainer.test()` or `trainer.validate()` with `fast_dev_run=True` ([#6667](https://github.com/PyTorchLightning/pytorch-lightning/pull/6667)) + - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) From 2bde80181a138a9a5702a722f3470efbe2eb7c3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 26 Mar 2021 18:32:05 +0100 Subject: [PATCH 14/23] Apply suggestions from code review --- pytorch_lightning/trainer/trainer.py | 4 ++-- tests/trainer/test_trainer.py | 16 ++++++---------- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6fcb563749154..3bccda59d1a92 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -959,8 +959,8 @@ def __load_ckpt_weights( if ckpt_path == 'best': if not self.checkpoint_callback.best_model_path and self.fast_dev_run: raise MisconfigurationException( - 'You cannot execute `trainer.test()` or trainer.validate()`' - ' when `fast_dev_run=True`.' + 'You cannot execute `trainer.test()` or `trainer.validate()`' + ' with `fast_dev_run=True`.' ) if not self.checkpoint_callback.best_model_path: diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index e1c26e80094ef..691a83f1b18f0 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1779,17 +1779,13 @@ def on_fit_start(self, trainer, pl_module: LightningModule) -> None: trainer.fit(model, datamodule=dm) -@pytest.mark.parametrize("fast_dev_run", [True, False]) -def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir, fast_dev_run): +def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir): model = BoringModel() - - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, fast_dev_run=fast_dev_run) - + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) trainer.fit(model) - if fast_dev_run: - with pytest.raises(MisconfigurationException, match=".*when `fast_dev_run=True`*"): - trainer.validate() + with pytest.raises(MisconfigurationException, match=".*when `fast_dev_run=True`*"): + trainer.validate() - with pytest.raises(MisconfigurationException, match=".*when `fast_dev_run=True`*"): - trainer.test() + with pytest.raises(MisconfigurationException, match=".*when `fast_dev_run=True`*"): + trainer.test() From 4c771fcf00a78e27ea7bc8ae676fe4787058c780 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 26 Mar 2021 18:34:46 +0100 Subject: [PATCH 15/23] Move comment --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 3bccda59d1a92..dd172e28e35a4 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -955,7 +955,6 @@ def __load_ckpt_weights( model, ckpt_path: Optional[str] = None, ) -> Optional[str]: - # if user requests the best checkpoint but we don't have it, error if ckpt_path == 'best': if not self.checkpoint_callback.best_model_path and self.fast_dev_run: raise MisconfigurationException( @@ -963,6 +962,7 @@ def __load_ckpt_weights( ' with `fast_dev_run=True`.' ) + # if user requests the best checkpoint but we don't have it, error if not self.checkpoint_callback.best_model_path: raise MisconfigurationException( 'ckpt_path is "best", but `ModelCheckpoint` is not configured to save the best model.' From 29d47735376e33cb9f29bbaa993f764645e17b48 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 26 Mar 2021 18:36:37 +0100 Subject: [PATCH 16/23] Unnecessary fit --- tests/trainer/test_trainer.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 691a83f1b18f0..b72f0137a3a75 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1780,12 +1780,9 @@ def on_fit_start(self, trainer, pl_module: LightningModule) -> None: def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir): - model = BoringModel() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - trainer.fit(model) - with pytest.raises(MisconfigurationException, match=".*when `fast_dev_run=True`*"): + with pytest.raises(MisconfigurationException, match=".*with `fast_dev_run=True`*"): trainer.validate() - - with pytest.raises(MisconfigurationException, match=".*when `fast_dev_run=True`*"): + with pytest.raises(MisconfigurationException, match=".*with `fast_dev_run=True`*"): trainer.test() From 39da91a23e9255b5cb43b2f09967365de93d8059 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 26 Mar 2021 18:37:05 +0100 Subject: [PATCH 17/23] Fix match --- tests/trainer/test_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index b72f0137a3a75..4c0b0d9731ae8 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1782,7 +1782,7 @@ def on_fit_start(self, trainer, pl_module: LightningModule) -> None: def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - with pytest.raises(MisconfigurationException, match=".*with `fast_dev_run=True`*"): + with pytest.raises(MisconfigurationException, match="with `fast_dev_run=True`"): trainer.validate() - with pytest.raises(MisconfigurationException, match=".*with `fast_dev_run=True`*"): + with pytest.raises(MisconfigurationException, match="with `fast_dev_run=True`"): trainer.test() From 5065ec4abf7d18659f3ea472d96cd1d507c3b946 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 29 Mar 2021 14:01:06 +0200 Subject: [PATCH 18/23] Refactor. Mention ckpt_path --- pytorch_lightning/trainer/trainer.py | 35 +++++++++++++--------------- tests/trainer/test_trainer.py | 4 ++-- 2 files changed, 18 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index dd172e28e35a4..9d28b4b051c58 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -955,38 +955,35 @@ def __load_ckpt_weights( model, ckpt_path: Optional[str] = None, ) -> Optional[str]: + fn = self.state.value + if ckpt_path == 'best': if not self.checkpoint_callback.best_model_path and self.fast_dev_run: raise MisconfigurationException( - 'You cannot execute `trainer.test()` or `trainer.validate()`' - ' with `fast_dev_run=True`.' + f'You cannot execute `trainer.{fn}()` with `fast_dev_run=True` unless you do' + f'`trainer.{fn}(ckpt_path=...)` as no checkpoint path was generated during fitting.' ) - # if user requests the best checkpoint but we don't have it, error if not self.checkpoint_callback.best_model_path: raise MisconfigurationException( 'ckpt_path is "best", but `ModelCheckpoint` is not configured to save the best model.' ) + # load best weights + ckpt_path = self.checkpoint_callback.best_model_path - # load best weights - if ckpt_path is not None: - # ckpt_path is 'best' so load the best model - if ckpt_path == 'best': - ckpt_path = self.checkpoint_callback.best_model_path + if not ckpt_path: + raise MisconfigurationException( + f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please' + f' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`' + ) - if not ckpt_path: - fn = self.state.value - raise MisconfigurationException( - f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please' - ' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`' - ) + # only one process running at this point for TPUs, as spawn isn't triggered yet + if self._device_type != DeviceType.TPU: + self.training_type_plugin.barrier() - # only one process running at this point for TPUs, as spawn isn't triggered yet - if not self._device_type == DeviceType.TPU: - self.training_type_plugin.barrier() + ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage) + model.load_state_dict(ckpt['state_dict']) - ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage) - model.load_state_dict(ckpt['state_dict']) return ckpt_path def predict( diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 4c0b0d9731ae8..a9ce7ad2a47bf 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1782,7 +1782,7 @@ def on_fit_start(self, trainer, pl_module: LightningModule) -> None: def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - with pytest.raises(MisconfigurationException, match="with `fast_dev_run=True`"): + with pytest.raises(MisconfigurationException, match="trainer.validate\(\)` with `fast_dev_run=True"): trainer.validate() - with pytest.raises(MisconfigurationException, match="with `fast_dev_run=True`"): + with pytest.raises(MisconfigurationException, match="trainer.test\(\)` with `fast_dev_run=True"): trainer.test() From 538f4d0f6b2bbc066d3e263344505405ff0f3e08 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 29 Mar 2021 14:03:18 +0200 Subject: [PATCH 19/23] flake8 --- tests/trainer/test_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index a9ce7ad2a47bf..2896b3f11258e 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1782,7 +1782,7 @@ def on_fit_start(self, trainer, pl_module: LightningModule) -> None: def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - with pytest.raises(MisconfigurationException, match="trainer.validate\(\)` with `fast_dev_run=True"): + with pytest.raises(MisconfigurationException, match=r"trainer.validate\(\)` with `fast_dev_run=True"): trainer.validate() - with pytest.raises(MisconfigurationException, match="trainer.test\(\)` with `fast_dev_run=True"): + with pytest.raises(MisconfigurationException, match=r"trainer.test\(\)` with `fast_dev_run=True"): trainer.test() From 1610ec473284932dda4ac1ff343bbff515b8e166 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 29 Mar 2021 14:07:56 +0200 Subject: [PATCH 20/23] Consistency --- pytorch_lightning/trainer/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 9d28b4b051c58..5735157c2b84c 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -960,8 +960,8 @@ def __load_ckpt_weights( if ckpt_path == 'best': if not self.checkpoint_callback.best_model_path and self.fast_dev_run: raise MisconfigurationException( - f'You cannot execute `trainer.{fn}()` with `fast_dev_run=True` unless you do' - f'`trainer.{fn}(ckpt_path=...)` as no checkpoint path was generated during fitting.' + f'You cannot execute `.{fn}()` with `fast_dev_run=True` unless you do' + f' `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting.' ) # if user requests the best checkpoint but we don't have it, error if not self.checkpoint_callback.best_model_path: From 728349b49ea262ec5de1b2c2368a746ead15de6c Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 29 Mar 2021 14:11:21 +0200 Subject: [PATCH 21/23] Reuse if --- pytorch_lightning/trainer/trainer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 5735157c2b84c..82d626c709972 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -958,15 +958,15 @@ def __load_ckpt_weights( fn = self.state.value if ckpt_path == 'best': - if not self.checkpoint_callback.best_model_path and self.fast_dev_run: - raise MisconfigurationException( - f'You cannot execute `.{fn}()` with `fast_dev_run=True` unless you do' - f' `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting.' - ) # if user requests the best checkpoint but we don't have it, error if not self.checkpoint_callback.best_model_path: + if self.fast_dev_run: + raise MisconfigurationException( + f'You cannot execute `.{fn}()` with `fast_dev_run=True` unless you do' + f' `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting.' + ) raise MisconfigurationException( - 'ckpt_path is "best", but `ModelCheckpoint` is not configured to save the best model.' + f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.' ) # load best weights ckpt_path = self.checkpoint_callback.best_model_path From 137dd04f667cda658cea8874dba0a3567efc1cab Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 29 Mar 2021 14:43:11 +0200 Subject: [PATCH 22/23] Add early exit --- pytorch_lightning/trainer/trainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 82d626c709972..fa02df7fb7ad1 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -955,6 +955,9 @@ def __load_ckpt_weights( model, ckpt_path: Optional[str] = None, ) -> Optional[str]: + if ckpt_path is None: + return + fn = self.state.value if ckpt_path == 'best': From b15434da6174f1b763b7e1784a9279ffc5933ab0 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 29 Mar 2021 15:04:08 +0200 Subject: [PATCH 23/23] Fix test --- tests/trainer/test_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 2896b3f11258e..ee93ca59eca76 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1782,7 +1782,7 @@ def on_fit_start(self, trainer, pl_module: LightningModule) -> None: def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - with pytest.raises(MisconfigurationException, match=r"trainer.validate\(\)` with `fast_dev_run=True"): + with pytest.raises(MisconfigurationException, match=r"\.validate\(\)` with `fast_dev_run=True"): trainer.validate() - with pytest.raises(MisconfigurationException, match=r"trainer.test\(\)` with `fast_dev_run=True"): + with pytest.raises(MisconfigurationException, match=r"\.test\(\)` with `fast_dev_run=True"): trainer.test()