-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enable val/test loop disabling + datamodule tests (#2692)
* 🎨 warn instead of error out on loaders * 🐛 test misconfiguration should still fail * 🚧 . * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj * updated docs with new result obj Co-authored-by: William Falcon <[email protected]>
- Loading branch information
1 parent
4bf1918
commit 9076551
Showing
13 changed files
with
393 additions
and
279 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -133,5 +133,6 @@ mnist/ | |
# pl tests | ||
ml-runs/ | ||
*.zip | ||
*.ckpt | ||
pytorch\ lightning | ||
test-reports/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
from pytorch_lightning.core.lightning import LightningModule | ||
from pytorch_lightning.utilities.exceptions import MisconfigurationException | ||
from pytorch_lightning.utilities import rank_zero_warn | ||
|
||
|
||
class ConfigValidator(object): | ||
|
||
def __init__(self, trainer): | ||
self.trainer = trainer | ||
|
||
def enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule): | ||
# If you supply a datamodule you can't supply train_dataloader or val_dataloaders | ||
if (train_dataloader or val_dataloaders) and datamodule: | ||
raise MisconfigurationException( | ||
'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule' | ||
) | ||
|
||
def verify_loop_configurations(self, model: LightningModule): | ||
r""" | ||
Checks that the model is configured correctly before training or testing is started. | ||
Args: | ||
model: The model to check the configuration. | ||
""" | ||
if not self.trainer.testing: | ||
self.__verify_train_loop_configuration(model) | ||
self.__verify_eval_loop_configuration(model, 'validation') | ||
else: | ||
# check test loop configuration | ||
self.__verify_eval_loop_configuration(model, 'test') | ||
|
||
def __verify_train_loop_configuration(self, model): | ||
# ----------------------------------- | ||
# verify model has a training step | ||
# ----------------------------------- | ||
has_training_step = self.trainer.is_overridden('training_step', model) | ||
if not has_training_step: | ||
raise MisconfigurationException( | ||
'No `training_step()` method defined. Lightning `Trainer` expects as minimum a' | ||
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' | ||
) | ||
|
||
# ----------------------------------- | ||
# verify model has a train dataloader | ||
# ----------------------------------- | ||
has_train_dataloader = self.trainer.is_overridden('train_dataloader', model) | ||
if not has_train_dataloader: | ||
raise MisconfigurationException( | ||
'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a' | ||
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' | ||
) | ||
|
||
# ----------------------------------- | ||
# verify model has optimizer | ||
# ----------------------------------- | ||
has_optimizers = self.trainer.is_overridden('configure_optimizers', model) | ||
if not has_optimizers: | ||
raise MisconfigurationException( | ||
'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a' | ||
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' | ||
) | ||
|
||
def __verify_eval_loop_configuration(self, model, eval_loop_name): | ||
step_name = f'{eval_loop_name}_step' | ||
|
||
# map the dataloader name | ||
loader_name = f'{eval_loop_name}_dataloader' | ||
if eval_loop_name == 'validation': | ||
loader_name = 'val_dataloader' | ||
|
||
has_loader = self.trainer.is_overridden(loader_name, model) | ||
has_step = self.trainer.is_overridden(step_name, model) | ||
|
||
if has_loader and not has_step: | ||
rank_zero_warn( | ||
f'you passed in a {loader_name} but have no {step_name}. Skipping {eval_loop_name} loop' | ||
) | ||
if has_step and not has_loader: | ||
rank_zero_warn( | ||
f'you defined a {step_name} but have no {loader_name}. Skipping {eval_loop_name} loop' | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.