Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

4/n Move Accelerator into strategy - remove X_step() from accelerator #10890

Merged
merged 11 commits
Dec 6, 2021
Previous commit
Next commit
update
four4fish committed Dec 4, 2021
commit 47c4da3077c36314d07fc4a500f0e2574a255d30
25 changes: 0 additions & 25 deletions pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
@@ -20,7 +20,6 @@
import pytorch_lightning as pl
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
from pytorch_lightning.utilities.types import STEP_OUTPUT


class Accelerator:
@@ -118,30 +117,6 @@ def teardown(self) -> None:
"""
self.training_type_plugin.teardown()

def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
"""The actual validation step.

See :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` for more details
"""
with self.training_type_plugin.precision_plugin.val_step_context():
return self.training_type_plugin.validation_step(*args, **kwargs)

def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
"""The actual test step.

See :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` for more details
"""
with self.training_type_plugin.precision_plugin.test_step_context():
return self.training_type_plugin.test_step(*args, **kwargs)

def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
"""The actual predict step.

See :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` for more details
"""
with self.training_type_plugin.precision_plugin.predict_step_context():
return self.training_type_plugin.predict_step(*args, **kwargs)

def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
"""Gets stats for a given device.

4 changes: 2 additions & 2 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
Original file line number Diff line number Diff line change
@@ -220,11 +220,11 @@ def _evaluation_step(self, **kwargs: Any) -> Optional[STEP_OUTPUT]:
if self.trainer.testing:
self.trainer.lightning_module._current_fx_name = "test_step"
with self.trainer.profiler.profile("test_step"):
output = self.trainer.accelerator.test_step(*kwargs.values())
output = self.trainer.training_type_plugin.test_step(*kwargs.values())
else:
self.trainer.lightning_module._current_fx_name = "validation_step"
with self.trainer.profiler.profile("validation_step"):
output = self.trainer.accelerator.validation_step(*kwargs.values())
output = self.trainer.training_type_plugin.validation_step(*kwargs.values())

return output

2 changes: 1 addition & 1 deletion pytorch_lightning/loops/epoch/prediction_epoch_loop.py
Original file line number Diff line number Diff line change
@@ -132,7 +132,7 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None
self.batch_progress.increment_started()

model_ref._current_fx_name = "predict_step"
predictions = self.trainer.accelerator.predict_step(*step_kwargs.values())
predictions = self.trainer.training_type_plugin.predict_step(*step_kwargs.values())

self.batch_progress.increment_processed()

2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/deepspeed.py
Original file line number Diff line number Diff line change
@@ -42,7 +42,7 @@
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeTuple
from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeTuple, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache

warning_cache = WarningCache()
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/fully_sharded.py
Original file line number Diff line number Diff line change
@@ -24,6 +24,7 @@
from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
from fairscale.nn import default_auto_wrap_policy, enable_wrap