From bcbba3b7028fe1a33e0d2dba99704c27728533ff Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Thu, 10 Dec 2020 00:42:44 +0530
Subject: [PATCH] Simplify GPU and TPU accelerator (#5024)

---
 .../accelerators/gpu_accelerator.py | 48 ++++---------
 .../accelerators/tpu_accelerator.py | 24 ++++------
 2 files changed, 18 insertions(+), 54 deletions(-)

diff --git a/pytorch_lightning/accelerators/gpu_accelerator.py b/pytorch_lightning/accelerators/gpu_accelerator.py
index abc065cd39ed4..f4d31213c7e5f 100644
--- a/pytorch_lightning/accelerators/gpu_accelerator.py
+++ b/pytorch_lightning/accelerators/gpu_accelerator.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 
@@ -66,53 +66,25 @@ def train(self):
         results = self.train_or_test()
         return results
 
-    def training_step(self, args):
+    def _step(self, model_step: Callable, args):
+        args[0] = self.to_device(args[0])
+
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
-                output = self.__training_step(args)
+                output = model_step(*args)
         else:
-            output = self.__training_step(args)
+            output = model_step(*args)
 
         return output
 
-    def __training_step(self, args):
-        batch = args[0]
-        batch = self.to_device(batch)
-        args[0] = batch
-        output = self.trainer.model.training_step(*args)
-        return output
+    def training_step(self, args):
+        return self._step(self.trainer.model.training_step, args)
 
     def validation_step(self, args):
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.__validation_step(args)
-        else:
-            output = self.__validation_step(args)
-
-        return output
-
-    def __validation_step(self, args):
-        batch = args[0]
-        batch = self.to_device(batch)
-        args[0] = batch
-        output = self.trainer.model.validation_step(*args)
-        return output
+        return self._step(self.trainer.model.validation_step, args)
 
     def test_step(self, args):
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.__test_step(args)
-        else:
-            output = self.__test_step(args)
-
-        return output
-
-    def __test_step(self, args):
-        batch = args[0]
-        batch = self.to_device(batch)
-        args[0] = batch
-        output = self.trainer.model.test_step(*args)
-        return output
+        return self._step(self.trainer.model.test_step, args)
 
     def to_device(self, batch):
         gpu_id = 0
diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py
index a7752e42a96cf..74fd201df8a66 100644
--- a/pytorch_lightning/accelerators/tpu_accelerator.py
+++ b/pytorch_lightning/accelerators/tpu_accelerator.py
@@ -14,7 +14,7 @@
 import io
 import os
 import re
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 import torch.multiprocessing as mp
@@ -145,26 +145,18 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine
         # persist info in spawn
         self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
 
+    def _step(self, model_step: Callable, args):
+        args[0] = self.to_device(args[0])
+        return model_step(*args)
+
     def training_step(self, args):
-        batch = args[0]
-        batch = self.to_device(batch)
-        args[0] = batch
-        output = self.trainer.model.training_step(*args)
-        return output
+        return self._step(self.trainer.model.training_step, args)
 
     def validation_step(self, args):
-        batch = args[0]
-        batch = self.to_device(batch)
-        args[0] = batch
-        output = self.trainer.model.validation_step(*args)
-        return output
+        return self._step(self.trainer.model.validation_step, args)
 
     def test_step(self, args):
-        batch = args[0]
-        batch = self.to_device(batch)
-        args[0] = batch
-        output = self.trainer.model.test_step(*args)
-        return output
+        return self._step(self.trainer.model.test_step, args)
 
     def process_dataloader(self, dataloader):
         device = xm.xla_device(self.trainer.tpu_id)
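
Note for reviewers: both accelerators now funnel training_step / validation_step / test_step through a single _step helper that moves the batch (always args[0]) to the device and then invokes the model hook; the GPU variant additionally wraps the call in torch.cuda.amp.autocast() under native AMP. Below is a minimal, self-contained sketch of that dispatch pattern, assuming a toy stand-in class; ToyAccelerator and its model_training_step hook are hypothetical illustrations, not Lightning APIs.

from typing import Any, Callable, List

class ToyAccelerator:
    """Hypothetical stand-in illustrating the _step dispatch pattern."""

    def to_device(self, batch: Any) -> Any:
        # Stand-in for the real transfer, e.g. batch.to(torch.device("cuda", 0)).
        return batch

    def _step(self, model_step: Callable, args: List[Any]) -> Any:
        # The batch is always args[0]: move it to the device, then call the hook.
        args[0] = self.to_device(args[0])
        return model_step(*args)

    def training_step(self, args: List[Any]) -> Any:
        # Each *_step method reduces to a one-line delegation.
        return self._step(self.model_training_step, args)

    def model_training_step(self, batch: List[int]) -> int:
        # Toy "model" hook so the sketch runs end to end.
        return sum(batch)

accelerator = ToyAccelerator()
print(accelerator.training_step([[1, 2, 3]]))  # -> 6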