Simplify GPU and TPU accelerator (Lightning-AI#5024)
rohitgr7 authored Dec 9, 2020
1 parent 90d1d9f commit bcbba3b
Showing 2 changed files with 18 additions and 54 deletions.
48 changes: 10 additions & 38 deletions pytorch_lightning/accelerators/gpu_accelerator.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 
@@ -66,53 +66,25 @@ def train(self):
         results = self.train_or_test()
         return results
 
-    def training_step(self, args):
+    def _step(self, model_step: Callable, args):
+        args[0] = self.to_device(args[0])
+
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
-                output = self.__training_step(args)
+                output = model_step(*args)
         else:
-            output = self.__training_step(args)
+            output = model_step(*args)
 
         return output
 
-    def __training_step(self, args):
-        batch = args[0]
-        batch = self.to_device(batch)
-        args[0] = batch
-        output = self.trainer.model.training_step(*args)
-        return output
+    def training_step(self, args):
+        return self._step(self.trainer.model.training_step, args)
 
     def validation_step(self, args):
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.__validation_step(args)
-        else:
-            output = self.__validation_step(args)
-
-        return output
-
-    def __validation_step(self, args):
-        batch = args[0]
-        batch = self.to_device(batch)
-        args[0] = batch
-        output = self.trainer.model.validation_step(*args)
-        return output
+        return self._step(self.trainer.model.validation_step, args)
 
     def test_step(self, args):
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.__test_step(args)
-        else:
-            output = self.__test_step(args)
-
-        return output
-
-    def __test_step(self, args):
-        batch = args[0]
-        batch = self.to_device(batch)
-        args[0] = batch
-        output = self.trainer.model.test_step(*args)
-        return output
+        return self._step(self.trainer.model.test_step, args)
 
     def to_device(self, batch):
         gpu_id = 0
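Note: the diff replaces three copies of the same move-to-device and AMP-dispatch logic (including the name-mangled __training_step-style private helpers) with a single _step helper that receives the model hook as a Callable. Below is a minimal, self-contained sketch of that pattern, runnable outside Lightning; GPUStepper and DummyModel are hypothetical stand-ins invented for this illustration, not Lightning classes.

# Sketch of the pattern this diff introduces. GPUStepper and DummyModel
# are hypothetical stand-ins, not Lightning classes.
from typing import Any, Callable, List

import torch


class DummyModel(torch.nn.Module):
    def training_step(self, batch, batch_idx):
        return batch.sum()

    def validation_step(self, batch, batch_idx):
        return batch.mean()


class GPUStepper:
    def __init__(self, model, use_native_amp: bool):
        self.model = model
        self.use_native_amp = use_native_amp

    def to_device(self, batch):
        # Lightning's real to_device also handles nested collections;
        # a plain tensor is enough for this sketch.
        return batch.cuda(0) if torch.cuda.is_available() else batch

    def _step(self, model_step: Callable, args: List[Any]):
        # One shared path: move the batch (always args[0]) to the device,
        # then call the given hook, under autocast when native AMP is on.
        args[0] = self.to_device(args[0])
        if self.use_native_amp:
            with torch.cuda.amp.autocast():
                return model_step(*args)
        return model_step(*args)

    def training_step(self, args):
        return self._step(self.model.training_step, args)

    def validation_step(self, args):
        return self._step(self.model.validation_step, args)


stepper = GPUStepper(DummyModel(), use_native_amp=torch.cuda.is_available())
print(stepper.training_step([torch.ones(4), 0]))  # sums the batch: 4.0

Passing the bound method as a Callable is also what removes the need for the old double-underscore helpers, whose names Python mangles to _GPUAccelerator__training_step and so on.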
24 changes: 8 additions & 16 deletions pytorch_lightning/accelerators/tpu_accelerator.py
@@ -14,7 +14,7 @@
 import io
 import os
 import re
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 import torch.multiprocessing as mp
@@ -145,26 +145,18 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine
         # persist info in spawn
         self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
 
+    def _step(self, model_step: Callable, args):
+        args[0] = self.to_device(args[0])
+        return model_step(*args)
+
     def training_step(self, args):
-        batch = args[0]
-        batch = self.to_device(batch)
-        args[0] = batch
-        output = self.trainer.model.training_step(*args)
-        return output
+        return self._step(self.trainer.model.training_step, args)
 
     def validation_step(self, args):
-        batch = args[0]
-        batch = self.to_device(batch)
-        args[0] = batch
-        output = self.trainer.model.validation_step(*args)
-        return output
+        return self._step(self.trainer.model.validation_step, args)
 
     def test_step(self, args):
-        batch = args[0]
-        batch = self.to_device(batch)
-        args[0] = batch
-        output = self.trainer.model.test_step(*args)
-        return output
+        return self._step(self.trainer.model.test_step, args)
 
     def process_dataloader(self, dataloader):
         device = xm.xla_device(self.trainer.tpu_id)
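The TPU side collapses the same way, minus the AMP branch. A short sketch under the assumption that torch_xla is installed; TPUStepper is again a hypothetical stand-in, not the real TPUAccelerator.

# Same helper shape as the GPU sketch, without the autocast branch.
# Assumes torch_xla is installed; TPUStepper is a hypothetical stand-in.
from typing import Any, Callable, List

import torch
import torch_xla.core.xla_model as xm


class TPUStepper:
    def __init__(self, model):
        self.model = model
        self.device = xm.xla_device()  # this process's XLA device

    def to_device(self, batch):
        return batch.to(self.device)

    def _step(self, model_step: Callable, args: List[Any]):
        # Move the batch, then dispatch to whichever hook was passed in.
        args[0] = self.to_device(args[0])
        return model_step(*args)

    def training_step(self, args):
        return self._step(self.model.training_step, args)

    def validation_step(self, args):
        return self._step(self.model.validation_step, args)

    def test_step(self, args):
        return self._step(self.model.test_step, args)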
