Skip to content

Commit

Permalink
simplify accelerator steps (#5015)
Browse files Browse the repository at this point in the history
* simplify accelerator steps

* Apply suggestions from code review

Co-authored-by: Rohit Gupta <[email protected]>

Co-authored-by: Rohit Gupta <[email protected]>
  • Loading branch information
Borda and rohitgr7 authored Dec 10, 2020
1 parent 820d5c7 commit d5fa02e
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 52 deletions.
25 changes: 9 additions & 16 deletions pytorch_lightning/accelerators/cpu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union
from typing import Any, Optional, Union, Callable

import torch

Expand Down Expand Up @@ -61,29 +61,22 @@ def train(self):
results = self.train_or_test()
return results

def training_step(self, args):
def _step(self, model_step: Callable, args):
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model.training_step(*args)
output = model_step(*args)
else:
output = self.trainer.model.training_step(*args)
output = model_step(*args)
return output

def training_step(self, args):
return self._step(self.trainer.model.training_step, args)

def validation_step(self, args):
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model.validation_step(*args)
else:
output = self.trainer.model.validation_step(*args)
return output
return self._step(self.trainer.model.validation_step, args)

def test_step(self, args):
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model.test_step(*args)
else:
output = self.trainer.model.test_step(*args)
return output
return self._step(self.trainer.model.test_step, args)

def sync_tensor(self,
tensor: Union[torch.Tensor],
Expand Down
11 changes: 6 additions & 5 deletions pytorch_lightning/accelerators/dp_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,21 +116,22 @@ def teardown(self):
self.trainer.model.forward = self.model_autocast_original_forward
self.barrier()

def training_step(self, args):
def _step(self, args):
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model(*args)
else:
output = self.trainer.model(*args)
return output

def training_step(self, args):
return self._step(args)

def validation_step(self, args):
output = self.training_step(args)
return output
return self._step(args)

def test_step(self, args):
output = self.training_step(args)
return output
return self._step(args)

def training_step_end(self, output):
if isinstance(output, Result):
Expand Down
40 changes: 10 additions & 30 deletions pytorch_lightning/accelerators/horovod_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import ExitStack
from typing import Any, Optional, Union
from typing import Any, Optional, Union, Callable

import torch
from torch.optim.lr_scheduler import _LRScheduler
Expand Down Expand Up @@ -114,46 +114,26 @@ def train(self):
hvd.join()
return results

def training_step(self, args):
def _step(self, model_step: Callable, args):
if self.trainer.on_gpu:
batch = args[0]
batch = self.batch_to_device(batch, hvd.local_rank())
args[0] = batch
args[0] = self.batch_to_device(args[0], hvd.local_rank())

if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model.training_step(*args)
output = model_step(*args)
else:
output = self.trainer.model.training_step(*args)
output = model_step(*args)

return output

def validation_step(self, args):
if self.trainer.on_gpu:
batch = args[0]
batch = self.batch_to_device(batch, hvd.local_rank())
args[0] = batch

if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model.validation_step(*args)
else:
output = self.trainer.model.validation_step(*args)
def training_step(self, args):
return self._step(self.trainer.model.training_step, args)

return output
def validation_step(self, args):
return self._step(self.trainer.model.validation_step, args)

def test_step(self, args):
if self.trainer.on_gpu:
batch = args[0]
batch = self.batch_to_device(batch, hvd.local_rank())
args[0] = batch

if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model.test_step(*args)
else:
output = self.trainer.model.test_step(*args)
return output
return self._step(self.trainer.model.test_step, args)

def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
super().backward(closure_loss, optimizer, opt_idx, *args, **kwargs)
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/connectors/slurm_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def configure_slurm_ddp(self, num_gpu_nodes):
if self.trainer.is_slurm_managing_tasks:
rank_zero_info('Multi-processing is handled by Slurm.')

# todo: the same function as slurm_environment.py `_resolve_root_node_address`
def resolve_root_node_address(self, root_node):
if '[' in root_node:
name, numbers = root_node.split('[', maxsplit=1)
Expand Down Expand Up @@ -108,8 +109,8 @@ def term_handler(self, signum, frame):
# save
log.info("bypassing sigterm")

# todo: this is the same func as slurm_environment.py `master_port`
def connect_ddp(self, global_rank: int, world_size: int) -> None:
""""""
"""
Sets up environment variables necessary for pytorch distributed communications
based on slurm environment.
Expand Down

0 comments on commit d5fa02e

Please sign in to comment.