diff --git a/CHANGELOG.md b/CHANGELOG.md index ff4a53ed8ef62..7069fb7354dbd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,16 +33,22 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Replace `iteration_count` and other index attributes in the loops with progress dataclasses ([#8477](https://github.com/PyTorchLightning/pytorch-lightning/pull/8477)) +- The `trainer.lightning_module` reference is now properly set at the very beginning of the run ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536)) + + - Load ckpt path when model provided in validate/test/predict ([#8352](https://github.com/PyTorchLightning/pytorch-lightning/pull/8352))) +- The `Trainer` functions `reset_{train,val,test,predict}_dataloader`, `reset_train_val_dataloaders`, and `request_dataloader` `model` argument is now optional ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536)) + + - Saved checkpoints will no longer use the type of a `Callback` as the key to avoid issues with unpickling ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886)) - Improved string conversion for `ResultCollection` ([#8622](https://github.com/PyTorchLightning/pytorch-lightning/pull/8622)) -- +- The accelerator and training type plugin `setup` hooks no longer have a `model` argument ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536)) ### Deprecated diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index f098e2347135f..905d732ee5feb 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -75,15 +75,14 @@ def setup_environment(self) -> None: """ self.training_type_plugin.setup_environment() - def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None: + def setup(self, trainer: "pl.Trainer") -> None: """ Setup plugins for the trainer fit and creates optimizers. Args: trainer: the trainer instance - model: the LightningModule """ - self.setup_training_type_plugin(model) + self.setup_training_type_plugin() if not self.training_type_plugin.setup_optimizers_in_pre_dispatch: self.setup_optimizers(trainer) self.setup_precision_plugin() @@ -334,9 +333,9 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None: self.lr_schedulers = lr_schedulers self.optimizer_frequencies = optimizer_frequencies - def setup_training_type_plugin(self, model: "pl.LightningModule") -> None: + def setup_training_type_plugin(self) -> None: """Attaches the training type plugin to the accelerator.""" - self.training_type_plugin.setup(model) + self.training_type_plugin.setup() def setup_precision_plugin(self) -> None: """Attaches the precision plugin to the accelerator""" @@ -460,7 +459,7 @@ def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: "pl.Li rank_zero_warn( "Accelerator method `connect_training_type_plugin` was deprecated in v1.3. It will be removed in v1.5." ) - self.setup_training_type_plugin(model) + self.setup_training_type_plugin() # todo: remove in v1.5 def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None: diff --git a/pytorch_lightning/accelerators/cpu.py b/pytorch_lightning/accelerators/cpu.py index fc61ca54ecda1..48957219a1ec0 100644 --- a/pytorch_lightning/accelerators/cpu.py +++ b/pytorch_lightning/accelerators/cpu.py @@ -20,7 +20,7 @@ class CPUAccelerator(Accelerator): """Accelerator for CPU devices.""" - def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None: + def setup(self, trainer: "pl.Trainer") -> None: """ Raises: MisconfigurationException: @@ -36,4 +36,4 @@ def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None: if "cpu" not in str(self.root_device): raise MisconfigurationException(f"Device should be CPU, got {self.root_device} instead.") - return super().setup(trainer, model) + return super().setup(trainer) diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index e25a5fa3c417a..6a38cd2cf50e9 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -32,14 +32,14 @@ def setup_environment(self) -> None: raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead") torch.cuda.set_device(self.root_device) - def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None: + def setup(self, trainer: "pl.Trainer") -> None: """ Raises: MisconfigurationException: If the selected device is not GPU. """ self.set_nvidia_flags(trainer.local_rank) - return super().setup(trainer, model) + return super().setup(trainer) def on_train_start(self) -> None: # clear cache before training diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index ba75032b346b1..954bed3dbc58a 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -32,7 +32,7 @@ class TPUAccelerator(Accelerator): """Accelerator for TPU devices.""" - def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None: + def setup(self, trainer: "pl.Trainer") -> None: """ Raises: MisconfigurationException: @@ -45,7 +45,7 @@ def setup(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None: if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)): raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.") - return super().setup(trainer, model) + return super().setup(trainer) def run_optimizer_step( self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 9b8ee34d11c4d..c8c4af2624771 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -166,11 +166,10 @@ def get_max_batches(self) -> List[Union[int, float]]: def reload_evaluation_dataloaders(self) -> None: """Reloads dataloaders if necessary""" - model = self.trainer.lightning_module if self.trainer.testing: - self.trainer.reset_test_dataloader(model) + self.trainer.reset_test_dataloader() elif self.trainer.val_dataloaders is None or self.trainer._should_reload_dl_epoch: - self.trainer.reset_val_dataloader(model) + self.trainer.reset_val_dataloader() def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: """Runs ``on_{validation/test}_start`` hooks""" diff --git a/pytorch_lightning/plugins/training_type/ddp2.py b/pytorch_lightning/plugins/training_type/ddp2.py index bfcd78bb03547..9981d2a1fc260 100644 --- a/pytorch_lightning/plugins/training_type/ddp2.py +++ b/pytorch_lightning/plugins/training_type/ddp2.py @@ -29,8 +29,7 @@ def global_rank(self) -> int: def world_size(self) -> int: return self.num_nodes - def setup(self, model): - self._model = model + def setup(self) -> None: # set the task idx self.task_idx = self.cluster_environment.local_rank() # the difference to DDP is that we don't call children processes here diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 0d7387d5b5dda..b69fea03b53e3 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -136,7 +136,7 @@ def distributed_sampler_kwargs(self): def _is_single_process_single_device(self): return True - def setup(self, model): + def setup(self) -> None: os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) # pass in a state q smp = mp.get_context("spawn") diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index 2787ab5644ccd..beedac2942ac6 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -48,10 +48,10 @@ def node_rank(self) -> int: def world_size(self) -> int: return 1 - def setup(self, model): + def setup(self) -> None: # model needs to be moved to the device before it is wrapped - model.to(self.root_device) - self._model = DataParallel(LightningParallelModule(model), self.parallel_devices) + self.model_to_device() + self._model = DataParallel(LightningParallelModule(self._model), self.parallel_devices) def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION: """ @@ -76,9 +76,8 @@ def mean(t: torch.Tensor) -> torch.Tensor: def root_device(self): return self.parallel_devices[0] - def model_to_device(self): - # no need to do anything when model is wrapped in torch.nn.DataParallel - pass + def model_to_device(self) -> None: + self._model.to(self.root_device) def barrier(self, *args, **kwargs): pass diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index cc888a5364a8b..1acac25e96db4 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -16,7 +16,6 @@ import torch from torch import Tensor -from torch.nn import Module from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.ddp import DDPPlugin @@ -138,9 +137,11 @@ def wrap_policy(*args, **kwargs): ): yield - def connect(self, model: Module) -> None: - super().connect(model) - model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False) + def setup_environment(self) -> None: + super().setup_environment() + model_call_configure_sharded_model_hook = getattr( + self.lightning_module, "call_configure_sharded_model_hook", False + ) if not model_call_configure_sharded_model_hook: # if model has not called configure sharded model, we reset # the training type plugin's call_configure_sharded_model_hook diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index b61c7caa6ac23..34fe429d89362 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -58,8 +58,7 @@ def distributed_sampler_kwargs(self): distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank) return distributed_sampler_kwargs - def setup(self, model): - self._model = model + def setup(self) -> None: self.model_to_device() def pre_dispatch(self): diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index c62a5c9dc7d64..5399cffe19f68 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -63,9 +63,8 @@ def root_device(self) -> torch.device: def model_to_device(self) -> None: self._model.to(self.root_device) - def setup(self, model: torch.nn.Module) -> torch.nn.Module: + def setup(self) -> None: self.model_to_device() - return self.model @property def is_global_zero(self) -> bool: diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index fa2497d849c7a..d8b9457ffef19 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -111,9 +111,8 @@ def pre_dispatch(self): if self.debug: os.environ["PT_XLA_DEBUG"] = str(1) - def setup(self, model: Module) -> Module: + def setup(self) -> None: self.create_mp_queue() - return self.model def create_mp_queue(self): self.start_method = "fork" diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index a8b444de0bd27..cdff37fd9bcb2 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -54,7 +54,7 @@ def setup_environment(self) -> None: which allows the user to access the accelerator environment before setup is complete. """ - def setup(self, model: Module) -> None: + def setup(self) -> None: """Called by the accelerator to finish setup.""" @property diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 842a10aa69ef1..5aac1acb6c572 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -32,20 +32,20 @@ class TrainerCallbackHookMixin(ABC): callbacks: List[Callback] = [] lightning_module: "pl.LightningModule" - def on_before_accelerator_backend_setup(self, model: "pl.LightningModule") -> None: + def on_before_accelerator_backend_setup(self) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: - callback.on_before_accelerator_backend_setup(self, model) + callback.on_before_accelerator_backend_setup(self, self.lightning_module) - def configure_sharded_model(self, model: "pl.LightningModule") -> None: + def on_configure_sharded_model(self) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: - callback.on_configure_sharded_model(self, model) + callback.on_configure_sharded_model(self, self.lightning_module) - def setup(self, model: "pl.LightningModule", stage: Optional[str]) -> None: + def setup(self, stage: Optional[str]) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: - callback.setup(self, model, stage=stage) + callback.setup(self, self.lightning_module, stage=stage) def teardown(self, stage: Optional[str] = None) -> None: """Called at the end of fit (train + validate), validate, test, or predict, or tune.""" diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index c5d76ee61905c..9f8831f13de42 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -15,7 +15,6 @@ from datetime import timedelta from typing import Dict, List, Optional, Union -import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase from pytorch_lightning.callbacks.timer import Timer from pytorch_lightning.utilities import rank_zero_info @@ -132,25 +131,19 @@ def attach_model_logging_functions(self, model): callback.log = model.log callback.log_dict = model.log_dict - @staticmethod - def _attach_model_callbacks(model: "pl.LightningModule", trainer) -> None: + def _attach_model_callbacks(self) -> None: """ Attaches the callbacks defined in the model. If a callback returned by the model's configure_callback method has the same type as one or several callbacks already present in the trainer callbacks list, it will replace them. In addition, all :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks will be pushed to the end of the list, ensuring they run last. - - Args: - model: A model which may or may not define new callbacks in - :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_callbacks`. - trainer: The trainer on which the callbacks get attached/merged. """ - model_callbacks = model.configure_callbacks() + model_callbacks = self.trainer.model.configure_callbacks() if not model_callbacks: return model_callback_types = {type(c) for c in model_callbacks} - trainer_callback_types = {type(c) for c in trainer.callbacks} + trainer_callback_types = {type(c) for c in self.trainer.callbacks} override_types = model_callback_types.intersection(trainer_callback_types) if override_types: rank_zero_info( @@ -159,11 +152,11 @@ def _attach_model_callbacks(model: "pl.LightningModule", trainer) -> None: f" {', '.join(sorted(t.__name__ for t in override_types))}" ) # remove all callbacks with a type that occurs in model callbacks - all_callbacks = [c for c in trainer.callbacks if type(c) not in override_types] + all_callbacks = [c for c in self.trainer.callbacks if type(c) not in override_types] all_callbacks.extend(model_callbacks) all_callbacks = CallbackConnector._reorder_callbacks(all_callbacks) # TODO: connectors refactor: move callbacks list to connector and do not write Trainer state - trainer.callbacks = all_callbacks + self.trainer.callbacks = all_callbacks @staticmethod def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index f6149d83546bb..c6d471fad04b1 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -66,13 +66,13 @@ def get_profiled_train_dataloader(self, train_dataloader): ) return profiled_dl - def prepare_data(self, model): + def prepare_data(self) -> None: # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0 # or in the case where each node needs to do its own manipulation in which case just local_rank=0 if self.can_prepare_data(): if self.trainer.datamodule is not None: self.trainer.datamodule.prepare_data() - model.prepare_data() + self.trainer.lightning_module.prepare_data() self.trainer._is_data_prepared = True def can_prepare_data(self): diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 9f8afbe451306..361d64569505d 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -262,14 +262,14 @@ def _get_distributed_sampler( sampler = cls(dataloader.dataset, **kwargs) return sampler - def reset_train_dataloader(self, model: "pl.LightningModule") -> None: + def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None: """Resets the train dataloader and initialises required variables (number of batches, when to validate, etc.). Args: - model: The current `LightningModule` + model: The `LightningModule` if calling this outside of the trainer scope. """ - self.train_dataloader = self.request_dataloader(model, "train") + self.train_dataloader = self.request_dataloader("train", model=model) if self.overfit_batches > 0: if hasattr(self.train_dataloader, "sampler") and isinstance(self.train_dataloader.sampler, RandomSampler): @@ -351,20 +351,20 @@ def reset_train_dataloader(self, model: "pl.LightningModule") -> None: ) def _reset_eval_dataloader( - self, model: "pl.LightningModule", mode: str + self, mode: str, model: Optional["pl.LightningModule"] = None ) -> Tuple[List[Union[int, float]], List[DataLoader]]: """Generic method to reset a dataloader for evaluation. Args: - model: The current `LightningModule` mode: Either `'val'`, `'test'` or `'predict'` + model: The `LightningModule` if calling this outside of the trainer scope. Returns: Tuple (num_batches, dataloaders) """ # always get the loaders first so we can count how many there are loader_name = f"{mode}_dataloader" - dataloaders = self.request_dataloader(model, mode) + dataloaders = self.request_dataloader(mode, model=model) if not isinstance(dataloaders, list): dataloaders = [dataloaders] @@ -373,7 +373,7 @@ def _reset_eval_dataloader( # duplicate it the numb of times needed to match the train loaders if self.overfit_batches > 0: num_loaders = len(dataloaders) - train_dataloader = self.request_dataloader(model, "train") + train_dataloader = self.request_dataloader("train", model=model) dataloaders = [deepcopy(train_dataloader) for _ in range(num_loaders)] self.dev_debugger.track_load_dataloader_call(loader_name, dataloaders=dataloaders) @@ -450,52 +450,59 @@ def _reset_eval_dataloader( return loader_num_batches, dataloaders - def reset_val_dataloader(self, model: "pl.LightningModule") -> None: + def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None: """Resets the validation dataloader and determines the number of batches. Args: - model: The current `LightningModule` + model: The `LightningModule` if called outside of the trainer scope. """ - has_loader = is_overridden("val_dataloader", model) - has_step = is_overridden("validation_step", model) + pl_module = self.lightning_module or model + has_loader = is_overridden("val_dataloader", pl_module) + has_step = is_overridden("validation_step", pl_module) if has_loader and has_step: - self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, "val") + self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader("val", model=pl_module) - def reset_test_dataloader(self, model) -> None: + def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None: """Resets the test dataloader and determines the number of batches. Args: - model: The current `LightningModule` + model: The `LightningModule` if called outside of the trainer scope. """ - has_loader = is_overridden("test_dataloader", model) - has_step = is_overridden("test_step", model) + pl_module = self.lightning_module or model + has_loader = is_overridden("test_dataloader", pl_module) + has_step = is_overridden("test_step", pl_module) if has_loader and has_step: - self.num_test_batches, self.test_dataloaders = self._reset_eval_dataloader(model, "test") + self.num_test_batches, self.test_dataloaders = self._reset_eval_dataloader("test", model=pl_module) - def reset_predict_dataloader(self, model) -> None: + def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None: """Resets the predict dataloader and determines the number of batches. Args: - model: The current `LightningModule` + model: The `LightningModule` if called outside of the trainer scope. """ - has_loader = is_overridden("predict_dataloader", model) + pl_module = self.lightning_module or model + has_loader = is_overridden("predict_dataloader", pl_module) if has_loader: - self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader(model, "predict") + self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader("predict", model=pl_module) - def reset_train_val_dataloaders(self, model) -> None: + def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = None) -> None: """ Resets train and val dataloaders if none are attached to the trainer. The val dataloader must be initialized before training loop starts, as the training loop inspects the val dataloader to determine whether to run the evaluation loop. + + Args: + model: The `LightningModule` if called outside of the trainer scope. """ if self.train_dataloader is None: - self.reset_train_dataloader(model) - + self.reset_train_dataloader(model=model) if self.val_dataloaders is None: - self.reset_val_dataloader(model) + self.reset_val_dataloader(model=model) - def request_dataloader(self, model: "pl.LightningModule", stage: str) -> Union[DataLoader, List[DataLoader]]: + def request_dataloader( + self, stage: str, model: Optional["pl.LightningModule"] = None + ) -> Union[DataLoader, List[DataLoader]]: """Handles downloading data in the GPU or TPU case. Returns: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d2c6fab7ba559..081f854a492e1 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -861,9 +861,12 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, # attach model log function to callback self.callback_connector.attach_model_logging_functions(model) + # attach model to the training type plugin + self.accelerator.connect(model) + # hook - self.data_connector.prepare_data(model) - self.callback_connector._attach_model_callbacks(model, self) + self.data_connector.prepare_data() + self.callback_connector._attach_model_callbacks() if self._ckpt_path and not self.accelerator.restore_checkpoint_after_pre_dispatch: self._load_checkpoint_weights() @@ -871,17 +874,16 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, # ---------------------------- # SET UP TRAINING # ---------------------------- - self.call_hook("on_before_accelerator_backend_setup", model) - self.accelerator.connect(model) + self.call_hook("on_before_accelerator_backend_setup") self.accelerator.setup_environment() - self._call_setup_hook(model) # allow user to setup lightning_module in accelerator environment + self._call_setup_hook() # allow user to setup lightning_module in accelerator environment # check if we should delay restoring checkpoint till later if not self.accelerator.restore_checkpoint_after_pre_dispatch: self._restore_modules_and_callbacks() - self._call_configure_sharded_model(model) # allow user to setup in model sharded environment - self.accelerator.setup(self, model) # note: this sets up self.lightning_module + self._call_configure_sharded_model() # allow user to setup in model sharded environment + self.accelerator.setup(self) # ---------------------------- # INSPECT THE CORE LOOPS @@ -943,7 +945,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, self.call_hook("on_fit_end") # teardown - self._call_teardown_hook(model) + self._call_teardown_hook() if self.state.status != TrainerStatus.INTERRUPTED: self.state.status = TrainerStatus.FINISHED @@ -1178,44 +1180,45 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_ ) return ckpt_path - def _call_setup_hook(self, model: "pl.LightningModule") -> None: + def _call_setup_hook(self) -> None: fn = self.state.fn._setup_fn self.accelerator.barrier("pre_setup") if self.datamodule is not None: self.datamodule.setup(stage=fn) - self.setup(model, stage=fn) - model.setup(stage=fn) + self.setup(stage=fn) + self.lightning_module.setup(stage=fn) self.accelerator.barrier("post_setup") - def _call_configure_sharded_model(self, model: "pl.LightningModule") -> None: + def _call_configure_sharded_model(self) -> None: # Call configure sharded model hook if accelerator requests. In some cases # we will not call the hook; the hook has initialized the sharded model for example. # used on the model if the user re-create a trainer with resume_from_checkpoint + model = self.lightning_module model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False) if self.accelerator.call_configure_sharded_model_hook and not model_call_configure_sharded_model_hook: with self.accelerator.model_sharded_context(): model.configure_sharded_model() - self.configure_sharded_model(model) + self.on_configure_sharded_model() model.call_configure_sharded_model_hook = True self.accelerator.call_configure_sharded_model_hook = False - def _call_teardown_hook(self, model: "pl.LightningModule") -> None: + def _call_teardown_hook(self) -> None: fn = self.state.fn._setup_fn if self.datamodule is not None: self.datamodule.teardown(stage=fn) self.profiler.teardown(stage=fn) self.teardown(stage=fn) - model.teardown(stage=fn) + self.lightning_module.teardown(stage=fn) - model._current_fx_name = None - model._current_dataloader_idx = None + self.lightning_module._current_fx_name = None + self.lightning_module._current_dataloader_idx = None # these could have become stale if metrics are defined in `setup` - model._metric_attributes = None + self.lightning_module._metric_attributes = None def call_hook(self, hook_name: str, *args, **kwargs) -> Any: if self.lightning_module: diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index 7eaab10b9f24f..1a584ed444758 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -18,12 +18,11 @@ def test_unsupported_precision_plugins(): """Test error messages are raised for unsupported precision plugins with CPU.""" trainer = Mock() - model = Mock() accelerator = CPUAccelerator( training_type_plugin=SingleDevicePlugin(torch.device("cpu")), precision_plugin=MixedPrecisionPlugin() ) with pytest.raises(MisconfigurationException, match=r"AMP \+ CPU is not supported"): - accelerator.setup(trainer=trainer, model=model) + accelerator.setup(trainer=trainer) @pytest.mark.parametrize("delay_dispatch", [True, False]) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index eebbe3ec2138f..2ce27a1758533 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -38,6 +38,7 @@ def test_can_prepare_data(local_rank, node_rank): model = BoringModel() dm = BoringDataModule() trainer = Trainer() + trainer.model = model trainer.datamodule = dm # 1 no DM @@ -51,7 +52,7 @@ def test_can_prepare_data(local_rank, node_rank): assert trainer.local_rank == 0 assert trainer.data_connector.can_prepare_data() - trainer.data_connector.prepare_data(model) + trainer.data_connector.prepare_data() assert dm.random_full is not None # local rank = 1 (False) @@ -61,7 +62,7 @@ def test_can_prepare_data(local_rank, node_rank): assert trainer.local_rank == 1 assert not trainer.data_connector.can_prepare_data() - trainer.data_connector.prepare_data(model) + trainer.data_connector.prepare_data() assert dm.random_full is None # prepare_data_per_node = False (prepare across all nodes) @@ -73,7 +74,7 @@ def test_can_prepare_data(local_rank, node_rank): local_rank.return_value = 0 assert trainer.data_connector.can_prepare_data() - trainer.data_connector.prepare_data(model) + trainer.data_connector.prepare_data() assert dm.random_full is not None # global rank = 1 (False) @@ -83,14 +84,14 @@ def test_can_prepare_data(local_rank, node_rank): local_rank.return_value = 0 assert not trainer.data_connector.can_prepare_data() - trainer.data_connector.prepare_data(model) + trainer.data_connector.prepare_data() assert dm.random_full is None node_rank.return_value = 0 local_rank.return_value = 1 assert not trainer.data_connector.can_prepare_data() - trainer.data_connector.prepare_data(model) + trainer.data_connector.prepare_data() assert dm.random_full is None # 2 dm diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 338de72a31fed..45efa3c82bcba 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -40,16 +40,18 @@ def test_checkpoint_callbacks_are_last(tmpdir): model = Mock() model.configure_callbacks.return_value = [] trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, checkpoint2]) + trainer.model = model cb_connector = CallbackConnector(trainer) - cb_connector._attach_model_callbacks(model, trainer) + cb_connector._attach_model_callbacks() assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2] # with model-specific callbacks that substitute ones in Trainer model = Mock() model.configure_callbacks.return_value = [checkpoint1, early_stopping, checkpoint2] trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmpdir)]) + trainer.model = model cb_connector = CallbackConnector(trainer) - cb_connector._attach_model_callbacks(model, trainer) + cb_connector._attach_model_callbacks() assert trainer.callbacks == [progress_bar, lr_monitor, early_stopping, checkpoint1, checkpoint2] @@ -90,8 +92,9 @@ def assert_composition(trainer_callbacks, model_callbacks, expected): model = Mock() model.configure_callbacks.return_value = model_callbacks trainer = Trainer(checkpoint_callback=False, progress_bar_refresh_rate=0, callbacks=trainer_callbacks) + trainer.model = model cb_connector = CallbackConnector(trainer) - cb_connector._attach_model_callbacks(model, trainer) + cb_connector._attach_model_callbacks() assert trainer.callbacks == expected early_stopping = EarlyStopping() @@ -140,8 +143,9 @@ def test_attach_model_callbacks_override_info(caplog): model = Mock() model.configure_callbacks.return_value = [LearningRateMonitor(), EarlyStopping()] trainer = Trainer(checkpoint_callback=False, callbacks=[EarlyStopping(), LearningRateMonitor(), ProgressBar()]) + trainer.model = model cb_connector = CallbackConnector(trainer) with caplog.at_level(logging.INFO): - cb_connector._attach_model_callbacks(model, trainer) + cb_connector._attach_model_callbacks() assert "existing callbacks passed to Trainer: EarlyStopping, LearningRateMonitor" in caplog.text diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 56fc6cc74f4ae..fa48dc021d386 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -1456,9 +1456,7 @@ def predict_dataloader(self): def test_request_dataloader(tmpdir): - """ - This test asserts dataloader can be modified and properly set to the trainer. - """ + """This test asserts dataloader can be modified and properly set to the trainer.""" class DataLoaderWrapper: def __init__(self, loader): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 62bd0214b6d22..980dfff930dbb 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1232,15 +1232,14 @@ class CurrentModel(BoringModel): def setup(self, stage): self.stage = stage - class TrainerSubclass(Trainer): - def setup(self, model, stage): + class CurrentCallback(Callback): + def setup(self, trainer, model, stage): assert model is not None self.stage = stage model = CurrentModel() - - # fit model - trainer = TrainerSubclass(default_root_dir=tmpdir, max_epochs=1, checkpoint_callback=False) + callback = CurrentCallback() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, checkpoint_callback=False, callbacks=[callback]) if stage == "fit": trainer.fit(model) @@ -1249,8 +1248,8 @@ def setup(self, model, stage): else: trainer.test(model) - assert trainer.stage == stage - assert trainer.lightning_module.stage == stage + assert callback.stage == stage + assert model.stage == stage @pytest.mark.parametrize("train_batches, max_steps, log_interval", [(10, 10, 1), (3, 10, 1), (3, 10, 5)]) diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py index 57e1b9e27ae20..f63455e36475b 100644 --- a/tests/trainer/test_trainer_tricks.py +++ b/tests/trainer/test_trainer_tricks.py @@ -112,7 +112,7 @@ def test_overfit_batch_limits(tmpdir): # ------------------------------------------------------ # test overfit_batches as percent # ------------------------------------------------------ - loader_num_batches, dataloaders = Trainer(overfit_batches=0.11)._reset_eval_dataloader(model, split) + loader_num_batches, dataloaders = Trainer(overfit_batches=0.11)._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == num_train_samples # make sure we turned off shuffle for the user @@ -126,23 +126,23 @@ def test_overfit_batch_limits(tmpdir): # ------------------------------------------------------ # test overfit_batches as int # ------------------------------------------------------ - loader_num_batches, dataloaders = Trainer(overfit_batches=1)._reset_eval_dataloader(model, split) + loader_num_batches, dataloaders = Trainer(overfit_batches=1)._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == 1 - loader_num_batches, dataloaders = Trainer(overfit_batches=5)._reset_eval_dataloader(model, split) + loader_num_batches, dataloaders = Trainer(overfit_batches=5)._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == 5 # ------------------------------------------------------ # test limit_xxx_batches as percent AND int # ------------------------------------------------------ if split == "val": - loader_num_batches, dataloaders = Trainer(limit_val_batches=0.1)._reset_eval_dataloader(model, split) + loader_num_batches, dataloaders = Trainer(limit_val_batches=0.1)._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == int(0.1 * len(val_loader)) - loader_num_batches, dataloaders = Trainer(limit_val_batches=10)._reset_eval_dataloader(model, split) + loader_num_batches, dataloaders = Trainer(limit_val_batches=10)._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == 10 else: - loader_num_batches, dataloaders = Trainer(limit_test_batches=0.1)._reset_eval_dataloader(model, split) + loader_num_batches, dataloaders = Trainer(limit_test_batches=0.1)._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == int(0.1 * len(test_loader)) - loader_num_batches, dataloaders = Trainer(limit_test_batches=10)._reset_eval_dataloader(model, split) + loader_num_batches, dataloaders = Trainer(limit_test_batches=10)._reset_eval_dataloader(split, model=model) assert loader_num_batches[0] == 10