From d55043f7ce412b4bbaa4f4327eb72276c615d1bf Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 30 Apr 2021 16:00:23 +0200 Subject: [PATCH 01/15] Move trainer functions --- pytorch_lightning/trainer/trainer.py | 700 +++++++++++++-------------- 1 file changed, 350 insertions(+), 350 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b98c1c0c551c2..72c675d60a5f4 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -410,6 +410,279 @@ def __init__( # Callback system self.on_init_end() + def fit( + self, + model: LightningModule, + train_dataloader: Any = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + datamodule: Optional[LightningDataModule] = None, + ) -> None: + r""" + Runs the full optimization routine. + + Args: + model: Model to fit. + + train_dataloader: Either a single PyTorch DataLoader or a collection of these + (list, dict, nested lists and dicts). In the case of multiple dataloaders, please + see this :ref:`page ` + + val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. + If the model has a predefined val_dataloaders method this will be skipped + + datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. + """ + Trainer._log_api_event("fit") + + self.state = TrainerState.FITTING + self.training = True + + # if a datamodule comes in as the second arg, then fix it for the user + if isinstance(train_dataloader, LightningDataModule): + datamodule = train_dataloader + train_dataloader = None + # If you supply a datamodule you can't supply train_dataloader or val_dataloaders + if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None: + raise MisconfigurationException( + 'You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.fit(datamodule=...)`' + ) + + # links data to the trainer + self.data_connector.attach_data( + model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule + ) + + self._run(model) + + assert self.state.stopped + self.training = False + + def validate( + self, + model: Optional[LightningModule] = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + ckpt_path: Optional[str] = 'best', + verbose: bool = True, + datamodule: Optional[LightningDataModule] = None, + ) -> _EVALUATE_OUTPUT: + r""" + Perform one evaluation epoch over the validation set. + + Args: + model: The model to validate. + + val_dataloaders: Either a single PyTorch DataLoader or a list of them, + specifying validation samples. + + ckpt_path: Either ``best`` or path to the checkpoint you wish to validate. + If ``None``, use the current weights of the model. + When the model is given as argument, this parameter will not apply. + + verbose: If True, prints the validation results. + + datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. + + Returns: + The dictionary with final validation results returned by validation_epoch_end. + If validation_epoch_end is not defined, the output is a list of the dictionaries + returned by validation_step. 
+ """ + # -------------------- + # SETUP HOOK + # -------------------- + Trainer._log_api_event("validate") + self.verbose_evaluate = verbose + + self.state = TrainerState.VALIDATING + self.validating = True + + # If you supply a datamodule you can't supply val_dataloaders + if val_dataloaders is not None and datamodule: + raise MisconfigurationException( + 'You cannot pass both `trainer.validate(val_dataloaders=..., datamodule=...)`' + ) + + model_provided = model is not None + model = model or self.lightning_module + + # links data to the trainer + self.data_connector.attach_data(model, val_dataloaders=val_dataloaders, datamodule=datamodule) + + if not model_provided: + self.validated_ckpt_path = self.__load_ckpt_weights(ckpt_path) + + # run validate + results = self._run(model) + + assert self.state.stopped + self.validating = False + + return results + + def test( + self, + model: Optional[LightningModule] = None, + test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + ckpt_path: Optional[str] = 'best', + verbose: bool = True, + datamodule: Optional[LightningDataModule] = None, + ) -> _EVALUATE_OUTPUT: + r""" + Perform one evaluation epoch over the test set. It's separated from + fit to make sure you never run on your test set until you want to. + + Args: + model: The model to test. + + test_dataloaders: Either a single PyTorch DataLoader or a list of them, + specifying test samples. + + ckpt_path: Either ``best`` or path to the checkpoint you wish to test. + If ``None``, use the current weights of the model. + When the model is given as argument, this parameter will not apply. + + verbose: If True, prints the test results. + + datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. + + Returns: + Returns a list of dictionaries, one for each test dataloader containing their respective metrics. + """ + # -------------------- + # SETUP HOOK + # -------------------- + Trainer._log_api_event("test") + self.verbose_evaluate = verbose + + self.state = TrainerState.TESTING + self.testing = True + + # If you supply a datamodule you can't supply test_dataloaders + if test_dataloaders is not None and datamodule: + raise MisconfigurationException('You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`') + + model_provided = model is not None + model = model or self.lightning_module + + # links data to the trainer + self.data_connector.attach_data(model, test_dataloaders=test_dataloaders, datamodule=datamodule) + + if not model_provided: + self.tested_ckpt_path = self.__load_ckpt_weights(ckpt_path) + + # run test + results = self._run(model) + + assert self.state.stopped + self.testing = False + + return results + + def predict( + self, + model: Optional[LightningModule] = None, + dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + datamodule: Optional[LightningDataModule] = None, + return_predictions: Optional[bool] = None, + ) -> Optional[_PREDICT_OUTPUT]: + r""" + + Separates from fit to make sure you never run on your predictions set until you want to. + This will call the model forward function to compute predictions. + + Args: + model: The model to predict with. + + dataloaders: Either a single PyTorch DataLoader or a list of them, specifying inference samples. + + datamodule: The datamodule with a predict_dataloader method that returns one or more dataloaders. + + return_predictions: Whether to return predictions. 
+ ``True`` by default except when an accelerator that spawns processes is used (not supported). + + Returns: + Returns a list of dictionaries, one for each provided dataloader containing their respective predictions. + """ + + # -------------------- + # SETUP HOOK + # -------------------- + # If you supply a datamodule you can't supply dataloaders + Trainer._log_api_event("predict") + + model = model or self.lightning_module + + self.predict_loop.return_predictions = return_predictions + + self.state = TrainerState.PREDICTING + self.predicting = True + + if dataloaders is not None and datamodule: + raise MisconfigurationException('You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`') + + # links data to the trainer + self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule) + + results = self._run(model) + + assert self.state.stopped + self.predicting = False + + return results + + def tune( + self, + model: LightningModule, + train_dataloader: Optional[DataLoader] = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + datamodule: Optional[LightningDataModule] = None, + scale_batch_size_kwargs: Optional[Dict[str, Any]] = None, + lr_find_kwargs: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Optional[Union[int, _LRFinder]]]: + r""" + Runs routines to tune hyperparameters before training. + + Args: + model: Model to tune. + + train_dataloader: A Pytorch DataLoader with training samples. If the model has + a predefined train_dataloader method this will be skipped. + + val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. + If the model has a predefined val_dataloaders method this will be skipped + + datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. 
+ + scale_batch_size_kwargs: Arguments for :func:`~pytorch_lightning.tuner.batch_size_scaling.scale_batch_size` + + lr_find_kwargs: Arguments for :func:`~pytorch_lightning.tuner.lr_finder.lr_find` + """ + Trainer._log_api_event("tune") + self.state = TrainerState.TUNING + self.tuning = True + + # if a datamodule comes in as the second arg, then fix it for the user + if isinstance(train_dataloader, LightningDataModule): + datamodule = train_dataloader + train_dataloader = None + # If you supply a datamodule you can't supply train_dataloader or val_dataloaders + if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None: + raise MisconfigurationException( + 'You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.tune(datamodule=...)`' + ) + + # links data to the trainer + self.data_connector.attach_data( + model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule + ) + + result = self.tuner._tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs, lr_find_kwargs=lr_find_kwargs) + + assert self.state.stopped + self.tuning = False + + return result + def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: # clean hparams if hasattr(model, "hparams"): @@ -720,276 +993,108 @@ def run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT: # save predictions to disk self.evaluation_loop.predictions.to_disk() - # enable train mode again - self.evaluation_loop.on_evaluation_model_train() - - # reset cached results - self.logger_connector.reset() - - torch.set_grad_enabled(True) - - return eval_loop_results - - def track_output_for_epoch_end(self, outputs, output): - if output is not None: - if isinstance(output, Result): - output = output.detach() - if self.move_metrics_to_cpu: - output = output.cpu() - elif isinstance(output, dict): - output = recursive_detach(output, to_cpu=self.move_metrics_to_cpu) - elif isinstance(output, torch.Tensor) and output.is_cuda and self.move_metrics_to_cpu: - output = output.cpu() - outputs.append(output) - return outputs - - def run_evaluate(self) -> _EVALUATE_OUTPUT: - if not self.is_global_zero and self.progress_bar_callback is not None: - self.progress_bar_callback.disable() - - assert self.evaluating - - with self.profiler.profile(f"run_{self._running_stage}_evaluation"): - eval_loop_results = self.run_evaluation() - - # remove the tensors from the eval results - for i, result in enumerate(eval_loop_results): - if isinstance(result, dict): - for k, v in result.items(): - if isinstance(v, torch.Tensor): - result[k] = v.cpu().item() - - return eval_loop_results - - def run_predict(self) -> Optional[_PREDICT_OUTPUT]: - # prepare dataloaders - dataloaders, max_batches = self.predict_loop.get_predict_dataloaders() - - # check if we want to skip this evaluation - if self.predict_loop.should_skip_predict(max_batches): - return [] - - # set up the eval loop - self.predict_loop.setup(max_batches, dataloaders) - - # call hook - self.predict_loop.on_predict_start() - - # run validation/testing - for dataloader_idx, dataloader in enumerate(dataloaders): - dataloader = self.accelerator.process_dataloader(dataloader) - dl_max_batches = self.predict_loop.max_batches[dataloader_idx] - for batch_idx, batch in enumerate(dataloader): - if batch is None: - continue - - # stop short when running on limited batches - if batch_idx >= dl_max_batches: - break - - # lightning module methods - with self.profiler.profile("predict_step"): - 
self.predict_loop.predict_step(batch, batch_idx, dataloader_idx) - - # call hook - results = self.predict_loop.on_predict_epoch_end() - - # call hook - self.predict_loop.on_predict_end() - - return results - - def run_sanity_check(self, ref_model): - using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model) - should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0 - - # run tiny validation (if validation defined) - # to make sure program won't crash during val - if should_sanity_check: - stage = self._running_stage - self.sanity_checking = True - - # hook and callback - self.on_sanity_check_start() - - # run eval step - self.run_evaluation() - - self.on_sanity_check_end() - - self._running_stage = stage - - # reset the seed to what it was before sanity check - # prevents sanity check to affect random sampling in training - reset_seed() - - def fit( - self, - model: LightningModule, - train_dataloader: Any = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - datamodule: Optional[LightningDataModule] = None, - ) -> None: - r""" - Runs the full optimization routine. - - Args: - model: Model to fit. - - train_dataloader: Either a single PyTorch DataLoader or a collection of these - (list, dict, nested lists and dicts). In the case of multiple dataloaders, please - see this :ref:`page ` - - val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. - If the model has a predefined val_dataloaders method this will be skipped - - datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. - """ - Trainer._log_api_event("fit") - - self.state = TrainerState.FITTING - self.training = True - - # if a datamodule comes in as the second arg, then fix it for the user - if isinstance(train_dataloader, LightningDataModule): - datamodule = train_dataloader - train_dataloader = None - # If you supply a datamodule you can't supply train_dataloader or val_dataloaders - if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None: - raise MisconfigurationException( - 'You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.fit(datamodule=...)`' - ) - - # links data to the trainer - self.data_connector.attach_data( - model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule - ) - - self._run(model) - - assert self.state.stopped - self.training = False - - def validate( - self, - model: Optional[LightningModule] = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - ckpt_path: Optional[str] = 'best', - verbose: bool = True, - datamodule: Optional[LightningDataModule] = None, - ) -> _EVALUATE_OUTPUT: - r""" - Perform one evaluation epoch over the validation set. - - Args: - model: The model to validate. - - val_dataloaders: Either a single PyTorch DataLoader or a list of them, - specifying validation samples. - - ckpt_path: Either ``best`` or path to the checkpoint you wish to validate. - If ``None``, use the current weights of the model. - When the model is given as argument, this parameter will not apply. - - verbose: If True, prints the validation results. + # enable train mode again + self.evaluation_loop.on_evaluation_model_train() - datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. 
+ # reset cached results + self.logger_connector.reset() - Returns: - The dictionary with final validation results returned by validation_epoch_end. - If validation_epoch_end is not defined, the output is a list of the dictionaries - returned by validation_step. - """ - # -------------------- - # SETUP HOOK - # -------------------- - Trainer._log_api_event("validate") - self.verbose_evaluate = verbose + torch.set_grad_enabled(True) - self.state = TrainerState.VALIDATING - self.validating = True + return eval_loop_results - # If you supply a datamodule you can't supply val_dataloaders - if val_dataloaders is not None and datamodule: - raise MisconfigurationException( - 'You cannot pass both `trainer.validate(val_dataloaders=..., datamodule=...)`' - ) + def track_output_for_epoch_end(self, outputs, output): + if output is not None: + if isinstance(output, Result): + output = output.detach() + if self.move_metrics_to_cpu: + output = output.cpu() + elif isinstance(output, dict): + output = recursive_detach(output, to_cpu=self.move_metrics_to_cpu) + elif isinstance(output, torch.Tensor) and output.is_cuda and self.move_metrics_to_cpu: + output = output.cpu() + outputs.append(output) + return outputs - model_provided = model is not None - model = model or self.lightning_module + def run_evaluate(self) -> _EVALUATE_OUTPUT: + if not self.is_global_zero and self.progress_bar_callback is not None: + self.progress_bar_callback.disable() - # links data to the trainer - self.data_connector.attach_data(model, val_dataloaders=val_dataloaders, datamodule=datamodule) + assert self.evaluating - if not model_provided: - self.validated_ckpt_path = self.__load_ckpt_weights(ckpt_path) + with self.profiler.profile(f"run_{self._running_stage}_evaluation"): + eval_loop_results = self.run_evaluation() - # run validate - results = self._run(model) + # remove the tensors from the eval results + for i, result in enumerate(eval_loop_results): + if isinstance(result, dict): + for k, v in result.items(): + if isinstance(v, torch.Tensor): + result[k] = v.cpu().item() - assert self.state.stopped - self.validating = False + return eval_loop_results - return results + def run_predict(self) -> Optional[_PREDICT_OUTPUT]: + # prepare dataloaders + dataloaders, max_batches = self.predict_loop.get_predict_dataloaders() - def test( - self, - model: Optional[LightningModule] = None, - test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - ckpt_path: Optional[str] = 'best', - verbose: bool = True, - datamodule: Optional[LightningDataModule] = None, - ) -> _EVALUATE_OUTPUT: - r""" - Perform one evaluation epoch over the test set. It's separated from - fit to make sure you never run on your test set until you want to. + # check if we want to skip this evaluation + if self.predict_loop.should_skip_predict(max_batches): + return [] - Args: - model: The model to test. + # set up the eval loop + self.predict_loop.setup(max_batches, dataloaders) - test_dataloaders: Either a single PyTorch DataLoader or a list of them, - specifying test samples. + # call hook + self.predict_loop.on_predict_start() - ckpt_path: Either ``best`` or path to the checkpoint you wish to test. - If ``None``, use the current weights of the model. - When the model is given as argument, this parameter will not apply. 
+ # run validation/testing + for dataloader_idx, dataloader in enumerate(dataloaders): + dataloader = self.accelerator.process_dataloader(dataloader) + dl_max_batches = self.predict_loop.max_batches[dataloader_idx] + for batch_idx, batch in enumerate(dataloader): + if batch is None: + continue - verbose: If True, prints the test results. + # stop short when running on limited batches + if batch_idx >= dl_max_batches: + break - datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. + # lightning module methods + with self.profiler.profile("predict_step"): + self.predict_loop.predict_step(batch, batch_idx, dataloader_idx) - Returns: - Returns a list of dictionaries, one for each test dataloader containing their respective metrics. - """ - # -------------------- - # SETUP HOOK - # -------------------- - Trainer._log_api_event("test") - self.verbose_evaluate = verbose + # call hook + results = self.predict_loop.on_predict_epoch_end() - self.state = TrainerState.TESTING - self.testing = True + # call hook + self.predict_loop.on_predict_end() - # If you supply a datamodule you can't supply test_dataloaders - if test_dataloaders is not None and datamodule: - raise MisconfigurationException('You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`') + return results - model_provided = model is not None - model = model or self.lightning_module + def run_sanity_check(self, ref_model): + using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model) + should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0 - # links data to the trainer - self.data_connector.attach_data(model, test_dataloaders=test_dataloaders, datamodule=datamodule) + # run tiny validation (if validation defined) + # to make sure program won't crash during val + if should_sanity_check: + stage = self._running_stage + self.sanity_checking = True - if not model_provided: - self.tested_ckpt_path = self.__load_ckpt_weights(ckpt_path) + # hook and callback + self.on_sanity_check_start() - # run test - results = self._run(model) + # run eval step + self.run_evaluation() - assert self.state.stopped - self.testing = False + self.on_sanity_check_end() - return results + self._running_stage = stage + + # reset the seed to what it was before sanity check + # prevents sanity check to affect random sampling in training + reset_seed() def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]: if ckpt_path is None: @@ -1027,111 +1132,6 @@ def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]: ) return ckpt_path - def predict( - self, - model: Optional[LightningModule] = None, - dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - datamodule: Optional[LightningDataModule] = None, - return_predictions: Optional[bool] = None, - ) -> Optional[_PREDICT_OUTPUT]: - r""" - - Separates from fit to make sure you never run on your predictions set until you want to. - This will call the model forward function to compute predictions. - - Args: - model: The model to predict with. - - dataloaders: Either a single PyTorch DataLoader or a list of them, specifying inference samples. - - datamodule: The datamodule with a predict_dataloader method that returns one or more dataloaders. - - return_predictions: Whether to return predictions. - ``True`` by default except when an accelerator that spawns processes is used (not supported). 
- - Returns: - Returns a list of dictionaries, one for each provided dataloader containing their respective predictions. - """ - - # -------------------- - # SETUP HOOK - # -------------------- - # If you supply a datamodule you can't supply dataloaders - Trainer._log_api_event("predict") - - model = model or self.lightning_module - - self.predict_loop.return_predictions = return_predictions - - self.state = TrainerState.PREDICTING - self.predicting = True - - if dataloaders is not None and datamodule: - raise MisconfigurationException('You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`') - - # links data to the trainer - self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule) - - results = self._run(model) - - assert self.state.stopped - self.predicting = False - - return results - - def tune( - self, - model: LightningModule, - train_dataloader: Optional[DataLoader] = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - datamodule: Optional[LightningDataModule] = None, - scale_batch_size_kwargs: Optional[Dict[str, Any]] = None, - lr_find_kwargs: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Optional[Union[int, _LRFinder]]]: - r""" - Runs routines to tune hyperparameters before training. - - Args: - model: Model to tune. - - train_dataloader: A Pytorch DataLoader with training samples. If the model has - a predefined train_dataloader method this will be skipped. - - val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. - If the model has a predefined val_dataloaders method this will be skipped - - datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. - - scale_batch_size_kwargs: Arguments for :func:`~pytorch_lightning.tuner.batch_size_scaling.scale_batch_size` - - lr_find_kwargs: Arguments for :func:`~pytorch_lightning.tuner.lr_finder.lr_find` - """ - Trainer._log_api_event("tune") - self.state = TrainerState.TUNING - self.tuning = True - - # if a datamodule comes in as the second arg, then fix it for the user - if isinstance(train_dataloader, LightningDataModule): - datamodule = train_dataloader - train_dataloader = None - # If you supply a datamodule you can't supply train_dataloader or val_dataloaders - if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None: - raise MisconfigurationException( - 'You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.tune(datamodule=...)`' - ) - - # links data to the trainer - self.data_connector.attach_data( - model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule - ) - - result = self.tuner._tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs, lr_find_kwargs=lr_find_kwargs) - - assert self.state.stopped - self.tuning = False - - return result - def call_setup_hook(self, model: LightningModule) -> None: assert self.state.running, f"TrainerState: {self.state}" state = self._setup_state From 4df5804f7ea93dfd79a1998bafd1ce93b22d126b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 3 May 2021 12:46:06 +0200 Subject: [PATCH 02/15] Progress --- docs/source/advanced/multiple_loaders.rst | 17 --- .../trainer/connectors/data_connector.py | 10 +- pytorch_lightning/trainer/trainer.py | 128 ++++++++++++------ pytorch_lightning/tuner/tuning.py | 41 +++--- pytorch_lightning/utilities/types.py | 5 + tests/deprecated_api/test_remove_1-6.py | 26 ++++ 6 files changed, 146 insertions(+), 81 
deletions(-) create mode 100644 tests/deprecated_api/test_remove_1-6.py diff --git a/docs/source/advanced/multiple_loaders.rst b/docs/source/advanced/multiple_loaders.rst index 1a82641953c3c..02d5db143c95c 100644 --- a/docs/source/advanced/multiple_loaders.rst +++ b/docs/source/advanced/multiple_loaders.rst @@ -91,23 +91,6 @@ For more details please have a look at :paramref:`~pytorch_lightning.trainer.tra Furthermore, Lightning also supports that nested lists and dicts (or a combination) can be returned. -.. testcode:: - - class LitModel(LightningModule): - - def train_dataloader(self): - - loader_a = torch.utils.data.DataLoader(range(8), batch_size=4) - loader_b = torch.utils.data.DataLoader(range(16), batch_size=2) - - return {'a': loader_a, 'b': loader_b} - - def training_step(self, batch, batch_idx): - # access a dictionnary with a batch from each dataloader - batch_a = batch["a"] - batch_b = batch["b"] - - .. testcode:: class LitModel(LightningModule): diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 9fb531f8eb67c..ba931afcbe03b 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -70,7 +70,7 @@ def can_prepare_data(self): def attach_data( self, model: 'pl.LightningModule', - train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None, + train_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, @@ -79,7 +79,7 @@ def attach_data( # set up the passed in dataloaders (if needed) self.attach_dataloaders( model, - train_dataloader=train_dataloader, + train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, test_dataloaders=test_dataloaders, predict_dataloaders=predict_dataloaders, @@ -91,15 +91,15 @@ def attach_data( def attach_dataloaders( self, model: 'pl.LightningModule', - train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None, + train_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, ) -> None: # when dataloader is passed via fit, patch the train_dataloader # functions to overwrite with these implementations - if train_dataloader is not None: - model.train_dataloader = _PatchDataLoader(train_dataloader) + if train_dataloaders is not None: + model.train_dataloader = _PatchDataLoader(train_dataloaders) if val_dataloaders is not None: model.val_dataloader = _PatchDataLoader(val_dataloaders) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 72c675d60a5f4..7d7e9a2fe159e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -20,7 +20,6 @@ from typing import Any, Dict, Iterable, List, Optional, Union import torch -from torch.utils.data import DataLoader from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.callbacks import Callback @@ -58,13 +57,13 @@ from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.lr_finder import _LRFinder from pytorch_lightning.tuner.tuning import 
Tuner -from pytorch_lightning.utilities import DeviceType, parsing, rank_zero_warn +from pytorch_lightning.utilities import DeviceType, parsing, rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed -from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT +from pytorch_lightning.utilities.types import _DATALOADERS, _EVALUATE_OUTPUT, _PREDICT_OUTPUT log = logging.getLogger(__name__) # warnings to ignore in trainer @@ -413,9 +412,10 @@ def __init__( def fit( self, model: LightningModule, - train_dataloader: Any = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + train_dataloaders: Optional[Union[_DATALOADERS, LightningDataModule]] = None, + val_dataloaders: Optional[_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, + train_dataloader=None, # noqa ) -> None: r""" Runs the full optimization routine. @@ -423,12 +423,13 @@ def fit( Args: model: Model to fit. - train_dataloader: Either a single PyTorch DataLoader or a collection of these - (list, dict, nested lists and dicts). In the case of multiple dataloaders, please - see this :ref:`page ` + train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a + :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples. + In the case of multiple dataloaders, please see this :ref:`page `. - val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. - If the model has a predefined val_dataloaders method this will be skipped + val_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a + :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying validation samples. + In the case of multiple dataloaders, please see this :ref:`page `. datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. """ @@ -437,19 +438,25 @@ def fit( self.state = TrainerState.FITTING self.training = True + if train_dataloader is not None: + rank_zero_deprecation( + "`trainer.fit(train_dataloader)` is deprecated in v1.4 and will be removed in v1.6." + " Use `trainer.fit(train_dataloaders)` instead. 
HINT: added 's'" + ) + train_dataloaders = train_dataloader # if a datamodule comes in as the second arg, then fix it for the user - if isinstance(train_dataloader, LightningDataModule): - datamodule = train_dataloader - train_dataloader = None + if isinstance(train_dataloaders, LightningDataModule): + datamodule = train_dataloaders + train_dataloaders = None # If you supply a datamodule you can't supply train_dataloader or val_dataloaders - if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None: + if (train_dataloaders is not None or val_dataloaders is not None) and datamodule is not None: raise MisconfigurationException( 'You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.fit(datamodule=...)`' ) # links data to the trainer self.data_connector.attach_data( - model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule + model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule ) self._run(model) @@ -460,10 +467,11 @@ def fit( def validate( self, model: Optional[LightningModule] = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + dataloaders: Optional[Union[_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[str] = 'best', verbose: bool = True, datamodule: Optional[LightningDataModule] = None, + val_dataloaders=None, # noqa ) -> _EVALUATE_OUTPUT: r""" Perform one evaluation epoch over the validation set. @@ -471,8 +479,9 @@ def validate( Args: model: The model to validate. - val_dataloaders: Either a single PyTorch DataLoader or a list of them, - specifying validation samples. + dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a + :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying validation samples. + In the case of multiple dataloaders, please see this :ref:`page `. ckpt_path: Either ``best`` or path to the checkpoint you wish to validate. If ``None``, use the current weights of the model. @@ -496,17 +505,25 @@ def validate( self.state = TrainerState.VALIDATING self.validating = True - # If you supply a datamodule you can't supply val_dataloaders - if val_dataloaders is not None and datamodule: - raise MisconfigurationException( - 'You cannot pass both `trainer.validate(val_dataloaders=..., datamodule=...)`' + if val_dataloaders is not None: + rank_zero_deprecation( + "`trainer.validate(val_dataloaders)` is deprecated in v1.4 and will be removed in v1.6." + " Use `trainer.validate(dataloaders)` instead." 
) + dataloaders = val_dataloaders + # if a datamodule comes in as the second arg, then fix it for the user + if isinstance(dataloaders, LightningDataModule): + datamodule = dataloaders + dataloaders = None + # If you supply a datamodule you can't supply val_dataloaders + if dataloaders is not None and datamodule: + raise MisconfigurationException('You cannot pass both `trainer.validate(dataloaders=..., datamodule=...)`') model_provided = model is not None model = model or self.lightning_module # links data to the trainer - self.data_connector.attach_data(model, val_dataloaders=val_dataloaders, datamodule=datamodule) + self.data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule) if not model_provided: self.validated_ckpt_path = self.__load_ckpt_weights(ckpt_path) @@ -522,10 +539,11 @@ def validate( def test( self, model: Optional[LightningModule] = None, - test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + dataloaders: Optional[Union[_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[str] = 'best', verbose: bool = True, datamodule: Optional[LightningDataModule] = None, + test_dataloaders=None, # noqa ) -> _EVALUATE_OUTPUT: r""" Perform one evaluation epoch over the test set. It's separated from @@ -534,8 +552,9 @@ def test( Args: model: The model to test. - test_dataloaders: Either a single PyTorch DataLoader or a list of them, - specifying test samples. + dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a + :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying test samples. + In the case of multiple dataloaders, please see this :ref:`page `. ckpt_path: Either ``best`` or path to the checkpoint you wish to test. If ``None``, use the current weights of the model. @@ -557,15 +576,25 @@ def test( self.state = TrainerState.TESTING self.testing = True + if test_dataloaders is not None: + rank_zero_deprecation( + "`trainer.test(test_dataloaders)` is deprecated in v1.4 and will be removed in v1.6." + " Use `trainer.test(dataloaders)` instead." + ) + dataloaders = test_dataloaders + # if a datamodule comes in as the second arg, then fix it for the user + if isinstance(dataloaders, LightningDataModule): + datamodule = dataloaders + dataloaders = None # If you supply a datamodule you can't supply test_dataloaders - if test_dataloaders is not None and datamodule: - raise MisconfigurationException('You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`') + if dataloaders is not None and datamodule: + raise MisconfigurationException('You cannot pass both `trainer.test(dataloaders=..., datamodule=...)`') model_provided = model is not None model = model or self.lightning_module # links data to the trainer - self.data_connector.attach_data(model, test_dataloaders=test_dataloaders, datamodule=datamodule) + self.data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule) if not model_provided: self.tested_ckpt_path = self.__load_ckpt_weights(ckpt_path) @@ -581,7 +610,7 @@ def test( def predict( self, model: Optional[LightningModule] = None, - dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + dataloaders: Optional[Union[_DATALOADERS, LightningDataModule]] = None, datamodule: Optional[LightningDataModule] = None, return_predictions: Optional[bool] = None, ) -> Optional[_PREDICT_OUTPUT]: @@ -593,7 +622,9 @@ def predict( Args: model: The model to predict with. 
- dataloaders: Either a single PyTorch DataLoader or a list of them, specifying inference samples. + dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a + :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying inference samples. + In the case of multiple dataloaders, please see this :ref:`page `. datamodule: The datamodule with a predict_dataloader method that returns one or more dataloaders. @@ -617,6 +648,10 @@ def predict( self.state = TrainerState.PREDICTING self.predicting = True + # if a datamodule comes in as the second arg, then fix it for the user + if isinstance(dataloaders, LightningDataModule): + datamodule = dataloaders + dataloaders = None if dataloaders is not None and datamodule: raise MisconfigurationException('You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`') @@ -633,11 +668,12 @@ def predict( def tune( self, model: LightningModule, - train_dataloader: Optional[DataLoader] = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + train_dataloaders: Optional[Union[_DATALOADERS, LightningDataModule]] = None, + val_dataloaders: Optional[_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, scale_batch_size_kwargs: Optional[Dict[str, Any]] = None, lr_find_kwargs: Optional[Dict[str, Any]] = None, + train_dataloader=None, # noqa ) -> Dict[str, Optional[Union[int, _LRFinder]]]: r""" Runs routines to tune hyperparameters before training. @@ -645,11 +681,13 @@ def tune( Args: model: Model to tune. - train_dataloader: A Pytorch DataLoader with training samples. If the model has - a predefined train_dataloader method this will be skipped. + train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a + :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples. + In the case of multiple dataloaders, please see this :ref:`page `. - val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. - If the model has a predefined val_dataloaders method this will be skipped + val_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a + :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying validation samples. + In the case of multiple dataloaders, please see this :ref:`page `. datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. @@ -661,19 +699,25 @@ def tune( self.state = TrainerState.TUNING self.tuning = True + if train_dataloader is not None: + rank_zero_deprecation( + "`trainer.tune(train_dataloader)` is deprecated in v1.4 and will be removed in v1.6." + " Use `trainer.tune(train_dataloaders)` instead. 
HINT: added 's'" + ) + train_dataloaders = train_dataloader # if a datamodule comes in as the second arg, then fix it for the user - if isinstance(train_dataloader, LightningDataModule): - datamodule = train_dataloader - train_dataloader = None + if isinstance(train_dataloaders, LightningDataModule): + datamodule = train_dataloaders + train_dataloaders = None # If you supply a datamodule you can't supply train_dataloader or val_dataloaders - if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None: + if (train_dataloaders is not None or val_dataloaders is not None) and datamodule is not None: raise MisconfigurationException( 'You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.tune(datamodule=...)`' ) # links data to the trainer self.data_connector.attach_data( - model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule + model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule ) result = self.tuner._tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs, lr_find_kwargs=lr_find_kwargs) diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 8e3862b195cd6..441c11196fc77 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -11,14 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Union - -from torch.utils.data import DataLoader +from typing import Any, Dict, Optional, Union import pytorch_lightning as pl from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size from pytorch_lightning.tuner.lr_finder import _LRFinder, lr_find +from pytorch_lightning.utilities.types import _DATALOADERS class Tuner: @@ -65,14 +64,15 @@ def _run(self, *args: Any, **kwargs: Any) -> None: def scale_batch_size( self, model: 'pl.LightningModule', - train_dataloader: Optional[DataLoader] = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + train_dataloaders: Optional[Union[_DATALOADERS, 'pl.LightningDataModule']] = None, + val_dataloaders: Optional[_DATALOADERS] = None, datamodule: Optional['pl.LightningDataModule'] = None, mode: str = 'power', steps_per_trial: int = 3, init_val: int = 2, max_trials: int = 25, batch_arg_name: str = 'batch_size', + train_dataloader=None, # noqa ) -> Optional[int]: """ Iteratively try to find the largest batch size for a given model @@ -81,11 +81,13 @@ def scale_batch_size( Args: model: Model to tune. - train_dataloader: A Pytorch DataLoader with training samples. If the model has - a predefined train_dataloader method this will be skipped. + train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a + :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples. + In the case of multiple dataloaders, please see this :ref:`page `. - val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. - If the model has a predefined val_dataloaders method this will be skipped + val_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a + :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying validation samples. + In the case of multiple dataloaders, please see this :ref:`page `. 
datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. @@ -116,7 +118,8 @@ def scale_batch_size( self.trainer.auto_scale_batch_size = True result = self.trainer.tune( model, - train_dataloader=train_dataloader, + train_dataloaders=train_dataloaders, + train_dataloader=train_dataloader, # deprecated val_dataloaders=val_dataloaders, datamodule=datamodule, scale_batch_size_kwargs={ @@ -133,8 +136,8 @@ def scale_batch_size( def lr_find( self, model: 'pl.LightningModule', - train_dataloader: Optional[DataLoader] = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + train_dataloaders: Optional[Union[_DATALOADERS, 'pl.LightningDataModule']] = None, + val_dataloaders: Optional[_DATALOADERS] = None, datamodule: Optional['pl.LightningDataModule'] = None, min_lr: float = 1e-8, max_lr: float = 1, @@ -142,6 +145,7 @@ def lr_find( mode: str = 'exponential', early_stop_threshold: float = 4.0, update_attr: bool = False, + train_dataloader=None, # noqa ) -> Optional[_LRFinder]: """ Enables the user to do a range test of good initial learning rates, @@ -150,11 +154,13 @@ def lr_find( Args: model: Model to tune. - train_dataloader: A Pytorch DataLoader with training samples. If the model has - a predefined train_dataloader method this will be skipped. + train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a + :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples. + In the case of multiple dataloaders, please see this :ref:`page `. - val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. - If the model has a predefined val_dataloaders method this will be skipped + val_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a + :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying validation samples. + In the case of multiple dataloaders, please see this :ref:`page `. datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. @@ -183,7 +189,8 @@ def lr_find( self.trainer.auto_lr_find = True result = self.trainer.tune( model, - train_dataloader=train_dataloader, + train_dataloaders=train_dataloaders, + train_dataloader=train_dataloader, # deprecated val_dataloaders=val_dataloaders, datamodule=datamodule, lr_find_kwargs={ diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index ecb0101a2279e..65f9ea4f8f256 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -1,6 +1,7 @@ from typing import Any, Dict, Iterator, List, Union import torch +from torch.utils.data import DataLoader from torchmetrics import Metric """ Convention: @@ -13,3 +14,7 @@ _EVALUATE_OUTPUT = List[Dict[str, float]] # 1 dict per DataLoader _PREDICT_OUTPUT = Union[List[Any], List[List[Any]]] _PARAMETERS = Iterator[torch.nn.Parameter] +_DATALOADERS = Union[DataLoader, List[DataLoader], Dict[str, DataLoader], List[Dict[str, DataLoader]], + Dict[str, Dict[str, DataLoader]], List[List[DataLoader]], # ??? + Dict[str, List[DataLoader]], # ??? 
+ 'pl.trainer.supporters.CombinedLoader'] diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py new file mode 100644 index 0000000000000..2d8a312406da8 --- /dev/null +++ b/tests/deprecated_api/test_remove_1-6.py @@ -0,0 +1,26 @@ +import pytest + +from pytorch_lightning import Trainer +from tests.helpers import BoringModel + + +def test_v1_6_0_dataloader_renaming(tmpdir): + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + dl = model.train_dataloader() + + with pytest.deprecated_call(match=r"fit\(train_dataloader\)` is deprecated in v1.4"): + trainer.fit(model, train_dataloader=dl) + + with pytest.deprecated_call(match=r"validate\(val_dataloaders\)` is deprecated in v1.4"): + trainer.validate(model, val_dataloaders=dl) + + with pytest.deprecated_call(match=r"test\(test_dataloaders\)` is deprecated in v1.4"): + trainer.test(model, test_dataloaders=dl) + + with pytest.deprecated_call(match=r"tune\(train_dataloader\)` is deprecated in v1.4"): + trainer.tune(model, train_dataloader=dl) + with pytest.deprecated_call(match=r"tune\(train_dataloader\)` is deprecated in v1.4"): + trainer.tuner.scale_batch_size(model, train_dataloader=dl) + with pytest.deprecated_call(match=r"tune\(train_dataloader\)` is deprecated in v1.4"): + trainer.tuner.lr_find(model, train_dataloader=dl) From 368e05d31607c876cfce987436a0b69077fcdd85 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 3 May 2021 12:46:54 +0200 Subject: [PATCH 03/15] Debugging --- .../test_multiple_eval_dataloaders.py | 48 ++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py b/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py index 9a532cfe1ce47..6e653def65680 100644 --- a/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py +++ b/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py @@ -15,7 +15,7 @@ from torch.utils.data import Dataset from pytorch_lightning import Trainer -from tests.helpers.boring_model import BoringModel +from tests.helpers.boring_model import BoringModel, RandomDataset class RandomDatasetA(Dataset): @@ -172,3 +172,49 @@ def configure_optimizers(self): trainer.fit(model) assert model.opt_0_seen assert model.opt_1_seen + + +def test_kk(): + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + print(batch) + + def train_dataloader(self): + # Dict[str, DataLoader] + loaders_a_b = { + 'a': torch.utils.data.DataLoader(RandomDataset(32, 64)), + 'b': torch.utils.data.DataLoader(RandomDataset(32, 64)) + } + loaders_c_d = { + 'c': torch.utils.data.DataLoader(RandomDataset(32, 64)), + 'd': torch.utils.data.DataLoader(RandomDataset(32, 64)) + } + + # Dict[str, Dict[str, DataLoader]] + loaders = {'loaders_a_b': loaders_a_b, 'loaders_c_d': loaders_c_d} + + # List[Dict[str, DataLoader]] + #loaders = [loaders_a_b, loaders_c_d] + + # List[DataLoader] + loaders_a_b = [ + torch.utils.data.DataLoader(RandomDataset(32, 64)), + torch.utils.data.DataLoader(RandomDataset(32, 64)) + ] + loaders_c_d = [ + torch.utils.data.DataLoader(RandomDataset(32, 64)), + torch.utils.data.DataLoader(RandomDataset(32, 64)) + ] + + # List[List[DataLoader] + #loaders = [loaders_a_b, loaders_c_d] + + # Dict[str, List[DataLoader]] + # ??? 
+ + return loaders + + trainer = Trainer() + trainer.fit(TestModel()) From 1b8b28b35115aecf5f18dab36de059f33b003b29 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 7 May 2021 17:12:57 +0200 Subject: [PATCH 04/15] Progress --- pytorch_lightning/core/hooks.py | 21 ++++---- .../trainer/connectors/data_connector.py | 30 +++++------ pytorch_lightning/trainer/trainer.py | 53 +++++++++---------- pytorch_lightning/tuner/tuning.py | 26 ++++----- pytorch_lightning/utilities/types.py | 10 ++-- .../test_multiple_eval_dataloaders.py | 48 +---------------- 6 files changed, 67 insertions(+), 121 deletions(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 7ab0c8acbe329..57876224c42d4 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -20,7 +20,7 @@ from torch.utils.data import DataLoader from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn -from pytorch_lightning.utilities.types import STEP_OUTPUT +from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS class ModelHooks: @@ -428,14 +428,13 @@ def teardown(self, stage: Optional[str] = None) -> None: stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` """ - def train_dataloader(self) -> Union[DataLoader, List[DataLoader], Dict[str, DataLoader]]: + def train_dataloader(self) -> TRAIN_DATALOADERS: """ Implement one or more PyTorch DataLoaders for training. Return: - Either a single PyTorch :class:`~torch.utils.data.DataLoader` or a collection of these - (list, dict, nested lists and dicts). In the case of multiple dataloaders, please see - this :ref:`page ` + A collection of :class:`torch.utils.data.DataLoader` specifying training samples. + In the case of multiple dataloaders, please see this :ref:`page `. The dataloader you return will not be called every epoch unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``. @@ -503,7 +502,7 @@ def train_dataloader(self): """ rank_zero_warn("`train_dataloader` must be implemented to be used with the Lightning Trainer") - def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]: + def test_dataloader(self) -> EVAL_DATALOADERS: r""" Implement one or multiple PyTorch DataLoaders for testing. @@ -533,7 +532,7 @@ def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]: There is no need to set it yourself. Return: - Single or multiple PyTorch DataLoaders. + A :class:`torch.utils.data.DataLoader` or a sequence of them specifying testing samples. Example:: @@ -563,7 +562,7 @@ def test_dataloader(self): will have an argument ``dataloader_idx`` which matches the order here. """ - def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]: + def val_dataloader(self) -> EVAL_DATALOADERS: r""" Implement one or multiple PyTorch DataLoaders for validation. @@ -584,7 +583,7 @@ def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]: There is no need to set it yourself. Return: - Single or multiple PyTorch DataLoaders. + A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples. Examples:: @@ -614,7 +613,7 @@ def val_dataloader(self): will have an argument ``dataloader_idx`` which matches the order here. """ - def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]: + def predict_dataloader(self) -> EVAL_DATALOADERS: r""" Implement one or multiple PyTorch DataLoaders for prediction. 
@@ -632,7 +631,7 @@ def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]: There is no need to set it yourself. Return: - Single or multiple PyTorch DataLoaders. + A :class:`torch.utils.data.DataLoader` or a sequence of them specifying prediction samples. Note: In the case where you return multiple prediction dataloaders, the :meth:`predict` diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index ba931afcbe03b..0054cbd2ce3c4 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -12,17 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Union - -from torch.utils.data import DataLoader +from typing import Optional, Union import pytorch_lightning as pl from pytorch_lightning.trainer.supporters import prefetch_iterator from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS -class DataConnector(object): +class DataConnector: def __init__(self, trainer): self.trainer = trainer @@ -70,10 +69,10 @@ def can_prepare_data(self): def attach_data( self, model: 'pl.LightningModule', - train_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + train_dataloaders: Optional[TRAIN_DATALOADERS] = None, + val_dataloaders: Optional[EVAL_DATALOADERS] = None, + test_dataloaders: Optional[EVAL_DATALOADERS] = None, + predict_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional['pl.LightningDataModule'] = None ) -> None: # set up the passed in dataloaders (if needed) @@ -91,10 +90,10 @@ def attach_data( def attach_dataloaders( self, model: 'pl.LightningModule', - train_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + train_dataloaders: Optional[TRAIN_DATALOADERS] = None, + val_dataloaders: Optional[EVAL_DATALOADERS] = None, + test_dataloaders: Optional[EVAL_DATALOADERS] = None, + predict_dataloaders: Optional[EVAL_DATALOADERS] = None, ) -> None: # when dataloader is passed via fit, patch the train_dataloader # functions to overwrite with these implementations @@ -139,17 +138,16 @@ def attach_datamodule( model.data_pipeline = datamodule.data_pipeline -class _PatchDataLoader(object): +class _PatchDataLoader: r""" Callable object for patching dataloaders passed into trainer.fit(). Use this class to override model.*_dataloader() and be pickle-compatible. Args: dataloader: Dataloader object to return when called. 
- """ - def __init__(self, dataloader: Union[List[DataLoader], DataLoader]): + def __init__(self, dataloader: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]) -> None: self.dataloader = dataloader # cannot pickle __code__ so cannot verify if PatchDataloader @@ -157,5 +155,5 @@ def __init__(self, dataloader: Union[List[DataLoader], DataLoader]): # so, we hack it by using the string representation self.patch_loader_code = str(self.__call__.__code__) - def __call__(self) -> Union[List[DataLoader], DataLoader]: + def __call__(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: return self.dataloader diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index a490bae22ecbb..bfefba86c9d4b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -64,7 +64,13 @@ from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed -from pytorch_lightning.utilities.types import _DATALOADERS, _EVALUATE_OUTPUT, _PREDICT_OUTPUT +from pytorch_lightning.utilities.types import ( + _DATALOADERS, + _EVALUATE_OUTPUT, + _PREDICT_OUTPUT, + EVAL_DATALOADERS, + TRAIN_DATALOADERS, +) log = logging.getLogger(__name__) # warnings to ignore in trainer @@ -414,10 +420,10 @@ def __init__( def fit( self, model: LightningModule, - train_dataloaders: Optional[Union[_DATALOADERS, LightningDataModule]] = None, - val_dataloaders: Optional[_DATALOADERS] = None, + train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, + val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, - train_dataloader=None, # noqa + train_dataloader=None, # noqa TODO: remove with 1.6 ) -> None: r""" Runs the full optimization routine. @@ -429,9 +435,7 @@ def fit( :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples. In the case of multiple dataloaders, please see this :ref:`page `. - val_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a - :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying validation samples. - In the case of multiple dataloaders, please see this :ref:`page `. + val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples. datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. """ @@ -470,11 +474,11 @@ def fit( def validate( self, model: Optional[LightningModule] = None, - dataloaders: Optional[Union[_DATALOADERS, LightningDataModule]] = None, + dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[str] = 'best', verbose: bool = True, datamodule: Optional[LightningDataModule] = None, - val_dataloaders=None, # noqa + val_dataloaders=None, # noqa TODO: remove with 1.6 ) -> _EVALUATE_OUTPUT: r""" Perform one evaluation epoch over the validation set. @@ -482,9 +486,8 @@ def validate( Args: model: The model to validate. - dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a - :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying validation samples. - In the case of multiple dataloaders, please see this :ref:`page `. + dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them, + or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying validation samples. 
ckpt_path: Either ``best`` or path to the checkpoint you wish to validate. If ``None``, use the current weights of the model. @@ -543,11 +546,11 @@ def validate( def test( self, model: Optional[LightningModule] = None, - dataloaders: Optional[Union[_DATALOADERS, LightningDataModule]] = None, + dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[str] = 'best', verbose: bool = True, datamodule: Optional[LightningDataModule] = None, - test_dataloaders=None, # noqa + test_dataloaders=None, # noqa TODO: remove with 1.6 ) -> _EVALUATE_OUTPUT: r""" Perform one evaluation epoch over the test set. It's separated from @@ -556,9 +559,8 @@ def test( Args: model: The model to test. - dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a - :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying test samples. - In the case of multiple dataloaders, please see this :ref:`page `. + dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them, + or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying test samples. ckpt_path: Either ``best`` or path to the checkpoint you wish to test. If ``None``, use the current weights of the model. @@ -615,7 +617,7 @@ def test( def predict( self, model: Optional[LightningModule] = None, - dataloaders: Optional[Union[_DATALOADERS, LightningDataModule]] = None, + dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, datamodule: Optional[LightningDataModule] = None, return_predictions: Optional[bool] = None, ) -> Optional[_PREDICT_OUTPUT]: @@ -627,9 +629,8 @@ def predict( Args: model: The model to predict with. - dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a - :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying inference samples. - In the case of multiple dataloaders, please see this :ref:`page `. + dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them, + or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying prediction samples. datamodule: The datamodule with a predict_dataloader method that returns one or more dataloaders. @@ -673,12 +674,12 @@ def predict( def tune( self, model: LightningModule, - train_dataloaders: Optional[Union[_DATALOADERS, LightningDataModule]] = None, - val_dataloaders: Optional[_DATALOADERS] = None, + train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, + val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, scale_batch_size_kwargs: Optional[Dict[str, Any]] = None, lr_find_kwargs: Optional[Dict[str, Any]] = None, - train_dataloader=None, # noqa + train_dataloader=None, # noqa TODO: remove with 1.6 ) -> Dict[str, Optional[Union[int, _LRFinder]]]: r""" Runs routines to tune hyperparameters before training. @@ -690,9 +691,7 @@ def tune( :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples. In the case of multiple dataloaders, please see this :ref:`page `. - val_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a - :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying validation samples. - In the case of multiple dataloaders, please see this :ref:`page `. + val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples. datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. 
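For reference, a minimal sketch of how the standardized signatures introduced by this patch are meant to be called; `BoringModel` and `RandomDataset` are the test helpers used elsewhere in this series, and the keyword names (`train_dataloaders`, `val_dataloaders`, and the shared `dataloaders` argument) assume the trainer.py hunks above apply as shown:

from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel, RandomDataset

model = BoringModel()
trainer = Trainer(fast_dev_run=True)

train_loader = DataLoader(RandomDataset(32, 64))
val_loader = DataLoader(RandomDataset(32, 64))

# fit (and tune) keep separate train/val keywords, now consistently plural
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# validate/test/predict share a single `dataloaders` argument instead of the
# old per-stage `val_dataloaders`/`test_dataloaders` keywords
trainer.validate(model, dataloaders=val_loader)
trainer.test(model, dataloaders=DataLoader(RandomDataset(32, 64)))
trainer.predict(model, dataloaders=DataLoader(RandomDataset(32, 64)))

Per the Union types in the signatures above, a `LightningDataModule` may also be passed positionally in place of the loader collections.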
diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 96d62df7b6cb8..3953cdbcd4846 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -17,7 +17,7 @@ from pytorch_lightning.trainer.states import TrainerStatus from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size from pytorch_lightning.tuner.lr_finder import _LRFinder, lr_find -from pytorch_lightning.utilities.types import _DATALOADERS +from pytorch_lightning.utilities.types import _DATALOADERS, EVAL_DATALOADERS, TRAIN_DATALOADERS class Tuner: @@ -66,15 +66,15 @@ def _run(self, *args: Any, **kwargs: Any) -> None: def scale_batch_size( self, model: 'pl.LightningModule', - train_dataloaders: Optional[Union[_DATALOADERS, 'pl.LightningDataModule']] = None, - val_dataloaders: Optional[_DATALOADERS] = None, + train_dataloaders: Optional[Union[TRAIN_DATALOADERS, 'pl.LightningDataModule']] = None, + val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional['pl.LightningDataModule'] = None, mode: str = 'power', steps_per_trial: int = 3, init_val: int = 2, max_trials: int = 25, batch_arg_name: str = 'batch_size', - train_dataloader=None, # noqa + train_dataloader=None, # noqa TODO: remove with 1.6 ) -> Optional[int]: """ Iteratively try to find the largest batch size for a given model @@ -87,9 +87,7 @@ def scale_batch_size( :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples. In the case of multiple dataloaders, please see this :ref:`page `. - val_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a - :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying validation samples. - In the case of multiple dataloaders, please see this :ref:`page `. + val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples. datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. @@ -121,7 +119,7 @@ def scale_batch_size( result = self.trainer.tune( model, train_dataloaders=train_dataloaders, - train_dataloader=train_dataloader, # deprecated + train_dataloader=train_dataloader, # TODO: deprecated - remove with 1.6 val_dataloaders=val_dataloaders, datamodule=datamodule, scale_batch_size_kwargs={ @@ -138,8 +136,8 @@ def scale_batch_size( def lr_find( self, model: 'pl.LightningModule', - train_dataloaders: Optional[Union[_DATALOADERS, 'pl.LightningDataModule']] = None, - val_dataloaders: Optional[_DATALOADERS] = None, + train_dataloaders: Optional[Union[TRAIN_DATALOADERS, 'pl.LightningDataModule']] = None, + val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional['pl.LightningDataModule'] = None, min_lr: float = 1e-8, max_lr: float = 1, @@ -147,7 +145,7 @@ def lr_find( mode: str = 'exponential', early_stop_threshold: float = 4.0, update_attr: bool = False, - train_dataloader=None, # noqa + train_dataloader=None, # noqa TODO: remove with 1.6 ) -> Optional[_LRFinder]: """ Enables the user to do a range test of good initial learning rates, @@ -160,9 +158,7 @@ def lr_find( :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples. In the case of multiple dataloaders, please see this :ref:`page `. - val_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a - :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying validation samples. - In the case of multiple dataloaders, please see this :ref:`page `. 
+ val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples. datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. @@ -192,7 +188,7 @@ def lr_find( result = self.trainer.tune( model, train_dataloaders=train_dataloaders, - train_dataloader=train_dataloader, # deprecated + train_dataloader=train_dataloader, # TODO: deprecated - remove with 1.6 val_dataloaders=val_dataloaders, datamodule=datamodule, lr_find_kwargs={ diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 65f9ea4f8f256..2008d433929f7 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Iterator, List, Union +from typing import Any, Dict, Iterator, List, Sequence, Union import torch from torch.utils.data import DataLoader @@ -14,7 +14,7 @@ _EVALUATE_OUTPUT = List[Dict[str, float]] # 1 dict per DataLoader _PREDICT_OUTPUT = Union[List[Any], List[List[Any]]] _PARAMETERS = Iterator[torch.nn.Parameter] -_DATALOADERS = Union[DataLoader, List[DataLoader], Dict[str, DataLoader], List[Dict[str, DataLoader]], - Dict[str, Dict[str, DataLoader]], List[List[DataLoader]], # ??? - Dict[str, List[DataLoader]], # ??? - 'pl.trainer.supporters.CombinedLoader'] +TRAIN_DATALOADERS = Union[DataLoader, Sequence[DataLoader], Dict[str, DataLoader], + # TODO: expand allowed types once fixed. + 'pl.trainer.supporters.CombinedLoader', ] +EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]] diff --git a/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py b/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py index 6e653def65680..9a532cfe1ce47 100644 --- a/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py +++ b/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py @@ -15,7 +15,7 @@ from torch.utils.data import Dataset from pytorch_lightning import Trainer -from tests.helpers.boring_model import BoringModel, RandomDataset +from tests.helpers.boring_model import BoringModel class RandomDatasetA(Dataset): @@ -172,49 +172,3 @@ def configure_optimizers(self): trainer.fit(model) assert model.opt_0_seen assert model.opt_1_seen - - -def test_kk(): - - class TestModel(BoringModel): - - def training_step(self, batch, batch_idx): - print(batch) - - def train_dataloader(self): - # Dict[str, DataLoader] - loaders_a_b = { - 'a': torch.utils.data.DataLoader(RandomDataset(32, 64)), - 'b': torch.utils.data.DataLoader(RandomDataset(32, 64)) - } - loaders_c_d = { - 'c': torch.utils.data.DataLoader(RandomDataset(32, 64)), - 'd': torch.utils.data.DataLoader(RandomDataset(32, 64)) - } - - # Dict[str, Dict[str, DataLoader]] - loaders = {'loaders_a_b': loaders_a_b, 'loaders_c_d': loaders_c_d} - - # List[Dict[str, DataLoader]] - #loaders = [loaders_a_b, loaders_c_d] - - # List[DataLoader] - loaders_a_b = [ - torch.utils.data.DataLoader(RandomDataset(32, 64)), - torch.utils.data.DataLoader(RandomDataset(32, 64)) - ] - loaders_c_d = [ - torch.utils.data.DataLoader(RandomDataset(32, 64)), - torch.utils.data.DataLoader(RandomDataset(32, 64)) - ] - - # List[List[DataLoader] - #loaders = [loaders_a_b, loaders_c_d] - - # Dict[str, List[DataLoader]] - # ??? 
- - return loaders - - trainer = Trainer() - trainer.fit(TestModel()) From 5a799853eff1752f9cc5ad53c040f9db7b2af145 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 7 May 2021 17:16:10 +0200 Subject: [PATCH 05/15] Fixes --- pytorch_lightning/core/hooks.py | 3 +-- pytorch_lightning/trainer/trainer.py | 8 +------- pytorch_lightning/tuner/tuning.py | 2 +- pytorch_lightning/utilities/types.py | 5 ++--- 4 files changed, 5 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 57876224c42d4..2a81a67ea3217 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -13,11 +13,10 @@ # limitations under the License. """Various hooks to be used in the Lightning code.""" -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional import torch from torch.optim.optimizer import Optimizer -from torch.utils.data import DataLoader from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index bfefba86c9d4b..afb8695c1c143 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -64,13 +64,7 @@ from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed -from pytorch_lightning.utilities.types import ( - _DATALOADERS, - _EVALUATE_OUTPUT, - _PREDICT_OUTPUT, - EVAL_DATALOADERS, - TRAIN_DATALOADERS, -) +from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS log = logging.getLogger(__name__) # warnings to ignore in trainer diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 3953cdbcd4846..449f9d862ecef 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -17,7 +17,7 @@ from pytorch_lightning.trainer.states import TrainerStatus from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size from pytorch_lightning.tuner.lr_finder import _LRFinder, lr_find -from pytorch_lightning.utilities.types import _DATALOADERS, EVAL_DATALOADERS, TRAIN_DATALOADERS +from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS class Tuner: diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 2008d433929f7..9d1bc54f782aa 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -14,7 +14,6 @@ _EVALUATE_OUTPUT = List[Dict[str, float]] # 1 dict per DataLoader _PREDICT_OUTPUT = Union[List[Any], List[List[Any]]] _PARAMETERS = Iterator[torch.nn.Parameter] -TRAIN_DATALOADERS = Union[DataLoader, Sequence[DataLoader], Dict[str, DataLoader], - # TODO: expand allowed types once fixed. - 'pl.trainer.supporters.CombinedLoader', ] +# TODO: expand allowed train_dataloaders types once fixed. 
+TRAIN_DATALOADERS = Union[DataLoader, Sequence[DataLoader], Dict[str, DataLoader], 'CombinedLoader'] EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]] From cddb65211a02b6da8ef3436a520084299049b6a5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 7 May 2021 17:21:49 +0200 Subject: [PATCH 06/15] Update CHANGELOG --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index af142fdba3414..1ec2abf4fdc49 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,12 +10,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added support for passing a `LightningDataModule` positionally as the second argument to `trainer.{validate,test,predict}` ([#7431](https://github.com/PyTorchLightning/pytorch-lightning/pull/7431)) + + ### Changed ### Deprecated +- Standardized the dataloaders arguments of `trainer.{fit,valdiate,test,tune}` ([#7431](https://github.com/PyTorchLightning/pytorch-lightning/pull/7431)) + + ### Removed From ea564b8ad7a16f2176fe121da7e17b940d1e2d9a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 7 May 2021 17:36:10 +0200 Subject: [PATCH 07/15] Update tests --- tests/core/test_datamodules.py | 4 ++-- tests/trainer/test_config_validator.py | 12 ++++-------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 0041ccb52c2bb..2faab9d8d175d 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -355,12 +355,12 @@ def test_full_loop(tmpdir): assert dm.trainer is not None # validate - result = trainer.validate(datamodule=dm) + result = trainer.validate(model, dm) assert dm.trainer is not None assert result[0]['val_acc'] > 0.7 # test - result = trainer.test(datamodule=dm) + result = trainer.test(model, dm) assert dm.trainer is not None assert result[0]['test_acc'] > 0.6 diff --git a/tests/trainer/test_config_validator.py b/tests/trainer/test_config_validator.py index 9fccd9b36440a..344709af8d451 100644 --- a/tests/trainer/test_config_validator.py +++ b/tests/trainer/test_config_validator.py @@ -128,17 +128,13 @@ def test_dataloader(self): def predict_dataloader(self): return self._dataloaders - dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] + data = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] + if datamodule: + data = TestLightningDataModule(data) model = TestModel() - trainer = Trainer(default_root_dir=tmpdir) - - if datamodule: - datamodule = TestLightningDataModule(dataloaders) - results = trainer.predict(model, datamodule=datamodule) - else: - results = trainer.predict(model, dataloaders=dataloaders) + results = trainer.predict(model, data) assert len(results) == 2 assert results[0][0].shape == torch.Size([1, 2]) From 1de17925f1c4ede2b71f202315a4d6e36210b05f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 7 May 2021 17:36:27 +0200 Subject: [PATCH 08/15] Update docs --- docs/source/common/lightning_module.rst | 4 ++-- docs/source/common/test_set.rst | 4 ++-- docs/source/common/trainer.rst | 2 +- docs/source/extensions/datamodules.rst | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst index 3865400121fe2..2196dba5f90f9 100644 --- a/docs/source/common/lightning_module.rst +++ b/docs/source/common/lightning_module.rst @@ -441,12 +441,12 @@ 
There are two ways to call `test()`: trainer.fit(model) # automatically auto-loads the best weights - trainer.test(test_dataloaders=test_dataloader) + trainer.test(dataloaders=test_dataloader) # or call with pretrained model model = MyLightningModule.load_from_checkpoint(PATH) trainer = Trainer() - trainer.test(model, test_dataloaders=test_dataloader) + trainer.test(model, dataloaders=test_dataloader) ---------- diff --git a/docs/source/common/test_set.rst b/docs/source/common/test_set.rst index 4c9e9a6061977..5703d71d956de 100644 --- a/docs/source/common/test_set.rst +++ b/docs/source/common/test_set.rst @@ -80,10 +80,10 @@ is not available at the time your model was declared. .. code-block:: python # setup your data loader - test = DataLoader(...) + test_dataloader = DataLoader(...) # test (pass in the loader) - trainer.test(test_dataloaders=test) + trainer.test(dataloaders=test_dataloader) You can either pass in a single dataloader or a list of them. This optional named parameter can be used in conjunction with any of the above use cases. Additionally, diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index f9275fcbd898b..e5c4d76bc7334 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -159,7 +159,7 @@ or after it has already been trained. .. code-block:: python - trainer.validate(val_dataloaders=val_dataloaders) + trainer.validate(dataloaders=val_dataloaders) ------------ diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index a602a75b0f877..5129093bf614f 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -53,7 +53,7 @@ Datamodules are for you if you ever asked the questions: What is a DataModule -------------------- -A DataModule is simply a collection of a train_dataloader, val_dataloader(s), test_dataloader(s) along with the +A DataModule is simply a collection of a train_dataloader(s), val_dataloader(s), test_dataloader(s) along with the matching transforms and data processing/downloads steps required. Here's a simple PyTorch example: From 2b577da22a0face1ad9b5cbf3fd974467a7ebb5d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 7 May 2021 17:42:07 +0200 Subject: [PATCH 09/15] Fix tests --- pytorch_lightning/utilities/types.py | 5 ++++- tests/plugins/test_tpu_spawn.py | 6 +++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 9d1bc54f782aa..abde43c0233b3 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -15,5 +15,8 @@ _PREDICT_OUTPUT = Union[List[Any], List[List[Any]]] _PARAMETERS = Iterator[torch.nn.Parameter] # TODO: expand allowed train_dataloaders types once fixed. 
-TRAIN_DATALOADERS = Union[DataLoader, Sequence[DataLoader], Dict[str, DataLoader], 'CombinedLoader'] +TRAIN_DATALOADERS = Union[DataLoader, Sequence[DataLoader], Dict[str, DataLoader], + # can't import, otherwise circular imports + 'CombinedLoader', # noqa: F821 undefined name 'CombinedLoader' + ] EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]] diff --git a/tests/plugins/test_tpu_spawn.py b/tests/plugins/test_tpu_spawn.py index 8aa56c636cf47..a05be79e17fef 100644 --- a/tests/plugins/test_tpu_spawn.py +++ b/tests/plugins/test_tpu_spawn.py @@ -44,7 +44,7 @@ def predict_dataloader(self): @pytest.mark.parametrize( - "train_dataloader, val_dataloaders, test_dataloaders, predict_dataloaders", + "train_dataloaders, val_dataloaders, test_dataloaders, predict_dataloaders", [ (_loader_no_len, None, None, None), (None, _loader_no_len, None, None), @@ -55,14 +55,14 @@ def predict_dataloader(self): ) @mock.patch("pytorch_lightning.plugins.training_type.tpu_spawn.xm") def test_error_patched_iterable_dataloaders( - _, tmpdir, train_dataloader, val_dataloaders, test_dataloaders, predict_dataloaders + _, tmpdir, train_dataloaders, val_dataloaders, test_dataloaders, predict_dataloaders ): model = BoringModelNoDataloaders() connector = DataConnector(MagicMock()) connector.attach_dataloaders( model, - train_dataloader=train_dataloader, + train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, test_dataloaders=test_dataloaders, predict_dataloaders=predict_dataloaders, From 2ff11bba9530663eef97ee00d043e22f08459b53 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 May 2021 18:19:38 +0000 Subject: [PATCH 10/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/governance.rst | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/source/governance.rst b/docs/source/governance.rst index fac8b68e1df53..5b1f9bd1916c1 100644 --- a/docs/source/governance.rst +++ b/docs/source/governance.rst @@ -38,5 +38,3 @@ Alumni - Jeff Ling (`jeffling `_) - Teddy Koker (`teddykoker `_) - Nate Raw (`nateraw `_) - - From 0e2559c785c8b8d4c5f1f0b719a08e963d113d53 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 May 2021 21:29:08 +0000 Subject: [PATCH 11/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/deprecated_api/test_remove_1-6.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index 57343afdb8b06..6ee4d0cff3e80 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -29,7 +29,7 @@ def test_v1_6_0_trainer_model_hook_mixin(tmpdir): with pytest.deprecated_call(match="is deprecated in v1.4 and will be removed in v1.6"): trainer.has_arg("training_step", "batch") - + def test_v1_6_0_dataloader_renaming(tmpdir): model = BoringModel() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) From d17f2773426bdf69246a1a54a43da0f5fef5a562 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 10 May 2021 10:26:03 +0000 Subject: [PATCH 12/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/deprecated_api/test_remove_1-6.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) 
diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index 06c16d6030185..160f6119f9510 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -52,7 +52,7 @@ def test_v1_6_0_dataloader_renaming(tmpdir): with pytest.deprecated_call(match=r"tune\(train_dataloader\)` is deprecated in v1.4"): trainer.tuner.lr_find(model, train_dataloader=dl) - + def test_v1_6_0_ddp_num_nodes(): with pytest.deprecated_call(match="Argument `num_nodes` in `DDPPlugin` is deprecated in v1.4"): DDPPlugin(num_nodes=1) From bfabb36ee50fd8fd7c08afbc5c02f9a07372b725 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 13 May 2021 02:39:13 +0200 Subject: [PATCH 13/15] Resolve TODO --- pytorch_lightning/utilities/types.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 8cde6896ae4e6..4f547898c51c7 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -28,8 +28,9 @@ _EVALUATE_OUTPUT = List[Dict[str, float]] # 1 dict per DataLoader _PREDICT_OUTPUT = Union[List[Any], List[List[Any]]] _PARAMETERS = Iterator[torch.nn.Parameter] -# TODO: expand allowed train_dataloaders types once fixed. -TRAIN_DATALOADERS = Union[DataLoader, Sequence[DataLoader], Dict[str, DataLoader], +TRAIN_DATALOADERS = Union[DataLoader, Sequence[DataLoader], Sequence[Sequence[DataLoader]], Sequence[Dict[str, + DataLoader]], + Dict[str, DataLoader], Dict[str, Dict[str, DataLoader]], Dict[str, Sequence[DataLoader]], # can't import, otherwise circular imports 'CombinedLoader', # noqa: F821 undefined name 'CombinedLoader' ] From f59e2f7bdc7f869c3ec0c62465a803fdc611c28e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 13 May 2021 02:41:59 +0200 Subject: [PATCH 14/15] Resolve TODO --- pytorch_lightning/utilities/types.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 4f547898c51c7..0950a355c8eb9 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -28,10 +28,17 @@ _EVALUATE_OUTPUT = List[Dict[str, float]] # 1 dict per DataLoader _PREDICT_OUTPUT = Union[List[Any], List[List[Any]]] _PARAMETERS = Iterator[torch.nn.Parameter] -TRAIN_DATALOADERS = Union[DataLoader, Sequence[DataLoader], Sequence[Sequence[DataLoader]], Sequence[Dict[str, - DataLoader]], - Dict[str, DataLoader], Dict[str, Dict[str, DataLoader]], Dict[str, Sequence[DataLoader]], - # can't import, otherwise circular imports - 'CombinedLoader', # noqa: F821 undefined name 'CombinedLoader' - ] +# yapf: disable +TRAIN_DATALOADERS = Union[ + DataLoader, + Sequence[DataLoader], + Sequence[Sequence[DataLoader]], + Sequence[Dict[str, DataLoader]], + Dict[str, DataLoader], + Dict[str, Dict[str, DataLoader]], + Dict[str, Sequence[DataLoader]], + # can't import, otherwise circular imports + 'CombinedLoader', # noqa: F821 undefined name 'CombinedLoader' +] +# yapf: enable EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]] From 8950e113cd53157085e2f927e4de7b6890659f4b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 15 Jun 2021 13:25:57 +0200 Subject: [PATCH 15/15] Remove CombinedDataLoader --- pytorch_lightning/utilities/types.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 61814e71e9def..95ee9028ebace 
100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -39,8 +39,6 @@ Dict[str, DataLoader], Dict[str, Dict[str, DataLoader]], Dict[str, Sequence[DataLoader]], - # can't import, otherwise circular imports - 'CombinedLoader', # noqa: F821 undefined name 'CombinedLoader' ] # yapf: enable EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]
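Taken together, the final aliases describe the loader collections a `LightningModule` or `LightningDataModule` may return from its hooks: anything from a single `DataLoader` up to nested dicts and sequences for training, and a loader or sequence of loaders for evaluation. A rough sketch under those aliases, reusing the `BoringModel`/`RandomDataset` test helpers from this series (the class name is illustrative, and a real `training_step`/`validation_step` would need to handle the dict batches and `dataloader_idx` this produces):

from torch.utils.data import DataLoader

from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from tests.helpers.boring_model import BoringModel, RandomDataset


class MultipleLoadersModel(BoringModel):

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        # Dict[str, DataLoader]: each training batch arrives as a dict keyed 'a'/'b'
        return {
            'a': DataLoader(RandomDataset(32, 64)),
            'b': DataLoader(RandomDataset(32, 64)),
        }

    def val_dataloader(self) -> EVAL_DATALOADERS:
        # Sequence[DataLoader]: validation_step is called once per loader
        # with a dataloader_idx argument
        return [
            DataLoader(RandomDataset(32, 64)),
            DataLoader(RandomDataset(32, 64)),
        ]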