From 1015a0050621828c9e8af2c934e19c5c68d61a5e Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Mon, 24 Feb 2020 22:23:25 -0500
Subject: [PATCH] Clean up dataloader logic (#926)

* added get dataloaders directly using a getter
* deleted decorator
* added prepare_data hook
* refactored dataloader init
* added dataloader reset flag and main loop
* made changes
* fixed bad loaders
* fixed error in .fit with loaders
* fixes #909
* bug fix
* Fixes #902
---
 CHANGELOG.md                                  |   4 +
 docs/source/hooks.rst                         |   6 +-
 .../lightning_module_template.py              |  20 +-
 pytorch_lightning/core/decorators.py          |  29 +-
 pytorch_lightning/core/lightning.py           |  41 ++-
 pytorch_lightning/trainer/data_loading.py     | 276 ++++++++++--------
 pytorch_lightning/trainer/evaluation_loop.py  |  45 ++-
 pytorch_lightning/trainer/model_hooks.py      |   5 +-
 pytorch_lightning/trainer/trainer.py          | 130 ++++-----
 pytorch_lightning/trainer/training_loop.py    |  47 +--
 tests/models/__init__.py                      |   4 +
 tests/models/base.py                          |  22 +-
 tests/models/mixins.py                        | 196 ++++++++++++-
 tests/models/utils.py                         |  12 +-
 tests/test_cpu_models.py                      |   3 +-
 tests/test_gpu_models.py                      |  32 +-
 tests/test_restore_models.py                  |  16 +-
 tests/test_trainer.py                         |  86 +++---
 18 files changed, 615 insertions(+), 359 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b96da2d8d891f..18f10757768c2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added automatic sampler setup. Depending on DDP or TPU, lightning configures the sampler correctly (user needs to do nothing) ([#926](https://github.com/PyTorchLightning/pytorch-lightning/pull/926))
+- Added `reload_dataloaders_every_epoch=False` flag for trainer. Some users require reloading data every epoch ([#926](https://github.com/PyTorchLightning/pytorch-lightning/pull/926))
+- Added `progress_bar_refresh_rate=50` flag for trainer.
Throttle refresh rate on notebooks ([#926](https://github.com/PyTorchLightning/pytorch-lightning/pull/926)) - Updated governance docs - Added a check to ensure that the metric used for early stopping exists before training commences ([#542](https://github.com/PyTorchLightning/pytorch-lightning/pull/542)) - Added `optimizer_idx` argument to `backward` hook ([#733](https://github.com/PyTorchLightning/pytorch-lightning/pull/733)) @@ -22,6 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Removed `@data_loader` decorator ([#926](https://github.com/PyTorchLightning/pytorch-lightning/pull/926)) - Changed default TQDM to use `tqdm.auto` for prettier outputs in IPython notebooks ([#752](https://github.com/PyTorchLightning/pytorch-lightning/pull/752)) - Changed `pytorch_lightning.logging` to `pytorch_lightning.loggers` ([#767](https://github.com/PyTorchLightning/pytorch-lightning/pull/767)) - Moved the default `tqdm_dict` definition from Trainer to `LightningModule`, so it can be overridden by the user ([#749](https://github.com/PyTorchLightning/pytorch-lightning/pull/749)) diff --git a/docs/source/hooks.rst b/docs/source/hooks.rst index fee74ea218155..0cade1c4a8eea 100644 --- a/docs/source/hooks.rst +++ b/docs/source/hooks.rst @@ -8,9 +8,9 @@ Training set-up - init_optimizers - configure_apex - configure_ddp -- get_train_dataloader -- get_test_dataloaders -- get_val_dataloaders +- train_dataloader +- test_dataloaders +- val_dataloaders - summarize - restore_weights diff --git a/pl_examples/basic_examples/lightning_module_template.py b/pl_examples/basic_examples/lightning_module_template.py index 81cdf2acba6be..fabe18504686f 100644 --- a/pl_examples/basic_examples/lightning_module_template.py +++ b/pl_examples/basic_examples/lightning_module_template.py @@ -192,37 +192,35 @@ def __dataloader(self, train): transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))]) dataset = MNIST(root=self.hparams.data_root, train=train, - transform=transform, download=True) + transform=transform, download=False) # when using multi-node (ddp) we need to add the datasampler - train_sampler = None batch_size = self.hparams.batch_size - if self.use_ddp: - train_sampler = DistributedSampler(dataset) - - should_shuffle = train_sampler is None loader = DataLoader( dataset=dataset, batch_size=batch_size, - shuffle=should_shuffle, - sampler=train_sampler, num_workers=0 ) return loader - @pl.data_loader + def prepare_data(self): + transform = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.5,), (1.0,))]) + dataset = MNIST(root=self.hparams.data_root, train=True, + transform=transform, download=True) + dataset = MNIST(root=self.hparams.data_root, train=False, + transform=transform, download=True) + def train_dataloader(self): log.info('Training data loader called.') return self.__dataloader(train=True) - @pl.data_loader def val_dataloader(self): log.info('Validation data loader called.') return self.__dataloader(train=False) - @pl.data_loader def test_dataloader(self): log.info('Test data loader called.') return self.__dataloader(train=False) diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py index 3448fda4d4864..be4fd41d06ee8 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -1,5 +1,6 @@ import traceback from functools import wraps +import warnings def data_loader(fn): @@ -8,27 +9,9 @@ def data_loader(fn): :param fn: :return: """ - 
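For illustration, a condensed sketch of the hook-based pattern the template above migrates to: the download happens once in `prepare_data`, the hooks return plain `DataLoader`s, and there is no `@data_loader` decorator and no manual `DistributedSampler`. The class name and the local `./data` path are placeholders, and torchvision is assumed to be installed.

.. code-block:: python

    import pytorch_lightning as pl
    import torch
    import torch.nn.functional as F
    from torch.utils.data import DataLoader
    from torchvision import transforms
    from torchvision.datasets import MNIST


    class MNISTExample(pl.LightningModule):
        """Hook-based data handling: no decorator, no manual sampler wiring."""

        def __init__(self, data_root='./data', batch_size=32):
            super().__init__()
            self.data_root = data_root
            self.batch_size = batch_size
            self.l1 = torch.nn.Linear(28 * 28, 10)

        def forward(self, x):
            return self.l1(x.view(x.size(0), -1))

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = F.cross_entropy(self.forward(x), y)
            return {'loss': loss}

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

        def prepare_data(self):
            # called once per node (proc 0) / TPU core 0: download only, keep no state here
            MNIST(self.data_root, train=True, download=True)
            MNIST(self.data_root, train=False, download=True)

        def __dataloader(self, train):
            transform = transforms.Compose([transforms.ToTensor(),
                                            transforms.Normalize((0.5,), (1.0,))])
            dataset = MNIST(self.data_root, train=train,
                            transform=transform, download=False)
            # return a plain DataLoader; the trainer attaches the right sampler itself
            return DataLoader(dataset, batch_size=self.batch_size, num_workers=0)

        def train_dataloader(self):
            return self.__dataloader(train=True)

        def val_dataloader(self):
            return self.__dataloader(train=False)

        def test_dataloader(self):
            return self.__dataloader(train=False)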
wraps(fn) - attr_name = '_lazy_' + fn.__name__ - @wraps(fn) - def _get_data_loader(self): - try: - value = getattr(self, attr_name) - except AttributeError: - try: - value = fn(self) # Lazy evaluation, done only once. - if ( - value is not None and - not isinstance(value, list) and - fn.__name__ in ['test_dataloader', 'val_dataloader'] - ): - value = [value] - except AttributeError as e: - # Guard against AttributeError suppression. (Issue #142) - traceback.print_exc() - error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e) - raise RuntimeError(error) from e - setattr(self, attr_name, value) # Memoize evaluation. - return value + w = 'data_loader decorator deprecated in 0.6.1. Will remove 0.8.0' + warnings.warn(w) - return _get_data_loader + def inner_fx(self): + return fn(self) + return inner_fx diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 7a09aabb3cd45..85a713427e151 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -16,6 +16,13 @@ from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel +try: + import torch_xla.core.xla_model as xm + XLA_AVAILABLE = True + +except ImportError: + XLA_AVAILABLE = False + class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): @@ -798,7 +805,9 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx, sec optimizer.zero_grad() """ - if isinstance(optimizer, torch.optim.LBFGS): + if self.trainer.use_tpu and XLA_AVAILABLE: + xm.optimizer_step(optimizer) + elif isinstance(optimizer, torch.optim.LBFGS): optimizer.step(second_order_closure) else: optimizer.step() @@ -868,7 +877,33 @@ def tbptt_split_batch(self, batch, split_size): return splits - @data_loader + def prepare_data(self): + """Use this to download and prepare data. + In distributed (GPU, TPU), this will only be called once + + :return: PyTorch DataLoader + + This is called before requesting the dataloaders + + .. code-block:: python + + model.prepare_data() + model.train_dataloader() + model.val_dataloader() + model.test_dataloader() + + Example + ------- + + .. 
code-block:: python + + def prepare_data(self): + download_imagenet() + clean_imagenet() + cache_imagenet() + """ + return None + def train_dataloader(self): """Implement a PyTorch DataLoader @@ -908,7 +943,6 @@ def tng_dataloader(self): # todo: remove in v0.8.0 " and this method will be removed in v0.8.0", DeprecationWarning) return output - @data_loader def test_dataloader(self): r""" @@ -942,7 +976,6 @@ def test_dataloader(self): """ return None - @data_loader def val_dataloader(self): r""" diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 43e87928a6abe..c355aef456c23 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -2,6 +2,9 @@ from abc import ABC import torch.distributed as dist +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data import RandomSampler, SequentialSampler, DataLoader, BatchSampler +from pytorch_lightning.utilities.debugging import MisconfigurationException try: # loading for pyTorch 1.3 @@ -14,9 +17,6 @@ EXIST_ITER_DATASET = False else: EXIST_ITER_DATASET = True -from torch.utils.data.distributed import DistributedSampler - -from pytorch_lightning.utilities.debugging import MisconfigurationException try: from apex import amp @@ -47,6 +47,13 @@ def __init__(self): self.val_check_interval = None self.use_tpu = None self.tpu_local_core_rank = None + self.train_dataloader = None + self.num_training_batches = None + self.val_check_batch = None + self.val_dataloaders = None + self.num_val_batches = None + self.test_dataloaders = None + self.num_test_batches = None def _percent_range_check(self, name): value = getattr(self, name) @@ -57,21 +64,97 @@ def _percent_range_check(self, name): if not 0. 
<= value <= 1.: raise ValueError(msg) - def init_train_dataloader(self, model): + def call_prepare_data(self, model): + """ + Let model download the data on proc==0 only + :param model: + """ + # download data on DDP+ + if self.use_ddp or self.use_ddp2: + if self.proc_rank == 0: + model.prepare_data() + + # all processes wait until data download has happened + dist.barrier() + + # data download/load on TPU + elif self.use_tpu and XLA_AVAILABLE: + if self.tpu_local_core_rank == 0: + model.prepare_data() + + # all processes wait until data download has happened + torch_xla.core.xla_model.rendezvous("pl.TrainerDataLoadingMixin.get_dataloaders") + + else: + # regular download + model.prepare_data() + + def auto_add_sampler(self, dataloader, train): + # do nothing when user gives a sampler + dl_args = { + 'dataset': dataloader.dataset, + 'batch_size': dataloader.batch_size, + 'shuffle': False, + 'num_workers': dataloader.num_workers, + 'collate_fn': dataloader.collate_fn, + 'pin_memory': dataloader.pin_memory, + 'drop_last': dataloader.drop_last, + 'timeout': dataloader.timeout, + 'worker_init_fn': dataloader.worker_init_fn + } + + if train: + if self.use_ddp or self.use_ddp2: + sampler = DistributedSampler(dataloader.dataset) + dl_args['shuffle'] = False + + elif self.use_tpu: + sampler = DistributedSampler( + dataloader.dataset, + num_replicas=xm.xrt_world_size(), + rank=xm.get_ordinal() + ) + dl_args['shuffle'] = False + else: + sampler = RandomSampler(dataloader.dataset) + + # on not train + else: + if self.use_tpu: + sampler = DistributedSampler( + dataloader.dataset, + num_replicas=xm.xrt_world_size(), + rank=xm.get_ordinal() + ) + dl_args['shuffle'] = False + else: + sampler = SequentialSampler(dataloader.dataset) + + dl_args['sampler'] = sampler + + new_dataloader = DataLoader(**dl_args) + return new_dataloader + + def reset_train_dataloader(self, model): """ Dataloaders are provided by the model :param model: :return: """ - self.get_train_dataloader = model.train_dataloader + + self.train_dataloader = self.request_data_loader(model.train_dataloader) + self.num_training_batches = 0 + + # automatically add samplers + self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True) # determine number of training batches - if EXIST_ITER_DATASET and isinstance(self.get_train_dataloader().dataset, IterableDataset): + if EXIST_ITER_DATASET and isinstance(self.train_dataloader.dataset, IterableDataset): self.num_training_batches = float('inf') else: self._percent_range_check('train_percent_check') - self.num_training_batches = len(self.get_train_dataloader()) + self.num_training_batches = len(self.train_dataloader) self.num_training_batches = int(self.num_training_batches * self.train_percent_check) # determine when to check validation @@ -90,161 +173,98 @@ def init_train_dataloader(self, model): self.val_check_batch = int(self.num_training_batches * self.val_check_interval) self.val_check_batch = max(1, self.val_check_batch) - on_ddp = self.use_ddp or self.use_ddp2 - needs_sampler = on_ddp or self.use_tpu - if needs_sampler and not isinstance(self.get_train_dataloader().sampler, DistributedSampler): - msg = """ - You're using multiple gpus and multiple nodes, or TPUs without using a - to assign a subset of your data to each process. To silence this warning, pass a - DistributedSampler to your DataLoader. 
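As a standalone sketch (not the trainer's code path itself), the DDP training branch of `auto_add_sampler` above boils down to rebuilding the user's loader with the same arguments plus a `DistributedSampler`. The helper name below is hypothetical, and constructing `DistributedSampler` without explicit `num_replicas`/`rank` assumes the default process group has already been initialised.

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler


    def add_ddp_sampler(dataloader: DataLoader) -> DataLoader:
        # rebuild the loader with identical settings but a DistributedSampler,
        # so each DDP process iterates over a disjoint shard of the dataset
        dl_args = {
            'dataset': dataloader.dataset,
            'batch_size': dataloader.batch_size,
            'shuffle': False,  # shuffling is delegated to the sampler
            'num_workers': dataloader.num_workers,
            'collate_fn': dataloader.collate_fn,
            'pin_memory': dataloader.pin_memory,
            'drop_last': dataloader.drop_last,
            'timeout': dataloader.timeout,
            'worker_init_fn': dataloader.worker_init_fn,
            'sampler': DistributedSampler(dataloader.dataset),
        }
        return DataLoader(**dl_args)


    if __name__ == '__main__':
        dataset = TensorDataset(torch.arange(100).float().unsqueeze(1))
        plain_loader = DataLoader(dataset, batch_size=10)  # what the user hands over
        # sharded_loader = add_ddp_sampler(plain_loader)   # what the trainer builds under DDP

Epoch-to-epoch shuffling then comes from `sampler.set_epoch(epoch)`, which the training loop calls at the start of each epoch under DDP.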
- - ie: this: - dataset = myDataset() - dataloader = Dataloader(dataset) - - becomes: - dataset = myDataset() - dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset) - dataloader = Dataloader(dataset, sampler=dist_sampler) - - If you want each process to load the full dataset, ignore this warning. - """ - if msg not in self.shown_warnings and self.proc_rank == 0: - self.shown_warnings.add(msg) - warnings.warn(msg) - - def init_val_dataloader(self, model): + # support IterableDataset for train data + self.is_iterable_train_dataloader = ( + EXIST_ITER_DATASET and isinstance(self.train_dataloader.dataset, IterableDataset) + ) + if self.is_iterable_train_dataloader and not isinstance(self.val_check_interval, int): + m = ''' + When using an iterableDataset for `train_dataloader`, + `Trainer(val_check_interval)` must be an int. + An int k specifies checking validation every k training batches + ''' + raise MisconfigurationException(m) + + def reset_val_dataloader(self, model): """ Dataloaders are provided by the model :param model: :return: """ - self.get_val_dataloaders = model.val_dataloader + if not self.is_overriden('validation_step'): + return + + self.val_dataloaders = self.request_data_loader(model.val_dataloader) + if not isinstance(self.val_dataloaders, list): + self.val_dataloaders = [self.val_dataloaders] self.num_val_batches = 0 + # add samplers + for i, dataloader in enumerate(self.val_dataloaders): + dl = self.auto_add_sampler(dataloader, train=False) + self.val_dataloaders[i] = dl + # determine number of validation batches # val datasets could be none, 1 or 2+ - if self.get_val_dataloaders() is not None: + if self.val_dataloaders is not None: self._percent_range_check('val_percent_check') - self.num_val_batches = sum(len(dataloader) for dataloader in self.get_val_dataloaders()) + self.num_val_batches = sum(len(dataloader) for dataloader in self.val_dataloaders) self.num_val_batches = int(self.num_val_batches * self.val_percent_check) - on_ddp = self.use_ddp or self.use_ddp2 - needs_sampler = on_ddp or self.use_tpu - if needs_sampler and self.get_val_dataloaders() is not None: - for dataloader in self.get_val_dataloaders(): - if not isinstance(dataloader.sampler, DistributedSampler): - msg = """ - Your val_dataloader(s) don't use DistributedSampler. - - You're using multiple gpus and multiple nodes, or TPUs without using a - DistributedSampler to assign a subset of your data to each process. - To silence this warning, pass a DistributedSampler to your DataLoader. - - ie: this: - dataset = myDataset() - dataloader = Dataloader(dataset) - - becomes: - dataset = myDataset() - dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset) - dataloader = Dataloader(dataset, sampler=dist_sampler) - - If you want each process to load the full dataset, ignore this warning. - """ - if msg not in self.shown_warnings and self.proc_rank == 0: - self.shown_warnings.add(msg) - warnings.warn(msg) - break - - def init_test_dataloader(self, model): + def reset_test_dataloader(self, model): """Dataloaders are provided by the model. 
:param model: """ + if not self.is_overriden('test_step'): + return - self.get_test_dataloaders = model.test_dataloader + # get actual loader + self.test_dataloaders = self.request_data_loader(model.test_dataloader) + if not isinstance(self.test_dataloaders, list): + self.test_dataloaders = [self.test_dataloaders] + self.num_test_batches = 0 + + # add samplers + for i, dataloader in enumerate(self.test_dataloaders): + dl = self.auto_add_sampler(dataloader, train=False) + self.test_dataloaders[i] = dl # determine number of test batches - if self.get_test_dataloaders() is not None: + if self.test_dataloaders is not None: self._percent_range_check('test_percent_check') - len_sum = sum(len(dataloader) for dataloader in self.get_test_dataloaders()) + len_sum = sum(len(dataloader) for dataloader in self.test_dataloaders) self.num_test_batches = len_sum self.num_test_batches = int(self.num_test_batches * self.test_percent_check) - on_ddp = self.use_ddp or self.use_ddp2 - needs_sampler = on_ddp or self.use_tpu - if needs_sampler and self.get_test_dataloaders() is not None: - for dataloader in self.get_test_dataloaders(): - if not isinstance(dataloader.sampler, DistributedSampler): - msg = """ - Your `test_dataloader(s)` don't use DistributedSampler. - - You're using multiple gpus and multiple nodes, or TPUs without using a - DistributedSampler to assign a subset of your data to each process. - To silence this warning, pass a DistributedSampler to your DataLoader. - - ie: this:: - - dataset = myDataset() - dataloader = Dataloader(dataset) - - becomes:: - - dataset = myDataset() - dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset) - dataloader = Dataloader(dataset, sampler=dist_sampler) - - If you want each process to load the full dataset, ignore this warning. - """ - if msg not in self.shown_warnings and self.proc_rank == 0: - self.shown_warnings.add(msg) - warnings.warn(msg) - break - - def get_dataloaders(self, model): + def request_data_loader(self, data_loader_fx): """ - Dataloaders are provided by the model - :param model: + Handles downloading data in the GPU or TPU case. + + :param data_loader_fx: :return: """ - - self.init_train_dataloader(model) - self.init_test_dataloader(model) - self.init_val_dataloader(model) - + # get the function we'll use to get data if self.use_ddp or self.use_ddp2: - # wait for all processes to catch up - dist.barrier() + data_loader = data_loader_fx() - # load each dataloader - self.get_train_dataloader() - self.get_test_dataloaders() - self.get_val_dataloaders() + # all processes wait until data download has happened + dist.barrier() - # on TPUs load each dataloader only on process 0 - # this will trigger the data downloads - if self.use_tpu and XLA_AVAILABLE: - if self.tpu_local_core_rank == 0: - self.get_train_dataloader() - self.get_test_dataloaders() - self.get_val_dataloaders() + # data download/load on TPU + elif self.use_tpu and XLA_AVAILABLE: + data_loader = data_loader_fx() - # wait for all processes to catch up + # all processes wait until data download has happened torch_xla.core.xla_model.rendezvous("pl.TrainerDataLoadingMixin.get_dataloaders") - # support IterableDataset for train data - self.is_iterable_train_dataloader = ( - EXIST_ITER_DATASET and isinstance(self.get_train_dataloader().dataset, IterableDataset)) - if self.is_iterable_train_dataloader and not isinstance(self.val_check_interval, int): - m = ''' - When using an iterableDataset for `train_dataloader`, - `Trainer(val_check_interval)` must be an int. 
- An int k specifies checking validation every k training batches - ''' - raise MisconfigurationException(m) + # regular start + else: + data_loader = data_loader_fx() + + return data_loader def determine_data_use_amount(self, train_percent_check, val_percent_check, test_percent_check, overfit_pct): diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index f5d2b9327f9fa..37847393a10cc 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -165,9 +165,11 @@ def __init__(self): self.checkpoint_callback = None self.current_epoch = None self.callback_metrics = None - self.get_test_dataloaders = None - self.get_val_dataloaders = None + self.test_dataloaders = None + self.val_dataloaders = None self.use_tpu = None + self.reload_dataloaders_every_epoch = None + self.progress_bar_refresh_rate = None @abstractmethod def copy_trainer_model_properties(self, model): @@ -204,6 +206,16 @@ def log_metrics(self, metrics, grad_norm_dic): # this is just empty shell for code from other class pass + @abstractmethod + def reset_test_dataloader(self, model): + # this is just empty shell for code from other class + pass + + @abstractmethod + def reset_val_dataloader(self, model): + # this is just empty shell for code from other class + pass + def evaluate(self, model, dataloaders, max_batches, test=False): """Run evaluation code. @@ -258,11 +270,12 @@ def evaluate(self, model, dataloaders, max_batches, test=False): dl_outputs.append(output) # batch done - if test: - self.test_progress_bar.update(1) - else: - self.val_progress_bar.update(1) - self.main_progress_bar.update(1) + if batch_idx % self.progress_bar_refresh_rate == 0: + if test: + self.test_progress_bar.update(self.progress_bar_refresh_rate) + else: + self.val_progress_bar.update(self.progress_bar_refresh_rate) + self.main_progress_bar.update(self.progress_bar_refresh_rate) outputs.append(dl_outputs) eval_results = {} @@ -288,8 +301,8 @@ def evaluate(self, model, dataloaders, max_batches, test=False): def run_evaluation(self, test=False): # when testing make sure user defined a test step - if test and not (self.is_overriden('test_step') and self.is_overriden('test_end')): - m = '''You called `.test()` without defining model's `.test_step()` or `.test_end()`. + if test and not (self.is_overriden('test_step')): + m = '''You called `.test()` without defining model's `.test_step()`. 
Please define and try again''' raise MisconfigurationException(m) @@ -299,11 +312,17 @@ def run_evaluation(self, test=False): # select dataloaders if test: - dataloaders = self.get_test_dataloaders() + if self.reload_dataloaders_every_epoch or self.test_dataloaders is None: + self.reset_test_dataloader(model) + + dataloaders = self.test_dataloaders max_batches = self.num_test_batches else: # val - dataloaders = self.get_val_dataloaders() + if self.reload_dataloaders_every_epoch or self.val_dataloaders is None: + self.reset_val_dataloader(model) + + dataloaders = self.val_dataloaders max_batches = self.num_val_batches # cap max batches to 1 when using fast_dev_run @@ -357,10 +376,10 @@ def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False # make dataloader_idx arg in validation_step optional args = [batch, batch_idx] - if test and len(self.get_test_dataloaders()) > 1: + if test and len(self.test_dataloaders) > 1: args.append(dataloader_idx) - elif not test and len(self.get_val_dataloaders()) > 1: + elif not test and len(self.val_dataloaders) > 1: args.append(dataloader_idx) # handle DP, DDP forward diff --git a/pytorch_lightning/trainer/model_hooks.py b/pytorch_lightning/trainer/model_hooks.py index e5afc90aebe6f..eb0d529d2681b 100644 --- a/pytorch_lightning/trainer/model_hooks.py +++ b/pytorch_lightning/trainer/model_hooks.py @@ -11,8 +11,9 @@ def is_function_implemented(self, f_name): f_op = getattr(model, f_name, None) return callable(f_op) - def is_overriden(self, f_name): - model = self.get_model() + def is_overriden(self, f_name, model=None): + if model is None: + model = self.get_model() super_object = LightningModule # when code pointers are different, it was overriden diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4d204bd287e36..855163900b62d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -79,6 +79,7 @@ def __init__( num_tpu_cores: Optional[int] = None, log_gpu_memory: Optional[str] = None, show_progress_bar: bool = True, + progress_bar_refresh_rate: int = 50, overfit_pct: float = 0.0, track_grad_norm: int = -1, check_val_every_n_epoch: int = 1, @@ -109,6 +110,7 @@ def __init__( truncated_bptt_steps: Optional[int] = None, resume_from_checkpoint: Optional[str] = None, profiler: Optional[BaseProfiler] = None, + reload_dataloaders_every_epoch: bool = False, ): r""" @@ -284,6 +286,8 @@ def __init__( # default used by the Trainer trainer = Trainer(show_progress_bar=True) + progress_bar_refresh_rate: How often to refresh progress bar (in steps) + overfit_pct: uses this much data of all datasets. Example:: @@ -577,6 +581,7 @@ def __init__( # advanced profiler for function-level stats profiler = AdvancedProfiler() trainer = Trainer(profiler=profiler) + reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch .. 
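As a small usage sketch of the two new trainer arguments introduced above, reusing the hypothetical `MNISTExample` module sketched earlier:

.. code-block:: python

    from pytorch_lightning import Trainer

    model = MNISTExample()  # any LightningModule with the dataloader hooks defined

    trainer = Trainer(
        max_epochs=3,
        progress_bar_refresh_rate=50,         # redraw tqdm every 50 batches (notebook friendly)
        reload_dataloaders_every_epoch=True,  # re-call the *_dataloader hooks each epoch
    )
    trainer.fit(model)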
warning:: Following arguments become deprecated and they will be removed in v0.8.0: @@ -592,7 +597,6 @@ def __init__( if not num_nodes: # in case you did not set the proper value num_nodes = nb_gpu_nodes self.num_gpu_nodes = num_nodes - self.log_gpu_memory = log_gpu_memory # Backward compatibility @@ -603,6 +607,8 @@ def __init__( gradient_clip_val = gradient_clip self.gradient_clip_val = gradient_clip_val + self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch + self.progress_bar_refresh_rate = progress_bar_refresh_rate self.check_val_every_n_epoch = check_val_every_n_epoch self.track_grad_norm = track_grad_norm self.on_gpu = True if (gpus and torch.cuda.is_available()) else False @@ -672,9 +678,9 @@ def __init__( self.num_val_batches = 0 self.num_training_batches = 0 self.num_test_batches = 0 - self.get_train_dataloader = None - self.get_test_dataloaders = None - self.get_val_dataloaders = None + self.train_dataloader = None + self.test_dataloaders = None + self.val_dataloaders = None self.is_iterable_train_dataloader = False # training state @@ -849,8 +855,8 @@ def fit( self, model: LightningModule, train_dataloader: Optional[DataLoader] = None, - val_dataloader: Optional[DataLoader] = None, - test_dataloader: Optional[DataLoader] = None + val_dataloaders: Optional[DataLoader] = None, + test_dataloaders: Optional[DataLoader] = None ): r""" Runs the full optimization routine. @@ -862,13 +868,13 @@ def fit( DataLoader with training samples. If the model has a predefined train_dataloader method this will be skipped. - val_dataloader: Either a single + val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. - If the model has a predefined val_dataloader method this will be skipped + If the model has a predefined val_dataloaders method this will be skipped - test_dataloader: Either a single + test_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. 
- If the model has a predefined val_dataloader method this will be skipped + If the model has a predefined test_dataloaders method this will be skipped Example:: @@ -894,13 +900,10 @@ def fit( # feed to .fit() """ + # set up the passed in dataloaders (if needed) + self.__set_fit_dataloaders(model, train_dataloader, val_dataloaders, test_dataloaders) - # Update the dataloader attributes of the model with the ones supplied here, - # if they are not already defined in model - _set_dataloader(model, train_dataloader, 'train_dataloader') - _set_dataloader(model, val_dataloader, 'val_dataloader') - _set_dataloader(model, test_dataloader, 'test_dataloader') - + # route to appropriate start method # when using multi-node or DDP within a node start each module in a separate process if self.use_ddp2: task = int(os.environ['SLURM_LOCALID']) @@ -944,6 +947,39 @@ def fit( # used for testing or when we need to know that training succeeded return 1 + def __set_fit_dataloaders(self, model, train_dataloader, val_dataloaders, test_dataloaders): + # when dataloader is passed via fit, patch the train_dataloader + # functions to overwrite with these implementations + if train_dataloader is not None: + if not self.is_overriden('training_step', model): + m = 'You called .fit() with a train_dataloader but did not define training_step()' + raise MisconfigurationException(m) + + def patch_train_dataloader(): + return train_dataloader + + model.train_dataloader = patch_train_dataloader + + if val_dataloaders is not None: + if not self.is_overriden('validation_step', model): + m = 'You called .fit() with a val_dataloaders but did not define validation_step()' + raise MisconfigurationException(m) + + def patch_val_dataloader(): + return val_dataloaders + + model.val_dataloader = patch_val_dataloader + + if test_dataloaders is not None: + if not self.is_overriden('test_step', model): + m = 'You called .fit() with a test_dataloaders but did not define test_step()' + raise MisconfigurationException(m) + + def patch_test_dataloader(): + return test_dataloaders + + model.test_dataloader = patch_test_dataloader + def init_optimizers( self, optimizers: Union[Optimizer, Tuple[List, List], List[Optimizer], Tuple[Optimizer]] @@ -1010,9 +1046,6 @@ def run_pretrain_routine(self, model: LightningModule): # register auto-resubmit when on SLURM self.register_slurm_signal_handlers() - # transfer data loaders from model - self.get_dataloaders(ref_model) - # print model summary if self.proc_rank == 0 and self.weights_summary is not None: if self.weights_summary in ['full', 'top']: @@ -1028,11 +1061,20 @@ def run_pretrain_routine(self, model: LightningModule): # restore training and model before hpc call self.restore_weights(model) + # download the data and do whatever transforms we need + self.call_prepare_data(ref_model) + # when testing requested only run test and return if self.testing: + # only load test dataloader for testing + self.reset_test_dataloader(ref_model) self.run_evaluation(test=True) return + # load the dataloaders + self.reset_train_dataloader(ref_model) + self.reset_val_dataloader(ref_model) + # check if we should run validation during training self.disable_validation = ((self.num_val_batches == 0 or not self.is_overriden('validation_step')) and @@ -1045,14 +1087,14 @@ def run_pretrain_routine(self, model: LightningModule): if not self.disable_validation and self.num_sanity_val_steps > 0: # init progress bars for validation sanity check pbar = tqdm(desc='Validation sanity check', - total=self.num_sanity_val_steps * 
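To illustrate the patched `.fit()` call path above (loaders handed to the trainer instead of being defined on the model), a hedged sketch that again reuses the hypothetical `MNISTExample`; note the plural keyword names `val_dataloaders`/`test_dataloaders`:

.. code-block:: python

    from pytorch_lightning import Trainer
    from torch.utils.data import DataLoader
    from torchvision import transforms
    from torchvision.datasets import MNIST

    model = MNISTExample(data_root='./data')
    model.prepare_data()  # make sure MNIST exists locally before building the loader

    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    train_loader = DataLoader(MNIST('./data', train=True, transform=transform),
                              batch_size=32)

    trainer = Trainer(max_epochs=1)
    trainer.fit(model, train_dataloader=train_loader)

    # val_dataloaders=... and test_dataloaders=... accept a single DataLoader or a
    # list of them in the same way, but only if the model defines validation_step /
    # test_step; otherwise __set_fit_dataloaders raises MisconfigurationException.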
len(self.get_val_dataloaders()), + total=self.num_sanity_val_steps * len(self.val_dataloaders), leave=False, position=2 * self.process_position, disable=not self.show_progress_bar, dynamic_ncols=True) self.main_progress_bar = pbar # dummy validation progress bar self.val_progress_bar = tqdm(disable=True) - eval_results = self.evaluate(model, self.get_val_dataloaders(), + eval_results = self.evaluate(model, self.val_dataloaders, self.num_sanity_val_steps, False) _, _, _, callback_metrics, _ = self.process_output(eval_results) @@ -1105,49 +1147,3 @@ def test(self, model: Optional[LightningModule] = None): self.fit(model) else: self.run_evaluation(test=True) - - -def _set_dataloader(model, dataloader, attribute): - r''' - Check dataloaders passed to .fit() method if they are pytorch DataLoader - objects and whether or not we should overright the corresponding dataloader - in the model - - Args: - model (LightningModule): The model to check - - dataloader: If a pytorch dataloader (or a list of pytorch dataloaders) - is passed, it will be incorporate into the model as model.attribute. - If attribute alreay exist it will warn the userpass. If not a - dataloader will throw an error - - attribute (str): The attribute to save the dataloader under - - ''' - # Check if attribute comes directly from base class or - # derived in user subclass - if LightningModule.__qualname__ in getattr(model, attribute).__qualname__: - # Val and test should be list of dataloaders - dataloader = dataloader if attribute == 'train_dataloader' or \ - (attribute != 'train_dataloader' and isinstance(dataloader, list)) else [dataloader] - - # Check we are given valid dataloaders - is_dataloader = isinstance(dataloader, torch.utils.data.DataLoader) - is_dataloader_list = isinstance(dataloader, list) - if is_dataloader_list: - valid_loaders = all(isinstance(d, torch.utils.data.DataLoader) for d in dataloader) - if is_dataloader or is_dataloader_list and valid_loaders: - - # Overwrite abstract methods - dl = lambda: dataloader - dl.__name__ = attribute - setattr(model, attribute, dl) - - elif dataloader and dataloader != [None]: - raise ValueError(f'`{attribute}` needs to be an instance of ' - '`torch.utils.data.DataLoader` or a list of ' - 'DataLoaders, instead got %r`' % dataloader) - - elif dataloader: # if default (None) is passed, do not warn the user - warnings.warn(f'Model has predefined `{attribute}`,' - f' will skip `{attribute}={dataloader}` passed to fit method.') diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 6d12d6fe6fb10..0542d0a837f5d 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -168,15 +168,9 @@ def training_step(self, batch, batch_idx): except ImportError: APEX_AVAILABLE = False -try: - import torch_xla.core.xla_model as xm - - XLA_AVAILABLE = True -except ImportError: - XLA_AVAILABLE = False - try: import torch_xla.distributed.parallel_loader as xla_pl + import torch_xla.core.xla_model as xm XLA_AVAILABLE = True @@ -226,11 +220,13 @@ def __init__(self): self.model = None self.running_loss = None self.training_tqdm_dict = None - self.get_train_dataloader = None self.reduce_lr_on_plateau_scheduler = None self.profiler = None self.batch_idx = None self.precision = None + self.train_dataloader = None + self.reload_dataloaders_every_epoch = None + self.progress_bar_refresh_rate = None @property def max_nb_epochs(self): @@ -305,6 +301,11 @@ def process_output(self, output, train): # this is just 
empty shell for code from other class pass + @abstractmethod + def reset_train_dataloader(self, model): + # this is just empty shell for code from other class + pass + def train(self): warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,' ' but will start from "0" in v0.8.0.', DeprecationWarning) @@ -314,9 +315,9 @@ def train(self): # run all epochs for epoch in range(self.current_epoch, self.max_epochs): # set seed for distributed sampler (enables shuffling for each epoch) - if (self.use_ddp or self.use_tpu) \ - and hasattr(self.get_train_dataloader().sampler, 'set_epoch'): - self.get_train_dataloader().sampler.set_epoch(epoch) + if self.use_ddp \ + and hasattr(self.train_dataloader.sampler, 'set_epoch'): + self.train_dataloader.sampler.set_epoch(epoch) # get model model = self.get_model() @@ -394,6 +395,7 @@ def train(self): return self.run_training_teardown() + except KeyboardInterrupt: log.info('Detected KeyboardInterrupt, attempting graceful shutdown...') self.run_training_teardown() @@ -405,18 +407,19 @@ def run_training_epoch(self): with self.profiler.profile('on_epoch_start'): model.on_epoch_start() - # request the dataloader - train_dataloader = self.get_train_dataloader() + # reset train dataloader + if self.reload_dataloaders_every_epoch: + self.reset_train_dataloader(self.get_model()) # on TPU we have to wrap it under the ParallelLoader if self.use_tpu: device = xm.xla_device() - train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device]) - train_dataloader = train_dataloader.per_device_loader(device) + self.train_dataloader = xla_pl.ParallelLoader(self.train_dataloader, [device]) + self.train_dataloader = self.train_dataloader.per_device_loader(device) # run epoch for batch_idx, batch in self.profiler.profile_iterable( - enumerate(train_dataloader), "get_train_batch" + enumerate(self.train_dataloader), "get_train_batch" ): # stop epoch if we limited the number of training batches if batch_idx >= self.num_training_batches: @@ -591,11 +594,8 @@ def optimizer_closure(): # override function to modify this behavior model = self.get_model() with self.profiler.profile('optimizer_step'): - if self.use_tpu: - xm.optimizer_step(optimizer) - else: - model.optimizer_step(self.current_epoch, batch_idx, - optimizer, opt_idx, optimizer_closure) + model.optimizer_step(self.current_epoch, batch_idx, + optimizer, opt_idx, optimizer_closure) # calculate running loss for display self.running_loss.append(self.batch_loss_value) @@ -609,8 +609,9 @@ def optimizer_closure(): model.on_batch_end() # update progress bar - self.main_progress_bar.update(1) - self.main_progress_bar.set_postfix(**self.training_tqdm_dict) + if batch_idx % self.progress_bar_refresh_rate == 0: + self.main_progress_bar.update(self.progress_bar_refresh_rate) + self.main_progress_bar.set_postfix(**self.training_tqdm_dict) # collapse all metrics into one dict all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()} diff --git a/tests/models/__init__.py b/tests/models/__init__.py index 3e6424bd4982a..df16bffd668b8 100644 --- a/tests/models/__init__.py +++ b/tests/models/__init__.py @@ -12,6 +12,10 @@ LightningTestMixin, LightningTestStepMultipleDataloadersMixin, LightningTestMultipleDataloadersMixin, + LightningTestFitSingleTestDataloadersMixin, + LightningTestFitMultipleTestDataloadersMixin, + LightningValStepFitSingleDataloaderMixin, + LightningValStepFitMultipleDataloadersMixin ) diff --git a/tests/models/base.py b/tests/models/base.py index d33f0118dfd05..2c17402624d63 
100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -150,30 +150,27 @@ def configure_optimizers(self): # test returning only 1 list instead of 2 return optimizer + def prepare_data(self): + transform = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.5,), (1.0,))]) + dataset = TestingMNIST(root=self.hparams.data_root, train=True, + transform=transform, download=True, num_samples=2000) + dataset = TestingMNIST(root=self.hparams.data_root, train=False, + transform=transform, download=True, num_samples=2000) + def _dataloader(self, train): # init data generators transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))]) dataset = TestingMNIST(root=self.hparams.data_root, train=train, - transform=transform, download=True, num_samples=2000) + transform=transform, download=False, num_samples=2000) # when using multi-node we need to add the datasampler - train_sampler = None batch_size = self.hparams.batch_size - try: - if self.use_ddp and not self.force_remove_distributed_sampler: - train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank) - batch_size = batch_size // self.trainer.world_size # scale batch size - except Exception: - pass - - should_shuffle = train_sampler is None loader = DataLoader( dataset=dataset, batch_size=batch_size, - shuffle=should_shuffle, - sampler=train_sampler ) return loader @@ -218,7 +215,6 @@ def add_model_specific_args(parent_parser, root_dir): # pragma: no cover class LightningTestModelBase(TestModelBase): """ with pre-defined train dataloader """ - @data_loader def train_dataloader(self): return self._dataloader(train=True) diff --git a/tests/models/mixins.py b/tests/models/mixins.py index 03da85d59d096..940f5e30f350a 100644 --- a/tests/models/mixins.py +++ b/tests/models/mixins.py @@ -11,7 +11,6 @@ class LightningValidationStepMixin: when val_dataloader returns a single dataloader """ - @data_loader def val_dataloader(self): return self._dataloader(train=False) @@ -105,7 +104,6 @@ class LightningValidationStepMultipleDataloadersMixin: when val_dataloader returns multiple dataloaders """ - @data_loader def val_dataloader(self): return [self._dataloader(train=False), self._dataloader(train=False)] @@ -204,7 +202,6 @@ def validation_end(self, outputs): class LightningTestStepMixin: - @data_loader def test_dataloader(self): return self._dataloader(train=False) @@ -289,7 +286,6 @@ def test_end(self, outputs): class LightningTestStepMultipleDataloadersMixin: - @data_loader def test_dataloader(self): return [self._dataloader(train=False), self._dataloader(train=False)] @@ -343,6 +339,198 @@ def test_step(self, batch, batch_idx, dataloader_idx): return output +class LightningTestFitSingleTestDataloadersMixin: + def test_step(self, batch, batch_idx): + """ + Lightning calls this inside the validation loop + :param batch: + :return: + """ + x, y = batch + x = x.view(x.size(0), -1) + y_hat = self.forward(x) + + loss_test = self.loss(y, y_hat) + + # acc + labels_hat = torch.argmax(y_hat, dim=1) + test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) + test_acc = torch.tensor(test_acc) + + if self.on_gpu: + test_acc = test_acc.cuda(loss_test.device.index) + + # in DP mode (default) make sure if result is scalar, there's another dim in the beginning + if self.trainer.use_dp: + loss_test = loss_test.unsqueeze(0) + test_acc = test_acc.unsqueeze(0) + + # alternate possible outputs to test + if batch_idx % 1 == 0: + output = OrderedDict({ + 'test_loss': loss_test, + 'test_acc': test_acc, + 
}) + return output + if batch_idx % 2 == 0: + return test_acc + + if batch_idx % 3 == 0: + output = OrderedDict({ + 'test_loss': loss_test, + 'test_acc': test_acc, + 'test_dic': {'test_loss_a': loss_test} + }) + return output + + +class LightningTestFitMultipleTestDataloadersMixin: + def test_step(self, batch, batch_idx, dataloader_idx): + """ + Lightning calls this inside the validation loop + :param batch: + :return: + """ + x, y = batch + x = x.view(x.size(0), -1) + y_hat = self.forward(x) + + loss_test = self.loss(y, y_hat) + + # acc + labels_hat = torch.argmax(y_hat, dim=1) + test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) + test_acc = torch.tensor(test_acc) + + if self.on_gpu: + test_acc = test_acc.cuda(loss_test.device.index) + + # in DP mode (default) make sure if result is scalar, there's another dim in the beginning + if self.trainer.use_dp: + loss_test = loss_test.unsqueeze(0) + test_acc = test_acc.unsqueeze(0) + + # alternate possible outputs to test + if batch_idx % 1 == 0: + output = OrderedDict({ + 'test_loss': loss_test, + 'test_acc': test_acc, + }) + return output + if batch_idx % 2 == 0: + return test_acc + + if batch_idx % 3 == 0: + output = OrderedDict({ + 'test_loss': loss_test, + 'test_acc': test_acc, + 'test_dic': {'test_loss_a': loss_test} + }) + return output + if batch_idx % 5 == 0: + output = OrderedDict({ + f'test_loss_{dataloader_idx}': loss_test, + f'test_acc_{dataloader_idx}': test_acc, + }) + return output + + +class LightningValStepFitSingleDataloaderMixin: + def validation_step(self, batch, batch_idx): + """ + Lightning calls this inside the validation loop + :param batch: + :return: + """ + x, y = batch + x = x.view(x.size(0), -1) + y_hat = self.forward(x) + + loss_val = self.loss(y, y_hat) + + # acc + labels_hat = torch.argmax(y_hat, dim=1) + val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) + val_acc = torch.tensor(val_acc) + + if self.on_gpu: + val_acc = val_acc.cuda(loss_val.device.index) + + # in DP mode (default) make sure if result is scalar, there's another dim in the beginning + if self.trainer.use_dp: + loss_val = loss_val.unsqueeze(0) + val_acc = val_acc.unsqueeze(0) + + # alternate possible outputs to test + if batch_idx % 1 == 0: + output = OrderedDict({ + 'val_loss': loss_val, + 'val_acc': val_acc, + }) + return output + if batch_idx % 2 == 0: + return val_acc + + if batch_idx % 3 == 0: + output = OrderedDict({ + 'val_loss': loss_val, + 'val_acc': val_acc, + 'test_dic': {'val_loss_a': loss_val} + }) + return output + + +class LightningValStepFitMultipleDataloadersMixin: + def validation_step(self, batch, batch_idx, dataloader_idx): + """ + Lightning calls this inside the validation loop + :param batch: + :return: + """ + x, y = batch + x = x.view(x.size(0), -1) + y_hat = self.forward(x) + + loss_val = self.loss(y, y_hat) + + # acc + labels_hat = torch.argmax(y_hat, dim=1) + val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) + val_acc = torch.tensor(val_acc) + + if self.on_gpu: + val_acc = val_acc.cuda(loss_val.device.index) + + # in DP mode (default) make sure if result is scalar, there's another dim in the beginning + if self.trainer.use_dp: + loss_val = loss_val.unsqueeze(0) + val_acc = val_acc.unsqueeze(0) + + # alternate possible outputs to test + if batch_idx % 1 == 0: + output = OrderedDict({ + 'val_loss': loss_val, + 'val_acc': val_acc, + }) + return output + if batch_idx % 2 == 0: + return val_acc + + if batch_idx % 3 == 0: + output = OrderedDict({ + 'val_loss': loss_val, + 'val_acc': val_acc, + 
'test_dic': {'val_loss_a': loss_val} + }) + return output + if batch_idx % 5 == 0: + output = OrderedDict({ + f'val_loss_{dataloader_idx}': loss_val, + f'val_acc_{dataloader_idx}': val_acc, + }) + return output + + class LightningTestMultipleDataloadersMixin(LightningTestStepMultipleDataloadersMixin): def test_end(self, outputs): """ diff --git a/tests/models/utils.py b/tests/models/utils.py index 7a641b56ad6e2..644efb9c3ddd7 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -36,7 +36,11 @@ def run_model_test_no_loggers(trainer_options, model, min_acc=0.50): path_expt=trainer_options.get('default_save_path')) # test new model accuracy - for dataloader in model.test_dataloader(): + test_loaders = model.test_dataloader() + if not isinstance(test_loaders, list): + test_loaders = [test_loaders] + + for dataloader in test_loaders: run_prediction(dataloader, pretrained_model, min_acc=min_acc) if trainer.use_ddp: @@ -69,7 +73,11 @@ def run_model_test(trainer_options, model, on_gpu=True): pretrained_model = load_model(logger, trainer.checkpoint_callback.filepath) # test new model accuracy - [run_prediction(dataloader, pretrained_model) for dataloader in model.test_dataloader()] + test_loaders = model.test_dataloader() + if not isinstance(test_loaders, list): + test_loaders = [test_loaders] + + [run_prediction(dataloader, pretrained_model) for dataloader in test_loaders] if trainer.use_ddp or trainer.use_ddp2: # on hpc this would work fine... but need to hack it for the purpose of the test diff --git a/tests/test_cpu_models.py b/tests/test_cpu_models.py index 37daf58c12aa9..5d5a20bb561bd 100644 --- a/tests/test_cpu_models.py +++ b/tests/test_cpu_models.py @@ -46,7 +46,7 @@ def test_lbfgs_cpu_model(tmpdir): trainer_options = dict( default_save_path=tmpdir, - max_epochs=1, + max_epochs=2, print_nan_grads=True, show_progress_bar=False, weights_summary='top', @@ -304,7 +304,6 @@ def training_step(self, batch, batch_idx, hiddens): 'hiddens': self.test_hidden, } - @data_loader def train_dataloader(self): return torch.utils.data.DataLoader( dataset=MockSeq2SeqDataset(), diff --git a/tests/test_gpu_models.py b/tests/test_gpu_models.py index e982d2f25afd8..a47b97f95ed91 100644 --- a/tests/test_gpu_models.py +++ b/tests/test_gpu_models.py @@ -124,7 +124,11 @@ def test_cpu_slurm_save_load(tmpdir): # predict with trained model before saving # make a prediction - for dataloader in model.test_dataloader(): + dataloaders = model.test_dataloader() + if not isinstance(dataloaders, list): + dataloaders = [dataloaders] + + for dataloader in dataloaders: for batch in dataloader: break @@ -211,32 +215,6 @@ def test_multi_gpu_model_dp(tmpdir): memory.get_memory_profile('min_max') -def test_ddp_sampler_error(tmpdir): - """Make sure DDP + AMP work.""" - if not tutils.can_run_gpu_test(): - return - - tutils.reset_seed() - tutils.set_random_master_port() - - hparams = tutils.get_hparams() - model = LightningTestModel(hparams, force_remove_distributed_sampler=True) - - logger = tutils.get_test_tube_logger(tmpdir, True) - - trainer = Trainer( - logger=logger, - show_progress_bar=False, - max_epochs=1, - gpus=[0, 1], - distributed_backend='ddp', - precision=16 - ) - - with pytest.warns(UserWarning): - trainer.get_dataloaders(model) - - @pytest.fixture def mocked_device_count(monkeypatch): def device_count(): diff --git a/tests/test_restore_models.py b/tests/test_restore_models.py index 2bd033b09756f..ba347e474a326 100644 --- a/tests/test_restore_models.py +++ b/tests/test_restore_models.py @@ -53,7 +53,11 @@ 
def test_running_test_pretrained_model_ddp(tmpdir): new_trainer = Trainer(**trainer_options) new_trainer.test(pretrained_model) - for dataloader in model.test_dataloader(): + dataloaders = model.test_dataloader() + if not isinstance(dataloaders, list): + dataloaders = [dataloaders] + + for dataloader in dataloaders: tutils.run_prediction(dataloader, pretrained_model) @@ -244,7 +248,7 @@ def assert_good_acc(): dp_model = new_trainer.model dp_model.eval() - dataloader = trainer.get_train_dataloader() + dataloader = trainer.train_dataloader tutils.run_prediction(dataloader, dp_model, dp=True) # new model @@ -311,7 +315,7 @@ def assert_good_acc(): # if model and state loaded correctly, predictions will be good even though we # haven't trained with the new loaded model trainer.model.eval() - for dataloader in trainer.get_val_dataloaders(): + for dataloader in trainer.val_dataloaders: tutils.run_prediction(dataloader, trainer.model) model.on_train_start = assert_good_acc @@ -345,7 +349,11 @@ def test_model_saving_loading(tmpdir): assert result == 1, 'amp + ddp model failed to complete' # make a prediction - for dataloader in model.test_dataloader(): + dataloaders = model.test_dataloader() + if not isinstance(dataloaders, list): + dataloaders = [dataloaders] + + for dataloader in dataloaders: for batch in dataloader: break diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 231ef9508adbf..bc5b2979de781 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -17,6 +17,10 @@ LightningValidationStepMixin, LightningValidationMultipleDataloadersMixin, LightningTestMultipleDataloadersMixin, + LightningTestFitSingleTestDataloadersMixin, + LightningTestFitMultipleTestDataloadersMixin, + LightningValStepFitMultipleDataloadersMixin, + LightningValStepFitSingleDataloaderMixin ) from pytorch_lightning.core.lightning import load_hparams_from_tags_csv from pytorch_lightning.trainer.logging import TrainerLoggingMixin @@ -405,11 +409,11 @@ class CurrentTestModel( assert result == 1 # verify there are 2 val loaders - assert len(trainer.get_val_dataloaders()) == 2, \ + assert len(trainer.val_dataloaders) == 2, \ 'Multiple val_dataloaders not initiated properly' # make sure predictions are good for each val set - for dataloader in trainer.get_val_dataloaders(): + for dataloader in trainer.val_dataloaders: tutils.run_prediction(dataloader, trainer.model) @@ -506,12 +510,14 @@ class CurrentTestModel( trainer = Trainer(**trainer_options) result = trainer.fit(model) + trainer.test() + # verify there are 2 val loaders - assert len(trainer.get_test_dataloaders()) == 2, \ + assert len(trainer.test_dataloaders) == 2, \ 'Multiple test_dataloaders not initiated properly' # make sure predictions are good for each test set - for dataloader in trainer.get_test_dataloaders(): + for dataloader in trainer.test_dataloaders: tutils.run_prediction(dataloader, trainer.model) # run the test method @@ -523,7 +529,7 @@ def test_train_dataloaders_passed_to_fit(tmpdir): tutils.reset_seed() class CurrentTestModel( - LightningTestModelBaseWithoutDataloader + LightningTestModelBaseWithoutDataloader, ): pass @@ -549,7 +555,8 @@ def test_train_val_dataloaders_passed_to_fit(tmpdir): tutils.reset_seed() class CurrentTestModel( - LightningTestModelBaseWithoutDataloader + LightningValStepFitSingleDataloaderMixin, + LightningTestModelBaseWithoutDataloader, ): pass @@ -567,10 +574,11 @@ class CurrentTestModel( model = CurrentTestModel(hparams) trainer = Trainer(**trainer_options) fit_options = 
dict(train_dataloader=model._dataloader(train=True), - val_dataloader=model._dataloader(train=False)) + val_dataloaders=model._dataloader(train=False)) + results = trainer.fit(model, **fit_options) - assert len(trainer.get_val_dataloaders()) == 1, \ - f'`val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}' + assert len(trainer.val_dataloaders) == 1, \ + f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' def test_all_dataloaders_passed_to_fit(tmpdir): @@ -578,7 +586,9 @@ def test_all_dataloaders_passed_to_fit(tmpdir): tutils.reset_seed() class CurrentTestModel( - LightningTestModelBaseWithoutDataloader + LightningValStepFitSingleDataloaderMixin, + LightningTestFitSingleTestDataloadersMixin, + LightningTestModelBaseWithoutDataloader, ): pass @@ -596,14 +606,17 @@ class CurrentTestModel( model = CurrentTestModel(hparams) trainer = Trainer(**trainer_options) fit_options = dict(train_dataloader=model._dataloader(train=True), - val_dataloader=model._dataloader(train=False), - test_dataloader=model._dataloader(train=False)) + val_dataloaders=model._dataloader(train=False), + test_dataloaders=model._dataloader(train=False)) + results = trainer.fit(model, **fit_options) - assert len(trainer.get_val_dataloaders()) == 1, \ - f'`val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}' - assert len(trainer.get_test_dataloaders()) == 1, \ - f'`test_dataloaders` not initiated properly, got {trainer.get_test_dataloaders()}' + trainer.test() + + assert len(trainer.val_dataloaders) == 1, \ + f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' + assert len(trainer.test_dataloaders) == 1, \ + f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}' def test_multiple_dataloaders_passed_to_fit(tmpdir): @@ -611,7 +624,9 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir): tutils.reset_seed() class CurrentTestModel( - LightningTestModelBaseWithoutDataloader + LightningValStepFitMultipleDataloadersMixin, + LightningTestFitMultipleTestDataloadersMixin, + LightningTestModelBaseWithoutDataloader, ): pass @@ -629,16 +644,17 @@ class CurrentTestModel( model = CurrentTestModel(hparams) trainer = Trainer(**trainer_options) fit_options = dict(train_dataloader=model._dataloader(train=True), - val_dataloader=[model._dataloader(train=False), - model._dataloader(train=False)], - test_dataloader=[model._dataloader(train=False), - model._dataloader(train=False)]) + val_dataloaders=[model._dataloader(train=False), + model._dataloader(train=False)], + test_dataloaders=[model._dataloader(train=False), + model._dataloader(train=False)]) results = trainer.fit(model, **fit_options) + trainer.test() - assert len(trainer.get_val_dataloaders()) == 2, \ - f'Multiple `val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}' - assert len(trainer.get_test_dataloaders()) == 2, \ - f'Multiple `test_dataloaders` not initiated properly, got {trainer.get_test_dataloaders()}' + assert len(trainer.val_dataloaders) == 2, \ + f'Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' + assert len(trainer.test_dataloaders) == 2, \ + f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}' def test_mixing_of_dataloader_options(tmpdir): @@ -646,7 +662,9 @@ def test_mixing_of_dataloader_options(tmpdir): tutils.reset_seed() class CurrentTestModel( - LightningTestModelBase + LightningValStepFitSingleDataloaderMixin, + LightningTestFitSingleTestDataloadersMixin, + 
LightningTestModelBase, ): pass @@ -663,18 +681,20 @@ class CurrentTestModel( # fit model trainer = Trainer(**trainer_options) - fit_options = dict(val_dataloader=model._dataloader(train=False)) + fit_options = dict(val_dataloaders=model._dataloader(train=False)) results = trainer.fit(model, **fit_options) # fit model trainer = Trainer(**trainer_options) - fit_options = dict(val_dataloader=model._dataloader(train=False), - test_dataloader=model._dataloader(train=False)) + fit_options = dict(val_dataloaders=model._dataloader(train=False), + test_dataloaders=model._dataloader(train=False)) results = trainer.fit(model, **fit_options) - assert len(trainer.get_val_dataloaders()) == 1, \ - f'`val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}' - assert len(trainer.get_test_dataloaders()) == 1, \ - f'`test_dataloaders` not initiated properly, got {trainer.get_test_dataloaders()}' + trainer.test() + + assert len(trainer.val_dataloaders) == 1, \ + f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' + assert len(trainer.test_dataloaders) == 1, \ + f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}' def _init_steps_model():