Enable val/test loop disabling + datamodule tests (#2692)
* 🎨 warn instead of error out on loaders

* 🐛 test misconfiguration should still fail

* 🚧 .

* updated docs with new result obj

Co-authored-by: William Falcon <[email protected]>
nateraw and williamFalcon authored Jul 25, 2020
1 parent 4bf1918 commit 9076551
Showing 13 changed files with 393 additions and 279 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -133,5 +133,6 @@ mnist/
# pl tests
ml-runs/
*.zip
*.ckpt
pytorch\ lightning
test-reports/
82 changes: 82 additions & 0 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -0,0 +1,82 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import rank_zero_warn


class ConfigValidator(object):

def __init__(self, trainer):
self.trainer = trainer

def enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule):
# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
if (train_dataloader or val_dataloaders) and datamodule:
raise MisconfigurationException(
'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
)

def verify_loop_configurations(self, model: LightningModule):
r"""
Checks that the model is configured correctly before training or testing is started.
Args:
model: The model to check the configuration.
"""
if not self.trainer.testing:
self.__verify_train_loop_configuration(model)
self.__verify_eval_loop_configuration(model, 'validation')
else:
# check test loop configuration
self.__verify_eval_loop_configuration(model, 'test')

def __verify_train_loop_configuration(self, model):
# -----------------------------------
# verify model has a training step
# -----------------------------------
has_training_step = self.trainer.is_overridden('training_step', model)
if not has_training_step:
raise MisconfigurationException(
'No `training_step()` method defined. Lightning `Trainer` expects as minimum a'
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
)

# -----------------------------------
# verify model has a train dataloader
# -----------------------------------
has_train_dataloader = self.trainer.is_overridden('train_dataloader', model)
if not has_train_dataloader:
raise MisconfigurationException(
'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a'
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
)

# -----------------------------------
# verify model has optimizer
# -----------------------------------
has_optimizers = self.trainer.is_overridden('configure_optimizers', model)
if not has_optimizers:
raise MisconfigurationException(
'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a'
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
)

def __verify_eval_loop_configuration(self, model, eval_loop_name):
step_name = f'{eval_loop_name}_step'

# map the dataloader name
loader_name = f'{eval_loop_name}_dataloader'
if eval_loop_name == 'validation':
loader_name = 'val_dataloader'

has_loader = self.trainer.is_overridden(loader_name, model)
has_step = self.trainer.is_overridden(step_name, model)

if has_loader and not has_step:
rank_zero_warn(
f'you passed in a {loader_name} but have no {step_name}. Skipping {eval_loop_name} loop'
)
if has_step and not has_loader:
rank_zero_warn(
f'you defined a {step_name} but have no {loader_name}. Skipping {eval_loop_name} loop'
)
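
To illustrate the behaviour this validator enables, here is a minimal sketch (not part of the commit; the module, tensor shapes and dataset sizes are illustrative assumptions about a 2020-era PyTorch Lightning API). The module defines a validation_step but no val_dataloader: previously the pre-fit check raised a MisconfigurationException, whereas ConfigValidator only emits the warning above and skips the validation loop.

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class NoValLoaderModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {'loss': torch.nn.functional.cross_entropy(self(x), y)}

    # a validation_step is defined, but no val_dataloader ...
    def validation_step(self, batch, batch_idx):
        x, y = batch
        return {'val_loss': torch.nn.functional.cross_entropy(self(x), y)}

    def train_dataloader(self):
        dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(dataset, batch_size=8)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# ... so the validator only warns ("Skipping validation loop") and fit()
# proceeds with the validation loop disabled instead of raising
trainer = pl.Trainer(max_epochs=1)
trainer.fit(NoValLoaderModel())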
8 changes: 6 additions & 2 deletions pytorch_lightning/trainer/data_loading.py
@@ -339,7 +339,9 @@ def reset_val_dataloader(self, model: LightningModule) -> None:
Args:
model: The current `LightningModule`
"""
if self.is_overridden('validation_step'):
has_loader = self.is_overridden('val_dataloader', model)
has_step = self.is_overridden('validation_step', model)
if has_loader and has_step:
self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val')

def reset_test_dataloader(self, model) -> None:
@@ -348,7 +350,9 @@ def reset_test_dataloader(self, model) -> None:
Args:
model: The current `LightningModule`
"""
if self.is_overridden('test_step'):
has_loader = self.is_overridden('test_dataloader', model)
has_step = self.is_overridden('test_step', model)
if has_loader and has_step:
self.num_test_batches, self.test_dataloaders =\
self._reset_eval_dataloader(model, 'test')
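
Both guards lean on Trainer.is_overridden to decide whether the user's module actually provides the hook. As a rough illustration of that idea (a sketch only, not Lightning's actual implementation), an override can be detected by comparing the function object on the concrete class with the one on the base LightningModule:

def is_method_overridden(obj, method_name, base_cls):
    # the hook counts as overridden when the subclass supplies a different
    # function object than the base class
    sub_attr = getattr(type(obj), method_name, None)
    base_attr = getattr(base_cls, method_name, None)
    return sub_attr is not None and sub_attr is not base_attr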

2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/distrib_data_parallel.py
@@ -166,7 +166,7 @@ def train_fx(trial_hparams, cluster_manager, _):

pid = os.getpid()
rng1 = np.random.RandomState(pid)
RANDOM_PORTS = rng1.randint(10000, 19999, 100)
RANDOM_PORTS = rng1.randint(10000, 19999, 1000)


class TrainerDDPMixin(ABC):
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/evaluation_loop.py
@@ -515,7 +515,7 @@ def run_evaluation(self, test_mode: bool = False):

# enable fast_dev_run without val loop
if dataloaders is None:
return
return [], []

# cap max batches to 1 when using fast_dev_run
if self.fast_dev_run:
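
Combined with the dataloader guards above, returning two empty lists lets a fit proceed when no validation loop exists at all, e.g. with fast_dev_run. A minimal sketch, where TrainOnlyModel is a hypothetical LightningModule defining only training_step, train_dataloader and configure_optimizers:

trainer = pl.Trainer(fast_dev_run=True)  # runs a single train batch
trainer.fit(TrainOnlyModel())            # no validation loop is built or run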
88 changes: 12 additions & 76 deletions pytorch_lightning/trainer/trainer.py
@@ -36,6 +36,7 @@
from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.configuration_validator import ConfigValidator

# warnings to ignore in trainer
warnings.filterwarnings(
@@ -644,6 +645,7 @@ def __init__(

# tracks internal state for debugging
self.dev_debugger = InternalDebugger(self)
self.config_validator = ConfigValidator(self)

# Callback system
self.on_init_end()
@@ -974,18 +976,19 @@ def fit(
if hasattr(model, 'hparams'):
parsing.clean_namespace(model.hparams)

# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
if (train_dataloader or val_dataloaders) and datamodule:
raise MisconfigurationException(
'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
)
# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(train_dataloader, LightningDataModule):
datamodule = train_dataloader
train_dataloader = None

self.config_validator.enforce_datamodule_dataloader_override(train_dataloader, val_dataloaders, datamodule)

# set up the passed in dataloaders (if needed)
self.__attach_dataloaders(model, train_dataloader, val_dataloaders)
self.__attach_datamodule(model, datamodule)

# check that model is configured correctly
self.check_model_configuration(model)
self.config_validator.verify_loop_configurations(model)

# callbacks
self.on_fit_start()
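
A short usage sketch of the re-routing and the guard (illustrative only; LitClassifier and loader are hypothetical, TrialMNISTDataModule is the test datamodule further down in this diff):

dm = TrialMNISTDataModule()
model = LitClassifier()                 # hypothetical LightningModule
trainer = pl.Trainer(max_epochs=1)

# a datamodule passed as the second positional argument is re-assigned to
# `datamodule` for the user
trainer.fit(model, dm)

# mixing explicit loaders with a datamodule is rejected:
# trainer.fit(model, train_dataloader=loader, datamodule=dm)
#   -> MisconfigurationException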
@@ -1256,9 +1259,9 @@ def run_pretrain_routine(self, model: LightningModule):
self.train()

def _run_sanity_check(self, ref_model, model):
should_sanity_check = (
self.is_overridden('validation_step') and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0
)

using_val_step = ref_model.val_dataloader is not None and self.is_overridden('validation_step')
should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0

# run tiny validation (if validation defined)
# to make sure program won't crash during val
@@ -1448,73 +1451,6 @@ def __test_given_model(self, model, test_dataloaders):

return results

def check_model_configuration(self, model: LightningModule):
r"""
Checks that the model is configured correctly before training or testing is started.
Args:
model: The model to check the configuration.
"""
# Check training_step, train_dataloader, configure_optimizer methods
if not self.testing:
if not self.is_overridden('training_step', model):
raise MisconfigurationException(
'No `training_step()` method defined. Lightning `Trainer` expects as minimum a'
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
)

if not self.is_overridden('train_dataloader', model):
raise MisconfigurationException(
'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a'
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
)

if not self.is_overridden('configure_optimizers', model):
raise MisconfigurationException(
'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a'
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
)

# Check val_dataloader, validation_step and validation_epoch_end
if self.is_overridden('val_dataloader', model):
if not self.is_overridden('validation_step', model):
raise MisconfigurationException(
'You have passed in a `val_dataloader()`' ' but have not defined `validation_step()`.'
)
else:
if not self.is_overridden('validation_epoch_end', model):
rank_zero_warn(
'You have defined a `val_dataloader()` and have defined a `validation_step()`,'
' you may also want to define `validation_epoch_end()` for accumulating stats.',
RuntimeWarning,
)
else:
if self.is_overridden('validation_step', model):
raise MisconfigurationException(
'You have defined `validation_step()`,' ' but have not passed in a `val_dataloader()`.'
)

# Check test_dataloader, test_step and test_epoch_end
if self.is_overridden('test_dataloader', model):
if not self.is_overridden('test_step', model):
raise MisconfigurationException(
'You have passed in a `test_dataloader()`' ' but have not defined `test_step()`.'
)
else:
if not self.is_overridden('test_epoch_end', model):
rank_zero_warn(
'You have defined a `test_dataloader()` and have defined a `test_step()`, you may also want to'
' define `test_epoch_end()` for accumulating stats.',
RuntimeWarning,
)
else:
if self.testing and self.is_overridden('test_step', model):
raise MisconfigurationException(
'You have defined `test_step()` but did not'
' implement `test_dataloader` nor passed in `.test(test_dataloader)`.'
)

def barrier(self, name):
if self.use_ddp or self.use_ddp2:
pass
4 changes: 3 additions & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -335,7 +335,9 @@ def train(self):
# if reload_dataloaders_every_epoch, this is moved to the epoch loop
if not self.reload_dataloaders_every_epoch:
self.reset_train_dataloader(model)
self.reset_val_dataloader(model)

if model.val_dataloader is not None:
self.reset_val_dataloader(model)

# Train start events
with self.profiler.profile('on_train_start'):
18 changes: 9 additions & 9 deletions tests/base/datamodules.py
@@ -1,24 +1,24 @@
from torch.utils.data import random_split, DataLoader

from pytorch_lightning import LightningDataModule
from tests.base.datasets import MNIST
from pytorch_lightning.core.datamodule import LightningDataModule
from tests.base.datasets import TrialMNIST


class MNISTDataModule(LightningDataModule):
class TrialMNISTDataModule(LightningDataModule):

def __init__(self, data_dir: str = './'):
super(MNISTDataModule, self).__init__()
super().__init__()
self.data_dir = data_dir

def prepare_data(self):
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)
TrialMNIST(self.data_dir, train=True, download=True)
TrialMNIST(self.data_dir, train=False, download=True)

def setup(self):
mnist_full = MNIST(self.data_dir, train=True, download=False)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True)
self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64])
self.dims = tuple(self.mnist_train[0][0].shape)
self.mnist_test = MNIST(self.data_dir, train=False, download=False)
self.mnist_test = TrialMNIST(root=self.data_dir, train=False, num_samples=32, download=True)

def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=32)
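
A usage sketch for the test datamodule (illustrative, not the PR's actual test; LitClassifier is a hypothetical model that accepts 1x28x28 MNIST inputs):

import pytorch_lightning as pl
from tests.base.datamodules import TrialMNISTDataModule

dm = TrialMNISTDataModule(data_dir='./datasets')
model = LitClassifier()            # hypothetical model

trainer = pl.Trainer(max_epochs=1, limit_val_batches=2)
trainer.fit(model, datamodule=dm)  # loaders and dims come from the datamodule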