Deprecate Trainer.training_type_plugin in favor of trainer.strategy #11141

Merged: 5 commits (Dec 22, 2021)
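In downstream code the change is a straight attribute rename: everything that was reachable through `trainer.training_type_plugin` is now reachable through `trainer.strategy`, as the diffs below show. A minimal sketch of the migration in user code (the callback here is illustrative, not part of this PR):

```python
import pytorch_lightning as pl


class RankReport(pl.Callback):
    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # before this PR: trainer.training_type_plugin.root_device / .barrier()
        # after this PR:  the same strategy object is exposed as trainer.strategy
        device = trainer.strategy.root_device  # device this process was placed on
        trainer.strategy.barrier("RankReport.on_fit_start")  # wait for every rank
        print(f"rank {trainer.global_rank} is running on {device}")
```

The old attribute is expected to keep working for the deprecation window, just with a warning.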
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -158,6 +158,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- DeepSpeed does not require lightning module zero 3 partitioning ([#10655](https://github.com/PyTorchLightning/pytorch-lightning/pull/10655))


+ - Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))


### Deprecated

- Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103))
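The CHANGELOG entry above describes the user-facing change; the deprecation shim itself is not part of the hunks shown here. Conceptually it amounts to a forwarding property along these lines (a hedged sketch only; the helper name and exact message are assumptions, not code from this PR):

```python
from pytorch_lightning.utilities import rank_zero_deprecation


class Trainer:
    @property
    def training_type_plugin(self):
        # sketch: warn on access, then forward to the new attribute
        rank_zero_deprecation(
            "`Trainer.training_type_plugin` is deprecated. Use `Trainer.strategy` instead."
        )
        return self.strategy
```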
4 changes: 2 additions & 2 deletions pl_examples/loop_examples/kfold.py
@@ -205,8 +205,8 @@ def on_run_end(self) -> None:
voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths)
voting_model.trainer = self.trainer
# This requires to connect the new model and move it the right device.
- self.trainer.training_type_plugin.connect(voting_model)
- self.trainer.training_type_plugin.model_to_device()
+ self.trainer.strategy.connect(voting_model)
+ self.trainer.strategy.model_to_device()
self.trainer.test_loop.run()

def on_save_checkpoint(self) -> Dict[str, int]:
4 changes: 2 additions & 2 deletions pl_examples/loop_examples/yielding_training_step.py
@@ -77,7 +77,7 @@ def _get_generator(self, split_batch, batch_idx, opt_idx):
# Here we are basically calling `lightning_module.training_step()`
# and this returns a generator! The `training_step` is handled by the
# accelerator to enable distributed training.
- return self.trainer.training_type_plugin.training_step(*step_kwargs.values())
+ return self.trainer.strategy.training_step(*step_kwargs.values())

def _training_step(self, generator):
# required for logging
@@ -86,7 +86,7 @@ def _training_step(self, generator):
# Here, instead of calling `lightning_module.training_step()`
# we call next() on the generator!
training_step_output = next(generator)
- self.trainer.training_type_plugin.post_training_step()
+ self.trainer.strategy.post_training_step()

model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/early_stopping.py
@@ -200,7 +200,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
should_stop, reason = self._evaluate_stopping_criteria(current)

# stop every ddp process if any world process decides to stop
- should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
+ should_stop = trainer.strategy.reduce_boolean_decision(should_stop)
trainer.should_stop = trainer.should_stop or should_stop
if should_stop:
self.stopped_epoch = trainer.current_epoch
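The change in `early_stopping.py` (and in `model_checkpoint.py` and `timer.py` below) keeps the same pattern: any rank-local decision is funneled through the strategy so all processes agree before acting. A custom callback following that pattern could look like this (hedged sketch; `LossThreshold` is not an existing Lightning callback):

```python
import pytorch_lightning as pl


class LossThreshold(pl.Callback):
    """Stops training once a monitored metric drops below a threshold, consistently across ranks."""

    def __init__(self, monitor: str = "train_loss", threshold: float = 0.01):
        self.monitor = monitor
        self.threshold = threshold

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        current = trainer.callback_metrics.get(self.monitor)
        should_stop = current is not None and bool(current < self.threshold)
        # make every DDP process reach the same verdict before flipping the flag
        should_stop = trainer.strategy.reduce_boolean_decision(should_stop)
        trainer.should_stop = trainer.should_stop or should_stop
```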
14 changes: 7 additions & 7 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -286,7 +286,7 @@ def on_train_batch_end(
skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds()
# in case we have time differences across ranks
# broadcast the decision on whether to checkpoint from rank 0 to avoid possible hangs
- skip_time = trainer.training_type_plugin.broadcast(skip_time)
+ skip_time = trainer.strategy.broadcast(skip_time)

if skip_batch and skip_time:
return
@@ -492,7 +492,7 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[torch.Ten
should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])

# If using multiple devices, make sure all processes are unanimous on the decision.
- should_update_best_and_save = trainer.training_type_plugin.reduce_boolean_decision(should_update_best_and_save)
+ should_update_best_and_save = trainer.strategy.reduce_boolean_decision(should_update_best_and_save)

return should_update_best_and_save

@@ -598,7 +598,7 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
else:
ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")

- ckpt_path = trainer.training_type_plugin.broadcast(ckpt_path)
+ ckpt_path = trainer.strategy.broadcast(ckpt_path)

self.dirpath = ckpt_path

@@ -646,7 +646,7 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[
trainer.save_checkpoint(filepath, self.save_weights_only)

if self.last_model_path and self.last_model_path != filepath:
- trainer.training_type_plugin.remove_checkpoint(self.last_model_path)
+ trainer.strategy.remove_checkpoint(self.last_model_path)

self.last_model_path = filepath

@@ -671,7 +671,7 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate
trainer.save_checkpoint(filepath, self.save_weights_only)

if self.save_top_k == 1 and self.best_model_path and self.best_model_path != filepath:
- trainer.training_type_plugin.remove_checkpoint(self.best_model_path)
+ trainer.strategy.remove_checkpoint(self.best_model_path)

self.best_model_path = filepath

@@ -718,7 +718,7 @@ def _update_best_and_save(
trainer.save_checkpoint(filepath, self.save_weights_only)

if del_filepath is not None and filepath != del_filepath:
- trainer.training_type_plugin.remove_checkpoint(del_filepath)
+ trainer.strategy.remove_checkpoint(del_filepath)

def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
"""Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
@@ -733,4 +733,4 @@ def file_exists(self, filepath: _PATH, trainer: "pl.Trainer") -> bool:
"""Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal
state to diverge between ranks."""
exists = self._fs.exists(filepath)
return trainer.training_type_plugin.broadcast(exists)
return trainer.strategy.broadcast(exists)
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/timer.py
@@ -173,7 +173,7 @@ def on_load_checkpoint(
def _check_time_remaining(self, trainer: "pl.Trainer") -> None:
assert self._duration is not None
should_stop = self.time_elapsed() >= self._duration
- should_stop = trainer.training_type_plugin.broadcast(should_stop)
+ should_stop = trainer.strategy.broadcast(should_stop)
trainer.should_stop = trainer.should_stop or should_stop
if should_stop and self._verbose:
elapsed = timedelta(seconds=int(self.time_elapsed(RunningStage.TRAINING)))
8 changes: 4 additions & 4 deletions pytorch_lightning/callbacks/xla_stats_monitor.py
@@ -77,7 +77,7 @@ def on_train_start(self, trainer, pl_module) -> None:
)

memory_info = xm.get_memory_info(pl_module.device)
- total_memory = trainer.training_type_plugin.reduce(memory_info["kb_total"]) * 0.001
+ total_memory = trainer.strategy.reduce(memory_info["kb_total"]) * 0.001
rank_zero_info(f"Average Total memory: {total_memory:.2f} MB")

def on_train_epoch_start(self, trainer, pl_module) -> None:
@@ -91,9 +91,9 @@ def on_train_epoch_end(self, trainer, pl_module) -> None:
free_memory = memory_info["kb_free"]
peak_memory = memory_info["kb_total"] - free_memory

- free_memory = trainer.training_type_plugin.reduce(free_memory) * 0.001
- peak_memory = trainer.training_type_plugin.reduce(peak_memory) * 0.001
- epoch_time = trainer.training_type_plugin.reduce(epoch_time)
+ free_memory = trainer.strategy.reduce(free_memory) * 0.001
+ peak_memory = trainer.strategy.reduce(peak_memory) * 0.001
+ epoch_time = trainer.strategy.reduce(epoch_time)

logs["avg. free memory (MB)"] = free_memory
logs["avg. peak memory (MB)"] = peak_memory
6 changes: 3 additions & 3 deletions pytorch_lightning/core/lightning.py
@@ -421,7 +421,7 @@ def log(
add_dataloader_idx=add_dataloader_idx,
batch_size=batch_size,
sync_dist=sync_dist and distributed_available(),
- sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp,
+ sync_dist_fn=self.trainer.strategy.reduce or sync_ddp,
sync_dist_group=sync_dist_group,
metric_attribute=metric_attribute,
rank_zero_only=rank_zero_only,
@@ -536,7 +536,7 @@ def all_gather(
the output will also be a collection with tensors of this shape.
"""
group = group if group is not None else torch.distributed.group.WORLD
- all_gather = self.trainer.training_type_plugin.all_gather
+ all_gather = self.trainer.strategy.all_gather
data = convert_to_tensors(data, device=self.device)
return apply_to_collection(data, torch.Tensor, all_gather, group=group, sync_grads=sync_grads)

@@ -1337,7 +1337,7 @@ def training_step(...):
**kwargs: Additional keyword arguments to be forwarded to :meth:`~torch.Tensor.backward`
"""
self._verify_is_manual_optimization("manual_backward")
- self.trainer.training_type_plugin.backward(loss, None, None, *args, **kwargs)
+ self.trainer.strategy.backward(loss, None, None, *args, **kwargs)

def backward(
self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args, **kwargs
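`manual_backward` above simply forwards to the strategy's `backward`, which is what keeps manual optimization working unchanged under distributed strategies. A hedged usage sketch (the module is illustrative, not from this PR):

```python
import torch
import pytorch_lightning as pl


class ManualOptModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # opt into manual optimization
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.layer(batch).mean()
        self.manual_backward(loss)  # routed through trainer.strategy.backward internally
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```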
2 changes: 1 addition & 1 deletion pytorch_lightning/core/optimizer.py
@@ -161,4 +161,4 @@ def closure_dis():
trainer = self._trainer
assert trainer is not None
with trainer.profiler.profile(profiler_action):
- trainer.training_type_plugin.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
+ trainer.strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
4 changes: 1 addition & 3 deletions pytorch_lightning/loops/base.py
@@ -329,9 +329,7 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional
# Python primitives. However, their states are saved with the model's `state_dict`.
# On reload, we need to re-attach the `Metric`s back to the `_ResultCollection`.
# The references are provided through the `metric_attributes` dictionary.
- v.load_state_dict(
-     state_dict[key], metrics=metric_attributes, sync_fn=self.trainer.training_type_plugin.reduce
- )
+ v.load_state_dict(state_dict[key], metrics=metric_attributes, sync_fn=self.trainer.strategy.reduce)

if not self.trainer.is_global_zero:
v.reset(metrics=False)
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -109,7 +109,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
void(*args, **kwargs)

dataloader_idx = self.current_dataloader_idx
- dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader)
+ dataloader = self.trainer.strategy.process_dataloader(self.current_dataloader)
self.data_fetcher = dataloader = self.trainer._data_connector.get_profiled_dataloader(
dataloader, dataloader_idx=dataloader_idx
)
6 changes: 3 additions & 3 deletions pytorch_lightning/loops/dataloader/prediction_loop.py
@@ -30,11 +30,11 @@ def return_predictions(self) -> bool:
@return_predictions.setter
def return_predictions(self, return_predictions: Optional[bool] = None) -> None:
# `DDPSpawnStrategy` plugins and derivatives don't support return predictions.
- is_ddp_spawn = isinstance(self.trainer.training_type_plugin, DDPSpawnStrategy)
+ is_ddp_spawn = isinstance(self.trainer.strategy, DDPSpawnStrategy)
if return_predictions and is_ddp_spawn:
raise MisconfigurationException(
"`return_predictions` should be set to `False` when using the `DDPSpawnStrategy` or children class. "
f"Found {return_predictions} with training_type_plugin {type(self.trainer.training_type_plugin)}."
f"Found {return_predictions} with training_type_plugin {type(self.trainer.strategy)}."
)
# For non `DDPSpawnStrategy` plugin, the `return_predictions` is True by default unless user decide otherwise.
self._return_predictions = not is_ddp_spawn if return_predictions is None else return_predictions
@@ -86,7 +86,7 @@ def on_run_start(self) -> None: # type: ignore[override]
def advance(self, *args: Any, **kwargs: Any) -> None:
"""Predicts one entire dataloader."""
void(*args, **kwargs)
- dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader)
+ dataloader = self.trainer.strategy.process_dataloader(self.current_dataloader)
dataloader_iter = enumerate(dataloader)
dl_max_batches = self.max_batches[self.current_dataloader_idx]

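As the check above enforces, spawn-based strategies cannot hand predictions back to the launching process, so `return_predictions` must stay `False` with them. A hedged usage sketch (`MyModel` and `predict_loader` are placeholders defined elsewhere):

```python
import pytorch_lightning as pl

trainer = pl.Trainer(strategy="ddp_spawn", gpus=2)
# DDPSpawnStrategy cannot return predictions to the main process
predictions = trainer.predict(MyModel(), dataloaders=predict_loader, return_predictions=False)
assert predictions is None
```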
4 changes: 1 addition & 3 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -366,9 +366,7 @@ def _should_accumulate(self) -> bool:
# Lightning steps on the final batch
is_final_batch = self._num_ready_batches_reached()
# but the TTP might not
- ttp_accumulates_on_final_batch = (
-     self.trainer.training_type_plugin.handles_gradient_accumulation or not is_final_batch
- )
+ ttp_accumulates_on_final_batch = self.trainer.strategy.handles_gradient_accumulation or not is_final_batch
return not accumulation_done and ttp_accumulates_on_final_batch

@staticmethod
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/fit_loop.py
@@ -223,7 +223,7 @@ def on_advance_start(self) -> None: # type: ignore[override]

def advance(self) -> None: # type: ignore[override]
"""Runs one whole epoch."""
- dataloader = self.trainer.training_type_plugin.process_dataloader(self.trainer.train_dataloader)
+ dataloader = self.trainer.strategy.process_dataloader(self.trainer.train_dataloader)
data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader)

with self.trainer.profiler.profile("run_training_epoch"):
@@ -255,7 +255,7 @@ def on_run_end(self) -> None:
self.trainer._call_strategy_hook("on_train_end")

# give accelerators a chance to finish
- self.trainer.training_type_plugin.on_train_end()
+ self.trainer.strategy.on_train_end()

def teardown(self) -> None:
self.epoch_loop.teardown()
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/optimization/manual_loop.py
@@ -103,7 +103,7 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override]

# manually capture logged metrics
training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
- self.trainer.training_type_plugin.post_training_step()
+ self.trainer.strategy.post_training_step()

del step_kwargs

4 changes: 2 additions & 2 deletions pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -243,7 +243,7 @@ def _run_optimization(

if (
# when the training type plugin handles accumulation, we want to always call the optimizer step
- not self.trainer.training_type_plugin.handles_gradient_accumulation
+ not self.trainer.strategy.handles_gradient_accumulation
and self.trainer.fit_loop._should_accumulate()
):
# For gradient accumulation
@@ -427,7 +427,7 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos

# manually capture logged metrics
training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
- self.trainer.training_type_plugin.post_training_step()
+ self.trainer.strategy.post_training_step()

del step_kwargs

4 changes: 2 additions & 2 deletions pytorch_lightning/loops/utilities.py
@@ -171,8 +171,8 @@ def _block_parallel_sync_behavior(trainer: "pl.Trainer", block: bool = True) ->
Returns:
context manager with sync behaviour off
"""
- if isinstance(trainer.training_type_plugin, ParallelStrategy) and block:
-     with trainer.training_type_plugin.block_backward_sync():
+ if isinstance(trainer.strategy, ParallelStrategy) and block:
+     with trainer.strategy.block_backward_sync():
yield None
else:
yield None
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/precision/apex_amp.py
@@ -47,7 +47,7 @@ def main_params(self, optimizer: Optimizer) -> _PARAMETERS:

def dispatch(self, trainer: "pl.Trainer") -> None:
if not self._connected:
- strategy = trainer.training_type_plugin
+ strategy = trainer.strategy
_, strategy.optimizers = amp.initialize(
trainer.lightning_module, strategy.optimizers, opt_level=self.amp_level
)
19 changes: 8 additions & 11 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -81,7 +81,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:

def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
with pl_legacy_patch():
- loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint(checkpoint_path)
+ loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path)
if any(key in loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS):
raise ValueError(
"The checkpoint you're attempting to load follows an"
@@ -113,7 +113,7 @@ def resume_end(self) -> None:
torch.cuda.empty_cache()

# wait for all to catch up
- self.trainer.training_type_plugin.barrier("CheckpointConnector.resume_end")
+ self.trainer.strategy.barrier("CheckpointConnector.resume_end")

def restore(self, checkpoint_path: Optional[_PATH] = None) -> None:
"""Attempt to restore everything at once from a 'PyTorch-Lightning checkpoint' file through file-read and
@@ -170,7 +170,7 @@ def restore_model(self) -> None:
model.on_hpc_load(self._loaded_checkpoint)

# restore model state_dict
- self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint)
+ self.trainer.strategy.load_model_state_dict(self._loaded_checkpoint)

# reset metrics states on non-rank 0 as all states have been accumulated on rank 0 via syncing on checkpointing.
if not self.trainer.is_global_zero:
@@ -258,10 +258,7 @@ def restore_loops(self) -> None:

def restore_optimizers_and_schedulers(self) -> None:
"""Restores the optimizers and learning rate scheduler states from the pre-loaded checkpoint."""
- if (
-     not self._loaded_checkpoint
-     or not self.trainer.training_type_plugin.lightning_restore_optimizer_and_schedulers
- ):
+ if not self._loaded_checkpoint or not self.trainer.strategy.lightning_restore_optimizer_and_schedulers:
return

# validation
@@ -279,7 +276,7 @@ def restore_optimizers(self) -> None:
return

# restore the optimizers
- self.trainer.training_type_plugin.load_optimizer_state_dict(self._loaded_checkpoint)
+ self.trainer.strategy.load_optimizer_state_dict(self._loaded_checkpoint)
for optimizer in self.trainer.optimizers:
# move optimizer to GPU 1 weight at a time
# avoids OOM
@@ -387,7 +384,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
optimizer_states = []
for i, optimizer in enumerate(self.trainer.optimizers):
# Rely on accelerator to dump optimizer state
- optimizer_state = self.trainer.training_type_plugin.optimizer_state(optimizer)
+ optimizer_state = self.trainer.strategy.optimizer_state(optimizer)
optimizer_states.append(optimizer_state)

checkpoint["optimizer_states"] = optimizer_states
@@ -463,7 +460,7 @@ def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
weights_only: saving model weights only
"""
_checkpoint = self.dump_checkpoint(weights_only)
- self.trainer.training_type_plugin.save_checkpoint(_checkpoint, filepath)
+ self.trainer.strategy.save_checkpoint(_checkpoint, filepath)

def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]:
metrics = (
@@ -476,7 +473,7 @@ def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]:
metric.persistent(True)
metric.sync()

- state_dict = self.trainer.training_type_plugin.lightning_module_state_dict()
+ state_dict = self.trainer.strategy.lightning_module_state_dict()

for metric in metrics:
# sync can be a no-op (e.g. on cpu) so `unsync` would raise a user error exception if we don't check
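With this PR all checkpoint I/O in the connector goes through `trainer.strategy`, which lets strategies such as DeepSpeed control the on-disk format. The user-facing entry points are unchanged; a hedged sketch of the call path (`model` and `train_loader` are placeholders defined elsewhere):

```python
import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, train_loader)

# saving: CheckpointConnector.dump_checkpoint() -> trainer.strategy.save_checkpoint(...)
trainer.save_checkpoint("example.ckpt")

# resuming: CheckpointConnector.resume_start() -> trainer.strategy.load_checkpoint(...)
pl.Trainer(max_epochs=2).fit(model, train_loader, ckpt_path="example.ckpt")
```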