From d5fa02e7985c3920e72e268ece1366a1de96281b Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Thu, 10 Dec 2020 14:06:13 +0100
Subject: [PATCH] simplify accelerator steps (#5015)

* simplify accelerator steps

* Apply suggestions from code review

Co-authored-by: Rohit Gupta
Co-authored-by: Rohit Gupta
---
 .../accelerators/cpu_accelerator.py        | 25 +++++-------
 .../accelerators/dp_accelerator.py         | 11 ++---
 .../accelerators/horovod_accelerator.py    | 40 +++++--------------
 .../trainer/connectors/slurm_connector.py  |  3 +-
 4 files changed, 27 insertions(+), 52 deletions(-)

diff --git a/pytorch_lightning/accelerators/cpu_accelerator.py b/pytorch_lightning/accelerators/cpu_accelerator.py
index 9113331ef0a7d..2b290c5226d1b 100644
--- a/pytorch_lightning/accelerators/cpu_accelerator.py
+++ b/pytorch_lightning/accelerators/cpu_accelerator.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, Callable
 
 import torch
 
@@ -61,29 +61,22 @@ def train(self):
         results = self.train_or_test()
         return results
 
-    def training_step(self, args):
+    def _step(self, model_step: Callable, args):
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
-                output = self.trainer.model.training_step(*args)
+                output = model_step(*args)
         else:
-            output = self.trainer.model.training_step(*args)
+            output = model_step(*args)
         return output
 
+    def training_step(self, args):
+        return self._step(self.trainer.model.training_step, args)
+
     def validation_step(self, args):
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.trainer.model.validation_step(*args)
-        else:
-            output = self.trainer.model.validation_step(*args)
-        return output
+        return self._step(self.trainer.model.validation_step, args)
 
     def test_step(self, args):
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.trainer.model.test_step(*args)
-        else:
-            output = self.trainer.model.test_step(*args)
-        return output
+        return self._step(self.trainer.model.test_step, args)
 
     def sync_tensor(self,
                     tensor: Union[torch.Tensor],
diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py
index a7f3c260e682c..a3563e6a3af73 100644
--- a/pytorch_lightning/accelerators/dp_accelerator.py
+++ b/pytorch_lightning/accelerators/dp_accelerator.py
@@ -116,7 +116,7 @@ def teardown(self):
         self.trainer.model.forward = self.model_autocast_original_forward
         self.barrier()
 
-    def training_step(self, args):
+    def _step(self, args):
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
                 output = self.trainer.model(*args)
@@ -124,13 +124,14 @@ def training_step(self, args):
             output = self.trainer.model(*args)
         return output
 
+    def training_step(self, args):
+        return self._step(args)
+
     def validation_step(self, args):
-        output = self.training_step(args)
-        return output
+        return self._step(args)
 
     def test_step(self, args):
-        output = self.training_step(args)
-        return output
+        return self._step(args)
 
     def training_step_end(self, output):
         if isinstance(output, Result):
diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py
index 93983369f17a9..6582e3b376ff6 100644
--- a/pytorch_lightning/accelerators/horovod_accelerator.py
+++ b/pytorch_lightning/accelerators/horovod_accelerator.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import ExitStack
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, Callable
 
 import torch
 from torch.optim.lr_scheduler import _LRScheduler
@@ -114,46 +114,26 @@ def train(self):
         hvd.join()
         return results
 
-    def training_step(self, args):
+    def _step(self, model_step: Callable, args):
         if self.trainer.on_gpu:
-            batch = args[0]
-            batch = self.batch_to_device(batch, hvd.local_rank())
-            args[0] = batch
+            args[0] = self.batch_to_device(args[0], hvd.local_rank())
 
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
-                output = self.trainer.model.training_step(*args)
+                output = model_step(*args)
         else:
-            output = self.trainer.model.training_step(*args)
+            output = model_step(*args)
 
         return output
 
-    def validation_step(self, args):
-        if self.trainer.on_gpu:
-            batch = args[0]
-            batch = self.batch_to_device(batch, hvd.local_rank())
-            args[0] = batch
-
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.trainer.model.validation_step(*args)
-        else:
-            output = self.trainer.model.validation_step(*args)
+    def training_step(self, args):
+        return self._step(self.trainer.model.training_step, args)
 
-        return output
+    def validation_step(self, args):
+        return self._step(self.trainer.model.validation_step, args)
 
     def test_step(self, args):
-        if self.trainer.on_gpu:
-            batch = args[0]
-            batch = self.batch_to_device(batch, hvd.local_rank())
-            args[0] = batch
-
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.trainer.model.test_step(*args)
-        else:
-            output = self.trainer.model.test_step(*args)
-        return output
+        return self._step(self.trainer.model.test_step, args)
 
     def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
         super().backward(closure_loss, optimizer, opt_idx, *args, **kwargs)
diff --git a/pytorch_lightning/trainer/connectors/slurm_connector.py b/pytorch_lightning/trainer/connectors/slurm_connector.py
index 9ff8c13825976..4cb954a8e92fc 100644
--- a/pytorch_lightning/trainer/connectors/slurm_connector.py
+++ b/pytorch_lightning/trainer/connectors/slurm_connector.py
@@ -54,6 +54,7 @@ def configure_slurm_ddp(self, num_gpu_nodes):
         if self.trainer.is_slurm_managing_tasks:
             rank_zero_info('Multi-processing is handled by Slurm.')
 
+    # todo: the same function as slurm_environment.py `_resolve_root_node_address`
     def resolve_root_node_address(self, root_node):
         if '[' in root_node:
             name, numbers = root_node.split('[', maxsplit=1)
@@ -108,8 +109,8 @@ def term_handler(self, signum, frame):
         # save
         log.info("bypassing sigterm")
 
+    # todo: this is the same func as slurm_environment.py `master_port`
     def connect_ddp(self, global_rank: int, world_size: int) -> None:
-        """"""
         """
         Sets up environment variables necessary for pytorch distributed communications
         based on slurm environment.
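
The common refactor across the three accelerator files in this patch is replacing the copy-pasted training_step / validation_step / test_step bodies with one private _step helper that receives the concrete LightningModule hook as a Callable. The short, self-contained sketch below illustrates that dispatch pattern outside of Lightning; ToyAccelerator, ToyModel and use_amp are hypothetical stand-ins, and nullcontext() stands in for the torch.cuda.amp.autocast() block the real accelerators enter when trainer.amp_backend == AMPType.NATIVE.

# Minimal sketch of the `_step` dispatch pattern from this patch.
# `ToyAccelerator`, `ToyModel`, and `use_amp` are hypothetical stand-ins.
from contextlib import nullcontext
from typing import Callable


class ToyAccelerator:

    def __init__(self, model, use_amp: bool = False):
        self.model = model
        self.use_amp = use_amp

    def _step(self, model_step: Callable, args):
        # One shared body; the concrete step hook is passed in as a callable.
        if self.use_amp:
            with nullcontext():  # stand-in for torch.cuda.amp.autocast()
                output = model_step(*args)
        else:
            output = model_step(*args)
        return output

    def training_step(self, args):
        return self._step(self.model.training_step, args)

    def validation_step(self, args):
        return self._step(self.model.validation_step, args)

    def test_step(self, args):
        return self._step(self.model.test_step, args)


class ToyModel:

    def training_step(self, batch):
        return {"loss": sum(batch)}

    def validation_step(self, batch):
        return {"val_loss": sum(batch)}

    def test_step(self, batch):
        return {"test_loss": sum(batch)}


if __name__ == "__main__":
    acc = ToyAccelerator(ToyModel())
    print(acc.training_step(([1, 2, 3],)))    # {'loss': 6}
    print(acc.validation_step(([1, 2, 3],)))  # {'val_loss': 6}

As the diff shows, the DPAccelerator variant of _step takes no callable because DataParallel routes every step through self.trainer.model(*args), and only the Horovod version additionally moves args[0] to the local GPU via batch_to_device before dispatching.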