diff --git a/CHANGELOG.md b/CHANGELOG.md index 5fe3e6c7f3c00..edcb6f76dda28 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -158,6 +158,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - DeepSpeed does not require lightning module zero 3 partitioning ([#10655](https://github.com/PyTorchLightning/pytorch-lightning/pull/10655)) +- Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141)) + + ### Deprecated - Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103)) diff --git a/pl_examples/loop_examples/kfold.py b/pl_examples/loop_examples/kfold.py index ed4db6faa5011..de5a2d512a7b2 100644 --- a/pl_examples/loop_examples/kfold.py +++ b/pl_examples/loop_examples/kfold.py @@ -205,8 +205,8 @@ def on_run_end(self) -> None: voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths) voting_model.trainer = self.trainer # This requires to connect the new model and move it the right device. - self.trainer.training_type_plugin.connect(voting_model) - self.trainer.training_type_plugin.model_to_device() + self.trainer.strategy.connect(voting_model) + self.trainer.strategy.model_to_device() self.trainer.test_loop.run() def on_save_checkpoint(self) -> Dict[str, int]: diff --git a/pl_examples/loop_examples/yielding_training_step.py b/pl_examples/loop_examples/yielding_training_step.py index 739d4f0f2b6b9..e787c8bd98204 100644 --- a/pl_examples/loop_examples/yielding_training_step.py +++ b/pl_examples/loop_examples/yielding_training_step.py @@ -77,7 +77,7 @@ def _get_generator(self, split_batch, batch_idx, opt_idx): # Here we are basically calling `lightning_module.training_step()` # and this returns a generator! The `training_step` is handled by the # accelerator to enable distributed training. - return self.trainer.training_type_plugin.training_step(*step_kwargs.values()) + return self.trainer.strategy.training_step(*step_kwargs.values()) def _training_step(self, generator): # required for logging @@ -86,7 +86,7 @@ def _training_step(self, generator): # Here, instead of calling `lightning_module.training_step()` # we call next() on the generator! training_step_output = next(generator) - self.trainer.training_type_plugin.post_training_step() + self.trainer.strategy.post_training_step() model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output) strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index b2df8e78bec6d..0a2fe81ab2eb5 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -200,7 +200,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None: should_stop, reason = self._evaluate_stopping_criteria(current) # stop every ddp process if any world process decides to stop - should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop) + should_stop = trainer.strategy.reduce_boolean_decision(should_stop) trainer.should_stop = trainer.should_stop or should_stop if should_stop: self.stopped_epoch = trainer.current_epoch diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index d12752eabfff8..5187ab3ef66f3 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -286,7 +286,7 @@ def on_train_batch_end( skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds() # in case we have time differences across ranks # broadcast the decision on whether to checkpoint from rank 0 to avoid possible hangs - skip_time = trainer.training_type_plugin.broadcast(skip_time) + skip_time = trainer.strategy.broadcast(skip_time) if skip_batch and skip_time: return @@ -492,7 +492,7 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[torch.Ten should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path]) # If using multiple devices, make sure all processes are unanimous on the decision. - should_update_best_and_save = trainer.training_type_plugin.reduce_boolean_decision(should_update_best_and_save) + should_update_best_and_save = trainer.strategy.reduce_boolean_decision(should_update_best_and_save) return should_update_best_and_save @@ -598,7 +598,7 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None: else: ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints") - ckpt_path = trainer.training_type_plugin.broadcast(ckpt_path) + ckpt_path = trainer.strategy.broadcast(ckpt_path) self.dirpath = ckpt_path @@ -646,7 +646,7 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[ trainer.save_checkpoint(filepath, self.save_weights_only) if self.last_model_path and self.last_model_path != filepath: - trainer.training_type_plugin.remove_checkpoint(self.last_model_path) + trainer.strategy.remove_checkpoint(self.last_model_path) self.last_model_path = filepath @@ -671,7 +671,7 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate trainer.save_checkpoint(filepath, self.save_weights_only) if self.save_top_k == 1 and self.best_model_path and self.best_model_path != filepath: - trainer.training_type_plugin.remove_checkpoint(self.best_model_path) + trainer.strategy.remove_checkpoint(self.best_model_path) self.best_model_path = filepath @@ -718,7 +718,7 @@ def _update_best_and_save( trainer.save_checkpoint(filepath, self.save_weights_only) if del_filepath is not None and filepath != del_filepath: - trainer.training_type_plugin.remove_checkpoint(del_filepath) + trainer.strategy.remove_checkpoint(del_filepath) def to_yaml(self, filepath: Optional[_PATH] = None) -> None: """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML @@ -733,4 +733,4 @@ def file_exists(self, filepath: _PATH, trainer: "pl.Trainer") -> bool: """Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal state to diverge between ranks.""" exists = self._fs.exists(filepath) - return trainer.training_type_plugin.broadcast(exists) + return trainer.strategy.broadcast(exists) diff --git a/pytorch_lightning/callbacks/timer.py b/pytorch_lightning/callbacks/timer.py index 86c84d61e0ec1..dbc929385febd 100644 --- a/pytorch_lightning/callbacks/timer.py +++ b/pytorch_lightning/callbacks/timer.py @@ -173,7 +173,7 @@ def on_load_checkpoint( def _check_time_remaining(self, trainer: "pl.Trainer") -> None: assert self._duration is not None should_stop = self.time_elapsed() >= self._duration - should_stop = trainer.training_type_plugin.broadcast(should_stop) + should_stop = trainer.strategy.broadcast(should_stop) trainer.should_stop = trainer.should_stop or should_stop if should_stop and self._verbose: elapsed = timedelta(seconds=int(self.time_elapsed(RunningStage.TRAINING))) diff --git a/pytorch_lightning/callbacks/xla_stats_monitor.py b/pytorch_lightning/callbacks/xla_stats_monitor.py index 9c4f09c08a9b3..a14ee42e9a3ec 100644 --- a/pytorch_lightning/callbacks/xla_stats_monitor.py +++ b/pytorch_lightning/callbacks/xla_stats_monitor.py @@ -77,7 +77,7 @@ def on_train_start(self, trainer, pl_module) -> None: ) memory_info = xm.get_memory_info(pl_module.device) - total_memory = trainer.training_type_plugin.reduce(memory_info["kb_total"]) * 0.001 + total_memory = trainer.strategy.reduce(memory_info["kb_total"]) * 0.001 rank_zero_info(f"Average Total memory: {total_memory:.2f} MB") def on_train_epoch_start(self, trainer, pl_module) -> None: @@ -91,9 +91,9 @@ def on_train_epoch_end(self, trainer, pl_module) -> None: free_memory = memory_info["kb_free"] peak_memory = memory_info["kb_total"] - free_memory - free_memory = trainer.training_type_plugin.reduce(free_memory) * 0.001 - peak_memory = trainer.training_type_plugin.reduce(peak_memory) * 0.001 - epoch_time = trainer.training_type_plugin.reduce(epoch_time) + free_memory = trainer.strategy.reduce(free_memory) * 0.001 + peak_memory = trainer.strategy.reduce(peak_memory) * 0.001 + epoch_time = trainer.strategy.reduce(epoch_time) logs["avg. free memory (MB)"] = free_memory logs["avg. peak memory (MB)"] = peak_memory diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index ca9b8dfe7f96c..c40b2b65addc0 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -421,7 +421,7 @@ def log( add_dataloader_idx=add_dataloader_idx, batch_size=batch_size, sync_dist=sync_dist and distributed_available(), - sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp, + sync_dist_fn=self.trainer.strategy.reduce or sync_ddp, sync_dist_group=sync_dist_group, metric_attribute=metric_attribute, rank_zero_only=rank_zero_only, @@ -536,7 +536,7 @@ def all_gather( the output will also be a collection with tensors of this shape. """ group = group if group is not None else torch.distributed.group.WORLD - all_gather = self.trainer.training_type_plugin.all_gather + all_gather = self.trainer.strategy.all_gather data = convert_to_tensors(data, device=self.device) return apply_to_collection(data, torch.Tensor, all_gather, group=group, sync_grads=sync_grads) @@ -1337,7 +1337,7 @@ def training_step(...): **kwargs: Additional keyword arguments to be forwarded to :meth:`~torch.Tensor.backward` """ self._verify_is_manual_optimization("manual_backward") - self.trainer.training_type_plugin.backward(loss, None, None, *args, **kwargs) + self.trainer.strategy.backward(loss, None, None, *args, **kwargs) def backward( self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args, **kwargs diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index b3f49d393824f..c67decf97c3b5 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -161,4 +161,4 @@ def closure_dis(): trainer = self._trainer assert trainer is not None with trainer.profiler.profile(profiler_action): - trainer.training_type_plugin.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs) + trainer.strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 77139186dc92e..904e05ac6a804 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -329,9 +329,7 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional # Python primitives. However, their states are saved with the model's `state_dict`. # On reload, we need to re-attach the `Metric`s back to the `_ResultCollection`. # The references are provided through the `metric_attributes` dictionary. - v.load_state_dict( - state_dict[key], metrics=metric_attributes, sync_fn=self.trainer.training_type_plugin.reduce - ) + v.load_state_dict(state_dict[key], metrics=metric_attributes, sync_fn=self.trainer.strategy.reduce) if not self.trainer.is_global_zero: v.reset(metrics=False) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index c63e9a03b7b50..8614034889a19 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -109,7 +109,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: void(*args, **kwargs) dataloader_idx = self.current_dataloader_idx - dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader) + dataloader = self.trainer.strategy.process_dataloader(self.current_dataloader) self.data_fetcher = dataloader = self.trainer._data_connector.get_profiled_dataloader( dataloader, dataloader_idx=dataloader_idx ) diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py index 1c96f077630e6..7d4c95cca86bd 100644 --- a/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -30,11 +30,11 @@ def return_predictions(self) -> bool: @return_predictions.setter def return_predictions(self, return_predictions: Optional[bool] = None) -> None: # `DDPSpawnStrategy` plugins and derivatives don't support return predictions. - is_ddp_spawn = isinstance(self.trainer.training_type_plugin, DDPSpawnStrategy) + is_ddp_spawn = isinstance(self.trainer.strategy, DDPSpawnStrategy) if return_predictions and is_ddp_spawn: raise MisconfigurationException( "`return_predictions` should be set to `False` when using the `DDPSpawnStrategy` or children class. " - f"Found {return_predictions} with training_type_plugin {type(self.trainer.training_type_plugin)}." + f"Found {return_predictions} with training_type_plugin {type(self.trainer.strategy)}." ) # For non `DDPSpawnStrategy` plugin, the `return_predictions` is True by default unless user decide otherwise. self._return_predictions = not is_ddp_spawn if return_predictions is None else return_predictions @@ -86,7 +86,7 @@ def on_run_start(self) -> None: # type: ignore[override] def advance(self, *args: Any, **kwargs: Any) -> None: """Predicts one entire dataloader.""" void(*args, **kwargs) - dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader) + dataloader = self.trainer.strategy.process_dataloader(self.current_dataloader) dataloader_iter = enumerate(dataloader) dl_max_batches = self.max_batches[self.current_dataloader_idx] diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index c001a6de47e19..ce9e82bc93efd 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -366,9 +366,7 @@ def _should_accumulate(self) -> bool: # Lightning steps on the final batch is_final_batch = self._num_ready_batches_reached() # but the TTP might not - ttp_accumulates_on_final_batch = ( - self.trainer.training_type_plugin.handles_gradient_accumulation or not is_final_batch - ) + ttp_accumulates_on_final_batch = self.trainer.strategy.handles_gradient_accumulation or not is_final_batch return not accumulation_done and ttp_accumulates_on_final_batch @staticmethod diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 49b5a1ba5a1e2..697f1e3b0d840 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -223,7 +223,7 @@ def on_advance_start(self) -> None: # type: ignore[override] def advance(self) -> None: # type: ignore[override] """Runs one whole epoch.""" - dataloader = self.trainer.training_type_plugin.process_dataloader(self.trainer.train_dataloader) + dataloader = self.trainer.strategy.process_dataloader(self.trainer.train_dataloader) data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader) with self.trainer.profiler.profile("run_training_epoch"): @@ -255,7 +255,7 @@ def on_run_end(self) -> None: self.trainer._call_strategy_hook("on_train_end") # give accelerators a chance to finish - self.trainer.training_type_plugin.on_train_end() + self.trainer.strategy.on_train_end() def teardown(self) -> None: self.epoch_loop.teardown() diff --git a/pytorch_lightning/loops/optimization/manual_loop.py b/pytorch_lightning/loops/optimization/manual_loop.py index 9577d9e15d3a2..490ad6ce630ac 100644 --- a/pytorch_lightning/loops/optimization/manual_loop.py +++ b/pytorch_lightning/loops/optimization/manual_loop.py @@ -103,7 +103,7 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override] # manually capture logged metrics training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values()) - self.trainer.training_type_plugin.post_training_step() + self.trainer.strategy.post_training_step() del step_kwargs diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py index d54d06ba53c27..ee4af19134cf3 100644 --- a/pytorch_lightning/loops/optimization/optimizer_loop.py +++ b/pytorch_lightning/loops/optimization/optimizer_loop.py @@ -243,7 +243,7 @@ def _run_optimization( if ( # when the training type plugin handles accumulation, we want to always call the optimizer step - not self.trainer.training_type_plugin.handles_gradient_accumulation + not self.trainer.strategy.handles_gradient_accumulation and self.trainer.fit_loop._should_accumulate() ): # For gradient accumulation @@ -427,7 +427,7 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos # manually capture logged metrics training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values()) - self.trainer.training_type_plugin.post_training_step() + self.trainer.strategy.post_training_step() del step_kwargs diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py index 656f39711d6cb..4e3e3c8cffa23 100644 --- a/pytorch_lightning/loops/utilities.py +++ b/pytorch_lightning/loops/utilities.py @@ -171,8 +171,8 @@ def _block_parallel_sync_behavior(trainer: "pl.Trainer", block: bool = True) -> Returns: context manager with sync behaviour off """ - if isinstance(trainer.training_type_plugin, ParallelStrategy) and block: - with trainer.training_type_plugin.block_backward_sync(): + if isinstance(trainer.strategy, ParallelStrategy) and block: + with trainer.strategy.block_backward_sync(): yield None else: yield None diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index 1e448a226a2a1..1e86ec2633fe9 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -47,7 +47,7 @@ def main_params(self, optimizer: Optimizer) -> _PARAMETERS: def dispatch(self, trainer: "pl.Trainer") -> None: if not self._connected: - strategy = trainer.training_type_plugin + strategy = trainer.strategy _, strategy.optimizers = amp.initialize( trainer.lightning_module, strategy.optimizers, opt_level=self.amp_level ) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 2d6f6d3435bda..aa333a29942a9 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -81,7 +81,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None: def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: with pl_legacy_patch(): - loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint(checkpoint_path) + loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path) if any(key in loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS): raise ValueError( "The checkpoint you're attempting to load follows an" @@ -113,7 +113,7 @@ def resume_end(self) -> None: torch.cuda.empty_cache() # wait for all to catch up - self.trainer.training_type_plugin.barrier("CheckpointConnector.resume_end") + self.trainer.strategy.barrier("CheckpointConnector.resume_end") def restore(self, checkpoint_path: Optional[_PATH] = None) -> None: """Attempt to restore everything at once from a 'PyTorch-Lightning checkpoint' file through file-read and @@ -170,7 +170,7 @@ def restore_model(self) -> None: model.on_hpc_load(self._loaded_checkpoint) # restore model state_dict - self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint) + self.trainer.strategy.load_model_state_dict(self._loaded_checkpoint) # reset metrics states on non-rank 0 as all states have been accumulated on rank 0 via syncing on checkpointing. if not self.trainer.is_global_zero: @@ -258,10 +258,7 @@ def restore_loops(self) -> None: def restore_optimizers_and_schedulers(self) -> None: """Restores the optimizers and learning rate scheduler states from the pre-loaded checkpoint.""" - if ( - not self._loaded_checkpoint - or not self.trainer.training_type_plugin.lightning_restore_optimizer_and_schedulers - ): + if not self._loaded_checkpoint or not self.trainer.strategy.lightning_restore_optimizer_and_schedulers: return # validation @@ -279,7 +276,7 @@ def restore_optimizers(self) -> None: return # restore the optimizers - self.trainer.training_type_plugin.load_optimizer_state_dict(self._loaded_checkpoint) + self.trainer.strategy.load_optimizer_state_dict(self._loaded_checkpoint) for optimizer in self.trainer.optimizers: # move optimizer to GPU 1 weight at a time # avoids OOM @@ -387,7 +384,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: optimizer_states = [] for i, optimizer in enumerate(self.trainer.optimizers): # Rely on accelerator to dump optimizer state - optimizer_state = self.trainer.training_type_plugin.optimizer_state(optimizer) + optimizer_state = self.trainer.strategy.optimizer_state(optimizer) optimizer_states.append(optimizer_state) checkpoint["optimizer_states"] = optimizer_states @@ -463,7 +460,7 @@ def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None: weights_only: saving model weights only """ _checkpoint = self.dump_checkpoint(weights_only) - self.trainer.training_type_plugin.save_checkpoint(_checkpoint, filepath) + self.trainer.strategy.save_checkpoint(_checkpoint, filepath) def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]: metrics = ( @@ -476,7 +473,7 @@ def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]: metric.persistent(True) metric.sync() - state_dict = self.trainer.training_type_plugin.lightning_module_state_dict() + state_dict = self.trainer.strategy.lightning_module_state_dict() for metric in metrics: # sync can be a no-op (e.g. on cpu) so `unsync` would raise a user error exception if we don't check diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 90a0995305df3..c83b5b6483db3 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -106,7 +106,7 @@ def _select_data_fetcher(self) -> AbstractDataFetcher: elif self.trainer.training and os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1": # note: this is an experimental feature - if not self.trainer.training_type_plugin.on_gpu: + if not self.trainer.strategy.on_gpu: raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.") return InterBatchParallelDataFetcher() @@ -118,7 +118,7 @@ def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0) data_fetcher.setup( dataloader, stage=stage, - batch_to_device=partial(self.trainer.training_type_plugin.batch_to_device, dataloader_idx=dataloader_idx), + batch_to_device=partial(self.trainer.strategy.batch_to_device, dataloader_idx=dataloader_idx), profiler=self.trainer.profiler, ) setattr(self, f"{stage}_data_fetcher", data_fetcher) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 4460e235b11bd..890fe17a259b4 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -230,7 +230,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) - module = model or self.lightning_module or self.datamodule self.num_training_batches = ( len(self.train_dataloader) - if has_len_all_ranks(self.train_dataloader, self.training_type_plugin, module) + if has_len_all_ranks(self.train_dataloader, self.strategy, module) else float("inf") ) @@ -257,7 +257,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) - "If you want to disable validation set `limit_val_batches` to 0.0 instead." ) else: - if not has_len_all_ranks(self.train_dataloader, self.training_type_plugin, module): + if not has_len_all_ranks(self.train_dataloader, self.strategy, module): if self.val_check_interval == 1.0: self.val_check_batch = float("inf") else: @@ -323,9 +323,7 @@ def _reset_eval_dataloader( if len(dataloaders) != 0: for i, dataloader in enumerate(dataloaders): orig_num_batches = num_batches = ( - len(dataloader) - if has_len_all_ranks(dataloader, self.training_type_plugin, module) - else float("inf") + len(dataloader) if has_len_all_ranks(dataloader, self.strategy, module) else float("inf") ) self._worker_check(dataloader, f"{mode.dataloader_prefix}_dataloader {i}") @@ -430,7 +428,7 @@ def request_dataloader( dataloader = source.dataloader() if isinstance(dataloader, tuple): dataloader = list(dataloader) - self.training_type_plugin.barrier("get_dataloaders") + self.strategy.barrier("get_dataloaders") _validate_fault_tolerant_automatic(dataloader, stage) return dataloader diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 450590066cc4e..8f586afa6bd03 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -673,9 +673,9 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: **kwargs: keyword arguments to be passed to `trainer_fn` """ try: - if isinstance(self.training_type_plugin, DDPSpawnStrategy): - spawn_output: _SpawnOutput = self.training_type_plugin.spawn(trainer_fn, *args, **kwargs) - self.training_type_plugin._recover_results_in_main_process(spawn_output, self) + if isinstance(self.strategy, DDPSpawnStrategy): + spawn_output: _SpawnOutput = self.strategy.spawn(trainer_fn, *args, **kwargs) + self.strategy._recover_results_in_main_process(spawn_output, self) return spawn_output.trainer_results else: return trainer_fn(*args, **kwargs) @@ -691,7 +691,7 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: self.state.status = TrainerStatus.INTERRUPTED if distributed_available() and self.world_size > 1: # try syncing remaing processes, kill otherwise - self.training_type_plugin.reconciliate_processes(traceback.format_exc()) + self.strategy.reconciliate_processes(traceback.format_exc()) self._on_exception() # reset bookkeeping self.state.stage = None @@ -726,7 +726,7 @@ def fit( datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. """ - self.training_type_plugin.model = model + self.strategy.model = model self._call_and_handle_interrupt( self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path ) @@ -800,7 +800,7 @@ def validate( :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`, etc. The length of the list corresponds to the number of validation dataloaders used. """ - self.training_type_plugin.model = model or self.lightning_module + self.strategy.model = model or self.lightning_module return self._call_and_handle_interrupt(self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule) def _validate_impl( @@ -885,7 +885,7 @@ def test( :meth:`~pytorch_lightning.core.lightning.LightningModule.test_epoch_end`, etc. The length of the list corresponds to the number of test dataloaders used. """ - self.training_type_plugin.model = model or self.lightning_module + self.strategy.model = model or self.lightning_module return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule) def _test_impl( @@ -969,7 +969,7 @@ def predict( Returns: Returns a list of dictionaries, one for each provided dataloader containing their respective predictions. """ - self.training_type_plugin.model = model or self.lightning_module + self.strategy.model = model or self.lightning_module return self._call_and_handle_interrupt( self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path ) @@ -1093,7 +1093,7 @@ def _run( parsing.clean_namespace(model.hparams) # attach model to the training type plugin - self.training_type_plugin.connect(model) + self.strategy.connect(model) self._callback_connector._attach_model_callbacks() self._callback_connector._attach_model_logging_functions() @@ -1107,11 +1107,11 @@ def _run( # SET UP TRAINING # ---------------------------- self._call_callback_hooks("on_before_accelerator_backend_setup") - self.training_type_plugin.setup_environment() + self.strategy.setup_environment() self._call_setup_hook() # allow user to setup lightning_module in accelerator environment # check if we should delay restoring checkpoint till later - if not self.training_type_plugin.restore_checkpoint_after_setup: + if not self.strategy.restore_checkpoint_after_setup: self._restore_modules_and_callbacks(ckpt_path) self._call_configure_sharded_model() # allow user to setup in model sharded environment @@ -1124,7 +1124,7 @@ def _run( {Trainer.fit} or {Trainer.test} or {Trainer.predict} || | || spawn processes || - {self.training_type_plugin.setup_environment} || + {self.strategy.setup_environment} || | || setup accelerator || and strategy || LIGHTNING @@ -1149,7 +1149,7 @@ def _run( self.logger_connector.reset_metrics() # strategy will configure model and move it to the device - self.training_type_plugin.setup(self) + self.strategy.setup(self) # hook if self.state.fn == TrainerFn.FITTING: @@ -1158,7 +1158,7 @@ def _run( self._log_hyperparams() - if self.training_type_plugin.restore_checkpoint_after_setup: + if self.strategy.restore_checkpoint_after_setup: self._restore_modules_and_callbacks(ckpt_path) # restore optimizers, etc. @@ -1183,8 +1183,8 @@ def _run( self.state.status = TrainerStatus.FINISHED self.state.stage = None - if isinstance(self.training_type_plugin, DDPSpawnStrategy): - results = self.training_type_plugin._collect_rank_zero_results(self, results) + if isinstance(self.strategy, DDPSpawnStrategy): + results = self.strategy._collect_rank_zero_results(self, results) return results @@ -1228,8 +1228,8 @@ def _log_hyperparams(self) -> None: def _teardown(self): """This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and Callback; those are handled by :meth:`_call_teardown_hook`.""" - self.training_type_plugin.post_dispatch(self) - self.training_type_plugin.teardown() + self.strategy.post_dispatch(self) + self.strategy.teardown() self._data_connector.teardown() self._active_loop.teardown() self.logger_connector.teardown() @@ -1243,8 +1243,8 @@ def run_stage(self) -> None: return self._run_stage() def _run_stage(self): - self.training_type_plugin.barrier("run-stage") - self.training_type_plugin.dispatch(self) + self.strategy.barrier("run-stage") + self.strategy.dispatch(self) self.__setup_profiler() if self.evaluating: @@ -1255,7 +1255,7 @@ def _run_stage(self): def _pre_training_routine(self): # wait for all to join if on distributed - self.training_type_plugin.barrier("setup_training") + self.strategy.barrier("setup_training") # register signals self.signal_connector.register_signal_handlers() @@ -1401,17 +1401,17 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_ def _call_setup_hook(self) -> None: fn = self.state.fn._setup_fn - self.training_type_plugin.barrier("pre_setup") + self.strategy.barrier("pre_setup") if self.datamodule is not None: self.datamodule.setup(stage=fn) self._call_callback_hooks("setup", stage=fn) self._call_lightning_module_hook("setup", stage=fn) - self.training_type_plugin.barrier("post_setup") + self.strategy.barrier("post_setup") def _call_configure_sharded_model(self) -> None: - with self.training_type_plugin.model_sharded_context(): + with self.strategy.model_sharded_context(): self._handle_meta_model() self._call_lightning_module_hook("configure_sharded_model") self._call_callback_hooks("on_configure_sharded_model") @@ -1420,7 +1420,7 @@ def _handle_meta_model(self) -> None: if not is_on_meta_device(self.lightning_module): return - if isinstance(self.training_type_plugin, DDPSpawnStrategy): + if isinstance(self.strategy, DDPSpawnStrategy): raise MisconfigurationException("LightningModule on meta device isn't supported with spawn.") materialize_module(self.lightning_module) @@ -1492,10 +1492,8 @@ def call_hook( output = accelerator_output if output is None else output # call the ttp hook - if hook_name not in ("setup", "teardown", "on_train_start") and hasattr( - self.training_type_plugin, hook_name - ): - ttp_hook = getattr(self.training_type_plugin, hook_name) + if hook_name not in ("setup", "teardown", "on_train_start") and hasattr(self.strategy, hook_name): + ttp_hook = getattr(self.strategy, hook_name) ttp_output = ttp_hook(*args, **kwargs) output = ttp_output if output is None else output @@ -1645,11 +1643,11 @@ def _call_strategy_hook( prev_fx_name = pl_module._current_fx_name pl_module._current_fx_name = hook_name - fn = getattr(self.training_type_plugin, hook_name) + fn = getattr(self.strategy, hook_name) if not callable(fn): return - with self.profiler.profile(f"[Strategy]{self.training_type_plugin.__class__.__name__}.{hook_name}"): + with self.profiler.profile(f"[Strategy]{self.strategy.__class__.__name__}.{hook_name}"): output = fn(*args, **kwargs) # restore current_fx when nested context @@ -1738,41 +1736,49 @@ def _on_exception(self) -> None: @property def accelerator(self) -> Accelerator: - return self.training_type_plugin.accelerator + return self.strategy.accelerator @property - def training_type_plugin(self) -> Strategy: + def strategy(self) -> Strategy: return self._accelerator_connector.training_type_plugin + @property + def training_type_plugin(self) -> Strategy: + rank_zero_deprecation( + "`Trainer.training_type_plugin` is deprecated in v1.6 and will be removed in v1.8. Use" + " `Trainer.strategy` instead." + ) + return self.strategy + @property def precision_plugin(self) -> PrecisionPlugin: - return self.training_type_plugin.precision_plugin + return self.strategy.precision_plugin @property def global_rank(self) -> int: - return self.training_type_plugin.global_rank + return self.strategy.global_rank @property def local_rank(self) -> int: # some training types define a local rank - return getattr(self.training_type_plugin, "local_rank", 0) + return getattr(self.strategy, "local_rank", 0) @property def node_rank(self) -> int: # some training types define a node rank - return getattr(self.training_type_plugin, "node_rank", 0) + return getattr(self.strategy, "node_rank", 0) @property def world_size(self) -> int: # some training types define a world size - return getattr(self.training_type_plugin, "world_size", 1) + return getattr(self.strategy, "world_size", 1) @property def should_rank_save_checkpoint(self) -> bool: rank_zero_deprecation( "`Trainer.should_rank_save_checkpoint` is deprecated in v1.6 and will be removed in v1.8.", stacklevel=5 ) - ttp = self.training_type_plugin + ttp = self.strategy return isinstance(ttp, pl.plugins.TPUSpawnStrategy) and ttp.local_rank == 0 or ttp.is_global_zero @property @@ -1817,11 +1823,11 @@ def data_parallel_device_ids(self) -> Optional[List[int]]: @property def lightning_module(self) -> "pl.LightningModule": - return self.training_type_plugin.lightning_module + return self.strategy.lightning_module @property def optimizers(self) -> List[Optimizer]: - return self.training_type_plugin.optimizers + return self.strategy.optimizers @optimizers.setter def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None: @@ -1830,23 +1836,23 @@ def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None: # the `lightning_optimizers` trainer property self._lightning_optimizers = None - self.training_type_plugin.optimizers = new_optims + self.strategy.optimizers = new_optims @property def lr_schedulers(self) -> List[LRSchedulerTypeUnion]: - return self.training_type_plugin.lr_schedulers + return self.strategy.lr_schedulers @lr_schedulers.setter def lr_schedulers(self, new_schedulers: List[LRSchedulerTypeUnion]) -> None: - self.training_type_plugin.lr_schedulers = new_schedulers + self.strategy.lr_schedulers = new_schedulers @property def optimizer_frequencies(self) -> list: - return self.training_type_plugin.optimizer_frequencies + return self.strategy.optimizer_frequencies @optimizer_frequencies.setter def optimizer_frequencies(self, new_freqs: list) -> None: - self.training_type_plugin.optimizer_frequencies = new_freqs + self.strategy.optimizer_frequencies = new_freqs @property def amp_backend(self) -> Optional[AMPType]: @@ -1858,7 +1864,7 @@ def amp_backend(self) -> Optional[AMPType]: @property def precision(self) -> Union[str, int]: - return self.training_type_plugin.precision_plugin.precision + return self.strategy.precision_plugin.precision @property def scaler(self) -> Optional[Any]: @@ -1875,7 +1881,7 @@ def model(self) -> torch.nn.Module: To access the pure LightningModule, use :meth:`~pytorch_lightning.trainer.trainer.Trainer.lightning_module` instead. """ - return self.training_type_plugin.model + return self.strategy.model @model.setter def model(self, model: torch.nn.Module) -> None: @@ -1886,7 +1892,7 @@ def model(self, model: torch.nn.Module) -> None: model: The LightningModule, possibly wrapped into DataParallel or DistributedDataParallel, depending on the backend. """ - self.training_type_plugin.model = model + self.strategy.model = model """ General properties @@ -1903,7 +1909,7 @@ def log_dir(self) -> Optional[str]: else: dirpath = self.logger.save_dir - dirpath = self.training_type_plugin.broadcast(dirpath) + dirpath = self.strategy.broadcast(dirpath) return dirpath @property @@ -1927,8 +1933,8 @@ def lightning_optimizers(self) -> List[LightningOptimizer]: @property def distributed_sampler_kwargs(self) -> Optional[dict]: - if isinstance(self.training_type_plugin, ParallelStrategy): - return self.training_type_plugin.distributed_sampler_kwargs + if isinstance(self.strategy, ParallelStrategy): + return self.strategy.distributed_sampler_kwargs @property def data_parallel(self) -> bool: @@ -2318,8 +2324,8 @@ def _exit_gracefully_on_signal(self) -> None: raise ExitGracefullyException(0) def _should_terminate_gracefully(self) -> bool: - value = torch.tensor(int(self._terminate_gracefully), device=self.training_type_plugin.root_device) - return self.training_type_plugin.reduce(value, reduce_op="sum") > 0 + value = torch.tensor(int(self._terminate_gracefully), device=self.strategy.root_device) + return self.strategy.reduce(value, reduce_op="sum") > 0 @property def weights_summary(self) -> Optional[str]: diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 8f611513c9b0c..84467310568f7 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -266,4 +266,4 @@ def _adjust_batch_size( def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer"): module = trainer.lightning_module or trainer.datamodule - return not has_len_all_ranks(dataloader, trainer.training_type_plugin, module) or batch_size <= len(dataloader) + return not has_len_all_ranks(dataloader, trainer.strategy, module) or batch_size <= len(dataloader) diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index 72d1ac83de04e..45386713f326a 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -50,7 +50,7 @@ def test_accelerator_choice_cpu(tmpdir): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) assert isinstance(trainer.accelerator, CPUAccelerator) - assert isinstance(trainer.training_type_plugin, SingleDeviceStrategy) + assert isinstance(trainer.strategy, SingleDeviceStrategy) @pytest.mark.parametrize(("num_processes", "num_nodes"), ([(1, 1), (1, 2), (2, 1), (2, 2)])) @@ -58,8 +58,8 @@ def test_accelerator_choice_ddp_cpu(tmpdir, num_processes: int, num_nodes: int): trainer = Trainer(fast_dev_run=True, accelerator="ddp_cpu", num_processes=num_processes, num_nodes=num_nodes) assert isinstance(trainer.accelerator, CPUAccelerator) no_spawn = num_processes == 1 and num_nodes > 1 - assert isinstance(trainer.training_type_plugin, DDPStrategy if no_spawn else DDPSpawnStrategy) - assert isinstance(trainer.training_type_plugin.cluster_environment, LightningEnvironment) + assert isinstance(trainer.strategy, DDPStrategy if no_spawn else DDPSpawnStrategy) + assert isinstance(trainer.strategy.cluster_environment, LightningEnvironment) @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}) @@ -69,8 +69,8 @@ def test_accelerator_choice_ddp(cuda_available_mock, device_count_mock): with pytest.deprecated_call(match=r"accelerator='ddp'\)` has been deprecated"): trainer = Trainer(fast_dev_run=True, accelerator="ddp", gpus=1) assert isinstance(trainer.accelerator, GPUAccelerator) - assert isinstance(trainer.training_type_plugin, DDPStrategy) - assert isinstance(trainer.training_type_plugin.cluster_environment, LightningEnvironment) + assert isinstance(trainer.strategy, DDPStrategy) + assert isinstance(trainer.strategy.cluster_environment, LightningEnvironment) @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}) @@ -80,8 +80,8 @@ def test_accelerator_choice_ddp_spawn(cuda_available_mock, device_count_mock): with pytest.deprecated_call(match=r"accelerator='ddp_spawn'\)` has been deprecated"): trainer = Trainer(fast_dev_run=True, accelerator="ddp_spawn", gpus=1) assert isinstance(trainer.accelerator, GPUAccelerator) - assert isinstance(trainer.training_type_plugin, DDPSpawnStrategy) - assert isinstance(trainer.training_type_plugin.cluster_environment, LightningEnvironment) + assert isinstance(trainer.strategy, DDPSpawnStrategy) + assert isinstance(trainer.strategy.cluster_environment, LightningEnvironment) @mock.patch.dict( @@ -103,10 +103,10 @@ def test_accelerator_choice_ddp_slurm(*_): trainer = Trainer(fast_dev_run=True, accelerator="ddp", gpus=2) assert trainer._accelerator_connector._is_slurm_managing_tasks() assert isinstance(trainer.accelerator, GPUAccelerator) - assert isinstance(trainer.training_type_plugin, DDPStrategy) - assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment) - assert trainer.training_type_plugin.cluster_environment.local_rank() == 1 - assert trainer.training_type_plugin.local_rank == 1 + assert isinstance(trainer.strategy, DDPStrategy) + assert isinstance(trainer.strategy.cluster_environment, SLURMEnvironment) + assert trainer.strategy.cluster_environment.local_rank() == 1 + assert trainer.strategy.local_rank == 1 @mock.patch.dict( @@ -128,10 +128,10 @@ def test_accelerator_choice_ddp2_slurm(*_): trainer = Trainer(fast_dev_run=True, accelerator="ddp2", gpus=2) assert trainer._accelerator_connector._is_slurm_managing_tasks() assert isinstance(trainer.accelerator, GPUAccelerator) - assert isinstance(trainer.training_type_plugin, DDP2Strategy) - assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment) - assert trainer.training_type_plugin.cluster_environment.local_rank() == 1 - assert trainer.training_type_plugin.local_rank == 1 + assert isinstance(trainer.strategy, DDP2Strategy) + assert isinstance(trainer.strategy.cluster_environment, SLURMEnvironment) + assert trainer.strategy.cluster_environment.local_rank() == 1 + assert trainer.strategy.local_rank == 1 @mock.patch.dict( @@ -152,10 +152,10 @@ def test_accelerator_choice_ddp_te(*_): with pytest.deprecated_call(match=r"accelerator='ddp'\)` has been deprecated in v1.5"): trainer = Trainer(fast_dev_run=True, accelerator="ddp", gpus=2) assert isinstance(trainer.accelerator, GPUAccelerator) - assert isinstance(trainer.training_type_plugin, DDPStrategy) - assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment) - assert trainer.training_type_plugin.cluster_environment.local_rank() == 1 - assert trainer.training_type_plugin.local_rank == 1 + assert isinstance(trainer.strategy, DDPStrategy) + assert isinstance(trainer.strategy.cluster_environment, TorchElasticEnvironment) + assert trainer.strategy.cluster_environment.local_rank() == 1 + assert trainer.strategy.local_rank == 1 @mock.patch.dict( @@ -176,10 +176,10 @@ def test_accelerator_choice_ddp2_te(*_): with pytest.deprecated_call(match=r"accelerator='ddp2'\)` has been deprecated in v1.5"): trainer = Trainer(fast_dev_run=True, accelerator="ddp2", gpus=2) assert isinstance(trainer.accelerator, GPUAccelerator) - assert isinstance(trainer.training_type_plugin, DDP2Strategy) - assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment) - assert trainer.training_type_plugin.cluster_environment.local_rank() == 1 - assert trainer.training_type_plugin.local_rank == 1 + assert isinstance(trainer.strategy, DDP2Strategy) + assert isinstance(trainer.strategy.cluster_environment, TorchElasticEnvironment) + assert trainer.strategy.cluster_environment.local_rank() == 1 + assert trainer.strategy.local_rank == 1 @mock.patch.dict( @@ -190,10 +190,10 @@ def test_accelerator_choice_ddp2_te(*_): def test_accelerator_choice_ddp_cpu_te(*_): trainer = Trainer(fast_dev_run=True, accelerator="ddp_cpu", num_processes=2) assert isinstance(trainer.accelerator, CPUAccelerator) - assert isinstance(trainer.training_type_plugin, DDPStrategy) - assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment) - assert trainer.training_type_plugin.cluster_environment.local_rank() == 1 - assert trainer.training_type_plugin.local_rank == 1 + assert isinstance(trainer.strategy, DDPStrategy) + assert isinstance(trainer.strategy.cluster_environment, TorchElasticEnvironment) + assert trainer.strategy.cluster_environment.local_rank() == 1 + assert trainer.strategy.local_rank == 1 @mock.patch.dict( @@ -214,10 +214,10 @@ def test_accelerator_choice_ddp_kubeflow(*_): with pytest.deprecated_call(match=r"accelerator='ddp'\)` has been deprecated in v1.5"): trainer = Trainer(fast_dev_run=True, accelerator="ddp", gpus=1) assert isinstance(trainer.accelerator, GPUAccelerator) - assert isinstance(trainer.training_type_plugin, DDPStrategy) - assert isinstance(trainer.training_type_plugin.cluster_environment, KubeflowEnvironment) - assert trainer.training_type_plugin.cluster_environment.local_rank() == 0 - assert trainer.training_type_plugin.local_rank == 0 + assert isinstance(trainer.strategy, DDPStrategy) + assert isinstance(trainer.strategy.cluster_environment, KubeflowEnvironment) + assert trainer.strategy.cluster_environment.local_rank() == 0 + assert trainer.strategy.local_rank == 0 @mock.patch.dict( @@ -235,10 +235,10 @@ def test_accelerator_choice_ddp_kubeflow(*_): def test_accelerator_choice_ddp_cpu_kubeflow(*_): trainer = Trainer(fast_dev_run=True, accelerator="ddp_cpu", num_processes=1) assert isinstance(trainer.accelerator, CPUAccelerator) - assert isinstance(trainer.training_type_plugin, DDPStrategy) - assert isinstance(trainer.training_type_plugin.cluster_environment, KubeflowEnvironment) - assert trainer.training_type_plugin.cluster_environment.local_rank() == 0 - assert trainer.training_type_plugin.local_rank == 0 + assert isinstance(trainer.strategy, DDPStrategy) + assert isinstance(trainer.strategy.cluster_environment, KubeflowEnvironment) + assert trainer.strategy.cluster_environment.local_rank() == 0 + assert trainer.strategy.local_rank == 0 @mock.patch.dict( @@ -258,9 +258,9 @@ def test_accelerator_choice_ddp_cpu_slurm(*_): trainer = Trainer(fast_dev_run=True, accelerator="ddp_cpu", num_processes=2) assert trainer._accelerator_connector._is_slurm_managing_tasks() assert isinstance(trainer.accelerator, CPUAccelerator) - assert isinstance(trainer.training_type_plugin, DDPStrategy) - assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment) - assert trainer.training_type_plugin.local_rank == 0 + assert isinstance(trainer.strategy, DDPStrategy) + assert isinstance(trainer.strategy.cluster_environment, SLURMEnvironment) + assert trainer.strategy.local_rank == 0 @RunIf(skip_windows=True, standalone=True) @@ -283,10 +283,10 @@ def _test_accelerator_choice_ddp_cpu_and_strategy(tmpdir, ddp_strategy_class): accelerator="ddp_cpu", num_processes=2, ) - assert isinstance(trainer.training_type_plugin, ddp_strategy_class) + assert isinstance(trainer.strategy, ddp_strategy_class) assert isinstance(trainer.accelerator, CPUAccelerator) - assert trainer.training_type_plugin.num_processes == 2 - assert trainer.training_type_plugin.parallel_devices == [torch.device("cpu")] * 2 + assert trainer.strategy.num_processes == 2 + assert trainer.strategy.parallel_devices == [torch.device("cpu")] * 2 @mock.patch.dict( @@ -317,8 +317,8 @@ def creates_processes_externally(self) -> bool: default_root_dir=tmpdir, plugins=[CustomCluster()], fast_dev_run=True, accelerator="ddp_cpu", num_processes=2 ) assert isinstance(trainer.accelerator, CPUAccelerator) - assert isinstance(trainer.training_type_plugin, DDPStrategy) - assert isinstance(trainer.training_type_plugin.cluster_environment, CustomCluster) + assert isinstance(trainer.strategy, DDPStrategy) + assert isinstance(trainer.strategy.cluster_environment, CustomCluster) @mock.patch.dict( @@ -347,7 +347,7 @@ class TrainTypePlugin(SingleDeviceStrategy): ttp = TrainTypePlugin(device=torch.device("cpu"), accelerator=Accel(), precision_plugin=Prec()) trainer = Trainer(strategy=ttp, fast_dev_run=True, num_processes=2) assert isinstance(trainer.accelerator, Accel) - assert isinstance(trainer.training_type_plugin, TrainTypePlugin) + assert isinstance(trainer.strategy, TrainTypePlugin) assert isinstance(trainer.precision_plugin, Prec) assert trainer._accelerator_connector.training_type_plugin is ttp @@ -357,7 +357,7 @@ class DistributedPlugin(DDPStrategy): ttp = DistributedPlugin(accelerator=Accel(), precision_plugin=Prec()) trainer = Trainer(strategy=ttp, fast_dev_run=True, num_processes=2) assert isinstance(trainer.accelerator, Accel) - assert isinstance(trainer.training_type_plugin, DistributedPlugin) + assert isinstance(trainer.strategy, DistributedPlugin) assert isinstance(trainer.precision_plugin, Prec) assert trainer._accelerator_connector.training_type_plugin is ttp @@ -378,8 +378,8 @@ class DistributedPlugin(DDPStrategy): def test_dist_backend_accelerator_mapping(*_): trainer = Trainer(fast_dev_run=True, strategy="ddp_spawn", num_processes=2) assert isinstance(trainer.accelerator, CPUAccelerator) - assert isinstance(trainer.training_type_plugin, DDPStrategy) - assert trainer.training_type_plugin.local_rank == 0 + assert isinstance(trainer.strategy, DDPStrategy) + assert trainer.strategy.local_rank == 0 @mock.patch("pytorch_lightning.utilities._IS_INTERACTIVE", return_value=True) @@ -406,11 +406,11 @@ def test_plugin_accelerator_choice(accelerator: Optional[str], plugin: str): else: with pytest.deprecated_call(match=r"accelerator=.*\)` has been deprecated"): trainer = Trainer(accelerator=accelerator, plugins=plugin, num_processes=2) - assert isinstance(trainer.training_type_plugin, DDPShardedStrategy) + assert isinstance(trainer.strategy, DDPShardedStrategy) with pytest.deprecated_call(match="Passing .* `strategy` to the `plugins`"): trainer = Trainer(plugins=plugin, num_processes=2) - assert isinstance(trainer.training_type_plugin, DDPShardedStrategy) + assert isinstance(trainer.strategy, DDPShardedStrategy) @pytest.mark.parametrize( @@ -431,7 +431,7 @@ def test_accelerator_choice_multi_node_gpu( ): with pytest.deprecated_call(match=r"accelerator=.*\)` has been deprecated"): trainer = Trainer(accelerator=accelerator, default_root_dir=tmpdir, num_nodes=2, gpus=gpus) - assert isinstance(trainer.training_type_plugin, plugin) + assert isinstance(trainer.strategy, plugin) @pytest.mark.skipif(torch.cuda.is_available(), reason="test doesn't require GPU") @@ -492,7 +492,7 @@ def test_accelerator_cpu_with_devices(devices, plugin): trainer = Trainer(accelerator="cpu", devices=devices) assert trainer.num_processes == devices - assert isinstance(trainer.training_type_plugin, plugin) + assert isinstance(trainer.strategy, plugin) assert isinstance(trainer.accelerator, CPUAccelerator) @@ -515,7 +515,7 @@ def test_accelerator_gpu_with_devices(devices, plugin): trainer = Trainer(accelerator="gpu", devices=devices) assert trainer.gpus == devices - assert isinstance(trainer.training_type_plugin, plugin) + assert isinstance(trainer.strategy, plugin) assert isinstance(trainer.accelerator, GPUAccelerator) @@ -577,7 +577,7 @@ def test_accelerator_ddp_for_cpu(tmpdir): with pytest.deprecated_call(match=r"accelerator='ddp'\)` has been deprecated"): trainer = Trainer(accelerator="ddp", num_processes=2) assert isinstance(trainer.accelerator, CPUAccelerator) - assert isinstance(trainer.training_type_plugin, DDPStrategy) + assert isinstance(trainer.strategy, DDPStrategy) def test_exception_when_strategy_used_with_accelerator(): @@ -610,13 +610,13 @@ def test_exception_invalid_strategy(): ) def test_strategy_choice_cpu_str(tmpdir, strategy, plugin): trainer = Trainer(strategy=strategy, accelerator="cpu", devices=2) - assert isinstance(trainer.training_type_plugin, plugin) + assert isinstance(trainer.strategy, plugin) @pytest.mark.parametrize("plugin", [DDPSpawnStrategy, DDPStrategy]) def test_strategy_choice_cpu_plugin(tmpdir, plugin): trainer = Trainer(strategy=plugin(), accelerator="cpu", devices=2) - assert isinstance(trainer.training_type_plugin, plugin) + assert isinstance(trainer.strategy, plugin) @RunIf(min_gpus=2) @@ -636,14 +636,14 @@ def test_strategy_choice_cpu_plugin(tmpdir, plugin): ) def test_strategy_choice_gpu_str(tmpdir, strategy, plugin): trainer = Trainer(strategy=strategy, accelerator="gpu", devices=2) - assert isinstance(trainer.training_type_plugin, plugin) + assert isinstance(trainer.strategy, plugin) @RunIf(min_gpus=2) @pytest.mark.parametrize("plugin", [DDPSpawnStrategy, DDPStrategy]) def test_strategy_choice_gpu_plugin(tmpdir, plugin): trainer = Trainer(strategy=plugin(), accelerator="gpu", devices=2) - assert isinstance(trainer.training_type_plugin, plugin) + assert isinstance(trainer.strategy, plugin) @RunIf(min_gpus=2) @@ -651,7 +651,7 @@ def test_strategy_choice_gpu_plugin(tmpdir, plugin): def test_device_type_when_training_plugin_gpu_passed(tmpdir, plugin): trainer = Trainer(strategy=plugin(), gpus=2) - assert isinstance(trainer.training_type_plugin, plugin) + assert isinstance(trainer.strategy, plugin) assert trainer._device_type == _AcceleratorType.GPU assert isinstance(trainer.accelerator, GPUAccelerator) @@ -671,8 +671,8 @@ def test_amp_level_raises_error_with_native(): def test_strategy_choice_ddp_spawn_cpu(tmpdir): trainer = Trainer(fast_dev_run=True, strategy="ddp_spawn", num_processes=2) assert isinstance(trainer.accelerator, CPUAccelerator) - assert isinstance(trainer.training_type_plugin, DDPSpawnStrategy) - assert isinstance(trainer.training_type_plugin.cluster_environment, LightningEnvironment) + assert isinstance(trainer.strategy, DDPSpawnStrategy) + assert isinstance(trainer.strategy.cluster_environment, LightningEnvironment) @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}) @@ -681,8 +681,8 @@ def test_strategy_choice_ddp_spawn_cpu(tmpdir): def test_strategy_choice_ddp(cuda_available_mock, device_count_mock): trainer = Trainer(fast_dev_run=True, strategy="ddp", gpus=1) assert isinstance(trainer.accelerator, GPUAccelerator) - assert isinstance(trainer.training_type_plugin, DDPStrategy) - assert isinstance(trainer.training_type_plugin.cluster_environment, LightningEnvironment) + assert isinstance(trainer.strategy, DDPStrategy) + assert isinstance(trainer.strategy.cluster_environment, LightningEnvironment) @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}) @@ -691,8 +691,8 @@ def test_strategy_choice_ddp(cuda_available_mock, device_count_mock): def test_strategy_choice_ddp_spawn(cuda_available_mock, device_count_mock): trainer = Trainer(fast_dev_run=True, strategy="ddp_spawn", gpus=1) assert isinstance(trainer.accelerator, GPUAccelerator) - assert isinstance(trainer.training_type_plugin, DDPSpawnStrategy) - assert isinstance(trainer.training_type_plugin.cluster_environment, LightningEnvironment) + assert isinstance(trainer.strategy, DDPSpawnStrategy) + assert isinstance(trainer.strategy.cluster_environment, LightningEnvironment) @RunIf(min_gpus=2) @@ -713,10 +713,10 @@ def test_strategy_choice_ddp_slurm(setup_distributed_mock, strategy): trainer = Trainer(fast_dev_run=True, strategy=strategy, gpus=2) assert trainer._accelerator_connector._is_slurm_managing_tasks() assert isinstance(trainer.accelerator, GPUAccelerator) - assert isinstance(trainer.training_type_plugin, DDPStrategy) - assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment) - assert trainer.training_type_plugin.cluster_environment.local_rank() == 1 - assert trainer.training_type_plugin.local_rank == 1 + assert isinstance(trainer.strategy, DDPStrategy) + assert isinstance(trainer.strategy.cluster_environment, SLURMEnvironment) + assert trainer.strategy.cluster_environment.local_rank() == 1 + assert trainer.strategy.local_rank == 1 @mock.patch.dict( @@ -738,10 +738,10 @@ def test_strategy_choice_ddp2_slurm(set_device_mock, device_count_mock, setup_di trainer = Trainer(fast_dev_run=True, strategy=strategy, gpus=2) assert trainer._accelerator_connector._is_slurm_managing_tasks() assert isinstance(trainer.accelerator, GPUAccelerator) - assert isinstance(trainer.training_type_plugin, DDP2Strategy) - assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment) - assert trainer.training_type_plugin.cluster_environment.local_rank() == 1 - assert trainer.training_type_plugin.local_rank == 1 + assert isinstance(trainer.strategy, DDP2Strategy) + assert isinstance(trainer.strategy.cluster_environment, SLURMEnvironment) + assert trainer.strategy.cluster_environment.local_rank() == 1 + assert trainer.strategy.local_rank == 1 @mock.patch.dict( @@ -761,10 +761,10 @@ def test_strategy_choice_ddp2_slurm(set_device_mock, device_count_mock, setup_di def test_strategy_choice_ddp_te(*_): trainer = Trainer(fast_dev_run=True, strategy="ddp", gpus=2) assert isinstance(trainer.accelerator, GPUAccelerator) - assert isinstance(trainer.training_type_plugin, DDPStrategy) - assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment) - assert trainer.training_type_plugin.cluster_environment.local_rank() == 1 - assert trainer.training_type_plugin.local_rank == 1 + assert isinstance(trainer.strategy, DDPStrategy) + assert isinstance(trainer.strategy.cluster_environment, TorchElasticEnvironment) + assert trainer.strategy.cluster_environment.local_rank() == 1 + assert trainer.strategy.local_rank == 1 @mock.patch.dict( @@ -784,10 +784,10 @@ def test_strategy_choice_ddp_te(*_): def test_strategy_choice_ddp2_te(*_): trainer = Trainer(fast_dev_run=True, strategy="ddp2", gpus=2) assert isinstance(trainer.accelerator, GPUAccelerator) - assert isinstance(trainer.training_type_plugin, DDP2Strategy) - assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment) - assert trainer.training_type_plugin.cluster_environment.local_rank() == 1 - assert trainer.training_type_plugin.local_rank == 1 + assert isinstance(trainer.strategy, DDP2Strategy) + assert isinstance(trainer.strategy.cluster_environment, TorchElasticEnvironment) + assert trainer.strategy.cluster_environment.local_rank() == 1 + assert trainer.strategy.local_rank == 1 @mock.patch.dict( @@ -798,10 +798,10 @@ def test_strategy_choice_ddp2_te(*_): def test_strategy_choice_ddp_cpu_te(*_): trainer = Trainer(fast_dev_run=True, strategy="ddp_spawn", num_processes=2) assert isinstance(trainer.accelerator, CPUAccelerator) - assert isinstance(trainer.training_type_plugin, DDPStrategy) - assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment) - assert trainer.training_type_plugin.cluster_environment.local_rank() == 1 - assert trainer.training_type_plugin.local_rank == 1 + assert isinstance(trainer.strategy, DDPStrategy) + assert isinstance(trainer.strategy.cluster_environment, TorchElasticEnvironment) + assert trainer.strategy.cluster_environment.local_rank() == 1 + assert trainer.strategy.local_rank == 1 @mock.patch.dict( @@ -821,10 +821,10 @@ def test_strategy_choice_ddp_cpu_te(*_): def test_strategy_choice_ddp_kubeflow(*_): trainer = Trainer(fast_dev_run=True, strategy="ddp", gpus=1) assert isinstance(trainer.accelerator, GPUAccelerator) - assert isinstance(trainer.training_type_plugin, DDPStrategy) - assert isinstance(trainer.training_type_plugin.cluster_environment, KubeflowEnvironment) - assert trainer.training_type_plugin.cluster_environment.local_rank() == 0 - assert trainer.training_type_plugin.local_rank == 0 + assert isinstance(trainer.strategy, DDPStrategy) + assert isinstance(trainer.strategy.cluster_environment, KubeflowEnvironment) + assert trainer.strategy.cluster_environment.local_rank() == 0 + assert trainer.strategy.local_rank == 0 @mock.patch.dict( @@ -842,10 +842,10 @@ def test_strategy_choice_ddp_kubeflow(*_): def test_strategy_choice_ddp_cpu_kubeflow(*_): trainer = Trainer(fast_dev_run=True, strategy="ddp_spawn", num_processes=2) assert isinstance(trainer.accelerator, CPUAccelerator) - assert isinstance(trainer.training_type_plugin, DDPStrategy) - assert isinstance(trainer.training_type_plugin.cluster_environment, KubeflowEnvironment) - assert trainer.training_type_plugin.cluster_environment.local_rank() == 0 - assert trainer.training_type_plugin.local_rank == 0 + assert isinstance(trainer.strategy, DDPStrategy) + assert isinstance(trainer.strategy.cluster_environment, KubeflowEnvironment) + assert trainer.strategy.cluster_environment.local_rank() == 0 + assert trainer.strategy.local_rank == 0 @mock.patch.dict( @@ -865,9 +865,9 @@ def test_strategy_choice_ddp_cpu_kubeflow(*_): def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock, strategy): trainer = Trainer(fast_dev_run=True, strategy=strategy, num_processes=2) assert isinstance(trainer.accelerator, CPUAccelerator) - assert isinstance(trainer.training_type_plugin, DDPStrategy) - assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment) - assert trainer.training_type_plugin.local_rank == 0 + assert isinstance(trainer.strategy, DDPStrategy) + assert isinstance(trainer.strategy.cluster_environment, SLURMEnvironment) + assert trainer.strategy.local_rank == 0 def test_unsupported_tpu_choice(monkeypatch): diff --git a/tests/accelerators/test_ddp.py b/tests/accelerators/test_ddp.py index db2f388971c12..c356ecf935ae1 100644 --- a/tests/accelerators/test_ddp.py +++ b/tests/accelerators/test_ddp.py @@ -127,9 +127,9 @@ def __init__(self): class CustomCallback(Callback): def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - assert isinstance(trainer.training_type_plugin.model, DistributedDataParallel) - assert trainer.training_type_plugin.model.parameters_to_ignore == ("something") - assert trainer.training_type_plugin.model.module._ddp_params_and_buffers_to_ignore == ("something") + assert isinstance(trainer.strategy.model, DistributedDataParallel) + assert trainer.strategy.model.parameters_to_ignore == ("something") + assert trainer.strategy.model.module._ddp_params_and_buffers_to_ignore == ("something") model = CustomModel() trainer = Trainer( diff --git a/tests/accelerators/test_ipu.py b/tests/accelerators/test_ipu.py index 57894f98d18a3..3e9b727dd68b8 100644 --- a/tests/accelerators/test_ipu.py +++ b/tests/accelerators/test_ipu.py @@ -188,13 +188,13 @@ def test_optimization(tmpdir): def test_mixed_precision(tmpdir): class TestCallback(Callback): def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: - assert trainer.training_type_plugin.model.precision == 16 + assert trainer.strategy.model.precision == 16 raise SystemExit model = IPUModel() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, ipus=1, precision=16, callbacks=TestCallback()) - assert isinstance(trainer.training_type_plugin.precision_plugin, IPUPrecisionPlugin) - assert trainer.training_type_plugin.precision_plugin.precision == 16 + assert isinstance(trainer.strategy.precision_plugin, IPUPrecisionPlugin) + assert trainer.strategy.precision_plugin.precision == 16 with pytest.raises(SystemExit): trainer.fit(model) @@ -203,8 +203,8 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[st def test_pure_half_precision(tmpdir): class TestCallback(Callback): def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - assert trainer.training_type_plugin.model.precision == 16 - for param in trainer.training_type_plugin.model.parameters(): + assert trainer.strategy.model.precision == 16 + for param in trainer.strategy.model.parameters(): assert param.dtype == torch.float16 raise SystemExit @@ -212,9 +212,9 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: model = model.half() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, ipus=1, precision=16, callbacks=TestCallback()) - assert isinstance(trainer.training_type_plugin, IPUStrategy) - assert isinstance(trainer.training_type_plugin.precision_plugin, IPUPrecisionPlugin) - assert trainer.training_type_plugin.precision_plugin.precision == 16 + assert isinstance(trainer.strategy, IPUStrategy) + assert isinstance(trainer.strategy.precision_plugin, IPUPrecisionPlugin) + assert trainer.strategy.precision_plugin.precision == 16 with pytest.raises(SystemExit): trainer.fit(model) @@ -224,9 +224,9 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: def test_device_iterations_ipu_plugin(tmpdir): class TestCallback(Callback): def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - assert trainer.training_type_plugin.device_iterations == 2 + assert trainer.strategy.device_iterations == 2 # assert device iterations has been set correctly within the poptorch options - poptorch_model = trainer.training_type_plugin.poptorch_models[RunningStage.TRAINING] + poptorch_model = trainer.strategy.poptorch_models[RunningStage.TRAINING] assert poptorch_model._options.toDict()["device_iterations"] == 2 raise SystemExit @@ -238,7 +238,7 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: strategy=IPUStrategy(device_iterations=2), callbacks=TestCallback(), ) - assert isinstance(trainer.training_type_plugin, IPUStrategy) + assert isinstance(trainer.strategy, IPUStrategy) with pytest.raises(SystemExit): trainer.fit(model) @@ -251,7 +251,7 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: # since ipu handle accumulation assert trainer.accumulation_scheduler.scheduling == {0: 1} # assert poptorch option have been set correctly - poptorch_model = trainer.training_type_plugin.poptorch_models[RunningStage.TRAINING] + poptorch_model = trainer.strategy.poptorch_models[RunningStage.TRAINING] assert poptorch_model._options.Training.toDict()["gradient_accumulation"] == 2 raise SystemExit @@ -356,9 +356,9 @@ def test_manual_poptorch_opts(tmpdir): ) trainer.fit(model) - assert isinstance(trainer.training_type_plugin, IPUStrategy) - assert trainer.training_type_plugin.training_opts == training_opts - assert trainer.training_type_plugin.inference_opts == inference_opts + assert isinstance(trainer.strategy, IPUStrategy) + assert trainer.strategy.training_opts == training_opts + assert trainer.strategy.inference_opts == inference_opts @RunIf(ipu=True) @@ -380,7 +380,7 @@ def test_manual_poptorch_opts_custom(tmpdir): class TestCallback(Callback): def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None: # ensure dataloaders were correctly set up during training. - plugin = trainer.training_type_plugin + plugin = trainer.strategy assert isinstance(plugin, IPUStrategy) assert plugin.training_opts.replication_factor == 2 assert plugin.inference_opts.replication_factor == 1 @@ -400,7 +400,7 @@ def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None: trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy=plugin, callbacks=TestCallback()) trainer.fit(model) - plugin = trainer.training_type_plugin + plugin = trainer.strategy assert isinstance(plugin, IPUStrategy) training_opts = plugin.training_opts @@ -422,7 +422,7 @@ def test_replication_factor(tmpdir): plugin = IPUStrategy() trainer = Trainer(ipus=2, default_root_dir=tmpdir, fast_dev_run=True, strategy=plugin) assert trainer.ipus == 2 - assert trainer.training_type_plugin.replication_factor == 2 + assert trainer.strategy.replication_factor == 2 model = BoringModel() training_opts = poptorch.Options() @@ -436,12 +436,12 @@ def test_replication_factor(tmpdir): plugin.model = model model.trainer = trainer trainer.state.fn = TrainerFn.FITTING - trainer.training_type_plugin.setup(trainer) + trainer.strategy.setup(trainer) trainer.state.stage = RunningStage.TRAINING - assert trainer.training_type_plugin.replication_factor == 8 + assert trainer.strategy.replication_factor == 8 trainer.state.stage = RunningStage.VALIDATING - assert trainer.training_type_plugin.replication_factor == 7 + assert trainer.strategy.replication_factor == 7 for fn, stage in ( (TrainerFn.VALIDATING, RunningStage.VALIDATING), @@ -450,8 +450,8 @@ def test_replication_factor(tmpdir): ): trainer.state.fn = fn trainer.state.stage = stage - trainer.training_type_plugin.setup(trainer) - assert trainer.training_type_plugin.replication_factor == 7 + trainer.strategy.setup(trainer) + assert trainer.strategy.replication_factor == 7 @RunIf(ipu=True) @@ -462,9 +462,9 @@ def test_default_opts(tmpdir): trainer = Trainer(default_root_dir=tmpdir, ipus=1, fast_dev_run=True) trainer.fit(model) - assert isinstance(trainer.training_type_plugin, IPUStrategy) - inference_opts = trainer.training_type_plugin.inference_opts - training_opts = trainer.training_type_plugin.training_opts + assert isinstance(trainer.strategy, IPUStrategy) + inference_opts = trainer.strategy.inference_opts + training_opts = trainer.strategy.training_opts for opts in (inference_opts, training_opts): assert isinstance(opts, poptorch.Options) assert opts.Training.gradient_accumulation == 1 @@ -529,7 +529,7 @@ def test_accelerator_ipu_with_devices(): trainer = Trainer(accelerator="ipu", devices=8) assert trainer.ipus == 8 - assert isinstance(trainer.training_type_plugin, IPUStrategy) + assert isinstance(trainer.strategy, IPUStrategy) assert isinstance(trainer.accelerator, IPUAccelerator) @@ -563,14 +563,14 @@ def test_set_devices_if_none_ipu(): @RunIf(ipu=True) def test_strategy_choice_ipu_plugin(tmpdir): trainer = Trainer(strategy=IPUStrategy(), accelerator="ipu", devices=8) - assert isinstance(trainer.training_type_plugin, IPUStrategy) + assert isinstance(trainer.strategy, IPUStrategy) @RunIf(ipu=True) def test_device_type_when_training_plugin_ipu_passed(tmpdir): trainer = Trainer(strategy=IPUStrategy(), ipus=8) - assert isinstance(trainer.training_type_plugin, IPUStrategy) + assert isinstance(trainer.strategy, IPUStrategy) assert trainer._device_type == _AcceleratorType.IPU assert isinstance(trainer.accelerator, IPUAccelerator) @@ -585,8 +585,8 @@ def test_poptorch_models_at_different_stages(tmpdir): trainer.optimizers = model.configure_optimizers()[0] trainer.state.fn = TrainerFn.FITTING - trainer.training_type_plugin.setup(trainer) - assert list(trainer.training_type_plugin.poptorch_models) == [RunningStage.TRAINING, RunningStage.VALIDATING] + trainer.strategy.setup(trainer) + assert list(trainer.strategy.poptorch_models) == [RunningStage.TRAINING, RunningStage.VALIDATING] for fn, stage in ( (TrainerFn.VALIDATING, RunningStage.VALIDATING), @@ -595,8 +595,8 @@ def test_poptorch_models_at_different_stages(tmpdir): ): trainer.state.fn = fn trainer.state.stage = stage - trainer.training_type_plugin.setup(trainer) - assert list(trainer.training_type_plugin.poptorch_models) == [stage] + trainer.strategy.setup(trainer) + assert list(trainer.strategy.poptorch_models) == [stage] @RunIf(ipu=True) diff --git a/tests/accelerators/test_tpu.py b/tests/accelerators/test_tpu.py index 794fb54670632..9f79663b90089 100644 --- a/tests/accelerators/test_tpu.py +++ b/tests/accelerators/test_tpu.py @@ -118,7 +118,7 @@ def test_accelerator_tpu_with_devices(): trainer = Trainer(accelerator="tpu", devices=8) assert trainer.tpu_cores == 8 - assert isinstance(trainer.training_type_plugin, TPUSpawnStrategy) + assert isinstance(trainer.strategy, TPUSpawnStrategy) assert isinstance(trainer.accelerator, TPUAccelerator) @@ -232,13 +232,13 @@ def test_ddp_cpu_not_supported_on_tpus(): @pytest.mark.parametrize("strategy", ["ddp_spawn", "tpu_spawn_debug"]) def test_strategy_choice_tpu_str(tmpdir, strategy): trainer = Trainer(strategy=strategy, accelerator="tpu", devices=8) - assert isinstance(trainer.training_type_plugin, TPUSpawnStrategy) + assert isinstance(trainer.strategy, TPUSpawnStrategy) @RunIf(tpu=True) def test_strategy_choice_tpu_strategy(tmpdir): trainer = Trainer(strategy=TPUSpawnStrategy(), accelerator="tpu", devices=8) - assert isinstance(trainer.training_type_plugin, TPUSpawnStrategy) + assert isinstance(trainer.strategy, TPUSpawnStrategy) @RunIf(tpu=True) @@ -314,7 +314,7 @@ def test_tpu_invalid_raises_set_precision_with_strategy(): @RunIf(tpu=True) def test_xla_checkpoint_plugin_being_default(): trainer = Trainer(tpu_cores=8) - assert isinstance(trainer.training_type_plugin.checkpoint_io, XLACheckpointIO) + assert isinstance(trainer.strategy.checkpoint_io, XLACheckpointIO) @RunIf(tpu=True) diff --git a/tests/benchmarks/test_sharded_parity.py b/tests/benchmarks/test_sharded_parity.py index 776c63584a795..97b4038159688 100644 --- a/tests/benchmarks/test_sharded_parity.py +++ b/tests/benchmarks/test_sharded_parity.py @@ -146,7 +146,7 @@ def plugin_parity_test( custom_plugin_model = model_cls() trainer = Trainer(fast_dev_run=True, max_epochs=1, gpus=gpus, precision=precision, strategy="ddp_sharded_spawn") - assert isinstance(trainer.training_type_plugin, DDPSpawnShardedStrategy) + assert isinstance(trainer.strategy, DDPSpawnShardedStrategy) max_memory_custom, custom_model_time = record_ddp_fit_model_stats( trainer=trainer, model=custom_plugin_model, use_cuda=use_cuda diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 1ffaadcd00e58..3d45f3f1c33be 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -99,9 +99,9 @@ def on_train_end(self, trainer, pl_module): assert trainer.accumulate_grad_batches == 2 assert trainer.num_training_batches == 5 - if not isinstance(trainer.training_type_plugin, DDPSpawnStrategy): + if not isinstance(trainer.strategy, DDPSpawnStrategy): # check backward call count. the batchnorm update epoch should not backward - assert trainer.training_type_plugin.backward.call_count == trainer.max_epochs * trainer.limit_train_batches + assert trainer.strategy.backward.call_count == trainer.max_epochs * trainer.limit_train_batches # check call counts assert self.update_parameters_calls == trainer.max_epochs - (self._swa_epoch_start - 1) @@ -131,7 +131,7 @@ def train_with_swa( num_processes=num_processes, ) - with mock.patch.object(Strategy, "backward", wraps=trainer.training_type_plugin.backward): + with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward): trainer.fit(model) # check the model is the expected diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py index 2c14c7de29b9c..66ac5648b83fc 100644 --- a/tests/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/checkpointing/test_checkpoint_callback_frequency.py @@ -102,7 +102,7 @@ def training_epoch_end(self, outputs) -> None: self.log("my_loss_2", (1 + local_rank), on_epoch=True, rank_zero_only=True) data = str(self.global_rank) obj = [[data], (data,), set(data)] - out = self.trainer.training_type_plugin.broadcast(obj) + out = self.trainer.strategy.broadcast(obj) assert obj == [[str(self.global_rank)], (str(self.global_rank),), set(str(self.global_rank))] assert out == [["0"], ("0",), set("0")] diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 4f327e5eaa13c..485c1b8b834e9 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -311,7 +311,7 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx): model.transfer_batch_to_device = dm.transfer_batch_to_device model.on_after_batch_transfer = dm.on_after_batch_transfer - batch_gpu = trainer.training_type_plugin.batch_to_device(batch, expected_device) + batch_gpu = trainer.strategy.batch_to_device(batch, expected_device) assert dm.on_before_batch_transfer_hook_rank == 0 assert dm.transfer_batch_to_device_hook_rank == 1 diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index b80ec8c88ded5..914afd7c1ac41 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -321,7 +321,7 @@ def on_save_checkpoint(self, checkpoint) -> None: assert state_dict["items"]["validation_step.v"]["value"].device.type == device # sync fn should be kept - assert results["validation_step.v"].meta.sync.fn == self.trainer.training_type_plugin.reduce + assert results["validation_step.v"].meta.sync.fn == self.trainer.strategy.reduce # sync fn dropped from the state dict assert "fn" not in state_dict["items"]["validation_step.v"]["meta"]["_sync"] @@ -331,7 +331,7 @@ def on_save_checkpoint(self, checkpoint) -> None: assert results["validation_step.v"].value.device.type == device # sync fn was preserved in the original result - assert results["validation_step.v"].meta.sync.fn == self.trainer.training_type_plugin.reduce + assert results["validation_step.v"].meta.sync.fn == self.trainer.strategy.reduce # default sync fn new_results = _ResultCollection(False, device) @@ -458,7 +458,7 @@ def on_epoch_end(self) -> None: assert not model.has_validated_sum tmpdir = ( - trainer.training_type_plugin.broadcast(trainer_kwargs["default_root_dir"], 0) + trainer.strategy.broadcast(trainer_kwargs["default_root_dir"], 0) if num_processes >= 2 else trainer_kwargs["default_root_dir"] ) diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index db75d982a4b4b..46f0a6c91dba8 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -232,3 +232,9 @@ def test_v1_8_0_deprecate_trainer_callback_hook_mixin(): ) with pytest.deprecated_call(match="was deprecated in v1.6 and will be removed in v1.8"): trainer.on_before_zero_grad(optimizer=optim.SGD(model.parameters(), lr=0.01, momentum=0.9)) + + +def test_v1_8_0_deprecated_training_type_plugin_property(): + trainer = Trainer() + with pytest.deprecated_call(match="in v1.6 and will be removed in v1.8"): + trainer.training_type_plugin diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index 76cbf8b4fcee7..f09257c83b029 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -58,8 +58,8 @@ def on_train_start(self) -> None: assert self.device == expected_device def training_epoch_end(self, outputs) -> None: - res = self.trainer.training_type_plugin.reduce(torch.tensor(1.0, device=self.device), reduce_op="sum") - assert res.sum() == self.trainer.training_type_plugin.world_size + res = self.trainer.strategy.reduce(torch.tensor(1.0, device=self.device), reduce_op="sum") + assert res.sum() == self.trainer.strategy.world_size model = TestModel() trainer = Trainer(**trainer_options) diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py index 1a6c91d0b9b98..b1d0116eb165a 100644 --- a/tests/models/test_amp.py +++ b/tests/models/test_amp.py @@ -153,11 +153,11 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir): assert trainer.state.finished, "amp + ddp model failed to complete" # test root model address - assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment) - assert trainer.training_type_plugin.cluster_environment.resolve_root_node_address("abc") == "abc" - assert trainer.training_type_plugin.cluster_environment.resolve_root_node_address("abc[23]") == "abc23" - assert trainer.training_type_plugin.cluster_environment.resolve_root_node_address("abc[23-24]") == "abc23" - generated = trainer.training_type_plugin.cluster_environment.resolve_root_node_address("abc[23-24, 45-40, 40]") + assert isinstance(trainer.strategy.cluster_environment, SLURMEnvironment) + assert trainer.strategy.cluster_environment.resolve_root_node_address("abc") == "abc" + assert trainer.strategy.cluster_environment.resolve_root_node_address("abc[23]") == "abc23" + assert trainer.strategy.cluster_environment.resolve_root_node_address("abc[23-24]") == "abc23" + generated = trainer.strategy.cluster_environment.resolve_root_node_address("abc[23-24, 45-40, 40]") assert generated == "abc23" diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index fe2deab90e680..a3d9977b31c80 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -252,35 +252,35 @@ def test_single_gpu_batch_parse(): # non-transferrable types primitive_objects = [None, {}, [], 1.0, "x", [None, 2], {"x": (1, 2), "y": None}] for batch in primitive_objects: - data = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) + data = trainer.strategy.batch_to_device(batch, torch.device("cuda:0")) assert data == batch # batch is just a tensor batch = torch.rand(2, 3) - batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.strategy.batch_to_device(batch, torch.device("cuda:0")) assert batch.device.index == 0 and batch.type() == "torch.cuda.FloatTensor" # tensor list batch = [torch.rand(2, 3), torch.rand(2, 3)] - batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.strategy.batch_to_device(batch, torch.device("cuda:0")) assert batch[0].device.index == 0 and batch[0].type() == "torch.cuda.FloatTensor" assert batch[1].device.index == 0 and batch[1].type() == "torch.cuda.FloatTensor" # tensor list of lists batch = [[torch.rand(2, 3), torch.rand(2, 3)]] - batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.strategy.batch_to_device(batch, torch.device("cuda:0")) assert batch[0][0].device.index == 0 and batch[0][0].type() == "torch.cuda.FloatTensor" assert batch[0][1].device.index == 0 and batch[0][1].type() == "torch.cuda.FloatTensor" # tensor dict batch = [{"a": torch.rand(2, 3), "b": torch.rand(2, 3)}] - batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.strategy.batch_to_device(batch, torch.device("cuda:0")) assert batch[0]["a"].device.index == 0 and batch[0]["a"].type() == "torch.cuda.FloatTensor" assert batch[0]["b"].device.index == 0 and batch[0]["b"].type() == "torch.cuda.FloatTensor" # tuple of tensor list and list of tensor dict batch = ([torch.rand(2, 3) for _ in range(2)], [{"a": torch.rand(2, 3), "b": torch.rand(2, 3)} for _ in range(2)]) - batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.strategy.batch_to_device(batch, torch.device("cuda:0")) assert batch[0][0].device.index == 0 and batch[0][0].type() == "torch.cuda.FloatTensor" assert batch[1][0]["a"].device.index == 0 @@ -292,7 +292,7 @@ def test_single_gpu_batch_parse(): # namedtuple of tensor BatchType = namedtuple("BatchType", ["a", "b"]) batch = [BatchType(a=torch.rand(2, 3), b=torch.rand(2, 3)) for _ in range(2)] - batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.strategy.batch_to_device(batch, torch.device("cuda:0")) assert batch[0].a.device.index == 0 assert batch[0].a.type() == "torch.cuda.FloatTensor" @@ -305,7 +305,7 @@ def to(self, *args, **kwargs): self.a = self.a.to(*args, **kwargs) return self - batch = trainer.training_type_plugin.batch_to_device(CustomBatchType(), torch.device("cuda:0")) + batch = trainer.strategy.batch_to_device(CustomBatchType(), torch.device("cuda:0")) assert batch.a.type() == "torch.cuda.FloatTensor" # torchtext.data.Batch @@ -331,7 +331,7 @@ def to(self, *args, **kwargs): batch = Batch(data=examples, dataset=dataset) with pytest.deprecated_call(match="The `torchtext.legacy.Batch` object is deprecated"): - batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.strategy.batch_to_device(batch, torch.device("cuda:0")) assert batch.text.type() == "torch.cuda.LongTensor" assert batch.label.type() == "torch.cuda.LongTensor" @@ -344,7 +344,7 @@ def test_non_blocking(): batch = torch.zeros(2, 3) with patch.object(batch, "to", wraps=batch.to) as mocked: - batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.strategy.batch_to_device(batch, torch.device("cuda:0")) mocked.assert_called_with(torch.device("cuda", 0), non_blocking=True) class BatchObject: @@ -353,5 +353,5 @@ def to(self, *args, **kwargs): batch = BatchObject() with patch.object(batch, "to", wraps=batch.to) as mocked: - batch = trainer.training_type_plugin.batch_to_device(batch, torch.device("cuda:0")) + batch = trainer.strategy.batch_to_device(batch, torch.device("cuda:0")) mocked.assert_called_with(torch.device("cuda", 0)) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 5842152278df4..06d7b63d0347a 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -160,7 +160,7 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx): # running .fit() would require us to implement custom data loaders, we mock the model reference instead model_getter_mock.return_value = model - batch_gpu = trainer.training_type_plugin.batch_to_device(batch, expected_device) + batch_gpu = trainer.strategy.batch_to_device(batch, expected_device) assert model.on_before_batch_transfer_hook_rank == 0 assert model.transfer_batch_to_device_hook_rank == 1 diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py index 59a22cf1656d1..c3dc03b1a7fde 100644 --- a/tests/models/test_horovod.py +++ b/tests/models/test_horovod.py @@ -340,7 +340,7 @@ def _compute_batch(): metric = Accuracy( compute_on_step=True, dist_sync_on_step=True, - dist_sync_fn=trainer.training_type_plugin.all_gather, + dist_sync_fn=trainer.strategy.all_gather, threshold=threshold, ) diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 084d6a739e26c..4e59761f2e934 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -283,9 +283,9 @@ def test_broadcast_on_tpu(): def test_broadcast(rank): trainer = Trainer(tpu_cores=8) assert isinstance(trainer.accelerator, TPUAccelerator) - assert isinstance(trainer.training_type_plugin, TPUSpawnStrategy) + assert isinstance(trainer.strategy, TPUSpawnStrategy) obj = ("ver_0.5", "logger_name", rank) - result = trainer.training_type_plugin.broadcast(obj) + result = trainer.strategy.broadcast(obj) assert result == ("ver_0.5", "logger_name", 0) xmp.spawn(test_broadcast, nprocs=8, start_method="fork") @@ -349,9 +349,9 @@ def test_reduce(rank): for reduce_op in reduce_ops: if reduce_op == "undefined" or reduce_op == ReduceOp.MAX: with pytest.raises(MisconfigurationException, match="TPUSpawn Strategy only support"): - result = trainer.training_type_plugin.reduce(1, reduce_op) + result = trainer.strategy.reduce(1, reduce_op) else: - result = trainer.training_type_plugin.reduce(1, reduce_op) + result = trainer.strategy.reduce(1, reduce_op) if isinstance(reduce_op, str) and reduce_op.lower() in ("mean", "avg"): assert result.item() == 1 else: @@ -473,6 +473,6 @@ def teardown(self, stage): def test_device_type_when_training_plugin_tpu_passed(tmpdir): trainer = Trainer(strategy=TPUSpawnStrategy(), tpu_cores=8) - assert isinstance(trainer.training_type_plugin, TPUSpawnStrategy) + assert isinstance(trainer.strategy, TPUSpawnStrategy) assert trainer._device_type == _AcceleratorType.TPU assert isinstance(trainer.accelerator, TPUAccelerator) diff --git a/tests/plugins/environments/torch_elastic_deadlock.py b/tests/plugins/environments/torch_elastic_deadlock.py index 7d612db3a5743..1acf7f60f1619 100644 --- a/tests/plugins/environments/torch_elastic_deadlock.py +++ b/tests/plugins/environments/torch_elastic_deadlock.py @@ -25,7 +25,7 @@ def training_step(self, batch, batch_idx): trainer = Trainer( default_root_dir=".", max_epochs=1, limit_train_batches=5, num_sanity_val_steps=0, gpus=2, strategy="ddp" ) - assert isinstance(trainer.training_type_plugin, DDPStrategy) + assert isinstance(trainer.strategy, DDPStrategy) with suppress(DeadlockDetectedException): # simulate random failure in training_step on rank 0 diff --git a/tests/plugins/test_cluster_integration.py b/tests/plugins/test_cluster_integration.py index 41e4922b5af1b..3694db9ccf629 100644 --- a/tests/plugins/test_cluster_integration.py +++ b/tests/plugins/test_cluster_integration.py @@ -108,7 +108,7 @@ def test_ranks_available_automatic_plugin_selection(mock0, mock1, trainer_kwargs with mock.patch.dict(os.environ, variables): trainer = Trainer(**trainer_kwargs) - assert type(trainer.training_type_plugin.cluster_environment) is type(cluster) + assert type(trainer.strategy.cluster_environment) is type(cluster) assert rank_zero_only.rank == expected["global_rank"] assert trainer.global_rank == expected["global_rank"] assert trainer.local_rank == expected["local_rank"] diff --git a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py index 6c47eaaaa98c5..98a65caacaa2d 100644 --- a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py @@ -23,8 +23,8 @@ def test_invalid_on_cpu(tmpdir): MisconfigurationException, match="You selected accelerator to be `ddp_fully_sharded`, but GPU is not available." ): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp") - assert isinstance(trainer.training_type_plugin, DDPFullyShardedStrategy) - trainer.training_type_plugin.setup_environment() + assert isinstance(trainer.strategy, DDPFullyShardedStrategy) + trainer.strategy.setup_environment() @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) @@ -34,8 +34,8 @@ def test_invalid_on_cpu(tmpdir): def test_fsdp_with_sharded_amp(device_count_mock, mock_cuda_available, tmpdir): """Test to ensure that plugin native amp plugin is correctly chosen when using sharded.""" trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp", gpus=1, precision=16) - assert isinstance(trainer.training_type_plugin, DDPFullyShardedStrategy) - assert isinstance(trainer.training_type_plugin.precision_plugin, FullyShardedNativeMixedPrecisionPlugin) + assert isinstance(trainer.strategy, DDPFullyShardedStrategy) + assert isinstance(trainer.strategy.precision_plugin, FullyShardedNativeMixedPrecisionPlugin) class TestFSDPModel(BoringModel): @@ -110,7 +110,7 @@ def test_fully_sharded_strategy_checkpoint_multi_gpus(tmpdir): def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel): # Use FullySharded to get the state dict for the sake of comparison - model_state_dict = trainer.training_type_plugin.lightning_module_state_dict() + model_state_dict = trainer.strategy.lightning_module_state_dict() if trainer.is_global_zero: saved_model = cls.load_from_checkpoint(ckpt_path) diff --git a/tests/plugins/test_ddp_plugin.py b/tests/plugins/test_ddp_plugin.py index 80033cb4ceb51..028d1a719f8ae 100644 --- a/tests/plugins/test_ddp_plugin.py +++ b/tests/plugins/test_ddp_plugin.py @@ -29,7 +29,7 @@ class BoringModelGPU(BoringModel): def on_train_start(self) -> None: # make sure that the model is on GPU when training - assert self.device == torch.device(f"cuda:{self.trainer.training_type_plugin.local_rank}") + assert self.device == torch.device(f"cuda:{self.trainer.strategy.local_rank}") self.start_cuda_memory = torch.cuda.memory_allocated() @@ -38,11 +38,11 @@ def test_ddp_with_2_gpus(): """Tests if device is set correctely when training and after teardown for DDPStrategy.""" trainer = Trainer(gpus=2, strategy="ddp", fast_dev_run=True) # assert training type plugin attributes for device setting - assert isinstance(trainer.training_type_plugin, DDPStrategy) - assert trainer.training_type_plugin.on_gpu - assert not trainer.training_type_plugin.on_tpu - local_rank = trainer.training_type_plugin.local_rank - assert trainer.training_type_plugin.root_device == torch.device(f"cuda:{local_rank}") + assert isinstance(trainer.strategy, DDPStrategy) + assert trainer.strategy.on_gpu + assert not trainer.strategy.on_tpu + local_rank = trainer.strategy.local_rank + assert trainer.strategy.root_device == torch.device(f"cuda:{local_rank}") model = BoringModelGPU() @@ -56,12 +56,12 @@ def test_ddp_with_2_gpus(): class BarrierModel(BoringModel): def setup(self, stage=None): - assert not isinstance(self.trainer.training_type_plugin.model, DistributedDataParallel) - self.trainer.training_type_plugin.barrier("barrier before model is wrapped") + assert not isinstance(self.trainer.strategy.model, DistributedDataParallel) + self.trainer.strategy.barrier("barrier before model is wrapped") def on_train_start(self): - assert isinstance(self.trainer.training_type_plugin.model, DistributedDataParallel) - self.trainer.training_type_plugin.barrier("barrier after model is wrapped") + assert isinstance(self.trainer.strategy.model, DistributedDataParallel) + self.trainer.strategy.barrier("barrier after model is wrapped") @RunIf(min_gpus=4, standalone=True) @@ -109,11 +109,11 @@ def test_ddp_configure_ddp(): ) # test wrap the model if fitting trainer.state.fn = TrainerFn.FITTING - trainer.training_type_plugin.connect(model) + trainer.strategy.connect(model) trainer.lightning_module.trainer = trainer - trainer.training_type_plugin.setup_environment() + trainer.strategy.setup_environment() assert isinstance(trainer.model, LightningModule) - trainer.training_type_plugin.setup(trainer) + trainer.strategy.setup(trainer) # in DDPStrategy configure_ddp(), model wrapped by DistributedDataParallel assert isinstance(trainer.model, DistributedDataParallel) @@ -123,9 +123,9 @@ def test_ddp_configure_ddp(): ) # test do not wrap the model if trainerFN is not fitting trainer.state.fn = TrainerFn.VALIDATING - trainer.training_type_plugin.connect(model) + trainer.strategy.connect(model) trainer.lightning_module.trainer = trainer - trainer.training_type_plugin.setup_environment() - trainer.training_type_plugin.setup(trainer) + trainer.strategy.setup_environment() + trainer.strategy.setup(trainer) # in DDPStrategy configure_ddp(), model are still LightningModule assert isinstance(trainer.model, LightningModule) diff --git a/tests/plugins/test_ddp_plugin_with_comm_hook.py b/tests/plugins/test_ddp_plugin_with_comm_hook.py index 2a5872fc6221c..7b4cfff5923d1 100644 --- a/tests/plugins/test_ddp_plugin_with_comm_hook.py +++ b/tests/plugins/test_ddp_plugin_with_comm_hook.py @@ -40,7 +40,7 @@ def test_ddp_fp16_compress_comm_hook(tmpdir): fast_dev_run=True, ) trainer.fit(model) - trainer_comm_hook = trainer.training_type_plugin.model.get_ddp_logging_data().comm_hook + trainer_comm_hook = trainer.strategy.model.get_ddp_logging_data().comm_hook expected_comm_hook = default.fp16_compress_hook.__qualname__ assert trainer_comm_hook == expected_comm_hook assert trainer.state.finished, f"Training failed with {trainer.state}" @@ -63,7 +63,7 @@ def test_ddp_sgd_comm_hook(tmpdir): fast_dev_run=True, ) trainer.fit(model) - trainer_comm_hook = trainer.training_type_plugin.model.get_ddp_logging_data().comm_hook + trainer_comm_hook = trainer.strategy.model.get_ddp_logging_data().comm_hook expected_comm_hook = powerSGD.powerSGD_hook.__qualname__ assert trainer_comm_hook == expected_comm_hook assert trainer.state.finished, f"Training failed with {trainer.state}" @@ -87,7 +87,7 @@ def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir): fast_dev_run=True, ) trainer.fit(model) - trainer_comm_hook = trainer.training_type_plugin.model.get_ddp_logging_data().comm_hook + trainer_comm_hook = trainer.strategy.model.get_ddp_logging_data().comm_hook expected_comm_hook = default.fp16_compress_wrapper(powerSGD.powerSGD_hook).__qualname__ assert trainer_comm_hook == expected_comm_hook assert trainer.state.finished, f"Training failed with {trainer.state}" @@ -132,7 +132,7 @@ def test_ddp_post_local_sgd_comm_hook(tmpdir): sync_batchnorm=True, ) trainer.fit(model) - trainer_comm_hook = trainer.training_type_plugin.model.get_ddp_logging_data().comm_hook + trainer_comm_hook = trainer.strategy.model.get_ddp_logging_data().comm_hook expected_comm_hook = post_localSGD.post_localSGD_hook.__qualname__ assert trainer_comm_hook == expected_comm_hook assert trainer.state.finished, f"Training failed with {trainer.state}" diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py index 5f0074dcd7718..48eda9091bc9c 100644 --- a/tests/plugins/test_ddp_spawn_plugin.py +++ b/tests/plugins/test_ddp_spawn_plugin.py @@ -56,10 +56,10 @@ def test_ddp_cpu(): trainer = Trainer(num_processes=2, fast_dev_run=True) # assert training type plugin attributes for device setting - assert isinstance(trainer.training_type_plugin, DDPSpawnStrategy) - assert not trainer.training_type_plugin.on_gpu - assert not trainer.training_type_plugin.on_tpu - assert trainer.training_type_plugin.root_device == torch.device("cpu") + assert isinstance(trainer.strategy, DDPSpawnStrategy) + assert not trainer.strategy.on_gpu + assert not trainer.strategy.on_tpu + assert trainer.strategy.root_device == torch.device("cpu") model = BoringModelDDPCPU() @@ -72,9 +72,9 @@ def test_ddp_spawn_extra_parameters(tmpdir): with Lightning Module (deprecated way).""" trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2, strategy="ddp_spawn") - assert isinstance(trainer.training_type_plugin, DDPSpawnStrategy) - assert trainer.training_type_plugin.on_gpu - assert trainer.training_type_plugin.root_device == torch.device("cuda:0") + assert isinstance(trainer.strategy, DDPSpawnStrategy) + assert trainer.strategy.on_gpu + assert trainer.strategy.root_device == torch.device("cuda:0") val: float = 1.0 val_name: str = "val_acc" diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 870686c1d9291..51c6299743d5d 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -133,8 +133,8 @@ def test_deepspeed_plugin_string(tmpdir, plugin): fast_dev_run=True, default_root_dir=tmpdir, strategy=plugin if isinstance(plugin, str) else plugin() ) - assert isinstance(trainer.training_type_plugin, DeepSpeedStrategy) - assert trainer.training_type_plugin.parallel_devices == [torch.device("cpu")] + assert isinstance(trainer.strategy, DeepSpeedStrategy) + assert trainer.strategy.parallel_devices == [torch.device("cpu")] @RunIf(deepspeed=True) @@ -147,7 +147,7 @@ def test_deepspeed_plugin_env(tmpdir, monkeypatch, deepspeed_config): trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir, strategy="deepspeed") - plugin = trainer.training_type_plugin + plugin = trainer.strategy assert isinstance(plugin, DeepSpeedStrategy) assert plugin.parallel_devices == [torch.device("cpu")] assert plugin.config == deepspeed_config @@ -169,9 +169,9 @@ def test_deepspeed_precision_choice(amp_backend, precision, tmpdir): fast_dev_run=True, default_root_dir=tmpdir, strategy="deepspeed", amp_backend=amp_backend, precision=precision ) - assert isinstance(trainer.training_type_plugin, DeepSpeedStrategy) - assert isinstance(trainer.training_type_plugin.precision_plugin, DeepSpeedPrecisionPlugin) - assert trainer.training_type_plugin.precision_plugin.precision == precision + assert isinstance(trainer.strategy, DeepSpeedStrategy) + assert isinstance(trainer.strategy.precision_plugin, DeepSpeedPrecisionPlugin) + assert trainer.strategy.precision_plugin.precision == precision @RunIf(deepspeed=True) @@ -240,8 +240,8 @@ def train_dataloader(self): class AssertCallback(Callback): def setup(self, trainer, pl_module, stage: Optional[str] = None) -> None: - assert isinstance(trainer.training_type_plugin, DeepSpeedStrategy) - config = trainer.training_type_plugin.config + assert isinstance(trainer.strategy, DeepSpeedStrategy) + config = trainer.strategy.config # int value overrides auto mode expected_value = value if isinstance(value, int) else 1 @@ -336,11 +336,11 @@ def test_deepspeed_custom_precision_params(tmpdir): class TestCB(Callback): def on_train_start(self, trainer, pl_module) -> None: - assert trainer.training_type_plugin.config["fp16"]["loss_scale"] == 10 - assert trainer.training_type_plugin.config["fp16"]["initial_scale_power"] == 10 - assert trainer.training_type_plugin.config["fp16"]["loss_scale_window"] == 10 - assert trainer.training_type_plugin.config["fp16"]["hysteresis"] == 10 - assert trainer.training_type_plugin.config["fp16"]["min_loss_scale"] == 10 + assert trainer.strategy.config["fp16"]["loss_scale"] == 10 + assert trainer.strategy.config["fp16"]["initial_scale_power"] == 10 + assert trainer.strategy.config["fp16"]["loss_scale_window"] == 10 + assert trainer.strategy.config["fp16"]["hysteresis"] == 10 + assert trainer.strategy.config["fp16"]["min_loss_scale"] == 10 raise SystemExit() model = BoringModel() @@ -406,7 +406,7 @@ def test_deepspeed_assert_config_zero_offload_disabled(tmpdir, deepspeed_zero_co class TestCallback(Callback): def on_before_accelerator_backend_setup(self, trainer, pl_module) -> None: - assert trainer.training_type_plugin.config["zero_optimization"]["offload_optimizer"] is False + assert trainer.strategy.config["zero_optimization"]["offload_optimizer"] is False raise SystemExit() model = BoringModel() @@ -478,7 +478,7 @@ def test_deepspeed_multigpu_single_file(tmpdir): trainer = Trainer( default_root_dir=tmpdir, strategy=DeepSpeedStrategy(stage=3), gpus=1, fast_dev_run=True, precision=16 ) - plugin = trainer.training_type_plugin + plugin = trainer.strategy assert isinstance(plugin, DeepSpeedStrategy) assert not plugin.load_full_weights with pytest.raises(MisconfigurationException, match="DeepSpeed was unable to load the checkpoint."): @@ -491,7 +491,7 @@ def test_deepspeed_multigpu_single_file(tmpdir): fast_dev_run=True, precision=16, ) - plugin = trainer.training_type_plugin + plugin = trainer.strategy assert isinstance(plugin, DeepSpeedStrategy) assert plugin.load_full_weights trainer.test(model, ckpt_path=checkpoint_path) @@ -690,8 +690,8 @@ class TestCallback(Callback): def on_train_batch_start( self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int ) -> None: - original_deepspeed_plugin = initial_trainer.training_type_plugin - current_deepspeed_plugin = trainer.training_type_plugin + original_deepspeed_plugin = initial_trainer.strategy + current_deepspeed_plugin = trainer.strategy assert isinstance(original_deepspeed_plugin, DeepSpeedStrategy) assert isinstance(current_deepspeed_plugin, DeepSpeedStrategy) @@ -731,7 +731,7 @@ def __init__(self): self.on_train_batch_start_called = False def on_train_batch_start(self, trainer, pl_module: LightningModule, batch: Any, batch_idx: int) -> None: - deepspeed_engine = trainer.training_type_plugin.model + deepspeed_engine = trainer.strategy.model assert trainer.global_step == deepspeed_engine.global_steps self.on_train_batch_start_called = True @@ -830,7 +830,7 @@ def test_deepspeed_plugin_env_variables(mock_deepspeed_distributed, tmpdir, plat When using windows, ranks environment variables should not be set, and deepspeed should handle this. """ trainer = Trainer(default_root_dir=tmpdir, strategy=DeepSpeedStrategy(stage=3)) - plugin = trainer.training_type_plugin + plugin = trainer.strategy assert isinstance(plugin, DeepSpeedStrategy) with mock.patch("platform.system", return_value=platform) as mock_platform: plugin._init_deepspeed_distributed() @@ -840,18 +840,18 @@ def test_deepspeed_plugin_env_variables(mock_deepspeed_distributed, tmpdir, plat # assert no env variables have been set within the DeepSpeedStrategy assert all(k not in os.environ for k in ("MASTER_PORT", "MASTER_ADDR", "RANK", "WORLD_SIZE", "LOCAL_RANK")) else: - assert os.environ["MASTER_ADDR"] == str(trainer.training_type_plugin.cluster_environment.main_address) - assert os.environ["MASTER_PORT"] == str(trainer.training_type_plugin.cluster_environment.main_port) - assert os.environ["RANK"] == str(trainer.training_type_plugin.global_rank) - assert os.environ["WORLD_SIZE"] == str(trainer.training_type_plugin.world_size) - assert os.environ["LOCAL_RANK"] == str(trainer.training_type_plugin.local_rank) + assert os.environ["MASTER_ADDR"] == str(trainer.strategy.cluster_environment.main_address) + assert os.environ["MASTER_PORT"] == str(trainer.strategy.cluster_environment.main_port) + assert os.environ["RANK"] == str(trainer.strategy.global_rank) + assert os.environ["WORLD_SIZE"] == str(trainer.strategy.world_size) + assert os.environ["LOCAL_RANK"] == str(trainer.strategy.local_rank) def _assert_save_model_is_equal(model, tmpdir, trainer): checkpoint_path = os.path.join(tmpdir, "model.pt") - checkpoint_path = trainer.training_type_plugin.broadcast(checkpoint_path) + checkpoint_path = trainer.strategy.broadcast(checkpoint_path) trainer.save_checkpoint(checkpoint_path) - trainer.training_type_plugin.barrier() + trainer.strategy.barrier() # carry out the check only on rank 0 if trainer.is_global_zero: diff --git a/tests/plugins/test_plugins_registry.py b/tests/plugins/test_plugins_registry.py index 3511c6e44f27a..d071c3843c31f 100644 --- a/tests/plugins/test_plugins_registry.py +++ b/tests/plugins/test_plugins_registry.py @@ -78,7 +78,7 @@ def test_deepspeed_training_type_plugins_registry_with_trainer(tmpdir, plugin): trainer = Trainer(default_root_dir=tmpdir, strategy=plugin, precision=16) - assert isinstance(trainer.training_type_plugin, DeepSpeedStrategy) + assert isinstance(trainer.strategy, DeepSpeedStrategy) def test_tpu_spawn_debug_plugins_registry(tmpdir): @@ -91,7 +91,7 @@ def test_tpu_spawn_debug_plugins_registry(tmpdir): trainer = Trainer(strategy=plugin) - assert isinstance(trainer.training_type_plugin, TPUSpawnStrategy) + assert isinstance(trainer.strategy, TPUSpawnStrategy) def test_fsdp_strategys_registry(tmpdir): @@ -103,7 +103,7 @@ def test_fsdp_strategys_registry(tmpdir): trainer = Trainer(strategy=plugin) - assert isinstance(trainer.training_type_plugin, DDPFullyShardedStrategy) + assert isinstance(trainer.strategy, DDPFullyShardedStrategy) @pytest.mark.parametrize( @@ -119,7 +119,7 @@ def test_ddp_find_unused_parameters_training_type_plugins_registry(tmpdir, plugi trainer = Trainer(default_root_dir=tmpdir, strategy=plugin_name) - assert isinstance(trainer.training_type_plugin, plugin) + assert isinstance(trainer.strategy, plugin) assert plugin_name in TrainingTypePluginsRegistry assert TrainingTypePluginsRegistry[plugin_name]["init_params"] == {"find_unused_parameters": False} @@ -148,5 +148,5 @@ def remove_checkpoint(self, path): ) trainer = Trainer(strategy="ddp_custom_checkpoint_io", accelerator="cpu", devices=2) - assert isinstance(trainer.training_type_plugin, DDPStrategy) - assert trainer.training_type_plugin.checkpoint_io == custom_checkpoint_io + assert isinstance(trainer.strategy, DDPStrategy) + assert trainer.strategy.checkpoint_io == custom_checkpoint_io diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index 994148fd1ff95..1ae27fec88569 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -37,7 +37,7 @@ def test_ddp_sharded_precision_16_clip_gradients(mock_oss_clip_grad_norm, clip_v def test_sharded_ddp_choice(tmpdir, strategy, expected): """Test to ensure that plugin is correctly chosen.""" trainer = Trainer(fast_dev_run=True, strategy=strategy) - assert isinstance(trainer.training_type_plugin, expected) + assert isinstance(trainer.strategy, expected) @RunIf(min_gpus=1, fairscale=True) @@ -47,7 +47,7 @@ def test_sharded_ddp_choice(tmpdir, strategy, expected): def test_ddp_choice_sharded_amp(tmpdir, strategy, expected): """Test to ensure that plugin native amp plugin is correctly chosen when using sharded.""" trainer = Trainer(fast_dev_run=True, gpus=1, precision=16, strategy=strategy) - assert isinstance(trainer.training_type_plugin, expected) + assert isinstance(trainer.strategy, expected) @RunIf(skip_windows=True, fairscale=True) diff --git a/tests/plugins/test_single_device_plugin.py b/tests/plugins/test_single_device_plugin.py index 93da416ce7aab..835065ac5c70d 100644 --- a/tests/plugins/test_single_device_plugin.py +++ b/tests/plugins/test_single_device_plugin.py @@ -22,10 +22,10 @@ def test_single_cpu(): """Tests if on_gpu and on_tpu is set correctly for single cpu plugin.""" trainer = Trainer() - assert isinstance(trainer.training_type_plugin, SingleDeviceStrategy) - assert not trainer.training_type_plugin.on_gpu - assert not trainer.training_type_plugin.on_tpu - assert trainer.training_type_plugin.root_device == torch.device("cpu") + assert isinstance(trainer.strategy, SingleDeviceStrategy) + assert not trainer.strategy.on_gpu + assert not trainer.strategy.on_tpu + assert trainer.strategy.root_device == torch.device("cpu") class BoringModelGPU(BoringModel): @@ -40,10 +40,10 @@ def test_single_gpu(): """Tests if device is set correctly when training and after teardown for single GPU plugin.""" trainer = Trainer(gpus=1, fast_dev_run=True) # assert training type plugin attributes for device setting - assert isinstance(trainer.training_type_plugin, SingleDeviceStrategy) - assert trainer.training_type_plugin.on_gpu - assert not trainer.training_type_plugin.on_tpu - assert trainer.training_type_plugin.root_device == torch.device("cuda:0") + assert isinstance(trainer.strategy, SingleDeviceStrategy) + assert trainer.strategy.on_gpu + assert not trainer.strategy.on_tpu + assert trainer.strategy.root_device == torch.device("cuda:0") model = BoringModelGPU() diff --git a/tests/plugins/test_tpu_spawn.py b/tests/plugins/test_tpu_spawn.py index 167d7c54fc464..dbaf13ff262cf 100644 --- a/tests/plugins/test_tpu_spawn.py +++ b/tests/plugins/test_tpu_spawn.py @@ -95,10 +95,10 @@ def test_model_tpu_one_core(): """Tests if device/debug flag is set correctely when training and after teardown for TPUSpawnStrategy.""" trainer = Trainer(tpu_cores=1, fast_dev_run=True, strategy=TPUSpawnStrategy(debug=True)) # assert training type plugin attributes for device setting - assert isinstance(trainer.training_type_plugin, TPUSpawnStrategy) - assert not trainer.training_type_plugin.on_gpu - assert trainer.training_type_plugin.on_tpu - assert trainer.training_type_plugin.root_device == torch.device("xla", index=1) + assert isinstance(trainer.strategy, TPUSpawnStrategy) + assert not trainer.strategy.on_gpu + assert trainer.strategy.on_tpu + assert trainer.strategy.root_device == torch.device("xla", index=1) model = BoringModelTPU() trainer.fit(model) assert "PT_XLA_DEBUG" not in os.environ diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 99b29cc279b8b..6da4659d32673 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -128,7 +128,7 @@ def on_train_end(self): ) scaler_step = scaler_step_patch.start() - with mock.patch.object(Strategy, "backward", wraps=trainer.training_type_plugin.backward) as bwd_mock: + with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 3 @@ -162,7 +162,7 @@ def training_epoch_end(self, outputs) -> None: enable_model_summary=False, ) - with mock.patch.object(Strategy, "backward", wraps=trainer.training_type_plugin.backward) as bwd_mock: + with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 3 @@ -189,7 +189,7 @@ def training_epoch_end(self, outputs) -> None: enable_model_summary=False, ) - with mock.patch.object(Strategy, "backward", wraps=trainer.training_type_plugin.backward) as bwd_mock: + with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 3 assert set(trainer.logged_metrics) == {"a_step", "a_epoch"} @@ -212,7 +212,7 @@ def test_multiple_optimizers_manual_native_amp(tmpdir): gpus=1, ) - with mock.patch.object(Strategy, "backward", wraps=trainer.training_type_plugin.backward) as bwd_mock: + with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 3 @@ -470,7 +470,7 @@ def log_grad_norm(self, grad_norm_dict): track_grad_norm=2, ) - with mock.patch.object(Strategy, "backward", wraps=trainer.training_type_plugin.backward) as bwd_mock: + with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 3 @@ -540,7 +540,7 @@ def configure_optimizers(self): log_every_n_steps=1, ) - with mock.patch.object(Strategy, "backward", wraps=trainer.training_type_plugin.backward) as bwd_mock: + with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 2 assert trainer.progress_bar_metrics["train_loss_step"] == model._losses[-1] @@ -596,7 +596,7 @@ def configure_optimizers(self): log_every_n_steps=1, ) - with mock.patch.object(Strategy, "backward", wraps=trainer.training_type_plugin.backward) as bwd_mock: + with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 2 diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 780ecea1a19fe..eebf9bb992615 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -260,7 +260,7 @@ def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, num_batches = 128 / batch_size for dl in (train_dl, val_dl, test_dl): - if has_len_all_ranks(dl, trainer.training_type_plugin, model): + if has_len_all_ranks(dl, trainer.strategy, model): assert len(dl) == num_batches else: assert sum(1 for _ in dl) == num_batches @@ -759,7 +759,7 @@ def __len__(self): # with __len__ defined trainer = Trainer(default_root_dir=tmpdir, max_steps=3) dataloader = DataLoader(IterableWithLen(), batch_size=16) - assert has_len_all_ranks(dataloader, trainer.training_type_plugin, model) + assert has_len_all_ranks(dataloader, trainer.strategy, model) assert has_iterable_dataset(dataloader) with pytest.warns(UserWarning, match="Your `IterableDataset` has `__len__` defined."): trainer.validate(model, dataloaders=[dataloader]) @@ -773,7 +773,7 @@ def __len__(self): # without __len__ defined trainer = Trainer(default_root_dir=tmpdir, max_steps=3) dataloader = DataLoader(IterableWithoutLen(), batch_size=16) - assert not has_len_all_ranks(dataloader, trainer.training_type_plugin, model) + assert not has_len_all_ranks(dataloader, trainer.strategy, model) assert has_iterable_dataset(dataloader) trainer.validate(model, dataloaders=dataloader) trainer.fit(model, train_dataloaders=dataloader, val_dataloaders=[dataloader]) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index fe24e3c061816..743b1cea351b4 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1432,7 +1432,7 @@ def predict( else: results = trainer.predict(model, dataloaders=dataloaders) - if not isinstance(trainer.training_type_plugin, DDPSpawnStrategy): + if not isinstance(trainer.strategy, DDPSpawnStrategy): if use_callbacks: assert cb.write_on_batch_end_called assert not cb.write_on_epoch_end_called @@ -1530,7 +1530,7 @@ def test_spawn_predict_return_predictions(_, __, accelerator): """Test that `return_predictions=True` raise a MisconfigurationException with spawn training type plugins.""" model = BoringModel() trainer = Trainer(accelerator=accelerator, strategy="ddp_spawn", devices=2, fast_dev_run=True) - assert isinstance(trainer.training_type_plugin, DDPSpawnStrategy) + assert isinstance(trainer.strategy, DDPSpawnStrategy) with pytest.raises(ProcessRaisedException, match="`return_predictions` should be set to `False`"): trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=True) @@ -1948,7 +1948,7 @@ def configure_optimizers(self): class Check(Callback): def on_epoch_start(self, trainer, *_): - assert isinstance(trainer.training_type_plugin.model, DistributedDataParallel) + assert isinstance(trainer.strategy.model, DistributedDataParallel) def current_memory(): # before measuring the memory force release any leftover allocations, including CUDA tensors @@ -1969,7 +1969,7 @@ def current_memory(): trainer = Trainer(**trainer_kwargs) trainer.fit(model) - assert trainer.training_type_plugin.model is model + assert trainer.strategy.model is model assert list(trainer.optimizers[0].state.values())[0]["exp_avg_sq"].device == torch.device("cpu") assert trainer.callback_metrics["train_loss"].device == torch.device("cpu") diff --git a/tests/utilities/test_data.py b/tests/utilities/test_data.py index e202941cf0fbb..629d141505004 100644 --- a/tests/utilities/test_data.py +++ b/tests/utilities/test_data.py @@ -111,9 +111,9 @@ def test_has_len_all_rank(): model = BoringModel() with pytest.raises(MisconfigurationException, match="Total length of `Dataloader` across ranks is zero."): - assert not has_len_all_ranks(DataLoader(RandomDataset(0, 0)), trainer.training_type_plugin, model) + assert not has_len_all_ranks(DataLoader(RandomDataset(0, 0)), trainer.strategy, model) - assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.training_type_plugin, model) + assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.strategy, model) def test_update_dataloader_typerror_custom_exception(): diff --git a/tests/utilities/test_deepspeed_collate_checkpoint.py b/tests/utilities/test_deepspeed_collate_checkpoint.py index 0058b49d3fda2..7911eaa8ea902 100644 --- a/tests/utilities/test_deepspeed_collate_checkpoint.py +++ b/tests/utilities/test_deepspeed_collate_checkpoint.py @@ -31,9 +31,9 @@ def test_deepspeed_collate_checkpoint(tmpdir): ) trainer.fit(model) checkpoint_path = os.path.join(tmpdir, "model.pt") - checkpoint_path = trainer.training_type_plugin.broadcast(checkpoint_path) + checkpoint_path = trainer.strategy.broadcast(checkpoint_path) trainer.save_checkpoint(checkpoint_path) - trainer.training_type_plugin.barrier() + trainer.strategy.barrier() if trainer.is_global_zero: # ensure function call works output_path = os.path.join(tmpdir, "single_model.pt")