deprecate trainer.training_type_plugin
four4fish committed Dec 21, 2021
1 parent f98cd78 commit d31a27e
Showing 55 changed files with 387 additions and 381 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -155,6 +155,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- DeepSpeed does not require lightning module zero 3 partitioning ([#10655](https://github.com/PyTorchLightning/pytorch-lightning/pull/10655))


- Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))


### Deprecated

- Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103))
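To illustrate the entry above, here is a minimal, hypothetical migration sketch for user code that accessed the old property. The assumption (not taken from this diff) is that the deprecated name keeps working during the deprecation window and forwards to the new `strategy` attribute while emitting a deprecation warning.

```python
import pytorch_lightning as pl

trainer = pl.Trainer()

# Old access path, now deprecated (assumed to still work while warning):
plugin = trainer.training_type_plugin

# New access path introduced by this change:
strategy = trainer.strategy

# Both names are expected to resolve to the same strategy object.
```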
4 changes: 2 additions & 2 deletions pl_examples/loop_examples/kfold.py
@@ -205,8 +205,8 @@ def on_run_end(self) -> None:
voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths)
voting_model.trainer = self.trainer
# This requires connecting the new model and moving it to the right device.
self.trainer.training_type_plugin.connect(voting_model)
self.trainer.training_type_plugin.model_to_device()
self.trainer.strategy.connect(voting_model)
self.trainer.strategy.model_to_device()
self.trainer.test_loop.run()

def on_save_checkpoint(self) -> Dict[str, int]:
4 changes: 2 additions & 2 deletions pl_examples/loop_examples/yielding_training_step.py
@@ -77,7 +77,7 @@ def _get_generator(self, split_batch, batch_idx, opt_idx):
# Here we are basically calling `lightning_module.training_step()`
# and this returns a generator! The `training_step` is handled by the
# accelerator to enable distributed training.
return self.trainer.training_type_plugin.training_step(*step_kwargs.values())
return self.trainer.strategy.training_step(*step_kwargs.values())

def _training_step(self, generator):
# required for logging
@@ -86,7 +86,7 @@ def _training_step(self, generator):
# Here, instead of calling `lightning_module.training_step()`
# we call next() on the generator!
training_step_output = next(generator)
self.trainer.training_type_plugin.post_training_step()
self.trainer.strategy.post_training_step()

model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/early_stopping.py
@@ -200,7 +200,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
should_stop, reason = self._evaluate_stopping_criteria(current)

# stop every ddp process if any world process decides to stop
should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
should_stop = trainer.strategy.reduce_boolean_decision(should_stop)
trainer.should_stop = trainer.should_stop or should_stop
if should_stop:
self.stopped_epoch = trainer.current_epoch
14 changes: 7 additions & 7 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -286,7 +286,7 @@ def on_train_batch_end(
skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds()
# in case we have time differences across ranks
# broadcast the decision on whether to checkpoint from rank 0 to avoid possible hangs
skip_time = trainer.training_type_plugin.broadcast(skip_time)
skip_time = trainer.strategy.broadcast(skip_time)

if skip_batch and skip_time:
return
@@ -492,7 +492,7 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[torch.Ten
should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])

# If using multiple devices, make sure all processes are unanimous on the decision.
should_update_best_and_save = trainer.training_type_plugin.reduce_boolean_decision(should_update_best_and_save)
should_update_best_and_save = trainer.strategy.reduce_boolean_decision(should_update_best_and_save)

return should_update_best_and_save

@@ -598,7 +598,7 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
else:
ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")

ckpt_path = trainer.training_type_plugin.broadcast(ckpt_path)
ckpt_path = trainer.strategy.broadcast(ckpt_path)

self.dirpath = ckpt_path

@@ -646,7 +646,7 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[
trainer.save_checkpoint(filepath, self.save_weights_only)

if self.last_model_path and self.last_model_path != filepath:
trainer.training_type_plugin.remove_checkpoint(self.last_model_path)
trainer.strategy.remove_checkpoint(self.last_model_path)

self.last_model_path = filepath

@@ -671,7 +671,7 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate
trainer.save_checkpoint(filepath, self.save_weights_only)

if self.save_top_k == 1 and self.best_model_path and self.best_model_path != filepath:
trainer.training_type_plugin.remove_checkpoint(self.best_model_path)
trainer.strategy.remove_checkpoint(self.best_model_path)

self.best_model_path = filepath

@@ -718,7 +718,7 @@ def _update_best_and_save(
trainer.save_checkpoint(filepath, self.save_weights_only)

if del_filepath is not None and filepath != del_filepath:
trainer.training_type_plugin.remove_checkpoint(del_filepath)
trainer.strategy.remove_checkpoint(del_filepath)

def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
"""Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
@@ -733,4 +733,4 @@ def file_exists(self, filepath: _PATH, trainer: "pl.Trainer") -> bool:
"""Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal
state to diverge between ranks."""
exists = self._fs.exists(filepath)
return trainer.training_type_plugin.broadcast(exists)
return trainer.strategy.broadcast(exists)
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/timer.py
@@ -173,7 +173,7 @@ def on_load_checkpoint(
def _check_time_remaining(self, trainer: "pl.Trainer") -> None:
assert self._duration is not None
should_stop = self.time_elapsed() >= self._duration
should_stop = trainer.training_type_plugin.broadcast(should_stop)
should_stop = trainer.strategy.broadcast(should_stop)
trainer.should_stop = trainer.should_stop or should_stop
if should_stop and self._verbose:
elapsed = timedelta(seconds=int(self.time_elapsed(RunningStage.TRAINING)))
8 changes: 4 additions & 4 deletions pytorch_lightning/callbacks/xla_stats_monitor.py
@@ -77,7 +77,7 @@ def on_train_start(self, trainer, pl_module) -> None:
)

memory_info = xm.get_memory_info(pl_module.device)
total_memory = trainer.training_type_plugin.reduce(memory_info["kb_total"]) * 0.001
total_memory = trainer.strategy.reduce(memory_info["kb_total"]) * 0.001
rank_zero_info(f"Average Total memory: {total_memory:.2f} MB")

def on_train_epoch_start(self, trainer, pl_module) -> None:
@@ -91,9 +91,9 @@ def on_train_epoch_end(self, trainer, pl_module) -> None:
free_memory = memory_info["kb_free"]
peak_memory = memory_info["kb_total"] - free_memory

free_memory = trainer.training_type_plugin.reduce(free_memory) * 0.001
peak_memory = trainer.training_type_plugin.reduce(peak_memory) * 0.001
epoch_time = trainer.training_type_plugin.reduce(epoch_time)
free_memory = trainer.strategy.reduce(free_memory) * 0.001
peak_memory = trainer.strategy.reduce(peak_memory) * 0.001
epoch_time = trainer.strategy.reduce(epoch_time)

logs["avg. free memory (MB)"] = free_memory
logs["avg. peak memory (MB)"] = peak_memory
6 changes: 3 additions & 3 deletions pytorch_lightning/core/lightning.py
@@ -421,7 +421,7 @@ def log(
add_dataloader_idx=add_dataloader_idx,
batch_size=batch_size,
sync_dist=sync_dist and distributed_available(),
sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp,
sync_dist_fn=self.trainer.strategy.reduce or sync_ddp,
sync_dist_group=sync_dist_group,
metric_attribute=metric_attribute,
rank_zero_only=rank_zero_only,
@@ -536,7 +536,7 @@ def all_gather(
the output will also be a collection with tensors of this shape.
"""
group = group if group is not None else torch.distributed.group.WORLD
all_gather = self.trainer.training_type_plugin.all_gather
all_gather = self.trainer.strategy.all_gather
data = convert_to_tensors(data, device=self.device)
return apply_to_collection(data, torch.Tensor, all_gather, group=group, sync_grads=sync_grads)

@@ -1337,7 +1337,7 @@ def training_step(...):
**kwargs: Additional keyword arguments to be forwarded to :meth:`~torch.Tensor.backward`
"""
self._verify_is_manual_optimization("manual_backward")
self.trainer.training_type_plugin.backward(loss, None, None, *args, **kwargs)
self.trainer.strategy.backward(loss, None, None, *args, **kwargs)

def backward(
self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args, **kwargs
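As context for the `manual_backward` change above, here is a minimal sketch of manual optimization in a `LightningModule`; the module and data are hypothetical and not part of this diff. After this commit, `manual_backward` routes through `trainer.strategy.backward` rather than `trainer.training_type_plugin.backward`.

```python
import torch
import pytorch_lightning as pl


class ManualOptModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # opt into manual optimization
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.layer(batch).sum()
        # Dispatches to trainer.strategy.backward(...) after this change.
        self.manual_backward(loss)
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```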
2 changes: 1 addition & 1 deletion pytorch_lightning/core/optimizer.py
@@ -161,4 +161,4 @@ def closure_dis():
trainer = self._trainer
assert trainer is not None
with trainer.profiler.profile(profiler_action):
trainer.training_type_plugin.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
trainer.strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
4 changes: 1 addition & 3 deletions pytorch_lightning/loops/base.py
@@ -329,9 +329,7 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional
# Python primitives. However, their states are saved with the model's `state_dict`.
# On reload, we need to re-attach the `Metric`s back to the `_ResultCollection`.
# The references are provided through the `metric_attributes` dictionary.
v.load_state_dict(
state_dict[key], metrics=metric_attributes, sync_fn=self.trainer.training_type_plugin.reduce
)
v.load_state_dict(state_dict[key], metrics=metric_attributes, sync_fn=self.trainer.strategy.reduce)

if not self.trainer.is_global_zero:
v.reset(metrics=False)
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -109,7 +109,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
void(*args, **kwargs)

dataloader_idx = self.current_dataloader_idx
dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader)
dataloader = self.trainer.strategy.process_dataloader(self.current_dataloader)
self.data_fetcher = dataloader = self.trainer._data_connector.get_profiled_dataloader(
dataloader, dataloader_idx=dataloader_idx
)
6 changes: 3 additions & 3 deletions pytorch_lightning/loops/dataloader/prediction_loop.py
@@ -30,11 +30,11 @@ def return_predictions(self) -> bool:
@return_predictions.setter
def return_predictions(self, return_predictions: Optional[bool] = None) -> None:
# `DDPSpawnStrategy` plugins and derivatives don't support return predictions.
is_ddp_spawn = isinstance(self.trainer.training_type_plugin, DDPSpawnStrategy)
is_ddp_spawn = isinstance(self.trainer.strategy, DDPSpawnStrategy)
if return_predictions and is_ddp_spawn:
raise MisconfigurationException(
"`return_predictions` should be set to `False` when using the `DDPSpawnStrategy` or children class. "
f"Found {return_predictions} with training_type_plugin {type(self.trainer.training_type_plugin)}."
f"Found {return_predictions} with training_type_plugin {type(self.trainer.strategy)}."
)
# For non-`DDPSpawnStrategy` plugins, `return_predictions` is True by default unless the user decides otherwise.
self._return_predictions = not is_ddp_spawn if return_predictions is None else return_predictions
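For reference, a hypothetical configuration that the `return_predictions` guard above targets: with a spawn-based strategy, predictions cannot be returned to the main process. This is a sketch only, with flag names assumed for a PyTorch Lightning 1.5-era `Trainer`.

```python
import pytorch_lightning as pl

# Hypothetical set-up matching the guard above (not part of the diff):
trainer = pl.Trainer(accelerator="cpu", devices=2, strategy="ddp_spawn")

# model, dataloader = ...  # any LightningModule and DataLoader
# trainer.predict(model, dataloader, return_predictions=True)   # would raise MisconfigurationException
# trainer.predict(model, dataloader, return_predictions=False)  # supported with spawn strategies
```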
@@ -86,7 +86,7 @@ def on_run_start(self) -> None: # type: ignore[override]
def advance(self, *args: Any, **kwargs: Any) -> None:
"""Predicts one entire dataloader."""
void(*args, **kwargs)
dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader)
dataloader = self.trainer.strategy.process_dataloader(self.current_dataloader)
dataloader_iter = enumerate(dataloader)
dl_max_batches = self.max_batches[self.current_dataloader_idx]

4 changes: 1 addition & 3 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -366,9 +366,7 @@ def _should_accumulate(self) -> bool:
# Lightning steps on the final batch
is_final_batch = self._num_ready_batches_reached()
# but the TTP might not
ttp_accumulates_on_final_batch = (
self.trainer.training_type_plugin.handles_gradient_accumulation or not is_final_batch
)
ttp_accumulates_on_final_batch = self.trainer.strategy.handles_gradient_accumulation or not is_final_batch
return not accumulation_done and ttp_accumulates_on_final_batch

@staticmethod
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/fit_loop.py
@@ -223,7 +223,7 @@ def on_advance_start(self) -> None: # type: ignore[override]

def advance(self) -> None: # type: ignore[override]
"""Runs one whole epoch."""
dataloader = self.trainer.training_type_plugin.process_dataloader(self.trainer.train_dataloader)
dataloader = self.trainer.strategy.process_dataloader(self.trainer.train_dataloader)
data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader)

with self.trainer.profiler.profile("run_training_epoch"):
@@ -255,7 +255,7 @@ def on_run_end(self) -> None:
self.trainer._call_strategy_hook("on_train_end")

# give accelerators a chance to finish
self.trainer.training_type_plugin.on_train_end()
self.trainer.strategy.on_train_end()

def teardown(self) -> None:
self.epoch_loop.teardown()
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/optimization/manual_loop.py
@@ -103,7 +103,7 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override]

# manually capture logged metrics
training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
self.trainer.training_type_plugin.post_training_step()
self.trainer.strategy.post_training_step()

del step_kwargs

4 changes: 2 additions & 2 deletions pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -243,7 +243,7 @@ def _run_optimization(

if (
# when the training type plugin handles accumulation, we want to always call the optimizer step
not self.trainer.training_type_plugin.handles_gradient_accumulation
not self.trainer.strategy.handles_gradient_accumulation
and self.trainer.fit_loop._should_accumulate()
):
# For gradient accumulation
@@ -427,7 +427,7 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos

# manually capture logged metrics
training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
self.trainer.training_type_plugin.post_training_step()
self.trainer.strategy.post_training_step()

del step_kwargs

4 changes: 2 additions & 2 deletions pytorch_lightning/loops/utilities.py
@@ -171,8 +171,8 @@ def _block_parallel_sync_behavior(trainer: "pl.Trainer", block: bool = True) ->
Returns:
context manager with sync behaviour off
"""
if isinstance(trainer.training_type_plugin, ParallelPlugin) and block:
with trainer.training_type_plugin.block_backward_sync():
if isinstance(trainer.strategy, ParallelPlugin) and block:
with trainer.strategy.block_backward_sync():
yield None
else:
yield None
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/precision/apex_amp.py
@@ -47,7 +47,7 @@ def main_params(self, optimizer: Optimizer) -> _PARAMETERS:

def dispatch(self, trainer: "pl.Trainer") -> None:
if not self._connected:
strategy = trainer.training_type_plugin
strategy = trainer.strategy
_, strategy.optimizers = amp.initialize(
trainer.lightning_module, strategy.optimizers, opt_level=self.amp_level
)
19 changes: 8 additions & 11 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -81,7 +81,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:

def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
with pl_legacy_patch():
loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint(checkpoint_path)
loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path)
if any(key in loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS):
raise ValueError(
"The checkpoint you're attempting to load follows an"
@@ -113,7 +113,7 @@ def resume_end(self) -> None:
torch.cuda.empty_cache()

# wait for all to catch up
self.trainer.training_type_plugin.barrier("CheckpointConnector.resume_end")
self.trainer.strategy.barrier("CheckpointConnector.resume_end")

def restore(self, checkpoint_path: Optional[_PATH] = None) -> None:
"""Attempt to restore everything at once from a 'PyTorch-Lightning checkpoint' file through file-read and
@@ -170,7 +170,7 @@ def restore_model(self) -> None:
model.on_hpc_load(self._loaded_checkpoint)

# restore model state_dict
self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint)
self.trainer.strategy.load_model_state_dict(self._loaded_checkpoint)

# reset metrics states on non-rank 0 as all states have been accumulated on rank 0 via syncing on checkpointing.
if not self.trainer.is_global_zero:
@@ -258,10 +258,7 @@ def restore_loops(self) -> None:

def restore_optimizers_and_schedulers(self) -> None:
"""Restores the optimizers and learning rate scheduler states from the pre-loaded checkpoint."""
if (
not self._loaded_checkpoint
or not self.trainer.training_type_plugin.lightning_restore_optimizer_and_schedulers
):
if not self._loaded_checkpoint or not self.trainer.strategy.lightning_restore_optimizer_and_schedulers:
return

# validation
@@ -279,7 +276,7 @@ def restore_optimizers(self) -> None:
return

# restore the optimizers
self.trainer.training_type_plugin.load_optimizer_state_dict(self._loaded_checkpoint)
self.trainer.strategy.load_optimizer_state_dict(self._loaded_checkpoint)
for optimizer in self.trainer.optimizers:
# move optimizer to GPU 1 weight at a time
# avoids OOM
@@ -387,7 +384,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
optimizer_states = []
for i, optimizer in enumerate(self.trainer.optimizers):
# Rely on accelerator to dump optimizer state
optimizer_state = self.trainer.training_type_plugin.optimizer_state(optimizer)
optimizer_state = self.trainer.strategy.optimizer_state(optimizer)
optimizer_states.append(optimizer_state)

checkpoint["optimizer_states"] = optimizer_states
@@ -463,7 +460,7 @@ def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
weights_only: saving model weights only
"""
_checkpoint = self.dump_checkpoint(weights_only)
self.trainer.training_type_plugin.save_checkpoint(_checkpoint, filepath)
self.trainer.strategy.save_checkpoint(_checkpoint, filepath)

def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]:
metrics = (
@@ -476,7 +473,7 @@ def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]:
metric.persistent(True)
metric.sync()

state_dict = self.trainer.training_type_plugin.lightning_module_state_dict()
state_dict = self.trainer.strategy.lightning_module_state_dict()

for metric in metrics:
# sync can be a no-op (e.g. on cpu) so `unsync` would raise a user error exception if we don't check