From bdc4272e99813997668d75457b458e2559e78a65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 28 Apr 2021 17:41:08 +0200 Subject: [PATCH] `_launch` refactor and types [1/n] (#7232) --- .../training_type/training_type_plugin.py | 7 +- .../logger_connector/logger_connector.py | 3 +- pytorch_lightning/trainer/predict_loop.py | 8 +- pytorch_lightning/trainer/trainer.py | 118 +++++++++--------- pytorch_lightning/tuner/batch_size_scaling.py | 4 +- pytorch_lightning/tuner/lr_finder.py | 4 +- pytorch_lightning/tuner/tuning.py | 9 +- pytorch_lightning/utilities/types.py | 2 + 8 files changed, 86 insertions(+), 69 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 15be889c85e3e..f4cf24b9285b7 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -26,6 +26,7 @@ from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT TBroadcast = TypeVar("T") @@ -37,7 +38,7 @@ class TrainingTypePlugin(Plugin, ABC): def __init__(self) -> None: self._model = None - self._results = None + self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None self._call_configure_sharded_model_hook = True def connect(self, model: Module) -> None: @@ -124,12 +125,12 @@ def lightning_module(self) -> 'pl.LightningModule': return unwrap_lightning_module(self._model) @property - def results(self) -> Any: + def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: """ The results of the last training/testing run will be cached here. In distributed training, we make sure to transfer the results to the appropriate master process. """ - # TODO: improve these docs + # TODO(@awaelchli): improve these docs return self._results @property diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 796b381e95223..932e6a49dcb6b 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -27,6 +27,7 @@ from pytorch_lightning.trainer.states import RunningStage, TrainerState from pytorch_lightning.utilities import DeviceType from pytorch_lightning.utilities.metrics import metrics_to_scalars +from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT class LoggerConnector: @@ -267,7 +268,7 @@ def prepare_eval_loop_results(self): for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders): self.add_to_eval_loop_results(dl_idx, has_been_initialized) - def get_evaluate_epoch_results(self): + def get_evaluate_epoch_results(self) -> _EVALUATE_OUTPUT: if not self.trainer.sanity_checking: # log all the metrics as a single dict metrics_to_log = self.cached_results.get_epoch_log_metrics() diff --git a/pytorch_lightning/trainer/predict_loop.py b/pytorch_lightning/trainer/predict_loop.py index e8d3221f4f0e9..4815987e26240 100644 --- a/pytorch_lightning/trainer/predict_loop.py +++ b/pytorch_lightning/trainer/predict_loop.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Union +from typing import Any, List, Optional import torch from torch.utils.data.dataloader import DataLoader @@ -19,6 +19,7 @@ from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper from pytorch_lightning.plugins import DDPSpawnPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import _PREDICT_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache @@ -31,6 +32,7 @@ def __init__(self, trainer): self.warning_cache = WarningCache() self.batch_indices: Optional[List[int]] = None self.epoch_batch_indices: Optional[List[List[int]]] = None + self.predictions: Optional[List[List[Any]]] = None # `DDPSpawnPlugin` plugins and derivate don't support return predictions. self._return_predictions: Optional[bool] = None self._previous_grad_status: Optional[bool] = None @@ -138,10 +140,10 @@ def on_predict_start(self) -> None: self.trainer.call_hook("on_predict_start") self.trainer.call_hook("on_predict_epoch_start") - def on_predict_epoch_end(self) -> Optional[Union[List[Any], List[List[Any]]]]: + def on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]: self.trainer.profiler.describe() - results: List[List[Any]] = self.predictions + results = self.predictions self.trainer.call_hook("on_predict_epoch_end", results) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index fa42a75c24829..133e25ee5ffc4 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -63,6 +63,7 @@ from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed +from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT log = logging.getLogger(__name__) # warnings to ignore in trainer @@ -408,36 +409,13 @@ def __init__( # Callback system self.on_init_end() - def fit( + def _launch( self, model: LightningModule, train_dataloader: Any = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, datamodule: Optional[LightningDataModule] = None, - ): - r""" - Runs the full optimization routine. - - Args: - datamodule: A instance of :class:`LightningDataModule`. - - model: Model to fit. - - train_dataloader: Either a single PyTorch DataLoader or a collection of these - (list, dict, nested lists and dicts). In the case of multiple dataloaders, please - see this :ref:`page ` - - val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. - If the model has a predefined val_dataloaders method this will be skipped - - """ - Trainer._log_api_event("fit") - # we reuse fit for other functions. When already set, it shouldn't be modified. 
- if not self.state.running: - self.state = TrainerState.FITTING - if self._running_stage is None or self.tuning: - self.training = True - + ) -> Union[int, _EVALUATE_OUTPUT, _PREDICT_OUTPUT]: # set local properties on the model self.model_connector.copy_trainer_model_properties(model) @@ -545,18 +523,14 @@ def dispatch(self): else: self.accelerator.start_training(self) - def run_stage(self): - results = None - + def run_stage(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: self.profile_connector.setup() if self.evaluating: - results = self.run_evaluate() - elif self.predicting: - results = self.run_predict() - else: - self.run_train() - return results + return self.run_evaluate() + if self.predicting: + return self.run_predict() + return self.run_train() def _pre_training_routine(self): # wait for all to join if on distributed @@ -586,7 +560,6 @@ def _pre_training_routine(self): ref_model.on_pretrain_routine_end() def run_train(self) -> None: - self._pre_training_routine() if not self.is_global_zero and self.progress_bar_callback is not None: @@ -660,7 +633,7 @@ def run_train(self) -> None: self._running_stage = None raise - def run_evaluation(self, on_epoch=False): + def run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT: if not (self.evaluating or self.sanity_checking): rank_zero_warn( f"`trainer.run_evaluation()` was called but the running stage is set to {self._running_stage}." @@ -777,7 +750,7 @@ def track_output_for_epoch_end(self, outputs, output): outputs.append(output) return outputs - def run_evaluate(self): + def run_evaluate(self) -> _EVALUATE_OUTPUT: if not self.is_global_zero and self.progress_bar_callback is not None: self.progress_bar_callback.disable() @@ -786,9 +759,6 @@ def run_evaluate(self): with self.profiler.profile(f"run_{self._running_stage}_evaluation"): eval_loop_results = self.run_evaluation() - if len(eval_loop_results) == 0: - return 1 - # remove the tensors from the eval results for i, result in enumerate(eval_loop_results): if isinstance(result, dict): @@ -798,7 +768,7 @@ def run_evaluate(self): return eval_loop_results - def run_predict(self): + def run_predict(self) -> Optional[_PREDICT_OUTPUT]: # prepare dataloaders dataloaders, max_batches = self.predict_loop.get_predict_dataloaders() @@ -860,6 +830,42 @@ def run_sanity_check(self, ref_model): # prevents sanity check to affect random sampling in training reset_seed() + def fit( + self, + model: LightningModule, + train_dataloader: Any = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + datamodule: Optional[LightningDataModule] = None, + ) -> Optional[int]: + r""" + Runs the full optimization routine. + + Args: + model: Model to fit. + + train_dataloader: Either a single PyTorch DataLoader or a collection of these + (list, dict, nested lists and dicts). In the case of multiple dataloaders, please + see this :ref:`page ` + + val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. + If the model has a predefined val_dataloaders method this will be skipped + + datamodule: A instance of :class:`LightningDataModule`. 
+ """ + Trainer._log_api_event("fit") + + self.state = TrainerState.FITTING + self.training = True + + results = self._launch( + model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule + ) + + assert self.state.stopped + self.training = False + + return results + def validate( self, model: Optional[LightningModule] = None, @@ -867,7 +873,7 @@ def validate( ckpt_path: Optional[str] = 'best', verbose: bool = True, datamodule: Optional[LightningDataModule] = None, - ): + ) -> _EVALUATE_OUTPUT: r""" Perform one evaluation epoch over the validation set. @@ -914,10 +920,10 @@ def validate( self.data_connector.attach_dataloaders(model, val_dataloaders=val_dataloaders) if not model_provided: - self.validated_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path) + self.validated_ckpt_path = self.__load_ckpt_weights(ckpt_path) # run validate - results = self.fit(model) + results = self._launch(model) assert self.state.stopped self.validating = False @@ -931,7 +937,7 @@ def test( ckpt_path: Optional[str] = 'best', verbose: bool = True, datamodule: Optional[LightningDataModule] = None, - ): + ) -> _EVALUATE_OUTPUT: r""" Perform one evaluation epoch over the test set. It's separated from fit to make sure you never run on your test set until you want to. @@ -975,21 +981,17 @@ def test( self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) if not model_provided: - self.tested_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path) + self.tested_ckpt_path = self.__load_ckpt_weights(ckpt_path) # run test - results = self.fit(model) + results = self._launch(model) assert self.state.stopped self.testing = False return results - def __load_ckpt_weights( - self, - model, - ckpt_path: Optional[str] = None, - ) -> Optional[str]: + def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]: if ckpt_path is None: return @@ -1031,7 +1033,7 @@ def predict( dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, datamodule: Optional[LightningDataModule] = None, return_predictions: Optional[bool] = None, - ): + ) -> Optional[_PREDICT_OUTPUT]: r""" Separates from fit to make sure you never run on your predictions set until you want to. @@ -1039,7 +1041,9 @@ def predict( Args: model: The model to predict with. + dataloaders: Either a single PyTorch DataLoader or a list of them, specifying inference samples. + datamodule: The datamodule with a predict_dataloader method that returns one or more dataloaders. return_predictions: Whether to return predictions. @@ -1063,16 +1067,14 @@ def predict( self.predicting = True if dataloaders is not None and datamodule: - raise MisconfigurationException( - 'You cannot pass dataloaders to trainer.predict if you supply a datamodule.' 
- ) + raise MisconfigurationException('You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`') # Attach datamodule to get setup/prepare_data added to model before the call to it below self.data_connector.attach_datamodule(model, datamodule) # Attach dataloaders (if given) self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders) - results = self.fit(model) + results = self._launch(model) assert self.state.stopped self.predicting = False @@ -1085,7 +1087,7 @@ def tune( train_dataloader: Optional[DataLoader] = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, datamodule: Optional[LightningDataModule] = None, - ): + ) -> None: r""" Runs routines to tune hyperparameters before training. diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 9c5e966c14cc1..7e9dc524099de 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -189,7 +189,7 @@ def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **f trainer.global_step = 0 # reset after each try try: # Try fit - trainer.fit(model, **fit_kwargs) + trainer.tuner._launch(model, **fit_kwargs) # Double in size new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded') except RuntimeError as exception: @@ -218,7 +218,7 @@ def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, trainer.global_step = 0 # reset after each try try: # Try fit - trainer.fit(model, **fit_kwargs) + trainer.tuner._launch(model, **fit_kwargs) count += 1 if count > max_trials: break diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 2d122a3d30cfd..14f21da856145 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -176,7 +176,9 @@ def lr_find( model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers) # Fit, lr & loss logged in callback - trainer.fit(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule) + trainer.tuner._launch( + model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule + ) # Prompt if we stopped early if trainer.global_step != num_training: diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index a7aa1ee256a5d..9d471e2c5cbca 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Optional, Union +from typing import Any, List, Optional, Union from torch.utils.data import DataLoader @@ -71,6 +71,13 @@ def tune(self, model, train_dataloader, val_dataloaders, datamodule): self.trainer.state = TrainerState.FINISHED + def _launch(self, *args: Any, **kwargs: Any) -> None: + """`_launch` wrapper to set the proper state during tuning, as this can be called multiple times""" + self.trainer.state = TrainerState.TUNING # last `_launch` call might have set it to `FINISHED` + self.trainer.training = True + self.trainer._launch(*args, **kwargs) + self.trainer.tuning = True + def scale_batch_size( self, model, diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index c1c40b98c71c7..ecb0101a2279e 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -10,4 +10,6 @@ _METRIC = Union[Metric, torch.Tensor, int, float] STEP_OUTPUT = Union[torch.Tensor, Dict[str, Any]] EPOCH_OUTPUT = List[STEP_OUTPUT] +_EVALUATE_OUTPUT = List[Dict[str, float]] # 1 dict per DataLoader +_PREDICT_OUTPUT = Union[List[Any], List[List[Any]]] _PARAMETERS = Iterator[torch.nn.Parameter]
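
Not part of the patch: a minimal usage sketch of the return annotations introduced above, assuming a plain `LightningModule` instance is at hand. The helper `consume_results` and everything inside it are illustrative assumptions; only the `Trainer` methods and the `_EVALUATE_OUTPUT`/`_PREDICT_OUTPUT` aliases come from the diff.

from typing import Any

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT


def consume_results(trainer: Trainer, model: Any) -> None:
    # `Trainer.validate` and `Trainer.test` are now annotated with `_EVALUATE_OUTPUT`:
    # one dict of scalar metrics per evaluation dataloader.
    val_results: _EVALUATE_OUTPUT = trainer.validate(model)
    for dataloader_idx, metrics in enumerate(val_results):
        print(f"dataloader {dataloader_idx}: {metrics}")

    # `Trainer.predict` is annotated with `Optional[_PREDICT_OUTPUT]`: a list of
    # per-batch predictions (nested per dataloader when several are used), or
    # `None` when predictions are not returned (e.g. with `DDPSpawnPlugin`,
    # which the diff notes does not support returning predictions).
    predictions = trainer.predict(model)
    if predictions is not None:
        print(f"collected {len(predictions)} prediction group(s)")

Routing `fit`, `validate`, `test`, and `predict` through the shared `_launch` lets each public method keep its own narrow return annotation while the internal entry point stays a single `Union[int, _EVALUATE_OUTPUT, _PREDICT_OUTPUT]`.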