From 15bfd1449660c803853a0a95302e6504cd6f9ea3 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 10:38:58 -0500 Subject: [PATCH 01/80] added get dataloaders directly using a getter --- pytorch_lightning/trainer/data_loading.py | 103 +++++++++++++++------- 1 file changed, 70 insertions(+), 33 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 43e87928a6abe..1c34c65c6794e 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -47,6 +47,13 @@ def __init__(self): self.val_check_interval = None self.use_tpu = None self.tpu_local_core_rank = None + self.train_dataloader = None + self.num_training_batches = None + self.val_check_batch = None + self.val_dataloaders = None + self.num_val_batches = None + self.test_dataloaders = None + self.num_test_batches = None def _percent_range_check(self, name): value = getattr(self, name) @@ -63,15 +70,15 @@ def init_train_dataloader(self, model): :param model: :return: """ - self.get_train_dataloader = model.train_dataloader + self.trigger_data_downloads(model.train_dataloader, 'train_dataloader') # determine number of training batches - if EXIST_ITER_DATASET and isinstance(self.get_train_dataloader().dataset, IterableDataset): + if EXIST_ITER_DATASET and isinstance(self.train_dataloader.dataset, IterableDataset): self.num_training_batches = float('inf') else: self._percent_range_check('train_percent_check') - self.num_training_batches = len(self.get_train_dataloader()) + self.num_training_batches = len(self.train_dataloader) self.num_training_batches = int(self.num_training_batches * self.train_percent_check) # determine when to check validation @@ -92,7 +99,7 @@ def init_train_dataloader(self, model): on_ddp = self.use_ddp or self.use_ddp2 needs_sampler = on_ddp or self.use_tpu - if needs_sampler and not isinstance(self.get_train_dataloader().sampler, DistributedSampler): + if needs_sampler and not isinstance(self.train_dataloader.sampler, DistributedSampler): msg = """ You're using multiple gpus and multiple nodes, or TPUs without using a to assign a subset of your data to each process. To silence this warning, pass a @@ -119,21 +126,21 @@ def init_val_dataloader(self, model): :param model: :return: """ - self.get_val_dataloaders = model.val_dataloader + self.trigger_data_downloads(model.val_dataloader, 'val_dataloaders') self.num_val_batches = 0 # determine number of validation batches # val datasets could be none, 1 or 2+ - if self.get_val_dataloaders() is not None: + if self.val_dataloaders is not None: self._percent_range_check('val_percent_check') - self.num_val_batches = sum(len(dataloader) for dataloader in self.get_val_dataloaders()) + self.num_val_batches = sum(len(dataloader) for dataloader in self.val_dataloaders) self.num_val_batches = int(self.num_val_batches * self.val_percent_check) on_ddp = self.use_ddp or self.use_ddp2 needs_sampler = on_ddp or self.use_tpu - if needs_sampler and self.get_val_dataloaders() is not None: - for dataloader in self.get_val_dataloaders(): + if needs_sampler and self.val_dataloaders is not None: + for dataloader in self.val_dataloaders: if not isinstance(dataloader.sampler, DistributedSampler): msg = """ Your val_dataloader(s) don't use DistributedSampler. 
@@ -164,20 +171,21 @@ def init_test_dataloader(self, model): :param model: """ - self.get_test_dataloaders = model.test_dataloader + self.trigger_data_downloads(model.test_dataloader, 'test_dataloaders') + self.num_test_batches = 0 # determine number of test batches - if self.get_test_dataloaders() is not None: + if self.test_dataloaders is not None: self._percent_range_check('test_percent_check') - len_sum = sum(len(dataloader) for dataloader in self.get_test_dataloaders()) + len_sum = sum(len(dataloader) for dataloader in self.test_dataloaders) self.num_test_batches = len_sum self.num_test_batches = int(self.num_test_batches * self.test_percent_check) on_ddp = self.use_ddp or self.use_ddp2 needs_sampler = on_ddp or self.use_tpu - if needs_sampler and self.get_test_dataloaders() is not None: - for dataloader in self.get_test_dataloaders(): + if needs_sampler and self.test_dataloaders is not None: + for dataloader in self.test_dataloaders: if not isinstance(dataloader.sampler, DistributedSampler): msg = """ Your `test_dataloader(s)` don't use DistributedSampler. @@ -204,37 +212,66 @@ def init_test_dataloader(self, model): warnings.warn(msg) break - def get_dataloaders(self, model): + def trigger_data_downloads(self, dataloader_fx, dataloader_name): """ - Dataloaders are provided by the model - :param model: + Handles downloading data in the GPU or TPU case. + + :param dataloader_fx: + :param dataloader_name: :return: """ + # get the function we'll use to get data - self.init_train_dataloader(model) - self.init_test_dataloader(model) - self.init_val_dataloader(model) - + # data download/load on GPU if self.use_ddp or self.use_ddp2: - # wait for all processes to catch up + if self.proc_rank == 0: + dataloader = dataloader_fx() + self.__setattr__(dataloader_name, dataloader) + + # all processes wait until data download has happened dist.barrier() - # load each dataloader - self.get_train_dataloader() - self.get_test_dataloaders() - self.get_val_dataloaders() + # get data from all other processes + if self.proc_rank != 0: + dataloader = dataloader_fx() + self.__setattr__(dataloader_name, dataloader) + + # all processes wait until data download has happened + dist.barrier() - # on TPUs load each dataloader only on process 0 - # this will trigger the data downloads - if self.use_tpu and XLA_AVAILABLE: + # data download/load on TPU + elif self.use_tpu and XLA_AVAILABLE: if self.tpu_local_core_rank == 0: - self.get_train_dataloader() - self.get_test_dataloaders() - self.get_val_dataloaders() + dataloader = dataloader_fx() + self.__setattr__(dataloader_name, dataloader) + + # all processes wait until data download has happened + torch_xla.core.xla_model.rendezvous("pl.TrainerDataLoadingMixin.get_dataloaders") + + # get data from all other processes + if self.proc_rank != 0: + dataloader = dataloader_fx() + self.__setattr__(dataloader_name, dataloader) - # wait for all processes to catch up + # all processes wait until data download has happened torch_xla.core.xla_model.rendezvous("pl.TrainerDataLoadingMixin.get_dataloaders") + # regular start + else: + dataloader = dataloader_fx() + self.__setattr__(dataloader_name, dataloader) + + def get_dataloaders(self, model): + """ + Dataloaders are provided by the model + :param model: + :return: + """ + + self.init_train_dataloader(model) + self.init_val_dataloader(model) + self.init_test_dataloader(model) + # support IterableDataset for train data self.is_iterable_train_dataloader = ( EXIST_ITER_DATASET and isinstance(self.get_train_dataloader().dataset, 
IterableDataset)) From 6c7a37e02a3f969e0ca2cb370f52c00cbefc2614 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 10:42:47 -0500 Subject: [PATCH 02/80] deleted decorator --- pytorch_lightning/core/decorators.py | 30 +++++----------------------- 1 file changed, 5 insertions(+), 25 deletions(-) diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py index 3448fda4d4864..005b220bac26c 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -1,5 +1,4 @@ -import traceback -from functools import wraps +import warnings def data_loader(fn): @@ -8,27 +7,8 @@ def data_loader(fn): :param fn: :return: """ - wraps(fn) - attr_name = '_lazy_' + fn.__name__ - @wraps(fn) - def _get_data_loader(self): - try: - value = getattr(self, attr_name) - except AttributeError: - try: - value = fn(self) # Lazy evaluation, done only once. - if ( - value is not None and - not isinstance(value, list) and - fn.__name__ in ['test_dataloader', 'val_dataloader'] - ): - value = [value] - except AttributeError as e: - # Guard against AttributeError suppression. (Issue #142) - traceback.print_exc() - error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e) - raise RuntimeError(error) from e - setattr(self, attr_name, value) # Memoize evaluation. - return value + w = 'data_loader decorator was deprecated in 0.6.1 and will be removed in 0.8.0' + warnings.warn(w) - return _get_data_loader + value = fn() + return value From 38e699191834c944a61017cad966270b23b80de1 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 13:56:43 -0500 Subject: [PATCH 03/80] added prepare_data hook --- pytorch_lightning/core/lightning.py | 30 +++++++++++++++++++--- pytorch_lightning/trainer/training_loop.py | 6 +++-- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 7a09aabb3cd45..a5d055c3b9959 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -868,7 +868,33 @@ def tbptt_split_batch(self, batch, split_size): return splits - @data_loader + def prepare_data(self): + """Use this to download and prepare data. + In distributed (GPU, TPU), this will only be called once + + :return: PyTorch DataLoader + + This is called before requesting the dataloaders + + .. code-block:: python + + model.prepare_data() + model.train_dataloader() + model.val_dataloader() + model.test_dataloader() + + Example + ------- + + .. 
code-block:: python + + def prepare_data(self): + download_imagenet() + clean_imagenet() + cache_imagenet() + """ + return None + def train_dataloader(self): """Implement a PyTorch DataLoader @@ -908,7 +934,6 @@ def tng_dataloader(self): # todo: remove in v0.8.0 " and this method will be removed in v0.8.0", DeprecationWarning) return output - @data_loader def test_dataloader(self): r""" @@ -942,7 +967,6 @@ def test_dataloader(self): """ return None - @data_loader def val_dataloader(self): r""" diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 6d12d6fe6fb10..d08e65a54dd7f 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -231,6 +231,7 @@ def __init__(self): self.profiler = None self.batch_idx = None self.precision = None + self.train_dataloader = None @property def max_nb_epochs(self): @@ -312,11 +313,12 @@ def train(self): model = self.get_model() try: # run all epochs + # TODO: finish replacing train_dataloader for epoch in range(self.current_epoch, self.max_epochs): # set seed for distributed sampler (enables shuffling for each epoch) if (self.use_ddp or self.use_tpu) \ - and hasattr(self.get_train_dataloader().sampler, 'set_epoch'): - self.get_train_dataloader().sampler.set_epoch(epoch) + and hasattr(self.train_dataloader.sampler, 'set_epoch'): + self.train_dataloader.sampler.set_epoch(epoch) # get model model = self.get_model() From 37be7e775756572675edf92845ff8c7c92aebd73 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 14:27:48 -0500 Subject: [PATCH 04/80] refactored dataloader init --- pytorch_lightning/trainer/data_loading.py | 220 ++++++++++------------ pytorch_lightning/trainer/trainer.py | 8 +- 2 files changed, 99 insertions(+), 129 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 1c34c65c6794e..1604140bf49b7 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -2,6 +2,9 @@ from abc import ABC import torch.distributed as dist +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data import RandomSampler, SequentialSampler +from pytorch_lightning.utilities.debugging import MisconfigurationException try: # loading for pyTorch 1.3 @@ -14,9 +17,6 @@ EXIST_ITER_DATASET = False else: EXIST_ITER_DATASET = True -from torch.utils.data.distributed import DistributedSampler - -from pytorch_lightning.utilities.debugging import MisconfigurationException try: from apex import amp @@ -64,13 +64,70 @@ def _percent_range_check(self, name): if not 0. 
<= value <= 1.: raise ValueError(msg) - def init_train_dataloader(self, model): + def call_prepare_data(self, model): + """ + Let model download the data on proc==0 only + :param model: + """ + # download data on DDP+ + if self.use_ddp or self.use_ddp2: + if self.proc_rank == 0: + model.prepare_data() + + # all processes wait until data download has happened + dist.barrier() + + # data download/load on TPU + elif self.use_tpu and XLA_AVAILABLE: + if self.tpu_local_core_rank == 0: + model.prepare_data() + + # all processes wait until data download has happened + torch_xla.core.xla_model.rendezvous("pl.TrainerDataLoadingMixin.get_dataloaders") + + else: + # regular download + model.prepare_data() + + def auto_add_sampler(self, dataloader, train): + # TODO: verify + # do nothing when user gives a sampler + if dataloader.sampler is not None: + return + + if train: + if self.use_ddp or self.use_ddp2: + self.train_dataloader.sampler = DistributedSampler(self.train_dataloader.dataset) + elif self.use_tpu: + sampler = DistributedSampler( + self.train_dataloader.dataset, + num_replicas=xm.xrt_world_size(), + rank=xm.get_ordinal() + ) + self.train_dataloader.sampler = sampler + else: + self.train_dataloader.sampler = RandomSampler(self.train_dataloader.dataset) + + # on not train + else: + if self.use_tpu: + sampler = DistributedSampler( + self.train_dataloader.dataset, + num_replicas=xm.xrt_world_size(), + rank=xm.get_ordinal() + ) + self.train_dataloader.sampler = sampler + else: + self.train_dataloader.sampler = SequentialSampler(self.train_dataloader.dataset) + + def reset_train_dataloader(self, model): """ Dataloaders are provided by the model :param model: :return: """ - self.trigger_data_downloads(model.train_dataloader, 'train_dataloader') + self.train_dataloader = self.request_data_loader(model.train_dataloader) + self.num_training_batches = 0 # determine number of training batches if EXIST_ITER_DATASET and isinstance(self.train_dataloader.dataset, IterableDataset): @@ -97,36 +154,16 @@ def init_train_dataloader(self, model): self.val_check_batch = int(self.num_training_batches * self.val_check_interval) self.val_check_batch = max(1, self.val_check_batch) - on_ddp = self.use_ddp or self.use_ddp2 - needs_sampler = on_ddp or self.use_tpu - if needs_sampler and not isinstance(self.train_dataloader.sampler, DistributedSampler): - msg = """ - You're using multiple gpus and multiple nodes, or TPUs without using a - to assign a subset of your data to each process. To silence this warning, pass a - DistributedSampler to your DataLoader. - - ie: this: - dataset = myDataset() - dataloader = Dataloader(dataset) - - becomes: - dataset = myDataset() - dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset) - dataloader = Dataloader(dataset, sampler=dist_sampler) - - If you want each process to load the full dataset, ignore this warning. 
- """ - if msg not in self.shown_warnings and self.proc_rank == 0: - self.shown_warnings.add(msg) - warnings.warn(msg) - - def init_val_dataloader(self, model): + # automatically add samplers + self.auto_add_sampler(self.train_dataloader, train=True) + + def reset_val_dataloader(self, model): """ Dataloaders are provided by the model :param model: :return: """ - self.trigger_data_downloads(model.val_dataloader, 'val_dataloaders') + self.val_dataloaders = self.request_data_loader(model.val_dataloader) self.num_val_batches = 0 # determine number of validation batches @@ -137,41 +174,17 @@ def init_val_dataloader(self, model): self.num_val_batches = sum(len(dataloader) for dataloader in self.val_dataloaders) self.num_val_batches = int(self.num_val_batches * self.val_percent_check) - on_ddp = self.use_ddp or self.use_ddp2 - needs_sampler = on_ddp or self.use_tpu - if needs_sampler and self.val_dataloaders is not None: - for dataloader in self.val_dataloaders: - if not isinstance(dataloader.sampler, DistributedSampler): - msg = """ - Your val_dataloader(s) don't use DistributedSampler. - - You're using multiple gpus and multiple nodes, or TPUs without using a - DistributedSampler to assign a subset of your data to each process. - To silence this warning, pass a DistributedSampler to your DataLoader. - - ie: this: - dataset = myDataset() - dataloader = Dataloader(dataset) - - becomes: - dataset = myDataset() - dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset) - dataloader = Dataloader(dataset, sampler=dist_sampler) - - If you want each process to load the full dataset, ignore this warning. - """ - if msg not in self.shown_warnings and self.proc_rank == 0: - self.shown_warnings.add(msg) - warnings.warn(msg) - break - - def init_test_dataloader(self, model): + # add samplers + for dataloader in self.val_dataloaders: + self.auto_add_sampler(dataloader, train=False) + + def reset_test_dataloader(self, model): """Dataloaders are provided by the model. :param model: """ - - self.trigger_data_downloads(model.test_dataloader, 'test_dataloaders') + # get actual loader + self.test_dataloaders = self.request_data_loader(model.test_dataloader) self.num_test_batches = 0 # determine number of test batches @@ -182,99 +195,56 @@ def init_test_dataloader(self, model): self.num_test_batches = len_sum self.num_test_batches = int(self.num_test_batches * self.test_percent_check) - on_ddp = self.use_ddp or self.use_ddp2 - needs_sampler = on_ddp or self.use_tpu - if needs_sampler and self.test_dataloaders is not None: - for dataloader in self.test_dataloaders: - if not isinstance(dataloader.sampler, DistributedSampler): - msg = """ - Your `test_dataloader(s)` don't use DistributedSampler. - - You're using multiple gpus and multiple nodes, or TPUs without using a - DistributedSampler to assign a subset of your data to each process. - To silence this warning, pass a DistributedSampler to your DataLoader. - - ie: this:: - - dataset = myDataset() - dataloader = Dataloader(dataset) + # add samplers + for dataloader in self.test_dataloaders: + self.auto_add_sampler(dataloader, train=False) - becomes:: - - dataset = myDataset() - dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset) - dataloader = Dataloader(dataset, sampler=dist_sampler) - - If you want each process to load the full dataset, ignore this warning. 
- """ - if msg not in self.shown_warnings and self.proc_rank == 0: - self.shown_warnings.add(msg) - warnings.warn(msg) - break - - def trigger_data_downloads(self, dataloader_fx, dataloader_name): + def request_data_loader(self, data_loader_fx): """ Handles downloading data in the GPU or TPU case. - :param dataloader_fx: - :param dataloader_name: + :param data_loader_fx: :return: """ # get the function we'll use to get data - - # data download/load on GPU if self.use_ddp or self.use_ddp2: - if self.proc_rank == 0: - dataloader = dataloader_fx() - self.__setattr__(dataloader_name, dataloader) - - # all processes wait until data download has happened - dist.barrier() - - # get data from all other processes - if self.proc_rank != 0: - dataloader = dataloader_fx() - self.__setattr__(dataloader_name, dataloader) + data_loader = data_loader_fx() # all processes wait until data download has happened dist.barrier() # data download/load on TPU elif self.use_tpu and XLA_AVAILABLE: - if self.tpu_local_core_rank == 0: - dataloader = dataloader_fx() - self.__setattr__(dataloader_name, dataloader) - - # all processes wait until data download has happened - torch_xla.core.xla_model.rendezvous("pl.TrainerDataLoadingMixin.get_dataloaders") - - # get data from all other processes - if self.proc_rank != 0: - dataloader = dataloader_fx() - self.__setattr__(dataloader_name, dataloader) + data_loader = data_loader_fx() # all processes wait until data download has happened torch_xla.core.xla_model.rendezvous("pl.TrainerDataLoadingMixin.get_dataloaders") # regular start else: - dataloader = dataloader_fx() - self.__setattr__(dataloader_name, dataloader) + data_loader = data_loader_fx() + + return data_loader - def get_dataloaders(self, model): + def setup_dataloaders(self, model): """ - Dataloaders are provided by the model + Give the model a chance to provide the dataloaders and get data :param model: :return: """ - self.init_train_dataloader(model) - self.init_val_dataloader(model) - self.init_test_dataloader(model) + # download the data + self.call_prepare_data(model) + + # load the dataloaders + self.reset_train_dataloader(model) + self.reset_val_dataloader(model) + self.reset_test_dataloader(model) # support IterableDataset for train data self.is_iterable_train_dataloader = ( - EXIST_ITER_DATASET and isinstance(self.get_train_dataloader().dataset, IterableDataset)) + EXIST_ITER_DATASET and isinstance(self.train_dataloader.dataset, IterableDataset) + ) if self.is_iterable_train_dataloader and not isinstance(self.val_check_interval, int): m = ''' When using an iterableDataset for `train_dataloader`, diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4d204bd287e36..22b586905426b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -672,9 +672,9 @@ def __init__( self.num_val_batches = 0 self.num_training_batches = 0 self.num_test_batches = 0 - self.get_train_dataloader = None - self.get_test_dataloaders = None - self.get_val_dataloaders = None + self.train_dataloader = None + self.test_dataloaders = None + self.val_dataloaders = None self.is_iterable_train_dataloader = False # training state @@ -1011,7 +1011,7 @@ def run_pretrain_routine(self, model: LightningModule): self.register_slurm_signal_handlers() # transfer data loaders from model - self.get_dataloaders(ref_model) + self.setup_dataloaders(ref_model) # print model summary if self.proc_rank == 0 and self.weights_summary is not None: From 
db1115e0ed1e234dc7a3d4371addbd817d773c82 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 14:30:56 -0500 Subject: [PATCH 05/80] refactored dataloader init --- pytorch_lightning/trainer/data_loading.py | 39 +++++++---------------- pytorch_lightning/trainer/trainer.py | 12 +++++-- 2 files changed, 21 insertions(+), 30 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 1604140bf49b7..515253fda562b 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -157,6 +157,18 @@ def reset_train_dataloader(self, model): # automatically add samplers self.auto_add_sampler(self.train_dataloader, train=True) + # support IterableDataset for train data + self.is_iterable_train_dataloader = ( + EXIST_ITER_DATASET and isinstance(self.train_dataloader.dataset, IterableDataset) + ) + if self.is_iterable_train_dataloader and not isinstance(self.val_check_interval, int): + m = ''' + When using an iterableDataset for `train_dataloader`, + `Trainer(val_check_interval)` must be an int. + An int k specifies checking validation every k training batches + ''' + raise MisconfigurationException(m) + def reset_val_dataloader(self, model): """ Dataloaders are provided by the model @@ -226,33 +238,6 @@ def request_data_loader(self, data_loader_fx): return data_loader - def setup_dataloaders(self, model): - """ - Give the model a chance to provide the dataloaders and get data - :param model: - :return: - """ - - # download the data - self.call_prepare_data(model) - - # load the dataloaders - self.reset_train_dataloader(model) - self.reset_val_dataloader(model) - self.reset_test_dataloader(model) - - # support IterableDataset for train data - self.is_iterable_train_dataloader = ( - EXIST_ITER_DATASET and isinstance(self.train_dataloader.dataset, IterableDataset) - ) - if self.is_iterable_train_dataloader and not isinstance(self.val_check_interval, int): - m = ''' - When using an iterableDataset for `train_dataloader`, - `Trainer(val_check_interval)` must be an int. 
- An int k specifies checking validation every k training batches - ''' - raise MisconfigurationException(m) - def determine_data_use_amount(self, train_percent_check, val_percent_check, test_percent_check, overfit_pct): """ diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 22b586905426b..aa57e098bbb73 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1010,9 +1010,6 @@ def run_pretrain_routine(self, model: LightningModule): # register auto-resubmit when on SLURM self.register_slurm_signal_handlers() - # transfer data loaders from model - self.setup_dataloaders(ref_model) - # print model summary if self.proc_rank == 0 and self.weights_summary is not None: if self.weights_summary in ['full', 'top']: @@ -1028,11 +1025,20 @@ def run_pretrain_routine(self, model: LightningModule): # restore training and model before hpc call self.restore_weights(model) + # download the data and do whatever transforms we need + self.call_prepare_data(model) + # when testing requested only run test and return if self.testing: + # only load test dataloader for testing + self.reset_test_dataloader(model) self.run_evaluation(test=True) return + # load the dataloaders + self.reset_train_dataloader(model) + self.reset_val_dataloader(model) + # check if we should run validation during training self.disable_validation = ((self.num_val_batches == 0 or not self.is_overriden('validation_step')) and From dbe3fc0022ac45a83999ca83555277f99db5ff6a Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 14:39:12 -0500 Subject: [PATCH 06/80] added dataloader reset flag and main loop --- pytorch_lightning/trainer/trainer.py | 4 +++- pytorch_lightning/trainer/training_loop.py | 20 ++++++++++++++------ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index aa57e098bbb73..7fee99e2f947b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -109,6 +109,7 @@ def __init__( truncated_bptt_steps: Optional[int] = None, resume_from_checkpoint: Optional[str] = None, profiler: Optional[BaseProfiler] = None, + reload_dataloaders_every_epoch: bool = False ): r""" @@ -577,6 +578,7 @@ def __init__( # advanced profiler for function-level stats profiler = AdvancedProfiler() trainer = Trainer(profiler=profiler) + reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch .. 
warning:: Following arguments become deprecated and they will be removed in v0.8.0: @@ -592,7 +594,6 @@ def __init__( if not num_nodes: # in case you did not set the proper value num_nodes = nb_gpu_nodes self.num_gpu_nodes = num_nodes - self.log_gpu_memory = log_gpu_memory # Backward compatibility @@ -603,6 +604,7 @@ def __init__( gradient_clip_val = gradient_clip self.gradient_clip_val = gradient_clip_val + self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch self.check_val_every_n_epoch = check_val_every_n_epoch self.track_grad_norm = track_grad_norm self.on_gpu = True if (gpus and torch.cuda.is_available()) else False diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index d08e65a54dd7f..8fffd81bf4d45 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -232,6 +232,7 @@ def __init__(self): self.batch_idx = None self.precision = None self.train_dataloader = None + self.reload_dataloaders_every_epoch = None @property def max_nb_epochs(self): @@ -306,6 +307,11 @@ def process_output(self, output, train): # this is just empty shell for code from other class pass + @abstractmethod + def reset_train_dataloader(self, model): + # this is just empty shell for code from other class + pass + def train(self): warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,' ' but will start from "0" in v0.8.0.', DeprecationWarning) @@ -313,7 +319,6 @@ def train(self): model = self.get_model() try: # run all epochs - # TODO: finish replacing train_dataloader for epoch in range(self.current_epoch, self.max_epochs): # set seed for distributed sampler (enables shuffling for each epoch) if (self.use_ddp or self.use_tpu) \ @@ -396,6 +401,7 @@ def train(self): return self.run_training_teardown() + except KeyboardInterrupt: log.info('Detected KeyboardInterrupt, attempting graceful shutdown...') self.run_training_teardown() @@ -407,18 +413,20 @@ def run_training_epoch(self): with self.profiler.profile('on_epoch_start'): model.on_epoch_start() - # request the dataloader - train_dataloader = self.get_train_dataloader() + # reset train dataloader + if self.reload_dataloaders_every_epoch: + self.reset_train_dataloader(self.get_model()) + self.train_dataloader = self.train_dataloader # on TPU we have to wrap it under the ParallelLoader if self.use_tpu: device = xm.xla_device() - train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device]) - train_dataloader = train_dataloader.per_device_loader(device) + self.train_dataloader = xla_pl.ParallelLoader(self.train_dataloader, [device]) + self.train_dataloader = self.train_dataloader.per_device_loader(device) # run epoch for batch_idx, batch in self.profiler.profile_iterable( - enumerate(train_dataloader), "get_train_batch" + enumerate(self.train_dataloader), "get_train_batch" ): # stop epoch if we limited the number of training batches if batch_idx >= self.num_training_batches: From bb642b8dfb6de86644c5bf5fdd14b2e1ce806733 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 14:40:56 -0500 Subject: [PATCH 07/80] added dataloader reset flag and main loop --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 8fffd81bf4d45..099786c156b9e 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -234,6 
+234,7 @@ def __init__(self): self.train_dataloader = None self.reload_dataloaders_every_epoch = None + @property def max_nb_epochs(self): """ @@ -416,7 +417,6 @@ def run_training_epoch(self): # reset train dataloader if self.reload_dataloaders_every_epoch: self.reset_train_dataloader(self.get_model()) - self.train_dataloader = self.train_dataloader # on TPU we have to wrap it under the ParallelLoader if self.use_tpu: From 6d646fdd099010940a1e0338237b422465aab314 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 14:44:15 -0500 Subject: [PATCH 08/80] added dataloader reset flag and main loop --- pytorch_lightning/trainer/evaluation_loop.py | 23 ++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index f5d2b9327f9fa..be5344d02281d 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -165,9 +165,10 @@ def __init__(self): self.checkpoint_callback = None self.current_epoch = None self.callback_metrics = None - self.get_test_dataloaders = None - self.get_val_dataloaders = None + self.test_dataloaders = None + self.val_dataloaders = None self.use_tpu = None + self.reload_dataloaders_every_epoch = None @abstractmethod def copy_trainer_model_properties(self, model): @@ -204,6 +205,16 @@ def log_metrics(self, metrics, grad_norm_dic): # this is just empty shell for code from other class pass + @abstractmethod + def reset_test_dataloader(self): + # this is just empty shell for code from other class + pass + + @abstractmethod + def reset_val_dataloader(self): + # this is just empty shell for code from other class + pass + def evaluate(self, model, dataloaders, max_batches, test=False): """Run evaluation code. 
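The `reload_dataloaders_every_epoch` flag these hooks serve is wired through both loops: when it is enabled, the training loop rebuilds the train loader at the start of each epoch and the evaluation loop (next hunk) rebuilds the val/test loaders before each run, instead of reusing the loaders built once during `run_pretrain_routine`. A minimal usage sketch, not part of this patch series (the model/dataset are illustrative and the required `training_step`/`configure_optimizers` hooks are omitted):

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from pytorch_lightning import Trainer, LightningModule

    class ReloadingModel(LightningModule):
        def train_dataloader(self):
            # re-invoked every epoch when reload_dataloaders_every_epoch=True,
            # so the data can be re-built or re-sampled here
            data = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
            return DataLoader(data, batch_size=8)

    trainer = Trainer(reload_dataloaders_every_epoch=True)
    trainer.fit(ReloadingModel())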
@@ -299,11 +310,15 @@ def run_evaluation(self, test=False): # select dataloaders if test: - dataloaders = self.get_test_dataloaders() + if self.reload_dataloaders_every_epoch: + self.reset_test_dataloader(model) + dataloaders = self.test_dataloaders max_batches = self.num_test_batches else: # val - dataloaders = self.get_val_dataloaders() + if self.reload_dataloaders_every_epoch: + self.reset_val_dataloader(model) + dataloaders = self.val_dataloaders max_batches = self.num_val_batches # cap max batches to 1 when using fast_dev_run From 412899c5c67b8095c02beb009e73f412518bdd16 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 14:48:20 -0500 Subject: [PATCH 09/80] made changes --- docs/source/hooks.rst | 6 ++--- pytorch_lightning/trainer/evaluation_loop.py | 4 +-- pytorch_lightning/trainer/trainer.py | 4 +-- pytorch_lightning/trainer/training_loop.py | 1 - tests/test_restore_models.py | 2 +- tests/test_trainer.py | 26 ++++++++++---------- 6 files changed, 21 insertions(+), 22 deletions(-) diff --git a/docs/source/hooks.rst b/docs/source/hooks.rst index fee74ea218155..0cade1c4a8eea 100644 --- a/docs/source/hooks.rst +++ b/docs/source/hooks.rst @@ -8,9 +8,9 @@ Training set-up - init_optimizers - configure_apex - configure_ddp -- get_train_dataloader -- get_test_dataloaders -- get_val_dataloaders +- train_dataloader +- test_dataloaders +- val_dataloaders - summarize - restore_weights diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index be5344d02281d..c92c528038de5 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -372,10 +372,10 @@ def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False # make dataloader_idx arg in validation_step optional args = [batch, batch_idx] - if test and len(self.get_test_dataloaders()) > 1: + if test and len(self.test_dataloaders) > 1: args.append(dataloader_idx) - elif not test and len(self.get_val_dataloaders()) > 1: + elif not test and len(self.val_dataloaders) > 1: args.append(dataloader_idx) # handle DP, DDP forward diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7fee99e2f947b..3328695903240 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1053,14 +1053,14 @@ def run_pretrain_routine(self, model: LightningModule): if not self.disable_validation and self.num_sanity_val_steps > 0: # init progress bars for validation sanity check pbar = tqdm(desc='Validation sanity check', - total=self.num_sanity_val_steps * len(self.get_val_dataloaders()), + total=self.num_sanity_val_steps * len(self.val_dataloaders), leave=False, position=2 * self.process_position, disable=not self.show_progress_bar, dynamic_ncols=True) self.main_progress_bar = pbar # dummy validation progress bar self.val_progress_bar = tqdm(disable=True) - eval_results = self.evaluate(model, self.get_val_dataloaders(), + eval_results = self.evaluate(model, self.val_dataloaders, self.num_sanity_val_steps, False) _, _, _, callback_metrics, _ = self.process_output(eval_results) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 099786c156b9e..2fbbfa0ed6000 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -226,7 +226,6 @@ def __init__(self): self.model = None self.running_loss = None self.training_tqdm_dict = None - self.get_train_dataloader = None 
self.reduce_lr_on_plateau_scheduler = None self.profiler = None self.batch_idx = None diff --git a/tests/test_restore_models.py b/tests/test_restore_models.py index 2bd033b09756f..0ca35308fa5a9 100644 --- a/tests/test_restore_models.py +++ b/tests/test_restore_models.py @@ -311,7 +311,7 @@ def assert_good_acc(): # if model and state loaded correctly, predictions will be good even though we # haven't trained with the new loaded model trainer.model.eval() - for dataloader in trainer.get_val_dataloaders(): + for dataloader in trainer.val_dataloaders: tutils.run_prediction(dataloader, trainer.model) model.on_train_start = assert_good_acc diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 231ef9508adbf..b642640528952 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -405,11 +405,11 @@ class CurrentTestModel( assert result == 1 # verify there are 2 val loaders - assert len(trainer.get_val_dataloaders()) == 2, \ + assert len(trainer.val_dataloaders) == 2, \ 'Multiple val_dataloaders not initiated properly' # make sure predictions are good for each val set - for dataloader in trainer.get_val_dataloaders(): + for dataloader in trainer.val_dataloaders: tutils.run_prediction(dataloader, trainer.model) @@ -569,8 +569,8 @@ class CurrentTestModel( fit_options = dict(train_dataloader=model._dataloader(train=True), val_dataloader=model._dataloader(train=False)) results = trainer.fit(model, **fit_options) - assert len(trainer.get_val_dataloaders()) == 1, \ - f'`val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}' + assert len(trainer.val_dataloaders) == 1, \ + f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' def test_all_dataloaders_passed_to_fit(tmpdir): @@ -600,10 +600,10 @@ class CurrentTestModel( test_dataloader=model._dataloader(train=False)) results = trainer.fit(model, **fit_options) - assert len(trainer.get_val_dataloaders()) == 1, \ - f'`val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}' + assert len(trainer.val_dataloaders) == 1, \ + f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' assert len(trainer.get_test_dataloaders()) == 1, \ - f'`test_dataloaders` not initiated properly, got {trainer.get_test_dataloaders()}' + f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}' def test_multiple_dataloaders_passed_to_fit(tmpdir): @@ -635,10 +635,10 @@ class CurrentTestModel( model._dataloader(train=False)]) results = trainer.fit(model, **fit_options) - assert len(trainer.get_val_dataloaders()) == 2, \ - f'Multiple `val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}' - assert len(trainer.get_test_dataloaders()) == 2, \ - f'Multiple `test_dataloaders` not initiated properly, got {trainer.get_test_dataloaders()}' + assert len(trainer.val_dataloaders) == 2, \ + f'Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' + assert len(trainer.test_dataloaders) == 2, \ + f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}' def test_mixing_of_dataloader_options(tmpdir): @@ -671,9 +671,9 @@ class CurrentTestModel( fit_options = dict(val_dataloader=model._dataloader(train=False), test_dataloader=model._dataloader(train=False)) results = trainer.fit(model, **fit_options) - assert len(trainer.get_val_dataloaders()) == 1, \ + assert len(trainer.val_dataloaders) == 1, \ f'`val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}' - assert 
len(trainer.get_test_dataloaders()) == 1, \ + assert len(trainer.test_dataloaders) == 1, \ f'`test_dataloaders` not initiated properly, got {trainer.get_test_dataloaders()}' From 252713f9630caa7aedf3511ab9ee34ef83f8da1f Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 14:48:33 -0500 Subject: [PATCH 10/80] made changes --- pytorch_lightning/trainer/training_loop.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 2fbbfa0ed6000..b9c679bbe392d 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -233,7 +233,6 @@ def __init__(self): self.train_dataloader = None self.reload_dataloaders_every_epoch = None - @property def max_nb_epochs(self): """ From de77ecea94e045072f4001432ec213441a2b3c43 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 14:55:40 -0500 Subject: [PATCH 11/80] made changes --- pytorch_lightning/core/decorators.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py index 005b220bac26c..f09c288ca8e7a 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -1,4 +1,5 @@ -import warnings +import traceback +from functools import wraps def data_loader(fn): @@ -7,8 +8,6 @@ def data_loader(fn): :param fn: :return: """ - w = 'data_loader decorator was deprecated in 0.6.1 and will be removed in 0.8.0' - warnings.warn(w) - - value = fn() - return value + def inner_fx(self): + return fn(self) + return inner_fx From f7235746d3ab6426ae5ffbf4977900d926b8fa96 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 14:56:23 -0500 Subject: [PATCH 12/80] made changes --- pytorch_lightning/core/decorators.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py index f09c288ca8e7a..be4fd41d06ee8 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -1,5 +1,6 @@ import traceback from functools import wraps +import warnings def data_loader(fn): @@ -8,6 +9,9 @@ def data_loader(fn): :param fn: :return: """ + w = 'data_loader decorator deprecated in 0.6.1. 
Will remove 0.8.0' + warnings.warn(w) + def inner_fx(self): return fn(self) return inner_fx From 9009c2208b2bf26e19778858215e8f67bca62a5d Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 15:00:41 -0500 Subject: [PATCH 13/80] made changes --- pytorch_lightning/trainer/data_loading.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 515253fda562b..0925672eff008 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -176,6 +176,8 @@ def reset_val_dataloader(self, model): :return: """ self.val_dataloaders = self.request_data_loader(model.val_dataloader) + if not isinstance(self.val_dataloaders, list): + self.val_dataloaders = [self.val_dataloaders] self.num_val_batches = 0 # determine number of validation batches @@ -197,6 +199,8 @@ def reset_test_dataloader(self, model): """ # get actual loader self.test_dataloaders = self.request_data_loader(model.test_dataloader) + if not isinstance(self.test_dataloaders, list): + self.test_dataloaders = [self.test_dataloaders] self.num_test_batches = 0 # determine number of test batches From 74984e3c4a6f03fe59d69901276b993204b81506 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 15:17:16 -0500 Subject: [PATCH 14/80] made changes --- pytorch_lightning/trainer/evaluation_loop.py | 12 +++++++----- pytorch_lightning/trainer/trainer.py | 6 +++++- pytorch_lightning/trainer/training_loop.py | 4 +++- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index c92c528038de5..62b4e6b864f52 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -169,6 +169,7 @@ def __init__(self): self.val_dataloaders = None self.use_tpu = None self.reload_dataloaders_every_epoch = None + self.progress_bar_refresh_rate = None @abstractmethod def copy_trainer_model_properties(self, model): @@ -269,11 +270,12 @@ def evaluate(self, model, dataloaders, max_batches, test=False): dl_outputs.append(output) # batch done - if test: - self.test_progress_bar.update(1) - else: - self.val_progress_bar.update(1) - self.main_progress_bar.update(1) + if batch_idx % self.progress_bar_refresh_rate == 0: + if test: + self.test_progress_bar.update(1) + else: + self.val_progress_bar.update(1) + self.main_progress_bar.update(1) outputs.append(dl_outputs) eval_results = {} diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 3328695903240..2c75d03ae8cf3 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -79,6 +79,7 @@ def __init__( num_tpu_cores: Optional[int] = None, log_gpu_memory: Optional[str] = None, show_progress_bar: bool = True, + progress_bar_refresh_rate: int = 100, overfit_pct: float = 0.0, track_grad_norm: int = -1, check_val_every_n_epoch: int = 1, @@ -109,7 +110,7 @@ def __init__( truncated_bptt_steps: Optional[int] = None, resume_from_checkpoint: Optional[str] = None, profiler: Optional[BaseProfiler] = None, - reload_dataloaders_every_epoch: bool = False + reload_dataloaders_every_epoch: bool = False, ): r""" @@ -285,6 +286,8 @@ def __init__( # default used by the Trainer trainer = Trainer(show_progress_bar=True) + progress_bar_refresh_rate: How often to refresh progress bar (in steps) + overfit_pct: uses this much data of all datasets. 
Example:: @@ -605,6 +608,7 @@ def __init__( self.gradient_clip_val = gradient_clip_val self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch + self.progress_bar_refresh_rate = progress_bar_refresh_rate self.check_val_every_n_epoch = check_val_every_n_epoch self.track_grad_norm = track_grad_norm self.on_gpu = True if (gpus and torch.cuda.is_available()) else False diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index b9c679bbe392d..e216bfcca5193 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -232,6 +232,7 @@ def __init__(self): self.precision = None self.train_dataloader = None self.reload_dataloaders_every_epoch = None + self.progress_bar_refresh_rate = None @property def max_nb_epochs(self): @@ -617,7 +618,8 @@ def optimizer_closure(): model.on_batch_end() # update progress bar - self.main_progress_bar.update(1) + if batch_idx % self.progress_bar_refresh_rate == 0: + self.main_progress_bar.update(1) self.main_progress_bar.set_postfix(**self.training_tqdm_dict) # collapse all metrics into one dict From 838879b86bfb5d047c3e8ae10b5665d7d2224ae0 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 15:25:44 -0500 Subject: [PATCH 15/80] made changes --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index e216bfcca5193..47f2c2f81f88f 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -620,7 +620,7 @@ def optimizer_closure(): # update progress bar if batch_idx % self.progress_bar_refresh_rate == 0: self.main_progress_bar.update(1) - self.main_progress_bar.set_postfix(**self.training_tqdm_dict) + self.main_progress_bar.set_postfix(**self.training_tqdm_dict) # collapse all metrics into one dict all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()} From 14f3a1d0f8b273d9212603e98d7614f27da59109 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 15:28:07 -0500 Subject: [PATCH 16/80] made changes --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2c75d03ae8cf3..9dcd6fec7364a 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -79,7 +79,7 @@ def __init__( num_tpu_cores: Optional[int] = None, log_gpu_memory: Optional[str] = None, show_progress_bar: bool = True, - progress_bar_refresh_rate: int = 100, + progress_bar_refresh_rate: int = 10, overfit_pct: float = 0.0, track_grad_norm: int = -1, check_val_every_n_epoch: int = 1, From 93f6b19f187708dd94a24780120c8838e5c249cd Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 15:30:16 -0500 Subject: [PATCH 17/80] made changes --- pytorch_lightning/trainer/evaluation_loop.py | 6 +++--- pytorch_lightning/trainer/training_loop.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 62b4e6b864f52..5bdac334257d3 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -272,10 +272,10 @@ def evaluate(self, model, dataloaders, max_batches, test=False): # batch done if batch_idx % self.progress_bar_refresh_rate == 0: if test: 
- self.test_progress_bar.update(1) + self.test_progress_bar.update(self.progress_bar_refresh_rate) else: - self.val_progress_bar.update(1) - self.main_progress_bar.update(1) + self.val_progress_bar.update(self.progress_bar_refresh_rate) + self.main_progress_bar.update(self.progress_bar_refresh_rate) outputs.append(dl_outputs) eval_results = {} diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 47f2c2f81f88f..8a6ea5492449e 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -619,7 +619,7 @@ def optimizer_closure(): # update progress bar if batch_idx % self.progress_bar_refresh_rate == 0: - self.main_progress_bar.update(1) + self.main_progress_bar.update(self.progress_bar_refresh_rate) self.main_progress_bar.set_postfix(**self.training_tqdm_dict) # collapse all metrics into one dict From b021212adfa078a737c4ec6c3d73f4fda4c87542 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 15:32:05 -0500 Subject: [PATCH 18/80] made changes --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 9dcd6fec7364a..019476922fc62 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -79,7 +79,7 @@ def __init__( num_tpu_cores: Optional[int] = None, log_gpu_memory: Optional[str] = None, show_progress_bar: bool = True, - progress_bar_refresh_rate: int = 10, + progress_bar_refresh_rate: int = 50, overfit_pct: float = 0.0, track_grad_norm: int = -1, check_val_every_n_epoch: int = 1, From 45db9be7bf062791e25913e059121d7d09fd94ec Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 16:15:11 -0500 Subject: [PATCH 19/80] made changes --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 8a6ea5492449e..1f6ad4511acc2 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -321,7 +321,7 @@ def train(self): # run all epochs for epoch in range(self.current_epoch, self.max_epochs): # set seed for distributed sampler (enables shuffling for each epoch) - if (self.use_ddp or self.use_tpu) \ + if self.use_ddp \ and hasattr(self.train_dataloader.sampler, 'set_epoch'): self.train_dataloader.sampler.set_epoch(epoch) From 20b5c6242edb516dbe410adbfee5b87b38bb68e9 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 16:29:12 -0500 Subject: [PATCH 20/80] made changes --- pytorch_lightning/trainer/data_loading.py | 58 ++++++++++++++++------- 1 file changed, 40 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 0925672eff008..e7426901de839 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -3,7 +3,7 @@ import torch.distributed as dist from torch.utils.data.distributed import DistributedSampler -from torch.utils.data import RandomSampler, SequentialSampler +from torch.utils.data import RandomSampler, SequentialSampler, DataLoader from pytorch_lightning.utilities.debugging import MisconfigurationException try: @@ -90,23 +90,37 @@ def call_prepare_data(self, model): model.prepare_data() def auto_add_sampler(self, dataloader, train): - # TODO: verify # do nothing when user 
gives a sampler if dataloader.sampler is not None: return + dl_args = { + 'dataset': dataloader.dataset, + 'batch_size': dataloader.batch_size, + 'shuffle': dataloader.shuffle, + 'batch_sampler': dataloader.batch_sampler, + 'num_workers': dataloader.num_workers, + 'collate_fn': dataloader.collate_fn, + 'pin_memory': dataloader.pin_memory, + 'drop_last': dataloader.drop_last, + 'timeout': dataloader.timeout, + 'worker_init_fn': None + } + if train: if self.use_ddp or self.use_ddp2: - self.train_dataloader.sampler = DistributedSampler(self.train_dataloader.dataset) + sampler = DistributedSampler(self.train_dataloader.dataset) + dl_args['shuffle'] = False + elif self.use_tpu: sampler = DistributedSampler( self.train_dataloader.dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal() ) - self.train_dataloader.sampler = sampler + dl_args['shuffle'] = False else: - self.train_dataloader.sampler = RandomSampler(self.train_dataloader.dataset) + sampler = RandomSampler(self.train_dataloader.dataset) # on not train else: @@ -116,9 +130,15 @@ def auto_add_sampler(self, dataloader, train): num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal() ) - self.train_dataloader.sampler = sampler + dl_args['shuffle'] = False else: - self.train_dataloader.sampler = SequentialSampler(self.train_dataloader.dataset) + sampler = SequentialSampler(self.train_dataloader.dataset) + + dl_args['sampler'] = sampler + dl_args['dataset'] + + new_dataloader = DataLoader(**dl_args) + return new_dataloader def reset_train_dataloader(self, model): """ @@ -129,6 +149,9 @@ def reset_train_dataloader(self, model): self.train_dataloader = self.request_data_loader(model.train_dataloader) self.num_training_batches = 0 + # automatically add samplers + self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True) + # determine number of training batches if EXIST_ITER_DATASET and isinstance(self.train_dataloader.dataset, IterableDataset): self.num_training_batches = float('inf') @@ -154,9 +177,6 @@ def reset_train_dataloader(self, model): self.val_check_batch = int(self.num_training_batches * self.val_check_interval) self.val_check_batch = max(1, self.val_check_batch) - # automatically add samplers - self.auto_add_sampler(self.train_dataloader, train=True) - # support IterableDataset for train data self.is_iterable_train_dataloader = ( EXIST_ITER_DATASET and isinstance(self.train_dataloader.dataset, IterableDataset) @@ -180,6 +200,11 @@ def reset_val_dataloader(self, model): self.val_dataloaders = [self.val_dataloaders] self.num_val_batches = 0 + # add samplers + for i, dataloader in enumerate(self.val_dataloaders): + dl = self.auto_add_sampler(dataloader, train=False) + self.val_dataloaders[i] = dl + # determine number of validation batches # val datasets could be none, 1 or 2+ if self.val_dataloaders is not None: @@ -188,10 +213,6 @@ def reset_val_dataloader(self, model): self.num_val_batches = sum(len(dataloader) for dataloader in self.val_dataloaders) self.num_val_batches = int(self.num_val_batches * self.val_percent_check) - # add samplers - for dataloader in self.val_dataloaders: - self.auto_add_sampler(dataloader, train=False) - def reset_test_dataloader(self, model): """Dataloaders are provided by the model. 
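The rewritten `auto_add_sampler` above no longer assigns to `dataloader.sampler` in place; it rebuilds the loader from `dl_args` with the appropriate sampler. Its effect on a user-supplied loader in the DDP training case is roughly the following (a simplified illustration, not the exact code from this diff):

.. code-block:: python

    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    def with_distributed_sampler(dataloader: DataLoader) -> DataLoader:
        # rebuild the loader with the same arguments but a DistributedSampler;
        # shuffle is forced off because the sampler shuffles within each rank
        sampler = DistributedSampler(dataloader.dataset)
        return DataLoader(
            dataloader.dataset,
            batch_size=dataloader.batch_size,
            shuffle=False,
            sampler=sampler,
            num_workers=dataloader.num_workers,
            collate_fn=dataloader.collate_fn,
            pin_memory=dataloader.pin_memory,
            drop_last=dataloader.drop_last,
        )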
@@ -203,6 +224,11 @@ def reset_test_dataloader(self, model): self.test_dataloaders = [self.test_dataloaders] self.num_test_batches = 0 + # add samplers + for i, dataloader in enumerate(self.test_dataloaders): + dl = self.auto_add_sampler(dataloader, train=False) + self.test_dataloaders[i] = dl + # determine number of test batches if self.test_dataloaders is not None: self._percent_range_check('test_percent_check') @@ -211,10 +237,6 @@ def reset_test_dataloader(self, model): self.num_test_batches = len_sum self.num_test_batches = int(self.num_test_batches * self.test_percent_check) - # add samplers - for dataloader in self.test_dataloaders: - self.auto_add_sampler(dataloader, train=False) - def request_data_loader(self, data_loader_fx): """ Handles downloading data in the GPU or TPU case. From 189bbb17fd3337ecb3c182d416077ca382b564df Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 16:32:44 -0500 Subject: [PATCH 21/80] made changes --- pytorch_lightning/trainer/data_loading.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index e7426901de839..cb2a18b7741c7 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -104,38 +104,37 @@ def auto_add_sampler(self, dataloader, train): 'pin_memory': dataloader.pin_memory, 'drop_last': dataloader.drop_last, 'timeout': dataloader.timeout, - 'worker_init_fn': None + 'worker_init_fn': dataloader.worker_init_fn } if train: if self.use_ddp or self.use_ddp2: - sampler = DistributedSampler(self.train_dataloader.dataset) + sampler = DistributedSampler(dataloader.dataset) dl_args['shuffle'] = False elif self.use_tpu: sampler = DistributedSampler( - self.train_dataloader.dataset, + dataloader.dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal() ) dl_args['shuffle'] = False else: - sampler = RandomSampler(self.train_dataloader.dataset) + sampler = RandomSampler(dataloader.dataset) # on not train else: if self.use_tpu: sampler = DistributedSampler( - self.train_dataloader.dataset, + dataloader.dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal() ) dl_args['shuffle'] = False else: - sampler = SequentialSampler(self.train_dataloader.dataset) + sampler = SequentialSampler(dataloader.dataset) dl_args['sampler'] = sampler - dl_args['dataset'] new_dataloader = DataLoader(**dl_args) return new_dataloader From 0a45b1a1b7e65b4e78ad5e6e7acdbfd00e6691b6 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 16:34:19 -0500 Subject: [PATCH 22/80] made changes --- pytorch_lightning/trainer/data_loading.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index cb2a18b7741c7..126541a43ea20 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -149,6 +149,7 @@ def reset_train_dataloader(self, model): self.num_training_batches = 0 # automatically add samplers + import pdb; pdb.set_trace() self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True) # determine number of training batches From 767ad23d750fe9b5e3747860e9cb6fa7efa36892 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 16:38:29 -0500 Subject: [PATCH 23/80] made changes --- pytorch_lightning/trainer/data_loading.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/data_loading.py 
b/pytorch_lightning/trainer/data_loading.py index 126541a43ea20..cb2a18b7741c7 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -149,7 +149,6 @@ def reset_train_dataloader(self, model): self.num_training_batches = 0 # automatically add samplers - import pdb; pdb.set_trace() self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True) # determine number of training batches From 56c0654f77300808adf0b2e14b9915ede8ebfeea Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 16:40:05 -0500 Subject: [PATCH 24/80] made changes --- pytorch_lightning/trainer/data_loading.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index cb2a18b7741c7..2b4a40dd91c82 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -92,7 +92,8 @@ def call_prepare_data(self, model): def auto_add_sampler(self, dataloader, train): # do nothing when user gives a sampler if dataloader.sampler is not None: - return + print('returning', dataloader.sampler) + return dataloader dl_args = { 'dataset': dataloader.dataset, From 783b5c7c5be9067acdfb4fed979d02eedf3caf6b Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 16:55:08 -0500 Subject: [PATCH 25/80] made changes --- pytorch_lightning/trainer/data_loading.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 2b4a40dd91c82..38e4803050952 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -3,7 +3,7 @@ import torch.distributed as dist from torch.utils.data.distributed import DistributedSampler -from torch.utils.data import RandomSampler, SequentialSampler, DataLoader +from torch.utils.data import RandomSampler, SequentialSampler, DataLoader, BatchSampler from pytorch_lightning.utilities.debugging import MisconfigurationException try: @@ -91,10 +91,6 @@ def call_prepare_data(self, model): def auto_add_sampler(self, dataloader, train): # do nothing when user gives a sampler - if dataloader.sampler is not None: - print('returning', dataloader.sampler) - return dataloader - dl_args = { 'dataset': dataloader.dataset, 'batch_size': dataloader.batch_size, @@ -135,7 +131,10 @@ def auto_add_sampler(self, dataloader, train): else: sampler = SequentialSampler(dataloader.dataset) + batch_sampler = BatchSampler(sampler, dl_args['batch_size'], dl_args['drop_last']) + dl_args['sampler'] = sampler + dl_args['batch_sampler'] = batch_sampler new_dataloader = DataLoader(**dl_args) return new_dataloader From 8183e82a62604c558e62e4523f11a0519142db4c Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 16:57:29 -0500 Subject: [PATCH 26/80] made changes --- pytorch_lightning/trainer/data_loading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 38e4803050952..31c72d623d733 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -94,7 +94,7 @@ def auto_add_sampler(self, dataloader, train): dl_args = { 'dataset': dataloader.dataset, 'batch_size': dataloader.batch_size, - 'shuffle': dataloader.shuffle, + 'shuffle': True, 'batch_sampler': dataloader.batch_sampler, 'num_workers': dataloader.num_workers, 
'collate_fn': dataloader.collate_fn, From 2a2a6efeabeca9ef5f36ab04c9be884974cce208 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 16:59:24 -0500 Subject: [PATCH 27/80] made changes --- pytorch_lightning/trainer/data_loading.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 31c72d623d733..420b3463ccccd 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -95,7 +95,6 @@ def auto_add_sampler(self, dataloader, train): 'dataset': dataloader.dataset, 'batch_size': dataloader.batch_size, 'shuffle': True, - 'batch_sampler': dataloader.batch_sampler, 'num_workers': dataloader.num_workers, 'collate_fn': dataloader.collate_fn, 'pin_memory': dataloader.pin_memory, @@ -131,10 +130,7 @@ def auto_add_sampler(self, dataloader, train): else: sampler = SequentialSampler(dataloader.dataset) - batch_sampler = BatchSampler(sampler, dl_args['batch_size'], dl_args['drop_last']) - dl_args['sampler'] = sampler - dl_args['batch_sampler'] = batch_sampler new_dataloader = DataLoader(**dl_args) return new_dataloader From 8347a189a56e7de980771b489555bf062067cac7 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 17:07:37 -0500 Subject: [PATCH 28/80] made changes --- pytorch_lightning/trainer/data_loading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 420b3463ccccd..51df38c76d883 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -94,7 +94,7 @@ def auto_add_sampler(self, dataloader, train): dl_args = { 'dataset': dataloader.dataset, 'batch_size': dataloader.batch_size, - 'shuffle': True, + 'shuffle': False, 'num_workers': dataloader.num_workers, 'collate_fn': dataloader.collate_fn, 'pin_memory': dataloader.pin_memory, From 176d62d905f24f76989339882def51cac2e9d724 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 17:09:07 -0500 Subject: [PATCH 29/80] made changes --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 019476922fc62..e76a666441de0 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1032,7 +1032,7 @@ def run_pretrain_routine(self, model: LightningModule): self.restore_weights(model) # download the data and do whatever transforms we need - self.call_prepare_data(model) + self.call_prepare_data(ref_model) # when testing requested only run test and return if self.testing: From 4e3fb96f811e744839113247617395923be42f02 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 17:10:32 -0500 Subject: [PATCH 30/80] made changes --- pytorch_lightning/trainer/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e76a666441de0..1dd0d829b6393 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1037,13 +1037,13 @@ def run_pretrain_routine(self, model: LightningModule): # when testing requested only run test and return if self.testing: # only load test dataloader for testing - self.reset_test_dataloader(model) + self.reset_test_dataloader(ref_model) self.run_evaluation(test=True) return # load the dataloaders 
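The auto_add_sampler rework in the patches above stops assigning to the sampler of the user's loader and instead rebuilds the DataLoader from its recoverable attributes, forcing shuffle off and injecting a DistributedSampler when running distributed. A minimal standalone sketch of that rebuild pattern, assuming a map-style dataset and an already initialised default process group; the helper name and the use_ddp/train flags are illustrative, not the trainer's API:

    from torch.utils.data import DataLoader, RandomSampler
    from torch.utils.data.distributed import DistributedSampler

    def rebuild_with_sampler(dataloader, use_ddp, train):
        # copy the constructor arguments that can be recovered from the existing loader
        dl_args = {
            'dataset': dataloader.dataset,
            'batch_size': dataloader.batch_size,
            'num_workers': dataloader.num_workers,
            'collate_fn': dataloader.collate_fn,
            'pin_memory': dataloader.pin_memory,
            'drop_last': dataloader.drop_last,
            'timeout': dataloader.timeout,
            'worker_init_fn': dataloader.worker_init_fn,
            'shuffle': False,  # shuffling, if any, is delegated to the sampler
        }
        if use_ddp:
            # each process draws a distinct shard of the dataset
            dl_args['sampler'] = DistributedSampler(dataloader.dataset)
        elif train:
            dl_args['sampler'] = RandomSampler(dataloader.dataset)
        # with no explicit sampler, DataLoader falls back to a SequentialSampler
        return DataLoader(**dl_args)

Rebuilding rather than mutating keeps the user's loader untouched; it also avoids relying on changing a DataLoader's sampler after construction.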
- self.reset_train_dataloader(model) - self.reset_val_dataloader(model) + self.reset_train_dataloader(ref_model) + self.reset_val_dataloader(ref_model) # check if we should run validation during training self.disable_validation = ((self.num_val_batches == 0 or From 51cc57fe3d5a9abf1f30ed42b59fa4e34d630e37 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 17:18:14 -0500 Subject: [PATCH 31/80] made changes --- pytorch_lightning/trainer/evaluation_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 5bdac334257d3..499c20aa6a1b9 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -207,12 +207,12 @@ def log_metrics(self, metrics, grad_norm_dic): pass @abstractmethod - def reset_test_dataloader(self): + def reset_test_dataloader(self, model): # this is just empty shell for code from other class pass @abstractmethod - def reset_val_dataloader(self): + def reset_val_dataloader(self, model): # this is just empty shell for code from other class pass From c55bb0d5a65dafdbd87d391c1244ad92dd781215 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 17:23:04 -0500 Subject: [PATCH 32/80] made changes --- pytorch_lightning/trainer/data_loading.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 51df38c76d883..d2bfc952feb88 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -141,7 +141,7 @@ def reset_train_dataloader(self, model): :param model: :return: """ - self.train_dataloader = self.request_data_loader(model.train_dataloader) + self.train_dataloader = self.request_data_loader(model.train_dataloader()) self.num_training_batches = 0 # automatically add samplers @@ -190,7 +190,7 @@ def reset_val_dataloader(self, model): :param model: :return: """ - self.val_dataloaders = self.request_data_loader(model.val_dataloader) + self.val_dataloaders = self.request_data_loader(model.val_dataloader()) if not isinstance(self.val_dataloaders, list): self.val_dataloaders = [self.val_dataloaders] self.num_val_batches = 0 @@ -214,7 +214,7 @@ def reset_test_dataloader(self, model): :param model: """ # get actual loader - self.test_dataloaders = self.request_data_loader(model.test_dataloader) + self.test_dataloaders = self.request_data_loader(model.test_dataloader()) if not isinstance(self.test_dataloaders, list): self.test_dataloaders = [self.test_dataloaders] self.num_test_batches = 0 From 6fe933b8225007956605127d2e9f7109133e29a7 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 17:24:54 -0500 Subject: [PATCH 33/80] made changes --- pytorch_lightning/trainer/data_loading.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index d2bfc952feb88..51df38c76d883 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -141,7 +141,7 @@ def reset_train_dataloader(self, model): :param model: :return: """ - self.train_dataloader = self.request_data_loader(model.train_dataloader()) + self.train_dataloader = self.request_data_loader(model.train_dataloader) self.num_training_batches = 0 # automatically add samplers @@ -190,7 +190,7 @@ def reset_val_dataloader(self, model): 
:param model: :return: """ - self.val_dataloaders = self.request_data_loader(model.val_dataloader()) + self.val_dataloaders = self.request_data_loader(model.val_dataloader) if not isinstance(self.val_dataloaders, list): self.val_dataloaders = [self.val_dataloaders] self.num_val_batches = 0 @@ -214,7 +214,7 @@ def reset_test_dataloader(self, model): :param model: """ # get actual loader - self.test_dataloaders = self.request_data_loader(model.test_dataloader()) + self.test_dataloaders = self.request_data_loader(model.test_dataloader) if not isinstance(self.test_dataloaders, list): self.test_dataloaders = [self.test_dataloaders] self.num_test_batches = 0 From d43c9e766cab1d6da3b1dbdd9c9d8c01e2155325 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 17:31:12 -0500 Subject: [PATCH 34/80] made changes --- pytorch_lightning/trainer/trainer.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1dd0d829b6393..8b34f38cda15e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -903,9 +903,12 @@ def fit( # Update the dataloader attributes of the model with the ones supplied here, # if they are not already defined in model - _set_dataloader(model, train_dataloader, 'train_dataloader') - _set_dataloader(model, val_dataloader, 'val_dataloader') - _set_dataloader(model, test_dataloader, 'test_dataloader') + if train_dataloader is not None: + _set_dataloader(model, train_dataloader, 'train_dataloader') + if val_dataloader is not None: + _set_dataloader(model, val_dataloader, 'val_dataloader') + if test_dataloader is not None: + _set_dataloader(model, test_dataloader, 'test_dataloader') # when using multi-node or DDP within a node start each module in a separate process if self.use_ddp2: From 05ab2db15c2c91854959b26ff33ac30da07887c8 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 17:36:38 -0500 Subject: [PATCH 35/80] made changes --- pl_examples/basic_examples/lightning_module_template.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pl_examples/basic_examples/lightning_module_template.py b/pl_examples/basic_examples/lightning_module_template.py index 81cdf2acba6be..d3c9c341aaf32 100644 --- a/pl_examples/basic_examples/lightning_module_template.py +++ b/pl_examples/basic_examples/lightning_module_template.py @@ -212,17 +212,14 @@ def __dataloader(self, train): return loader - @pl.data_loader def train_dataloader(self): log.info('Training data loader called.') return self.__dataloader(train=True) - @pl.data_loader def val_dataloader(self): log.info('Validation data loader called.') return self.__dataloader(train=False) - @pl.data_loader def test_dataloader(self): log.info('Test data loader called.') return self.__dataloader(train=False) From d165b46b34dcfa366abecf0ed148990f2d6c574c Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 17:41:12 -0500 Subject: [PATCH 36/80] made changes --- .../basic_examples/lightning_module_template.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/pl_examples/basic_examples/lightning_module_template.py b/pl_examples/basic_examples/lightning_module_template.py index d3c9c341aaf32..500b9eb40bafd 100644 --- a/pl_examples/basic_examples/lightning_module_template.py +++ b/pl_examples/basic_examples/lightning_module_template.py @@ -192,7 +192,7 @@ def __dataloader(self, train): transform = transforms.Compose([transforms.ToTensor(), 
transforms.Normalize((0.5,), (1.0,))]) dataset = MNIST(root=self.hparams.data_root, train=train, - transform=transform, download=True) + transform=transform, download=False) # when using multi-node (ddp) we need to add the datasampler train_sampler = None @@ -212,6 +212,14 @@ def __dataloader(self, train): return loader + def prepare_data(self): + transform = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.5,), (1.0,))]) + dataset = MNIST(root=self.hparams.data_root, train=True, + transform=transform, download=True) + dataset = MNIST(root=self.hparams.data_root, train=False, + transform=transform, download=True) + def train_dataloader(self): log.info('Training data loader called.') return self.__dataloader(train=True) From c5535e8196f1232795b4c247a5cc8599aa6c10c7 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 17:42:00 -0500 Subject: [PATCH 37/80] made changes --- tests/test_cpu_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_cpu_models.py b/tests/test_cpu_models.py index 37daf58c12aa9..a370993ab1a06 100644 --- a/tests/test_cpu_models.py +++ b/tests/test_cpu_models.py @@ -55,6 +55,7 @@ def test_lbfgs_cpu_model(tmpdir): ) model, hparams = tutils.get_model(use_test_model=True, lbfgs=True) + import pdb; pdb.set_trace() tutils.run_model_test_no_loggers(trainer_options, model, min_acc=0.30) From 1e562815ea17bdbdfed4d78629077cd7bb613d09 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 17:51:06 -0500 Subject: [PATCH 38/80] made changes --- tests/models/utils.py | 6 +++++- tests/test_cpu_models.py | 1 - 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/models/utils.py b/tests/models/utils.py index 7a641b56ad6e2..2225c107ad17e 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -36,7 +36,11 @@ def run_model_test_no_loggers(trainer_options, model, min_acc=0.50): path_expt=trainer_options.get('default_save_path')) # test new model accuracy - for dataloader in model.test_dataloader(): + test_loaders = model.test_dataloader() + if not isinstance(test_loaders, list): + test_loaders = [test_loaders] + + for dataloader in test_loaders: run_prediction(dataloader, pretrained_model, min_acc=min_acc) if trainer.use_ddp: diff --git a/tests/test_cpu_models.py b/tests/test_cpu_models.py index a370993ab1a06..37daf58c12aa9 100644 --- a/tests/test_cpu_models.py +++ b/tests/test_cpu_models.py @@ -55,7 +55,6 @@ def test_lbfgs_cpu_model(tmpdir): ) model, hparams = tutils.get_model(use_test_model=True, lbfgs=True) - import pdb; pdb.set_trace() tutils.run_model_test_no_loggers(trainer_options, model, min_acc=0.30) From 7623a27753fc3fe46cf689bc9a8cebec148830a4 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 17:55:00 -0500 Subject: [PATCH 39/80] made changes --- tests/models/base.py | 1 - tests/models/mixins.py | 4 ---- tests/test_cpu_models.py | 1 - 3 files changed, 6 deletions(-) diff --git a/tests/models/base.py b/tests/models/base.py index d33f0118dfd05..5e7d4153e6bec 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -218,7 +218,6 @@ def add_model_specific_args(parent_parser, root_dir): # pragma: no cover class LightningTestModelBase(TestModelBase): """ with pre-defined train dataloader """ - @data_loader def train_dataloader(self): return self._dataloader(train=True) diff --git a/tests/models/mixins.py b/tests/models/mixins.py index 03da85d59d096..d4261ef052ced 100644 --- a/tests/models/mixins.py +++ b/tests/models/mixins.py @@ -11,7 +11,6 @@ class 
LightningValidationStepMixin: when val_dataloader returns a single dataloader """ - @data_loader def val_dataloader(self): return self._dataloader(train=False) @@ -105,7 +104,6 @@ class LightningValidationStepMultipleDataloadersMixin: when val_dataloader returns multiple dataloaders """ - @data_loader def val_dataloader(self): return [self._dataloader(train=False), self._dataloader(train=False)] @@ -204,7 +202,6 @@ def validation_end(self, outputs): class LightningTestStepMixin: - @data_loader def test_dataloader(self): return self._dataloader(train=False) @@ -289,7 +286,6 @@ def test_end(self, outputs): class LightningTestStepMultipleDataloadersMixin: - @data_loader def test_dataloader(self): return [self._dataloader(train=False), self._dataloader(train=False)] diff --git a/tests/test_cpu_models.py b/tests/test_cpu_models.py index 37daf58c12aa9..b5f2121d532cd 100644 --- a/tests/test_cpu_models.py +++ b/tests/test_cpu_models.py @@ -304,7 +304,6 @@ def training_step(self, batch, batch_idx, hiddens): 'hiddens': self.test_hidden, } - @data_loader def train_dataloader(self): return torch.utils.data.DataLoader( dataset=MockSeq2SeqDataset(), From 36697f36ccb769230449398596fe539b8024249d Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 18:02:21 -0500 Subject: [PATCH 40/80] made changes --- tests/models/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/models/utils.py b/tests/models/utils.py index 2225c107ad17e..644efb9c3ddd7 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -73,7 +73,11 @@ def run_model_test(trainer_options, model, on_gpu=True): pretrained_model = load_model(logger, trainer.checkpoint_callback.filepath) # test new model accuracy - [run_prediction(dataloader, pretrained_model) for dataloader in model.test_dataloader()] + test_loaders = model.test_dataloader() + if not isinstance(test_loaders, list): + test_loaders = [test_loaders] + + [run_prediction(dataloader, pretrained_model) for dataloader in test_loaders] if trainer.use_ddp or trainer.use_ddp2: # on hpc this would work fine... 
but need to hack it for the purpose of the test From 803e72d9853aaf92fe0ebe386759d6038e781da5 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 18:05:56 -0500 Subject: [PATCH 41/80] made changes --- tests/test_cpu_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_cpu_models.py b/tests/test_cpu_models.py index b5f2121d532cd..5d5a20bb561bd 100644 --- a/tests/test_cpu_models.py +++ b/tests/test_cpu_models.py @@ -46,7 +46,7 @@ def test_lbfgs_cpu_model(tmpdir): trainer_options = dict( default_save_path=tmpdir, - max_epochs=1, + max_epochs=2, print_nan_grads=True, show_progress_bar=False, weights_summary='top', From a8f3e196c9e1986efb515c91897361908a4f424b Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 18:12:13 -0500 Subject: [PATCH 42/80] made changes --- tests/models/base.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/models/base.py b/tests/models/base.py index 5e7d4153e6bec..1f14ae07a487e 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -150,12 +150,20 @@ def configure_optimizers(self): # test returning only 1 list instead of 2 return optimizer + def prepare_data(self): + transform = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.5,), (1.0,))]) + dataset = TestingMNIST(root=self.hparams.data_root, train=True, + transform=transform, download=True, num_samples=2000) + dataset = TestingMNIST(root=self.hparams.data_root, train=False, + transform=transform, download=True, num_samples=2000) + def _dataloader(self, train): # init data generators transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))]) dataset = TestingMNIST(root=self.hparams.data_root, train=train, - transform=transform, download=True, num_samples=2000) + transform=transform, download=False, num_samples=2000) # when using multi-node we need to add the datasampler train_sampler = None From 53598e2b444298e42ae1b6c9c8140bbcc7a3c444 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 18:12:53 -0500 Subject: [PATCH 43/80] made changes --- tests/models/base.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/tests/models/base.py b/tests/models/base.py index 1f14ae07a487e..2c17402624d63 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -166,22 +166,11 @@ def _dataloader(self, train): transform=transform, download=False, num_samples=2000) # when using multi-node we need to add the datasampler - train_sampler = None batch_size = self.hparams.batch_size - try: - if self.use_ddp and not self.force_remove_distributed_sampler: - train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank) - batch_size = batch_size // self.trainer.world_size # scale batch size - except Exception: - pass - - should_shuffle = train_sampler is None loader = DataLoader( dataset=dataset, batch_size=batch_size, - shuffle=should_shuffle, - sampler=train_sampler ) return loader From cddbac8afdfd1f8f96166a4e3253ce3f7181a175 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 18:13:45 -0500 Subject: [PATCH 44/80] made changes --- pl_examples/basic_examples/lightning_module_template.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/pl_examples/basic_examples/lightning_module_template.py b/pl_examples/basic_examples/lightning_module_template.py index 500b9eb40bafd..fabe18504686f 100644 --- a/pl_examples/basic_examples/lightning_module_template.py +++ b/pl_examples/basic_examples/lightning_module_template.py 
@@ -195,18 +195,11 @@ def __dataloader(self, train): transform=transform, download=False) # when using multi-node (ddp) we need to add the datasampler - train_sampler = None batch_size = self.hparams.batch_size - if self.use_ddp: - train_sampler = DistributedSampler(dataset) - - should_shuffle = train_sampler is None loader = DataLoader( dataset=dataset, batch_size=batch_size, - shuffle=should_shuffle, - sampler=train_sampler, num_workers=0 ) From e42b1b7ce510f18ed5f4471df14498b24fae85c1 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 18:24:57 -0500 Subject: [PATCH 45/80] made changes --- pytorch_lightning/trainer/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 8b34f38cda15e..196e6e1b46b83 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1119,6 +1119,7 @@ def test(self, model: Optional[LightningModule] = None): if model is not None: self.fit(model) else: + self.reset_test_dataloader(self.get_model()) self.run_evaluation(test=True) From b83a7d7ee8e67c08c2b2281ef276b3c6d91f4046 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 18:33:49 -0500 Subject: [PATCH 46/80] made changes --- pytorch_lightning/trainer/trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 196e6e1b46b83..8b34f38cda15e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1119,7 +1119,6 @@ def test(self, model: Optional[LightningModule] = None): if model is not None: self.fit(model) else: - self.reset_test_dataloader(self.get_model()) self.run_evaluation(test=True) From d1177273eaf06d0641f60189582859158c3fc36b Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 18:37:03 -0500 Subject: [PATCH 47/80] made changes --- tests/test_gpu_models.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/tests/test_gpu_models.py b/tests/test_gpu_models.py index e982d2f25afd8..a64eb67e16435 100644 --- a/tests/test_gpu_models.py +++ b/tests/test_gpu_models.py @@ -211,32 +211,6 @@ def test_multi_gpu_model_dp(tmpdir): memory.get_memory_profile('min_max') -def test_ddp_sampler_error(tmpdir): - """Make sure DDP + AMP work.""" - if not tutils.can_run_gpu_test(): - return - - tutils.reset_seed() - tutils.set_random_master_port() - - hparams = tutils.get_hparams() - model = LightningTestModel(hparams, force_remove_distributed_sampler=True) - - logger = tutils.get_test_tube_logger(tmpdir, True) - - trainer = Trainer( - logger=logger, - show_progress_bar=False, - max_epochs=1, - gpus=[0, 1], - distributed_backend='ddp', - precision=16 - ) - - with pytest.warns(UserWarning): - trainer.get_dataloaders(model) - - @pytest.fixture def mocked_device_count(monkeypatch): def device_count(): From 6e57368e11ce98ce0446333b85d874dcf440d16c Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 18:39:19 -0500 Subject: [PATCH 48/80] made changes --- tests/test_restore_models.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_restore_models.py b/tests/test_restore_models.py index 0ca35308fa5a9..a348b26f161a1 100644 --- a/tests/test_restore_models.py +++ b/tests/test_restore_models.py @@ -53,7 +53,11 @@ def test_running_test_pretrained_model_ddp(tmpdir): new_trainer = Trainer(**trainer_options) new_trainer.test(pretrained_model) - for dataloader in 
model.test_dataloader(): + dataloaders = model.test_dataloader() + if not isinstance(dataloaders, list): + dataloaders = [dataloaders] + + for dataloader in dataloaders: tutils.run_prediction(dataloader, pretrained_model) @@ -345,7 +349,11 @@ def test_model_saving_loading(tmpdir): assert result == 1, 'amp + ddp model failed to complete' # make a prediction - for dataloader in model.test_dataloader(): + dataloaders = model.test_dataloader() + if not isinstance(dataloaders, list): + dataloaders = [dataloaders] + + for dataloader in dataloaders: for batch in dataloader: break From de6417593441961a3aba6d4b09516c2eb3ea9b6c Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 18:41:37 -0500 Subject: [PATCH 49/80] made changes --- tests/test_trainer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index b642640528952..529ff95ff496d 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -507,11 +507,11 @@ class CurrentTestModel( result = trainer.fit(model) # verify there are 2 val loaders - assert len(trainer.get_test_dataloaders()) == 2, \ + assert len(trainer.test_dataloaders) == 2, \ 'Multiple test_dataloaders not initiated properly' # make sure predictions are good for each test set - for dataloader in trainer.get_test_dataloaders(): + for dataloader in trainer.test_dataloaders: tutils.run_prediction(dataloader, trainer.model) # run the test method @@ -602,7 +602,7 @@ class CurrentTestModel( assert len(trainer.val_dataloaders) == 1, \ f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' - assert len(trainer.get_test_dataloaders()) == 1, \ + assert len(trainer.test_dataloaders) == 1, \ f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}' @@ -672,9 +672,9 @@ class CurrentTestModel( test_dataloader=model._dataloader(train=False)) results = trainer.fit(model, **fit_options) assert len(trainer.val_dataloaders) == 1, \ - f'`val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}' + f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' assert len(trainer.test_dataloaders) == 1, \ - f'`test_dataloaders` not initiated properly, got {trainer.get_test_dataloaders()}' + f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}' def _init_steps_model(): From 55d302d3892e285434bb947237b378511691f1a2 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 18:44:12 -0500 Subject: [PATCH 50/80] made changes --- tests/test_gpu_models.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_gpu_models.py b/tests/test_gpu_models.py index a64eb67e16435..a47b97f95ed91 100644 --- a/tests/test_gpu_models.py +++ b/tests/test_gpu_models.py @@ -124,7 +124,11 @@ def test_cpu_slurm_save_load(tmpdir): # predict with trained model before saving # make a prediction - for dataloader in model.test_dataloader(): + dataloaders = model.test_dataloader() + if not isinstance(dataloaders, list): + dataloaders = [dataloaders] + + for dataloader in dataloaders: for batch in dataloader: break From df70d2e49d6b0a44c2dd71a3a35e6acd06a7ef57 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 18:55:34 -0500 Subject: [PATCH 51/80] made changes --- pytorch_lightning/trainer/evaluation_loop.py | 6 ++++-- tests/test_cpu_models.py | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py 
b/pytorch_lightning/trainer/evaluation_loop.py index 499c20aa6a1b9..4c541d443ed1c 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -312,14 +312,16 @@ def run_evaluation(self, test=False): # select dataloaders if test: - if self.reload_dataloaders_every_epoch: + if self.reload_dataloaders_every_epoch or self.test_dataloaders is None: self.reset_test_dataloader(model) + dataloaders = self.test_dataloaders max_batches = self.num_test_batches else: # val - if self.reload_dataloaders_every_epoch: + if self.reload_dataloaders_every_epoch or self.val_dataloaders is None: self.reset_val_dataloader(model) + dataloaders = self.val_dataloaders max_batches = self.num_val_batches diff --git a/tests/test_cpu_models.py b/tests/test_cpu_models.py index 5d5a20bb561bd..4a3045ddb5cb3 100644 --- a/tests/test_cpu_models.py +++ b/tests/test_cpu_models.py @@ -111,6 +111,7 @@ def test_running_test_after_fitting(tmpdir): assert result == 1, 'training failed to complete' + import pdb; pdb.set_trace() trainer.test() # test we have good test accuracy From 3635e610fa286eec908e9fe94c494526b524b90b Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 19:04:16 -0500 Subject: [PATCH 52/80] made changes --- tests/test_cpu_models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_cpu_models.py b/tests/test_cpu_models.py index 4a3045ddb5cb3..5d5a20bb561bd 100644 --- a/tests/test_cpu_models.py +++ b/tests/test_cpu_models.py @@ -111,7 +111,6 @@ def test_running_test_after_fitting(tmpdir): assert result == 1, 'training failed to complete' - import pdb; pdb.set_trace() trainer.test() # test we have good test accuracy From f7a638269f30b15ce119d62982fdc70918290a65 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 19:41:41 -0500 Subject: [PATCH 53/80] made changes --- tests/test_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 529ff95ff496d..40df90475628d 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -504,6 +504,7 @@ class CurrentTestModel( # fit model trainer = Trainer(**trainer_options) + import pdb; pdb.set_trace() result = trainer.fit(model) # verify there are 2 val loaders From abd2126252806668805c7607179b5d34b8bf345f Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 19:43:32 -0500 Subject: [PATCH 54/80] made changes --- tests/test_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 40df90475628d..a1c5472d352a6 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -503,8 +503,8 @@ class CurrentTestModel( ) # fit model - trainer = Trainer(**trainer_options) import pdb; pdb.set_trace() + trainer = Trainer(**trainer_options) result = trainer.fit(model) # verify there are 2 val loaders From 90eef5e8171f351272dd18f4750b8c253bbc39e0 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 19:50:18 -0500 Subject: [PATCH 55/80] fixed bad loaders --- pytorch_lightning/trainer/data_loading.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 51df38c76d883..f45aad4811175 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -141,6 +141,7 @@ def reset_train_dataloader(self, model): :param model: :return: """ + self.train_dataloader = self.request_data_loader(model.train_dataloader) 
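A recurring detail in this part of the series, both in the evaluation loop above and in the test fixes, is that val_dataloader() and test_dataloader() may return either a single DataLoader or a list of them. A small helper capturing the normalisation convention these patches apply inline; the helper itself is illustrative and not part of the library:

    def as_dataloader_list(loaders):
        # hooks may return None, one DataLoader, or a list/tuple of DataLoaders;
        # always hand back a list so callers can iterate uniformly
        if loaders is None:
            return []
        if not isinstance(loaders, (list, tuple)):
            return [loaders]
        return list(loaders)

With the shape settled, code such as the prediction loops in the tests can simply iterate: for dataloader in as_dataloader_list(model.test_dataloader()): run_prediction(dataloader, model).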
self.num_training_batches = 0 @@ -190,6 +191,9 @@ def reset_val_dataloader(self, model): :param model: :return: """ + if not (self.is_overriden('validation_step') and self.is_overriden('validation_end')): + return + self.val_dataloaders = self.request_data_loader(model.val_dataloader) if not isinstance(self.val_dataloaders, list): self.val_dataloaders = [self.val_dataloaders] @@ -213,6 +217,9 @@ def reset_test_dataloader(self, model): :param model: """ + if not (self.is_overriden('test_step') and self.is_overriden('test_end')): + return + # get actual loader self.test_dataloaders = self.request_data_loader(model.test_dataloader) if not isinstance(self.test_dataloaders, list): From d0476183558db0c4eee108f23d16050f74d0d9cd Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 19:50:41 -0500 Subject: [PATCH 56/80] fixed bad loaders --- tests/test_trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index a1c5472d352a6..529ff95ff496d 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -503,7 +503,6 @@ class CurrentTestModel( ) # fit model - import pdb; pdb.set_trace() trainer = Trainer(**trainer_options) result = trainer.fit(model) From eb9a3800b6430b7ad4ebd5913440ccf25526d3ac Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 19:53:56 -0500 Subject: [PATCH 57/80] fixed bad loaders --- pytorch_lightning/trainer/data_loading.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index f45aad4811175..c355aef456c23 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -191,7 +191,7 @@ def reset_val_dataloader(self, model): :param model: :return: """ - if not (self.is_overriden('validation_step') and self.is_overriden('validation_end')): + if not self.is_overriden('validation_step'): return self.val_dataloaders = self.request_data_loader(model.val_dataloader) @@ -217,7 +217,7 @@ def reset_test_dataloader(self, model): :param model: """ - if not (self.is_overriden('test_step') and self.is_overriden('test_end')): + if not self.is_overriden('test_step'): return # get actual loader From cb4f761c4d1c43210c06f4b6226a638c40c4e871 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 19:55:38 -0500 Subject: [PATCH 58/80] fixed bad loaders --- tests/test_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 529ff95ff496d..40df90475628d 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -504,6 +504,7 @@ class CurrentTestModel( # fit model trainer = Trainer(**trainer_options) + import pdb; pdb.set_trace() result = trainer.fit(model) # verify there are 2 val loaders From cb8e977f207c14e549c15feae38f6aa90cb2edd0 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 20:00:11 -0500 Subject: [PATCH 59/80] fixed bad loaders --- tests/test_trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 40df90475628d..d9b1e024ded0b 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -504,9 +504,10 @@ class CurrentTestModel( # fit model trainer = Trainer(**trainer_options) - import pdb; pdb.set_trace() result = trainer.fit(model) + trainer.test() + # verify there are 2 val loaders assert len(trainer.test_dataloaders) == 2, \ 'Multiple test_dataloaders not initiated properly' From 
82ec6cee91ee0605965bb93a8b525a38af8ea83b Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 20:28:18 -0500 Subject: [PATCH 60/80] fixed bad loaders --- pytorch_lightning/trainer/trainer.py | 64 ++++++---------------------- 1 file changed, 13 insertions(+), 51 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 8b34f38cda15e..0332bb0427699 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -901,14 +901,22 @@ def fit( """ - # Update the dataloader attributes of the model with the ones supplied here, - # if they are not already defined in model + # when dataloader is passed via fit, patch the train_dataloader + # functions to overwrite with these implementations if train_dataloader is not None: - _set_dataloader(model, train_dataloader, 'train_dataloader') + def patch_train_dataloader(): + return train_dataloader + model.train_dataloader = patch_train_dataloader + if val_dataloader is not None: - _set_dataloader(model, val_dataloader, 'val_dataloader') + def patch_val_dataloader(): + return val_dataloader + model.val_dataloader = patch_val_dataloader + if test_dataloader is not None: - _set_dataloader(model, test_dataloader, 'test_dataloader') + def patch_test_dataloader(): + return test_dataloader + model.test_dataloader = patch_test_dataloader # when using multi-node or DDP within a node start each module in a separate process if self.use_ddp2: @@ -1120,49 +1128,3 @@ def test(self, model: Optional[LightningModule] = None): self.fit(model) else: self.run_evaluation(test=True) - - -def _set_dataloader(model, dataloader, attribute): - r''' - Check dataloaders passed to .fit() method if they are pytorch DataLoader - objects and whether or not we should overright the corresponding dataloader - in the model - - Args: - model (LightningModule): The model to check - - dataloader: If a pytorch dataloader (or a list of pytorch dataloaders) - is passed, it will be incorporate into the model as model.attribute. - If attribute alreay exist it will warn the userpass. 
If not a - dataloader will throw an error - - attribute (str): The attribute to save the dataloader under - - ''' - # Check if attribute comes directly from base class or - # derived in user subclass - if LightningModule.__qualname__ in getattr(model, attribute).__qualname__: - # Val and test should be list of dataloaders - dataloader = dataloader if attribute == 'train_dataloader' or \ - (attribute != 'train_dataloader' and isinstance(dataloader, list)) else [dataloader] - - # Check we are given valid dataloaders - is_dataloader = isinstance(dataloader, torch.utils.data.DataLoader) - is_dataloader_list = isinstance(dataloader, list) - if is_dataloader_list: - valid_loaders = all(isinstance(d, torch.utils.data.DataLoader) for d in dataloader) - if is_dataloader or is_dataloader_list and valid_loaders: - - # Overwrite abstract methods - dl = lambda: dataloader - dl.__name__ = attribute - setattr(model, attribute, dl) - - elif dataloader and dataloader != [None]: - raise ValueError(f'`{attribute}` needs to be an instance of ' - '`torch.utils.data.DataLoader` or a list of ' - 'DataLoaders, instead got %r`' % dataloader) - - elif dataloader: # if default (None) is passed, do not warn the user - warnings.warn(f'Model has predefined `{attribute}`,' - f' will skip `{attribute}={dataloader}` passed to fit method.') From 617dd32f0d9c061681b66acdcbf4740f97468773 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 20:35:59 -0500 Subject: [PATCH 61/80] fixed bad loaders --- pytorch_lightning/trainer/trainer.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0332bb0427699..cad0ad86ec0ea 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -855,8 +855,8 @@ def fit( self, model: LightningModule, train_dataloader: Optional[DataLoader] = None, - val_dataloader: Optional[DataLoader] = None, - test_dataloader: Optional[DataLoader] = None + val_dataloaders: Optional[DataLoader] = None, + test_dataloaders: Optional[DataLoader] = None ): r""" Runs the full optimization routine. @@ -868,13 +868,13 @@ def fit( DataLoader with training samples. If the model has a predefined train_dataloader method this will be skipped. - val_dataloader: Either a single + val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. - If the model has a predefined val_dataloader method this will be skipped + If the model has a predefined val_dataloaders method this will be skipped - test_dataloader: Either a single + test_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. 
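The fit() changes around this point drop the _set_dataloader helper in favour of patching the dataloaders passed to fit() directly onto the model as zero-argument functions, so the rest of the trainer only ever calls model.train_dataloader() and friends. A minimal sketch of the idea outside the Trainer, with illustrative names and without the MisconfigurationException checks:

    def attach_fit_dataloaders(model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None):
        # overwrite the model's hooks with closures returning the loaders given to fit()
        if train_dataloader is not None:
            model.train_dataloader = lambda: train_dataloader
        if val_dataloaders is not None:
            model.val_dataloader = lambda: val_dataloaders
        if test_dataloaders is not None:
            model.test_dataloader = lambda: test_dataloaders

Because the closures are set on the instance, they take precedence over a hook defined on the model class.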
- If the model has a predefined val_dataloader method this will be skipped + If the model has a predefined test_dataloaders method this will be skipped Example:: @@ -906,16 +906,19 @@ def fit( if train_dataloader is not None: def patch_train_dataloader(): return train_dataloader + model.train_dataloader = patch_train_dataloader - if val_dataloader is not None: + if val_dataloaders is not None: def patch_val_dataloader(): - return val_dataloader + return val_dataloaders + model.val_dataloader = patch_val_dataloader - if test_dataloader is not None: + if test_dataloaders is not None: def patch_test_dataloader(): - return test_dataloader + return test_dataloaders + model.test_dataloader = patch_test_dataloader # when using multi-node or DDP within a node start each module in a separate process From 97cf4c068024338e926cfa63ba0c95477cbbceec Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 20:36:58 -0500 Subject: [PATCH 62/80] fixed bad loaders --- tests/test_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index d9b1e024ded0b..1a1eb2ec83642 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -570,6 +570,8 @@ class CurrentTestModel( trainer = Trainer(**trainer_options) fit_options = dict(train_dataloader=model._dataloader(train=True), val_dataloader=model._dataloader(train=False)) + + import pdb; pdb.set_trace() results = trainer.fit(model, **fit_options) assert len(trainer.val_dataloaders) == 1, \ f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' From 1671fbc286828980935a08ee5f7214e928a013dc Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 20:40:01 -0500 Subject: [PATCH 63/80] fixed bad loaders --- tests/test_trainer.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 1a1eb2ec83642..bb3d75538c4bf 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -569,7 +569,7 @@ class CurrentTestModel( model = CurrentTestModel(hparams) trainer = Trainer(**trainer_options) fit_options = dict(train_dataloader=model._dataloader(train=True), - val_dataloader=model._dataloader(train=False)) + val_dataloaders=model._dataloader(train=False)) import pdb; pdb.set_trace() results = trainer.fit(model, **fit_options) @@ -600,8 +600,8 @@ class CurrentTestModel( model = CurrentTestModel(hparams) trainer = Trainer(**trainer_options) fit_options = dict(train_dataloader=model._dataloader(train=True), - val_dataloader=model._dataloader(train=False), - test_dataloader=model._dataloader(train=False)) + val_dataloaders=model._dataloader(train=False), + test_dataloaders=model._dataloader(train=False)) results = trainer.fit(model, **fit_options) assert len(trainer.val_dataloaders) == 1, \ @@ -633,9 +633,9 @@ class CurrentTestModel( model = CurrentTestModel(hparams) trainer = Trainer(**trainer_options) fit_options = dict(train_dataloader=model._dataloader(train=True), - val_dataloader=[model._dataloader(train=False), + val_dataloaders=[model._dataloader(train=False), model._dataloader(train=False)], - test_dataloader=[model._dataloader(train=False), + test_dataloaders=[model._dataloader(train=False), model._dataloader(train=False)]) results = trainer.fit(model, **fit_options) @@ -672,8 +672,8 @@ class CurrentTestModel( # fit model trainer = Trainer(**trainer_options) - fit_options = dict(val_dataloader=model._dataloader(train=False), - test_dataloader=model._dataloader(train=False)) + fit_options = 
dict(val_dataloaders=model._dataloader(train=False), + test_dataloaders=model._dataloader(train=False)) results = trainer.fit(model, **fit_options) assert len(trainer.val_dataloaders) == 1, \ f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' From 08eeb48b7978d924085d795a0a1c473e3707e71f Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 20:51:54 -0500 Subject: [PATCH 64/80] fixed error in .fit with loaders --- pytorch_lightning/trainer/model_hooks.py | 5 +++-- pytorch_lightning/trainer/trainer.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/model_hooks.py b/pytorch_lightning/trainer/model_hooks.py index e5afc90aebe6f..eb0d529d2681b 100644 --- a/pytorch_lightning/trainer/model_hooks.py +++ b/pytorch_lightning/trainer/model_hooks.py @@ -11,8 +11,9 @@ def is_function_implemented(self, f_name): f_op = getattr(model, f_name, None) return callable(f_op) - def is_overriden(self, f_name): - model = self.get_model() + def is_overriden(self, f_name, model=None): + if model is None: + model = self.get_model() super_object = LightningModule # when code pointers are different, it was overriden diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cad0ad86ec0ea..6e53946152e61 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -904,18 +904,30 @@ def fit( # when dataloader is passed via fit, patch the train_dataloader # functions to overwrite with these implementations if train_dataloader is not None: + if not self.is_overriden('training_step', model): + m = 'You called .fit() with a train_dataloader but did not define training_step()' + raise MisconfigurationException(m) + def patch_train_dataloader(): return train_dataloader model.train_dataloader = patch_train_dataloader if val_dataloaders is not None: + if not self.is_overriden('validation_step', model): + m = 'You called .fit() with a val_dataloaders but did not define validation_step()' + raise MisconfigurationException(m) + def patch_val_dataloader(): return val_dataloaders model.val_dataloader = patch_val_dataloader if test_dataloaders is not None: + if not self.is_overriden('test_step', model): + m = 'You called .fit() with a test_dataloaders but did not define test_step()' + raise MisconfigurationException(m) + def patch_test_dataloader(): return test_dataloaders From d9cfcdb583fe420e6bf8b95c4c4601bd7704a8c1 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 20:53:58 -0500 Subject: [PATCH 65/80] fixed error in .fit with loaders --- pytorch_lightning/trainer/trainer.py | 68 +++++++++++++++------------- 1 file changed, 36 insertions(+), 32 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6e53946152e61..855163900b62d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -900,39 +900,10 @@ def fit( # feed to .fit() """ + # set up the passed in dataloaders (if needed) + self.__set_fit_dataloaders(model, train_dataloader, val_dataloaders, test_dataloaders) - # when dataloader is passed via fit, patch the train_dataloader - # functions to overwrite with these implementations - if train_dataloader is not None: - if not self.is_overriden('training_step', model): - m = 'You called .fit() with a train_dataloader but did not define training_step()' - raise MisconfigurationException(m) - - def patch_train_dataloader(): - return train_dataloader - - 
model.train_dataloader = patch_train_dataloader - - if val_dataloaders is not None: - if not self.is_overriden('validation_step', model): - m = 'You called .fit() with a val_dataloaders but did not define validation_step()' - raise MisconfigurationException(m) - - def patch_val_dataloader(): - return val_dataloaders - - model.val_dataloader = patch_val_dataloader - - if test_dataloaders is not None: - if not self.is_overriden('test_step', model): - m = 'You called .fit() with a test_dataloaders but did not define test_step()' - raise MisconfigurationException(m) - - def patch_test_dataloader(): - return test_dataloaders - - model.test_dataloader = patch_test_dataloader - + # route to appropriate start method # when using multi-node or DDP within a node start each module in a separate process if self.use_ddp2: task = int(os.environ['SLURM_LOCALID']) @@ -976,6 +947,39 @@ def patch_test_dataloader(): # used for testing or when we need to know that training succeeded return 1 + def __set_fit_dataloaders(self, model, train_dataloader, val_dataloaders, test_dataloaders): + # when dataloader is passed via fit, patch the train_dataloader + # functions to overwrite with these implementations + if train_dataloader is not None: + if not self.is_overriden('training_step', model): + m = 'You called .fit() with a train_dataloader but did not define training_step()' + raise MisconfigurationException(m) + + def patch_train_dataloader(): + return train_dataloader + + model.train_dataloader = patch_train_dataloader + + if val_dataloaders is not None: + if not self.is_overriden('validation_step', model): + m = 'You called .fit() with a val_dataloaders but did not define validation_step()' + raise MisconfigurationException(m) + + def patch_val_dataloader(): + return val_dataloaders + + model.val_dataloader = patch_val_dataloader + + if test_dataloaders is not None: + if not self.is_overriden('test_step', model): + m = 'You called .fit() with a test_dataloaders but did not define test_step()' + raise MisconfigurationException(m) + + def patch_test_dataloader(): + return test_dataloaders + + model.test_dataloader = patch_test_dataloader + def init_optimizers( self, optimizers: Union[Optimizer, Tuple[List, List], List[Optimizer], Tuple[Optimizer]] From d2db8f2fff9f8907f60704c2a7751d602126e045 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 21:01:03 -0500 Subject: [PATCH 66/80] fixed error in .fit with loaders --- tests/models/__init__.py | 2 ++ tests/models/mixins.py | 10 ++++++++++ tests/test_trainer.py | 20 ++++++++++++++------ 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/tests/models/__init__.py b/tests/models/__init__.py index 3e6424bd4982a..4ab7f066c4359 100644 --- a/tests/models/__init__.py +++ b/tests/models/__init__.py @@ -12,6 +12,8 @@ LightningTestMixin, LightningTestStepMultipleDataloadersMixin, LightningTestMultipleDataloadersMixin, + LightningTestStepNoDataloadersMixin, + LightningValStepNoDataloadersMixin ) diff --git a/tests/models/mixins.py b/tests/models/mixins.py index d4261ef052ced..8469fb9969b47 100644 --- a/tests/models/mixins.py +++ b/tests/models/mixins.py @@ -339,6 +339,16 @@ def test_step(self, batch, batch_idx, dataloader_idx): return output +class LightningTestStepNoDataloadersMixin(LightningTestStepMultipleDataloadersMixin): + def test_dataloader(self): + return None + + +class LightningValStepNoDataloadersMixin(LightningTestStepMultipleDataloadersMixin): + def val_dataloader(self): + return None + + class 
LightningTestMultipleDataloadersMixin(LightningTestStepMultipleDataloadersMixin): def test_end(self, outputs): """ diff --git a/tests/test_trainer.py b/tests/test_trainer.py index bb3d75538c4bf..50d961c4d3c53 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -17,6 +17,8 @@ LightningValidationStepMixin, LightningValidationMultipleDataloadersMixin, LightningTestMultipleDataloadersMixin, + LightningTestStepNoDataloadersMixin, + LightningValStepNoDataloadersMixin ) from pytorch_lightning.core.lightning import load_hparams_from_tags_csv from pytorch_lightning.trainer.logging import TrainerLoggingMixin @@ -525,7 +527,7 @@ def test_train_dataloaders_passed_to_fit(tmpdir): tutils.reset_seed() class CurrentTestModel( - LightningTestModelBaseWithoutDataloader + LightningTestModelBaseWithoutDataloader, ): pass @@ -551,7 +553,8 @@ def test_train_val_dataloaders_passed_to_fit(tmpdir): tutils.reset_seed() class CurrentTestModel( - LightningTestModelBaseWithoutDataloader + LightningTestModelBaseWithoutDataloader, + LightningValStepNoDataloadersMixin ): pass @@ -571,7 +574,6 @@ class CurrentTestModel( fit_options = dict(train_dataloader=model._dataloader(train=True), val_dataloaders=model._dataloader(train=False)) - import pdb; pdb.set_trace() results = trainer.fit(model, **fit_options) assert len(trainer.val_dataloaders) == 1, \ f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' @@ -582,7 +584,9 @@ def test_all_dataloaders_passed_to_fit(tmpdir): tutils.reset_seed() class CurrentTestModel( - LightningTestModelBaseWithoutDataloader + LightningTestModelBaseWithoutDataloader, + LightningValStepNoDataloadersMixin, + LightningTestStepNoDataloadersMixin ): pass @@ -615,7 +619,9 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir): tutils.reset_seed() class CurrentTestModel( - LightningTestModelBaseWithoutDataloader + LightningTestModelBaseWithoutDataloader, + LightningValStepNoDataloadersMixin, + LightningTestStepNoDataloadersMixin ): pass @@ -650,7 +656,9 @@ def test_mixing_of_dataloader_options(tmpdir): tutils.reset_seed() class CurrentTestModel( - LightningTestModelBase + LightningTestModelBase, + LightningValStepNoDataloadersMixin, + LightningTestStepNoDataloadersMixin ): pass From 35a888038d7382893dede131e70d13682b81f954 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 21:04:11 -0500 Subject: [PATCH 67/80] fixed error in .fit with loaders --- tests/models/mixins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/mixins.py b/tests/models/mixins.py index 8469fb9969b47..ac49a7ab6b6ff 100644 --- a/tests/models/mixins.py +++ b/tests/models/mixins.py @@ -344,7 +344,7 @@ def test_dataloader(self): return None -class LightningValStepNoDataloadersMixin(LightningTestStepMultipleDataloadersMixin): +class LightningValStepNoDataloadersMixin(LightningValidationStepMultipleDataloadersMixin): def val_dataloader(self): return None From cec5931500a973f111c963917e4a083873be4b5a Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 21:07:05 -0500 Subject: [PATCH 68/80] fixed error in .fit with loaders --- tests/test_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 50d961c4d3c53..65f3f99e747b9 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -574,6 +574,7 @@ class CurrentTestModel( fit_options = dict(train_dataloader=model._dataloader(train=True), val_dataloaders=model._dataloader(train=False)) + import pdb; pdb.set_trace() results = 
trainer.fit(model, **fit_options) assert len(trainer.val_dataloaders) == 1, \ f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' From 6c412eda909df558a919f2dc9169e64c24d2c2c9 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 21:12:43 -0500 Subject: [PATCH 69/80] fixed error in .fit with loaders --- tests/models/mixins.py | 102 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 97 insertions(+), 5 deletions(-) diff --git a/tests/models/mixins.py b/tests/models/mixins.py index ac49a7ab6b6ff..8de780ba4dbc6 100644 --- a/tests/models/mixins.py +++ b/tests/models/mixins.py @@ -340,13 +340,105 @@ def test_step(self, batch, batch_idx, dataloader_idx): class LightningTestStepNoDataloadersMixin(LightningTestStepMultipleDataloadersMixin): - def test_dataloader(self): - return None + def test_step(self, batch, batch_idx, dataloader_idx): + """ + Lightning calls this inside the validation loop + :param batch: + :return: + """ + x, y = batch + x = x.view(x.size(0), -1) + y_hat = self.forward(x) + loss_test = self.loss(y, y_hat) -class LightningValStepNoDataloadersMixin(LightningValidationStepMultipleDataloadersMixin): - def val_dataloader(self): - return None + # acc + labels_hat = torch.argmax(y_hat, dim=1) + test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) + test_acc = torch.tensor(test_acc) + + if self.on_gpu: + test_acc = test_acc.cuda(loss_test.device.index) + + # in DP mode (default) make sure if result is scalar, there's another dim in the beginning + if self.trainer.use_dp: + loss_test = loss_test.unsqueeze(0) + test_acc = test_acc.unsqueeze(0) + + # alternate possible outputs to test + if batch_idx % 1 == 0: + output = OrderedDict({ + 'test_loss': loss_test, + 'test_acc': test_acc, + }) + return output + if batch_idx % 2 == 0: + return test_acc + + if batch_idx % 3 == 0: + output = OrderedDict({ + 'test_loss': loss_test, + 'test_acc': test_acc, + 'test_dic': {'test_loss_a': loss_test} + }) + return output + if batch_idx % 5 == 0: + output = OrderedDict({ + f'test_loss_{dataloader_idx}': loss_test, + f'test_acc_{dataloader_idx}': test_acc, + }) + return output + + +class LightningValStepNoDataloadersMixin: + def validation_step(self, batch, batch_idx, dataloader_idx): + """ + Lightning calls this inside the validation loop + :param batch: + :return: + """ + x, y = batch + x = x.view(x.size(0), -1) + y_hat = self.forward(x) + + loss_val = self.loss(y, y_hat) + + # acc + labels_hat = torch.argmax(y_hat, dim=1) + val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) + val_acc = torch.tensor(val_acc) + + if self.on_gpu: + val_acc = val_acc.cuda(loss_val.device.index) + + # in DP mode (default) make sure if result is scalar, there's another dim in the beginning + if self.trainer.use_dp: + loss_val = loss_val.unsqueeze(0) + val_acc = val_acc.unsqueeze(0) + + # alternate possible outputs to test + if batch_idx % 1 == 0: + output = OrderedDict({ + 'val_loss': loss_val, + 'val_acc': val_acc, + }) + return output + if batch_idx % 2 == 0: + return val_acc + + if batch_idx % 3 == 0: + output = OrderedDict({ + 'val_loss': loss_val, + 'val_acc': val_acc, + 'test_dic': {'val_loss_a': loss_val} + }) + return output + if batch_idx % 5 == 0: + output = OrderedDict({ + f'val_loss_{dataloader_idx}': loss_val, + f'val_acc_{dataloader_idx}': val_acc, + }) + return output class LightningTestMultipleDataloadersMixin(LightningTestStepMultipleDataloadersMixin): From 6287dd6c14fd72d2721ba8cab227b99aa32cb2f0 Mon Sep 17 00:00:00 2001 From: William Falcon 
Date: Mon, 24 Feb 2020 21:12:55 -0500 Subject: [PATCH 70/80] fixed error in .fit with loaders --- tests/test_trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 65f3f99e747b9..50d961c4d3c53 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -574,7 +574,6 @@ class CurrentTestModel( fit_options = dict(train_dataloader=model._dataloader(train=True), val_dataloaders=model._dataloader(train=False)) - import pdb; pdb.set_trace() results = trainer.fit(model, **fit_options) assert len(trainer.val_dataloaders) == 1, \ f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' From dc403e85b4b127bf046f0baa1dae4481f8cd1b19 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 21:15:23 -0500 Subject: [PATCH 71/80] fixed error in .fit with loaders --- tests/test_trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 50d961c4d3c53..7b713df5fb386 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -553,8 +553,8 @@ def test_train_val_dataloaders_passed_to_fit(tmpdir): tutils.reset_seed() class CurrentTestModel( + LightningValStepNoDataloadersMixin, LightningTestModelBaseWithoutDataloader, - LightningValStepNoDataloadersMixin ): pass @@ -584,9 +584,9 @@ def test_all_dataloaders_passed_to_fit(tmpdir): tutils.reset_seed() class CurrentTestModel( - LightningTestModelBaseWithoutDataloader, LightningValStepNoDataloadersMixin, - LightningTestStepNoDataloadersMixin + LightningTestStepNoDataloadersMixin, + LightningTestModelBaseWithoutDataloader, ): pass From b2755f9261e97a01af440422b96d82a7dcc3544d Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 21:16:39 -0500 Subject: [PATCH 72/80] fixed error in .fit with loaders --- tests/test_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 7b713df5fb386..463c4c19475d1 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -619,9 +619,9 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir): tutils.reset_seed() class CurrentTestModel( - LightningTestModelBaseWithoutDataloader, LightningValStepNoDataloadersMixin, - LightningTestStepNoDataloadersMixin + LightningTestStepNoDataloadersMixin, + LightningTestModelBaseWithoutDataloader, ): pass From 1be1cf681beee2ce597a7d2f14c9c13b25252372 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 21:17:46 -0500 Subject: [PATCH 73/80] fixed error in .fit with loaders --- tests/models/mixins.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/mixins.py b/tests/models/mixins.py index 8de780ba4dbc6..9d60682c23300 100644 --- a/tests/models/mixins.py +++ b/tests/models/mixins.py @@ -340,7 +340,7 @@ def test_step(self, batch, batch_idx, dataloader_idx): class LightningTestStepNoDataloadersMixin(LightningTestStepMultipleDataloadersMixin): - def test_step(self, batch, batch_idx, dataloader_idx): + def test_step(self, batch, batch_idx): """ Lightning calls this inside the validation loop :param batch: @@ -391,7 +391,7 @@ def test_step(self, batch, batch_idx, dataloader_idx): class LightningValStepNoDataloadersMixin: - def validation_step(self, batch, batch_idx, dataloader_idx): + def validation_step(self, batch, batch_idx): """ Lightning calls this inside the validation loop :param batch: From 83869c299c003d124f9b3bc07967dc60d4042a91 Mon Sep 17 00:00:00 2001 From: William Falcon Date: 
Mon, 24 Feb 2020 21:25:41 -0500 Subject: [PATCH 74/80] fixed error in .fit with loaders --- tests/models/__init__.py | 6 ++- tests/models/mixins.py | 95 ++++++++++++++++++++++++++++++++++++++-- tests/test_trainer.py | 22 +++++----- 3 files changed, 108 insertions(+), 15 deletions(-) diff --git a/tests/models/__init__.py b/tests/models/__init__.py index 4ab7f066c4359..df16bffd668b8 100644 --- a/tests/models/__init__.py +++ b/tests/models/__init__.py @@ -12,8 +12,10 @@ LightningTestMixin, LightningTestStepMultipleDataloadersMixin, LightningTestMultipleDataloadersMixin, - LightningTestStepNoDataloadersMixin, - LightningValStepNoDataloadersMixin + LightningTestFitSingleTestDataloadersMixin, + LightningTestFitMultipleTestDataloadersMixin, + LightningValStepFitSingleDataloaderMixin, + LightningValStepFitMultipleDataloadersMixin ) diff --git a/tests/models/mixins.py b/tests/models/mixins.py index 9d60682c23300..d6d286a044500 100644 --- a/tests/models/mixins.py +++ b/tests/models/mixins.py @@ -339,7 +339,7 @@ def test_step(self, batch, batch_idx, dataloader_idx): return output -class LightningTestStepNoDataloadersMixin(LightningTestStepMultipleDataloadersMixin): +class LightningTestFitSingleTestDataloadersMixin: def test_step(self, batch, batch_idx): """ Lightning calls this inside the validation loop @@ -382,6 +382,51 @@ def test_step(self, batch, batch_idx): 'test_dic': {'test_loss_a': loss_test} }) return output + + +class LightningTestFitMultipleTestDataloadersMixin: + def test_step(self, batch, batch_idx, dataloader_idx): + """ + Lightning calls this inside the validation loop + :param batch: + :return: + """ + x, y = batch + x = x.view(x.size(0), -1) + y_hat = self.forward(x) + + loss_test = self.loss(y, y_hat) + + # acc + labels_hat = torch.argmax(y_hat, dim=1) + test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) + test_acc = torch.tensor(test_acc) + + if self.on_gpu: + test_acc = test_acc.cuda(loss_test.device.index) + + # in DP mode (default) make sure if result is scalar, there's another dim in the beginning + if self.trainer.use_dp: + loss_test = loss_test.unsqueeze(0) + test_acc = test_acc.unsqueeze(0) + + # alternate possible outputs to test + if batch_idx % 1 == 0: + output = OrderedDict({ + 'test_loss': loss_test, + 'test_acc': test_acc, + }) + return output + if batch_idx % 2 == 0: + return test_acc + + if batch_idx % 3 == 0: + output = OrderedDict({ + 'test_loss': loss_test, + 'test_acc': test_acc, + 'test_dic': {'test_loss_a': loss_test} + }) + return output if batch_idx % 5 == 0: output = OrderedDict({ f'test_loss_{dataloader_idx}': loss_test, @@ -389,8 +434,7 @@ def test_step(self, batch, batch_idx): }) return output - -class LightningValStepNoDataloadersMixin: +class LightningValStepFitSingleDataloaderMixin: def validation_step(self, batch, batch_idx): """ Lightning calls this inside the validation loop @@ -433,6 +477,51 @@ def validation_step(self, batch, batch_idx): 'test_dic': {'val_loss_a': loss_val} }) return output + + +class LightningValStepFitMultipleDataloadersMixin: + def validation_step(self, batch, batch_idx, dataloader_idx): + """ + Lightning calls this inside the validation loop + :param batch: + :return: + """ + x, y = batch + x = x.view(x.size(0), -1) + y_hat = self.forward(x) + + loss_val = self.loss(y, y_hat) + + # acc + labels_hat = torch.argmax(y_hat, dim=1) + val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) + val_acc = torch.tensor(val_acc) + + if self.on_gpu: + val_acc = val_acc.cuda(loss_val.device.index) + + # in DP mode (default) 
make sure if result is scalar, there's another dim in the beginning + if self.trainer.use_dp: + loss_val = loss_val.unsqueeze(0) + val_acc = val_acc.unsqueeze(0) + + # alternate possible outputs to test + if batch_idx % 1 == 0: + output = OrderedDict({ + 'val_loss': loss_val, + 'val_acc': val_acc, + }) + return output + if batch_idx % 2 == 0: + return val_acc + + if batch_idx % 3 == 0: + output = OrderedDict({ + 'val_loss': loss_val, + 'val_acc': val_acc, + 'test_dic': {'val_loss_a': loss_val} + }) + return output if batch_idx % 5 == 0: output = OrderedDict({ f'val_loss_{dataloader_idx}': loss_val, diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 463c4c19475d1..850ca6e3d0571 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -17,8 +17,10 @@ LightningValidationStepMixin, LightningValidationMultipleDataloadersMixin, LightningTestMultipleDataloadersMixin, - LightningTestStepNoDataloadersMixin, - LightningValStepNoDataloadersMixin + LightningTestFitSingleTestDataloadersMixin, + LightningTestFitMultipleTestDataloadersMixin, + LightningValStepFitMultipleDataloadersMixin, + LightningValStepFitSingleDataloaderMixin ) from pytorch_lightning.core.lightning import load_hparams_from_tags_csv from pytorch_lightning.trainer.logging import TrainerLoggingMixin @@ -553,7 +555,7 @@ def test_train_val_dataloaders_passed_to_fit(tmpdir): tutils.reset_seed() class CurrentTestModel( - LightningValStepNoDataloadersMixin, + LightningValStepFitSingleDataloaderMixin, LightningTestModelBaseWithoutDataloader, ): pass @@ -584,8 +586,8 @@ def test_all_dataloaders_passed_to_fit(tmpdir): tutils.reset_seed() class CurrentTestModel( - LightningValStepNoDataloadersMixin, - LightningTestStepNoDataloadersMixin, + LightningValStepFitSingleDataloaderMixin, + LightningTestFitSingleTestDataloadersMixin, LightningTestModelBaseWithoutDataloader, ): pass @@ -619,8 +621,8 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir): tutils.reset_seed() class CurrentTestModel( - LightningValStepNoDataloadersMixin, - LightningTestStepNoDataloadersMixin, + LightningValStepFitMultipleDataloadersMixin, + LightningTestFitMultipleTestDataloadersMixin, LightningTestModelBaseWithoutDataloader, ): pass @@ -657,8 +659,8 @@ def test_mixing_of_dataloader_options(tmpdir): class CurrentTestModel( LightningTestModelBase, - LightningValStepNoDataloadersMixin, - LightningTestStepNoDataloadersMixin + LightningValStepFitSingleDataloaderMixin, + LightningTestFitSingleTestDataloadersMixin ): pass @@ -675,7 +677,7 @@ class CurrentTestModel( # fit model trainer = Trainer(**trainer_options) - fit_options = dict(val_dataloader=model._dataloader(train=False)) + fit_options = dict(val_dataloaders=model._dataloader(train=False)) results = trainer.fit(model, **fit_options) # fit model From 6bc958719dc8f4511955ffd764c973c556a0c527 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 21:30:32 -0500 Subject: [PATCH 75/80] fixed error in .fit with loaders --- tests/models/mixins.py | 1 + tests/test_trainer.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/tests/models/mixins.py b/tests/models/mixins.py index d6d286a044500..940f5e30f350a 100644 --- a/tests/models/mixins.py +++ b/tests/models/mixins.py @@ -434,6 +434,7 @@ def test_step(self, batch, batch_idx, dataloader_idx): }) return output + class LightningValStepFitSingleDataloaderMixin: def validation_step(self, batch, batch_idx): """ diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 850ca6e3d0571..e8332d0afe30a 100644 --- a/tests/test_trainer.py +++ 
b/tests/test_trainer.py @@ -608,6 +608,8 @@ class CurrentTestModel( fit_options = dict(train_dataloader=model._dataloader(train=True), val_dataloaders=model._dataloader(train=False), test_dataloaders=model._dataloader(train=False)) + + import pdb; pdb.set_trace() results = trainer.fit(model, **fit_options) assert len(trainer.val_dataloaders) == 1, \ From 66f55d79ab7a3ea2f8eb9858f1e438a08bc67613 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 21:33:42 -0500 Subject: [PATCH 76/80] fixed error in .fit with loaders --- tests/test_trainer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index e8332d0afe30a..d567a028d9294 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -609,9 +609,10 @@ class CurrentTestModel( val_dataloaders=model._dataloader(train=False), test_dataloaders=model._dataloader(train=False)) - import pdb; pdb.set_trace() results = trainer.fit(model, **fit_options) + trainer.test() + assert len(trainer.val_dataloaders) == 1, \ f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' assert len(trainer.test_dataloaders) == 1, \ @@ -648,6 +649,7 @@ class CurrentTestModel( test_dataloaders=[model._dataloader(train=False), model._dataloader(train=False)]) results = trainer.fit(model, **fit_options) + trainer.test() assert len(trainer.val_dataloaders) == 2, \ f'Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' @@ -687,6 +689,8 @@ class CurrentTestModel( fit_options = dict(val_dataloaders=model._dataloader(train=False), test_dataloaders=model._dataloader(train=False)) results = trainer.fit(model, **fit_options) + trainer.test() + assert len(trainer.val_dataloaders) == 1, \ f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' assert len(trainer.test_dataloaders) == 1, \ From 05ad57dd97bd291a96fe6aad376efe8653661f1e Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 21:37:08 -0500 Subject: [PATCH 77/80] fixes #909 --- pytorch_lightning/trainer/evaluation_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 4c541d443ed1c..37847393a10cc 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -301,8 +301,8 @@ def evaluate(self, model, dataloaders, max_batches, test=False): def run_evaluation(self, test=False): # when testing make sure user defined a test step - if test and not (self.is_overriden('test_step') and self.is_overriden('test_end')): - m = '''You called `.test()` without defining model's `.test_step()` or `.test_end()`. + if test and not (self.is_overriden('test_step')): + m = '''You called `.test()` without defining model's `.test_step()`. 
Please define and try again''' raise MisconfigurationException(m) From 9504461c6ccd5d48d3c79fb4ac2272168f0bf850 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 21:40:05 -0500 Subject: [PATCH 78/80] fixes #909 --- tests/test_trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index d567a028d9294..bc5b2979de781 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -645,9 +645,9 @@ class CurrentTestModel( trainer = Trainer(**trainer_options) fit_options = dict(train_dataloader=model._dataloader(train=True), val_dataloaders=[model._dataloader(train=False), - model._dataloader(train=False)], + model._dataloader(train=False)], test_dataloaders=[model._dataloader(train=False), - model._dataloader(train=False)]) + model._dataloader(train=False)]) results = trainer.fit(model, **fit_options) trainer.test() @@ -662,9 +662,9 @@ def test_mixing_of_dataloader_options(tmpdir): tutils.reset_seed() class CurrentTestModel( - LightningTestModelBase, LightningValStepFitSingleDataloaderMixin, - LightningTestFitSingleTestDataloadersMixin + LightningTestFitSingleTestDataloadersMixin, + LightningTestModelBase, ): pass From c12cb922fe10332455a044898993f99460d441cc Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 21:59:13 -0500 Subject: [PATCH 79/80] bug fix --- CHANGELOG.md | 4 ++++ tests/test_restore_models.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b96da2d8d891f..18f10757768c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added automatic sampler setup. Depending on DDP or TPU, lightning configures the sampler correctly (user needs to do nothing) ([#926](https://github.com/PyTorchLightning/pytorch-lightning/pull/926)) +- Added `reload_dataloaders_every_epoch=False` flag for trainer. Some users require reloading data every epoch ([#926](https://github.com/PyTorchLightning/pytorch-lightning/pull/926)) +- Added `progress_bar_refresh_rate=50` flag for trainer. Throttle refresh rate on notebooks ([#926](https://github.com/PyTorchLightning/pytorch-lightning/pull/926)) - Updated governance docs - Added a check to ensure that the metric used for early stopping exists before training commences ([#542](https://github.com/PyTorchLightning/pytorch-lightning/pull/542)) - Added `optimizer_idx` argument to `backward` hook ([#733](https://github.com/PyTorchLightning/pytorch-lightning/pull/733)) @@ -22,6 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Changed +- Removed `@data_loader` decorator ([#926](https://github.com/PyTorchLightning/pytorch-lightning/pull/926)) - Changed default TQDM to use `tqdm.auto` for prettier outputs in IPython notebooks ([#752](https://github.com/PyTorchLightning/pytorch-lightning/pull/752)) - Changed `pytorch_lightning.logging` to `pytorch_lightning.loggers` ([#767](https://github.com/PyTorchLightning/pytorch-lightning/pull/767)) - Moved the default `tqdm_dict` definition from Trainer to `LightningModule`, so it can be overridden by the user ([#749](https://github.com/PyTorchLightning/pytorch-lightning/pull/749)) diff --git a/tests/test_restore_models.py b/tests/test_restore_models.py index a348b26f161a1..ba347e474a326 100644 --- a/tests/test_restore_models.py +++ b/tests/test_restore_models.py @@ -248,7 +248,7 @@ def assert_good_acc(): dp_model = new_trainer.model dp_model.eval() - dataloader = trainer.get_train_dataloader() + dataloader = trainer.train_dataloader tutils.run_prediction(dataloader, dp_model, dp=True) # new model From 3173ad384db9ca996bc49a0abb776eb2f4310c2a Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Feb 2020 22:14:32 -0500 Subject: [PATCH 80/80] Fixes #902 --- pytorch_lightning/core/lightning.py | 11 ++++++++++- pytorch_lightning/trainer/training_loop.py | 15 +++------------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index a5d055c3b9959..85a713427e151 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -16,6 +16,13 @@ from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel +try: + import torch_xla.core.xla_model as xm + XLA_AVAILABLE = True + +except ImportError: + XLA_AVAILABLE = False + class LightningModule(ABC, GradInformation, ModelIO, ModelHooks): @@ -798,7 +805,9 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx, sec optimizer.zero_grad() """ - if isinstance(optimizer, torch.optim.LBFGS): + if self.trainer.use_tpu and XLA_AVAILABLE: + xm.optimizer_step(optimizer) + elif isinstance(optimizer, torch.optim.LBFGS): optimizer.step(second_order_closure) else: optimizer.step() diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 1f6ad4511acc2..0542d0a837f5d 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -168,15 +168,9 @@ def training_step(self, batch, batch_idx): except ImportError: APEX_AVAILABLE = False -try: - import torch_xla.core.xla_model as xm - - XLA_AVAILABLE = True -except ImportError: - XLA_AVAILABLE = False - try: import torch_xla.distributed.parallel_loader as xla_pl + import torch_xla.core.xla_model as xm XLA_AVAILABLE = True @@ -600,11 +594,8 @@ def optimizer_closure(): # override function to modify this behavior model = self.get_model() with self.profiler.profile('optimizer_step'): - if self.use_tpu: - xm.optimizer_step(optimizer) - else: - model.optimizer_step(self.current_epoch, batch_idx, - optimizer, opt_idx, optimizer_closure) + model.optimizer_step(self.current_epoch, batch_idx, + optimizer, opt_idx, optimizer_closure) # calculate running loss for display self.running_loss.append(self.batch_loss_value)
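
Note on PATCH 80: by routing the TPU case through `LightningModule.optimizer_step` instead of special-casing it in `training_loop.py`, user overrides of the hook are now honored on CPU, GPU, and TPU alike. Below is a minimal sketch of such an override. The warm-up schedule, the 500-step horizon, the 1e-3 base learning rate, and the use of `self.trainer.global_step` are illustrative assumptions, not part of the patch; only the hook signature and the dispatch behavior of the default implementation come from the diff above.

    import pytorch_lightning as pl

    class WarmupModule(pl.LightningModule):
        # forward / training_step / configure_optimizers omitted for brevity

        def optimizer_step(self, current_epoch, batch_idx, optimizer,
                           optimizer_idx, second_order_closure=None):
            # hypothetical linear warm-up over the first 500 optimizer steps
            if self.trainer.global_step < 500:
                scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
                for pg in optimizer.param_groups:
                    pg['lr'] = scale * 1e-3

            # delegate to the default hook, which after PATCH 80 dispatches to
            # xm.optimizer_step(optimizer) on TPU and passes the closure to LBFGS
            super().optimizer_step(current_epoch, batch_idx, optimizer,
                                   optimizer_idx, second_order_closure)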
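
Note on PATCHES 74-78: taken together they let a LightningModule omit the `*_dataloader` hooks entirely and receive its loaders through `trainer.fit()`, and PATCH 77 drops the `test_end` requirement so `trainer.test()` only needs `test_step`. A minimal end-to-end sketch of that call pattern follows; the module, the random `TensorDataset`, and the Trainer options are illustrative, and exact Trainer flags vary by version.

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl

    class LitClassifier(pl.LightningModule):
        # defines the step hooks but, deliberately, no *_dataloader hooks
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(28 * 28, 10)
            self.loss = nn.CrossEntropyLoss()

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return {'loss': self.loss(self.forward(x), y)}

        def validation_step(self, batch, batch_idx):
            x, y = batch
            return {'val_loss': self.loss(self.forward(x), y)}

        def test_step(self, batch, batch_idx):
            x, y = batch
            return {'test_loss': self.loss(self.forward(x), y)}

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=0.001)

    def make_loader():
        data = TensorDataset(torch.randn(64, 28 * 28), torch.randint(0, 10, (64,)))
        return DataLoader(data, batch_size=32)

    model = LitClassifier()
    trainer = pl.Trainer(max_epochs=1)

    # loaders go straight into fit(); val/test accept one loader or a list of loaders
    trainer.fit(model,
                train_dataloader=make_loader(),
                val_dataloaders=make_loader(),
                test_dataloaders=make_loader())
    trainer.test()

    # after fit() the trainer exposes the loaders as plain attributes
    # (PATCH 79 replaces trainer.get_train_dataloader() with trainer.train_dataloader)
    assert len(trainer.val_dataloaders) == 1
    assert len(trainer.test_dataloaders) == 1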
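
Note on the mixin split in PATCH 74: `LightningValStepFitSingleDataloaderMixin` vs. `LightningValStepFitMultipleDataloadersMixin` (and the analogous test-step pair) encode the two hook signatures the trainer dispatches on. With a single val/test loader the step hooks receive `(batch, batch_idx)`; with a list of loaders they additionally receive `dataloader_idx`, which the multi-loader mixins use to key their outputs (e.g. `f'val_loss_{dataloader_idx}'`). A schematic pair of hypothetical modules:

    import pytorch_lightning as pl

    class SingleValLoaderModule(pl.LightningModule):
        def validation_step(self, batch, batch_idx):
            # called when fit() receives a single val_dataloaders loader
            ...

    class MultiValLoaderModule(pl.LightningModule):
        def validation_step(self, batch, batch_idx, dataloader_idx):
            # called when fit() receives a list of val_dataloaders;
            # dataloader_idx identifies which loader produced the batch
            ...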
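
Note on the CHANGELOG entries in PATCH 79: besides the automatic sampler setup, two new Trainer flags are listed. An illustrative (non-default) configuration, assuming both flags are available in this build:

    from pytorch_lightning import Trainer

    trainer = Trainer(
        reload_dataloaders_every_epoch=True,  # default False per the changelog; rebuild loaders each epoch
        progress_bar_refresh_rate=50,         # default 50 per the changelog; throttles tqdm updates in notebooks
    )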