diff --git a/CHANGELOG.md b/CHANGELOG.md index 0829c7e069ff2..cd2b01db6c0aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). [#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574)) +- Added support for passing a `LightningDataModule` positionally as the second argument to `trainer.{validate,test,predict}` ([#7431](https://github.com/PyTorchLightning/pytorch-lightning/pull/7431)) + + - Added argument `trainer.predict(ckpt_path)` ([#7430](https://github.com/PyTorchLightning/pytorch-lightning/pull/7430)) @@ -174,6 +177,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- Standardized the dataloaders arguments of `trainer.{fit,validate,test,tune}` ([#7431](https://github.com/PyTorchLightning/pytorch-lightning/pull/7431)) + + - Deprecated `DataModule` properties: `has_prepared_data`, `has_setup_fit`, `has_setup_validate`, `has_setup_test`, `has_setup_predict`, `has_teardown_fit`, `has_teardown_validate`, `has_teardown_test`, `has_teardown_predict` ([#7657](https://github.com/PyTorchLightning/pytorch-lightning/pull/7657/)) diff --git a/docs/source/advanced/multiple_loaders.rst b/docs/source/advanced/multiple_loaders.rst index 1a82641953c3c..02d5db143c95c 100644 --- a/docs/source/advanced/multiple_loaders.rst +++ b/docs/source/advanced/multiple_loaders.rst @@ -91,23 +91,6 @@ For more details please have a look at :paramref:`~pytorch_lightning.trainer.tra Furthermore, Lightning also supports that nested lists and dicts (or a combination) can be returned. -.. testcode:: - - class LitModel(LightningModule): - - def train_dataloader(self): - - loader_a = torch.utils.data.DataLoader(range(8), batch_size=4) - loader_b = torch.utils.data.DataLoader(range(16), batch_size=2) - - return {'a': loader_a, 'b': loader_b} - - def training_step(self, batch, batch_idx): - # access a dictionnary with a batch from each dataloader - batch_a = batch["a"] - batch_b = batch["b"] - - .. testcode:: class LitModel(LightningModule): diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst index 532691455ed82..824fa7a2513f2 100644 --- a/docs/source/common/lightning_module.rst +++ b/docs/source/common/lightning_module.rst @@ -441,12 +441,12 @@ There are two ways to call `test()`: trainer.fit(model) # automatically auto-loads the best weights - trainer.test(test_dataloaders=test_dataloader) + trainer.test(dataloaders=test_dataloader) # or call with pretrained model model = MyLightningModule.load_from_checkpoint(PATH) trainer = Trainer() - trainer.test(model, test_dataloaders=test_dataloader) + trainer.test(model, dataloaders=test_dataloader) ---------- diff --git a/docs/source/common/test_set.rst b/docs/source/common/test_set.rst index 4c9e9a6061977..5703d71d956de 100644 --- a/docs/source/common/test_set.rst +++ b/docs/source/common/test_set.rst @@ -80,10 +80,10 @@ is not available at the time your model was declared. .. code-block:: python # setup your data loader - test = DataLoader(...) + test_dataloader = DataLoader(...) # test (pass in the loader) - trainer.test(test_dataloaders=test) + trainer.test(dataloaders=test_dataloader) You can either pass in a single dataloader or a list of them. This optional named parameter can be used in conjunction with any of the above use cases. 
Additionally, diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index ea2c425c11e0e..ea32ea3dd55dc 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -159,7 +159,7 @@ or after it has already been trained. .. code-block:: python - trainer.validate(val_dataloaders=val_dataloaders) + trainer.validate(dataloaders=val_dataloaders) ------------ diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index 27fdf176f5554..b710a43b2c580 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -53,7 +53,7 @@ Datamodules are for you if you ever asked the questions: What is a DataModule -------------------- -A DataModule is simply a collection of a train_dataloader, val_dataloader(s), test_dataloader(s) along with the +A DataModule is simply a collection of a train_dataloader(s), val_dataloader(s), test_dataloader(s) along with the matching transforms and data processing/downloads steps required. Here's a simple PyTorch example: diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index b015b9c809b93..50b058c3c24c2 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -13,14 +13,13 @@ # limitations under the License. """Various hooks to be used in the Lightning code.""" -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional import torch from torch.optim.optimizer import Optimizer -from torch.utils.data import DataLoader from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn -from pytorch_lightning.utilities.types import STEP_OUTPUT +from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS class ModelHooks: @@ -428,14 +427,13 @@ def teardown(self, stage: Optional[str] = None) -> None: stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` """ - def train_dataloader(self) -> Union[DataLoader, List[DataLoader], Dict[str, DataLoader]]: + def train_dataloader(self) -> TRAIN_DATALOADERS: """ Implement one or more PyTorch DataLoaders for training. Return: - Either a single PyTorch :class:`~torch.utils.data.DataLoader` or a collection of these - (list, dict, nested lists and dicts). In the case of multiple dataloaders, please see - this :ref:`page ` + A collection of :class:`torch.utils.data.DataLoader` specifying training samples. + In the case of multiple dataloaders, please see this :ref:`page `. The dataloader you return will not be called every epoch unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``. @@ -503,7 +501,7 @@ def train_dataloader(self): """ rank_zero_warn("`train_dataloader` must be implemented to be used with the Lightning Trainer") - def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]: + def test_dataloader(self) -> EVAL_DATALOADERS: r""" Implement one or multiple PyTorch DataLoaders for testing. @@ -533,7 +531,7 @@ def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]: There is no need to set it yourself. Return: - Single or multiple PyTorch DataLoaders. + A :class:`torch.utils.data.DataLoader` or a sequence of them specifying testing samples. Example:: @@ -563,7 +561,7 @@ def test_dataloader(self): will have an argument ``dataloader_idx`` which matches the order here. 
""" - def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]: + def val_dataloader(self) -> EVAL_DATALOADERS: r""" Implement one or multiple PyTorch DataLoaders for validation. @@ -584,7 +582,7 @@ def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]: There is no need to set it yourself. Return: - Single or multiple PyTorch DataLoaders. + A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples. Examples:: @@ -614,7 +612,7 @@ def val_dataloader(self): will have an argument ``dataloader_idx`` which matches the order here. """ - def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]: + def predict_dataloader(self) -> EVAL_DATALOADERS: r""" Implement one or multiple PyTorch DataLoaders for prediction. @@ -632,7 +630,7 @@ def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]: There is no need to set it yourself. Return: - Single or multiple PyTorch DataLoaders. + A :class:`torch.utils.data.DataLoader` or a sequence of them specifying prediction samples. Note: In the case where you return multiple prediction dataloaders, the :meth:`predict` diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index a867bf96a8d77..4ff7e5aa21a42 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -12,17 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Union - -from torch.utils.data import DataLoader +from typing import Optional, Union import pytorch_lightning as pl from pytorch_lightning.trainer.supporters import prefetch_iterator from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS -class DataConnector(object): +class DataConnector: def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"): self.trainer = trainer @@ -71,16 +70,16 @@ def can_prepare_data(self): def attach_data( self, model: 'pl.LightningModule', - train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + train_dataloaders: Optional[TRAIN_DATALOADERS] = None, + val_dataloaders: Optional[EVAL_DATALOADERS] = None, + test_dataloaders: Optional[EVAL_DATALOADERS] = None, + predict_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional['pl.LightningDataModule'] = None ) -> None: # set up the passed in dataloaders (if needed) self.attach_dataloaders( model, - train_dataloader=train_dataloader, + train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, test_dataloaders=test_dataloaders, predict_dataloaders=predict_dataloaders, @@ -92,15 +91,15 @@ def attach_data( def attach_dataloaders( self, model: 'pl.LightningModule', - train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + train_dataloaders: Optional[TRAIN_DATALOADERS] 
= None, + val_dataloaders: Optional[EVAL_DATALOADERS] = None, + test_dataloaders: Optional[EVAL_DATALOADERS] = None, + predict_dataloaders: Optional[EVAL_DATALOADERS] = None, ) -> None: # when dataloader is passed via fit, patch the train_dataloader # functions to overwrite with these implementations - if train_dataloader is not None: - model.train_dataloader = _PatchDataLoader(train_dataloader) + if train_dataloaders is not None: + model.train_dataloader = _PatchDataLoader(train_dataloaders) if val_dataloaders is not None: model.val_dataloader = _PatchDataLoader(val_dataloaders) @@ -140,17 +139,16 @@ def attach_datamodule( model.data_pipeline = datamodule.data_pipeline -class _PatchDataLoader(object): +class _PatchDataLoader: r""" Callable object for patching dataloaders passed into trainer.fit(). Use this class to override model.*_dataloader() and be pickle-compatible. Args: dataloader: Dataloader object to return when called. - """ - def __init__(self, dataloader: Union[List[DataLoader], DataLoader]): + def __init__(self, dataloader: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]) -> None: self.dataloader = dataloader # cannot pickle __code__ so cannot verify if PatchDataloader @@ -158,5 +156,5 @@ def __init__(self, dataloader: Union[List[DataLoader], DataLoader]): # so, we hack it by using the string representation self.patch_loader_code = str(self.__call__.__code__) - def __call__(self) -> Union[List[DataLoader], DataLoader]: + def __call__(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: return self.dataloader diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index a6f93d9b4263d..675ea0b70f1c8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -21,7 +21,6 @@ from weakref import proxy import torch -from torch.utils.data import DataLoader from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.callbacks import Callback @@ -65,13 +64,13 @@ from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.lr_finder import _LRFinder from pytorch_lightning.tuner.tuning import Tuner -from pytorch_lightning.utilities import DeviceType, parsing, rank_zero_warn +from pytorch_lightning.utilities import DeviceType, parsing, rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed -from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT +from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS log = logging.getLogger(__name__) # warnings to ignore in trainer @@ -416,9 +415,10 @@ def __init__( def fit( self, model: LightningModule, - train_dataloader: Any = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, + val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, + train_dataloader=None, # noqa TODO: remove with 1.6 ) -> None: r""" Runs the full optimization routine. @@ -426,12 +426,11 @@ def fit( Args: model: Model to fit. 
- train_dataloader: Either a single PyTorch DataLoader or a collection of these - (list, dict, nested lists and dicts). In the case of multiple dataloaders, please - see this :ref:`page ` + train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a + :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples. + In the case of multiple dataloaders, please see this :ref:`page `. - val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. - If the model has a predefined val_dataloaders method this will be skipped + val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples. datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. """ @@ -441,19 +440,25 @@ def fit( self.state.status = TrainerStatus.RUNNING self.training = True + if train_dataloader is not None: + rank_zero_deprecation( + "`trainer.fit(train_dataloader)` is deprecated in v1.4 and will be removed in v1.6." + " Use `trainer.fit(train_dataloaders)` instead. HINT: added 's'" + ) + train_dataloaders = train_dataloader # if a datamodule comes in as the second arg, then fix it for the user - if isinstance(train_dataloader, LightningDataModule): - datamodule = train_dataloader - train_dataloader = None + if isinstance(train_dataloaders, LightningDataModule): + datamodule = train_dataloaders + train_dataloaders = None # If you supply a datamodule you can't supply train_dataloader or val_dataloaders - if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None: + if (train_dataloaders is not None or val_dataloaders is not None) and datamodule is not None: raise MisconfigurationException( 'You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.fit(datamodule=...)`' ) # links data to the trainer self.data_connector.attach_data( - model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule + model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule ) self._run(model) @@ -464,10 +469,11 @@ def fit( def validate( self, model: Optional[LightningModule] = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[str] = 'best', verbose: bool = True, datamodule: Optional[LightningDataModule] = None, + val_dataloaders=None, # noqa TODO: remove with 1.6 ) -> _EVALUATE_OUTPUT: r""" Perform one evaluation epoch over the validation set. @@ -475,8 +481,8 @@ def validate( Args: model: The model to validate. - val_dataloaders: Either a single PyTorch DataLoader or a list of them, - specifying validation samples. + dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them, + or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying validation samples. ckpt_path: Either ``best`` or path to the checkpoint you wish to validate. If ``None``, use the current weights of the model. 
@@ -501,11 +507,19 @@ def validate( self.state.status = TrainerStatus.RUNNING self.validating = True - # If you supply a datamodule you can't supply val_dataloaders - if val_dataloaders is not None and datamodule: - raise MisconfigurationException( - 'You cannot pass both `trainer.validate(val_dataloaders=..., datamodule=...)`' + if val_dataloaders is not None: + rank_zero_deprecation( + "`trainer.validate(val_dataloaders)` is deprecated in v1.4 and will be removed in v1.6." + " Use `trainer.validate(dataloaders)` instead." ) + dataloaders = val_dataloaders + # if a datamodule comes in as the second arg, then fix it for the user + if isinstance(dataloaders, LightningDataModule): + datamodule = dataloaders + dataloaders = None + # If you supply a datamodule you can't supply val_dataloaders + if dataloaders is not None and datamodule: + raise MisconfigurationException('You cannot pass both `trainer.validate(dataloaders=..., datamodule=...)`') model_provided = model is not None model = model or self.lightning_module @@ -515,7 +529,7 @@ def validate( ) # links data to the trainer - self.data_connector.attach_data(model, val_dataloaders=val_dataloaders, datamodule=datamodule) + self.data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule) if not model_provided: self.validated_ckpt_path = self.__load_ckpt_weights(ckpt_path) @@ -531,10 +545,11 @@ def validate( def test( self, model: Optional[LightningModule] = None, - test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, ckpt_path: Optional[str] = 'best', verbose: bool = True, datamodule: Optional[LightningDataModule] = None, + test_dataloaders=None, # noqa TODO: remove with 1.6 ) -> _EVALUATE_OUTPUT: r""" Perform one evaluation epoch over the test set. It's separated from @@ -543,8 +558,8 @@ def test( Args: model: The model to test. - test_dataloaders: Either a single PyTorch DataLoader or a list of them, - specifying test samples. + dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them, + or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying test samples. ckpt_path: Either ``best`` or path to the checkpoint you wish to test. If ``None``, use the current weights of the model. @@ -567,9 +582,19 @@ def test( self.state.status = TrainerStatus.RUNNING self.testing = True + if test_dataloaders is not None: + rank_zero_deprecation( + "`trainer.test(test_dataloaders)` is deprecated in v1.4 and will be removed in v1.6." + " Use `trainer.test(dataloaders)` instead." 
+ ) + dataloaders = test_dataloaders + # if a datamodule comes in as the second arg, then fix it for the user + if isinstance(dataloaders, LightningDataModule): + datamodule = dataloaders + dataloaders = None # If you supply a datamodule you can't supply test_dataloaders - if test_dataloaders is not None and datamodule: - raise MisconfigurationException('You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`') + if dataloaders is not None and datamodule: + raise MisconfigurationException('You cannot pass both `trainer.test(dataloaders=..., datamodule=...)`') model_provided = model is not None model = model or self.lightning_module @@ -579,7 +604,7 @@ def test( ) # links data to the trainer - self.data_connector.attach_data(model, test_dataloaders=test_dataloaders, datamodule=datamodule) + self.data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule) if not model_provided: self.tested_ckpt_path = self.__load_ckpt_weights(ckpt_path) @@ -595,7 +620,7 @@ def test( def predict( self, model: Optional[LightningModule] = None, - dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, datamodule: Optional[LightningDataModule] = None, return_predictions: Optional[bool] = None, ckpt_path: Optional[str] = 'best', @@ -608,7 +633,8 @@ def predict( Args: model: The model to predict with. - dataloaders: Either a single PyTorch DataLoader or a list of them, specifying inference samples. + dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them, + or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying prediction samples. datamodule: The datamodule with a predict_dataloader method that returns one or more dataloaders. @@ -634,6 +660,10 @@ def predict( self.predict_loop.return_predictions = return_predictions + # if a datamodule comes in as the second arg, then fix it for the user + if isinstance(dataloaders, LightningDataModule): + datamodule = dataloaders + dataloaders = None if dataloaders is not None and datamodule: raise MisconfigurationException('You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`') @@ -660,11 +690,12 @@ def predict( def tune( self, model: LightningModule, - train_dataloader: Optional[DataLoader] = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, + val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional[LightningDataModule] = None, scale_batch_size_kwargs: Optional[Dict[str, Any]] = None, lr_find_kwargs: Optional[Dict[str, Any]] = None, + train_dataloader=None, # noqa TODO: remove with 1.6 ) -> Dict[str, Optional[Union[int, _LRFinder]]]: r""" Runs routines to tune hyperparameters before training. @@ -672,11 +703,11 @@ def tune( Args: model: Model to tune. - train_dataloader: A Pytorch DataLoader with training samples. If the model has - a predefined train_dataloader method this will be skipped. + train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a + :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples. + In the case of multiple dataloaders, please see this :ref:`page `. - val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. 
- If the model has a predefined val_dataloaders method this will be skipped + val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples. datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. @@ -690,19 +721,25 @@ def tune( self.state.status = TrainerStatus.RUNNING self.tuning = True + if train_dataloader is not None: + rank_zero_deprecation( + "`trainer.tune(train_dataloader)` is deprecated in v1.4 and will be removed in v1.6." + " Use `trainer.tune(train_dataloaders)` instead. HINT: added 's'" + ) + train_dataloaders = train_dataloader # if a datamodule comes in as the second arg, then fix it for the user - if isinstance(train_dataloader, LightningDataModule): - datamodule = train_dataloader - train_dataloader = None + if isinstance(train_dataloaders, LightningDataModule): + datamodule = train_dataloaders + train_dataloaders = None # If you supply a datamodule you can't supply train_dataloader or val_dataloaders - if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None: + if (train_dataloaders is not None or val_dataloaders is not None) and datamodule is not None: raise MisconfigurationException( 'You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.tune(datamodule=...)`' ) # links data to the trainer self.data_connector.attach_data( - model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule + model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule ) result = self.tuner._tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs, lr_find_kwargs=lr_find_kwargs) diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index a25b950ee3fca..449f9d862ecef 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -11,14 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Union - -from torch.utils.data import DataLoader +from typing import Any, Dict, Optional, Union import pytorch_lightning as pl from pytorch_lightning.trainer.states import TrainerStatus from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size from pytorch_lightning.tuner.lr_finder import _LRFinder, lr_find +from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS class Tuner: @@ -67,14 +66,15 @@ def _run(self, *args: Any, **kwargs: Any) -> None: def scale_batch_size( self, model: 'pl.LightningModule', - train_dataloader: Optional[DataLoader] = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + train_dataloaders: Optional[Union[TRAIN_DATALOADERS, 'pl.LightningDataModule']] = None, + val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional['pl.LightningDataModule'] = None, mode: str = 'power', steps_per_trial: int = 3, init_val: int = 2, max_trials: int = 25, batch_arg_name: str = 'batch_size', + train_dataloader=None, # noqa TODO: remove with 1.6 ) -> Optional[int]: """ Iteratively try to find the largest batch size for a given model @@ -83,11 +83,11 @@ def scale_batch_size( Args: model: Model to tune. - train_dataloader: A Pytorch DataLoader with training samples. If the model has - a predefined train_dataloader method this will be skipped. 
+ train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a + :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples. + In the case of multiple dataloaders, please see this :ref:`page `. - val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. - If the model has a predefined val_dataloaders method this will be skipped + val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples. datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. @@ -118,7 +118,8 @@ def scale_batch_size( self.trainer.auto_scale_batch_size = True result = self.trainer.tune( model, - train_dataloader=train_dataloader, + train_dataloaders=train_dataloaders, + train_dataloader=train_dataloader, # TODO: deprecated - remove with 1.6 val_dataloaders=val_dataloaders, datamodule=datamodule, scale_batch_size_kwargs={ @@ -135,8 +136,8 @@ def scale_batch_size( def lr_find( self, model: 'pl.LightningModule', - train_dataloader: Optional[DataLoader] = None, - val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + train_dataloaders: Optional[Union[TRAIN_DATALOADERS, 'pl.LightningDataModule']] = None, + val_dataloaders: Optional[EVAL_DATALOADERS] = None, datamodule: Optional['pl.LightningDataModule'] = None, min_lr: float = 1e-8, max_lr: float = 1, @@ -144,6 +145,7 @@ def lr_find( mode: str = 'exponential', early_stop_threshold: float = 4.0, update_attr: bool = False, + train_dataloader=None, # noqa TODO: remove with 1.6 ) -> Optional[_LRFinder]: """ Enables the user to do a range test of good initial learning rates, @@ -152,11 +154,11 @@ def lr_find( Args: model: Model to tune. - train_dataloader: A Pytorch DataLoader with training samples. If the model has - a predefined train_dataloader method this will be skipped. + train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a + :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples. + In the case of multiple dataloaders, please see this :ref:`page `. - val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. - If the model has a predefined val_dataloaders method this will be skipped + val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples. datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. 
@@ -185,7 +187,8 @@ def lr_find( self.trainer.auto_lr_find = True result = self.trainer.tune( model, - train_dataloader=train_dataloader, + train_dataloaders=train_dataloaders, + train_dataloader=train_dataloader, # TODO: deprecated - remove with 1.6 val_dataloaders=val_dataloaders, datamodule=datamodule, lr_find_kwargs={ diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index a04f7ba87284d..95ee9028ebace 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -17,9 +17,10 @@ - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no trailing `_`) """ from numbers import Number -from typing import Any, Dict, Iterator, List, Mapping, Union +from typing import Any, Dict, Iterator, List, Mapping, Sequence, Union import torch +from torch.utils.data import DataLoader from torchmetrics import Metric _METRIC = Union[Metric, torch.Tensor, Number] @@ -29,3 +30,15 @@ _EVALUATE_OUTPUT = List[Dict[str, float]] # 1 dict per DataLoader _PREDICT_OUTPUT = Union[List[Any], List[List[Any]]] _PARAMETERS = Iterator[torch.nn.Parameter] +# yapf: disable +TRAIN_DATALOADERS = Union[ + DataLoader, + Sequence[DataLoader], + Sequence[Sequence[DataLoader]], + Sequence[Dict[str, DataLoader]], + Dict[str, DataLoader], + Dict[str, Dict[str, DataLoader]], + Dict[str, Sequence[DataLoader]], +] +# yapf: enable +EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]] diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 66abba8d2ca67..30131cdcc80d2 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -380,12 +380,12 @@ def test_full_loop(tmpdir): assert dm.trainer is not None # validate - result = trainer.validate(datamodule=dm) + result = trainer.validate(model, dm) assert dm.trainer is not None assert result[0]['val_acc'] > 0.7 # test - result = trainer.test(datamodule=dm) + result = trainer.test(model, dm) assert dm.trainer is not None assert result[0]['test_acc'] > 0.6 diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index ced066381a6de..cb150cb013ec2 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -32,6 +32,28 @@ def test_v1_6_0_trainer_model_hook_mixin(tmpdir): trainer.has_arg("training_step", "batch") +def test_v1_6_0_dataloader_renaming(tmpdir): + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + dl = model.train_dataloader() + + with pytest.deprecated_call(match=r"fit\(train_dataloader\)` is deprecated in v1.4"): + trainer.fit(model, train_dataloader=dl) + + with pytest.deprecated_call(match=r"validate\(val_dataloaders\)` is deprecated in v1.4"): + trainer.validate(model, val_dataloaders=dl) + + with pytest.deprecated_call(match=r"test\(test_dataloaders\)` is deprecated in v1.4"): + trainer.test(model, test_dataloaders=dl) + + with pytest.deprecated_call(match=r"tune\(train_dataloader\)` is deprecated in v1.4"): + trainer.tune(model, train_dataloader=dl) + with pytest.deprecated_call(match=r"tune\(train_dataloader\)` is deprecated in v1.4"): + trainer.tuner.scale_batch_size(model, train_dataloader=dl) + with pytest.deprecated_call(match=r"tune\(train_dataloader\)` is deprecated in v1.4"): + trainer.tuner.lr_find(model, train_dataloader=dl) + + def test_old_transfer_batch_to_device_hook(tmpdir): class OldModel(BoringModel): diff --git a/tests/plugins/test_tpu_spawn.py 
b/tests/plugins/test_tpu_spawn.py index 85e1ecb781946..54c65c336fdd3 100644 --- a/tests/plugins/test_tpu_spawn.py +++ b/tests/plugins/test_tpu_spawn.py @@ -49,7 +49,7 @@ def predict_dataloader(self): @pytest.mark.parametrize( - "train_dataloader, val_dataloaders, test_dataloaders, predict_dataloaders", + "train_dataloaders, val_dataloaders, test_dataloaders, predict_dataloaders", [ (_loader_no_len, None, None, None), (None, _loader_no_len, None, None), @@ -60,14 +60,14 @@ def predict_dataloader(self): ) @mock.patch("pytorch_lightning.plugins.training_type.tpu_spawn.xm") def test_error_patched_iterable_dataloaders( - _, tmpdir, train_dataloader, val_dataloaders, test_dataloaders, predict_dataloaders + _, tmpdir, train_dataloaders, val_dataloaders, test_dataloaders, predict_dataloaders ): model = BoringModelNoDataloaders() connector = DataConnector(MagicMock()) connector.attach_dataloaders( model, - train_dataloader=train_dataloader, + train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, test_dataloaders=test_dataloaders, predict_dataloaders=predict_dataloaders, diff --git a/tests/trainer/test_config_validator.py b/tests/trainer/test_config_validator.py index f31829d02f2d9..6762d65f41bab 100644 --- a/tests/trainer/test_config_validator.py +++ b/tests/trainer/test_config_validator.py @@ -128,17 +128,13 @@ def test_dataloader(self): def predict_dataloader(self): return self._dataloaders - dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] + data = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] + if datamodule: + data = TestLightningDataModule(data) model = TestModel() - trainer = Trainer(default_root_dir=tmpdir) - - if datamodule: - datamodule = TestLightningDataModule(dataloaders) - results = trainer.predict(model, datamodule=datamodule) - else: - results = trainer.predict(model, dataloaders=dataloaders) + results = trainer.predict(model, data) assert len(results) == 2 assert results[0][0].shape == torch.Size([1, 2])
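For reviewers, here is a minimal end-to-end sketch of how the renamed arguments introduced above are intended to be used. The `TinyModel`, `TinyDataModule`, and `make_loader` names are hypothetical stand-ins written only for this example (they are not part of the changeset), and the snippet assumes the post-#7431 Trainer API:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    """Hypothetical minimal module used only to exercise the renamed Trainer arguments."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch[0]).sum()

    def validation_step(self, batch, batch_idx):
        self.log("val_loss", self.layer(batch[0]).sum())

    def test_step(self, batch, batch_idx):
        self.log("test_loss", self.layer(batch[0]).sum())

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


class TinyDataModule(pl.LightningDataModule):
    """Hypothetical datamodule, only here to show the new positional form."""

    def val_dataloader(self):
        return make_loader()

    def test_dataloader(self):
        return make_loader()


def make_loader(n=64):
    return DataLoader(TensorDataset(torch.randn(n, 32)), batch_size=4)


model = TinyModel()
trainer = pl.Trainer(fast_dev_run=True)

# fit/tune: `train_dataloader` becomes `train_dataloaders` (note the added "s")
trainer.fit(model, train_dataloaders=make_loader(), val_dataloaders=make_loader())

# validate/test: the stage-specific `val_dataloaders`/`test_dataloaders`
# keywords collapse into a single `dataloaders` argument
trainer.validate(model, dataloaders=make_loader())
trainer.test(model, dataloaders=make_loader())

# a LightningDataModule may now also be passed positionally as the second argument
trainer.validate(model, TinyDataModule())
trainer.test(model, TinyDataModule())
```

This mirrors the updated calls in `docs/source/common/test_set.rst`, `trainer.rst`, and `tests/core/test_datamodules.py` above.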
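The new `TRAIN_DATALOADERS` / `EVAL_DATALOADERS` aliases added to `pytorch_lightning/utilities/types.py` also make the accepted collection shapes explicit for user code. A small sketch of annotating the dataloader hooks with them; the `MultiLoaderModel` class is hypothetical and only illustrates the shapes the aliases cover:

```python
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS


class MultiLoaderModel(pl.LightningModule):
    """Hypothetical module illustrating the collection shapes the new aliases cover."""

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        # TRAIN_DATALOADERS admits a single DataLoader, sequences, dicts,
        # and nested combinations of these
        return {
            "a": DataLoader(range(8), batch_size=4),
            "b": DataLoader(range(16), batch_size=2),
        }

    def val_dataloader(self) -> EVAL_DATALOADERS:
        # EVAL_DATALOADERS is a single DataLoader or a sequence of DataLoaders
        return [DataLoader(range(8), batch_size=4), DataLoader(range(8), batch_size=4)]

    def training_step(self, batch, batch_idx):
        # with a dict of training loaders, each batch is a dict with the same keys
        batch_a, batch_b = batch["a"], batch["b"]
        ...
```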
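The old keyword names are kept as shimmed parameters until 1.6; passing them still works but emits a deprecation warning, which `tests/deprecated_api/test_remove_1-6.py` above asserts. Roughly, reusing the hypothetical `trainer`, `model`, and `make_loader` from the first sketch:

```python
import pytest

# old names are still accepted during the deprecation window, but warn;
# removal is scheduled for v1.6
with pytest.deprecated_call(match=r"fit\(train_dataloader\)` is deprecated in v1.4"):
    trainer.fit(model, train_dataloader=make_loader())

with pytest.deprecated_call(match=r"test\(test_dataloaders\)` is deprecated in v1.4"):
    trainer.test(model, test_dataloaders=make_loader())
```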