Commit

ref: part 4 of #3733
williamFalcon committed Oct 2, 2020
1 parent 62eabdd commit 434a328
Showing 2 changed files with 142 additions and 35 deletions.
141 changes: 124 additions & 17 deletions pytorch_lightning/accelerators/ddp_backend.py
@@ -13,17 +13,22 @@
# limitations under the License

import os
import torch
import torch.distributed as torch_distrib
import subprocess
import sys
from os.path import abspath
from time import sleep
from typing import Optional

import numpy as np
import torch

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities.distributed import find_free_network_port
from pytorch_lightning.accelerators.ddp_base_backend import DDPBase
from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities import AMPType


try:
from hydra.utils import to_absolute_path, get_original_cwd
@@ -34,13 +39,14 @@
HYDRA_AVAILABLE = True


class DDPBackend(DDPBase):
class DDPBackend(Accelerator):

def __init__(self, trainer, mode: str = 'ddp'):
super().__init__(trainer)
self.task_idx = None
self._has_spawned_children = False
self.mode = mode
self.interactive_ddp_procs = []

def setup(self, model):
if self.mode == 'ddp':
@@ -59,6 +65,10 @@ def __torchelastic_setup(self):
self.task_idx = int(os.environ['LOCAL_RANK'])

def __ddp_script_mode_setup(self):
# do nothing when already in a ddp subprocess
if os.environ.get('PL_IN_DDP_SUBPROCESS', '0') == '1':
return

assert self.trainer.global_rank == 0
self._check_can_spawn_children()
self._has_spawned_children = True
@@ -105,7 +115,7 @@ def __ddp_script_mode_setup(self):

os.environ['WORLD_SIZE'] = f'{num_gpus * self.trainer.num_nodes}'

self.trainer.interactive_ddp_procs = []
self.interactive_ddp_procs = []
for local_rank in range(1, self.trainer.num_processes):
env_copy = os.environ.copy()
env_copy['LOCAL_RANK'] = f'{local_rank}'
@@ -118,7 +128,7 @@ def __ddp_script_mode_setup(self):
if HydraConfig.initialized():
cwd = get_original_cwd()
proc = subprocess.Popen(command, env=env_copy, cwd=cwd)
self.trainer.interactive_ddp_procs.append(proc)
self.interactive_ddp_procs.append(proc)

# starting all processes at once can cause issues with dataloaders,
# so the launches are delayed by 1-10 seconds
@@ -127,14 +137,116 @@ def __ddp_script_mode_setup(self):

self.task_idx = 0

# wait for all the procs to start
sleep(2)

def train(self):
model = self.trainer.model
if self.mode == 'ddp':
results = self.ddp_train_tmp(process_idx=self.task_idx, mp_queue=None, model=model, is_master=True)
del os.environ['WORLD_SIZE']
results = self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model, is_master=True)
if 'WORLD_SIZE' in os.environ:
del os.environ['WORLD_SIZE']
return results
else:
self.ddp_train_tmp(process_idx=self.task_idx, mp_queue=None, model=model)
return self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model)

def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
"""
Entry point for ddp
Args:
process_idx:
mp_queue: multiprocessing queue
model:
is_master:
proc_offset:
Returns:
"""
# offset the process id if requested
process_idx = process_idx + proc_offset

# show the progress bar only on node rank 0, process 0
if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
self.trainer.progress_bar_callback.disable()

# determine which process we are and world size
self.set_world_ranks(process_idx)

# set warning rank
rank_zero_only.rank = self.trainer.global_rank

# set up server using proc 0's ip address
# try to init at most 20 times in case ports are taken
# where to store ip_table
model.trainer = self.trainer
model.init_ddp_connection(
self.trainer.global_rank,
self.trainer.world_size,
self.trainer.is_slurm_managing_tasks
)

# call setup after the ddp process has connected
self.trainer.call_setup_hook(model)

# on global rank 0, let everyone know training is starting
if self.trainer.is_global_zero and not torch.distributed.is_initialized():
log.info('-' * 100)
log.info(f'distributed_backend={self.trainer.distributed_backend}')
log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
log.info('-' * 100)

# call sync_bn before .cuda(), configure_apex and configure_ddp
if self.trainer.sync_batchnorm:
model = model.configure_sync_batchnorm(model)

# MODEL
# copy model to each gpu
self.model_to_device(model, process_idx, is_master)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.setup_optimizers(model)

# set model properties before going into wrapper
self.trainer.model_connector.copy_trainer_model_properties(model)

# AMP - run through amp wrapper before going to distributed DP
# DDP uses all GPUs on the machine
device_ids = self.get_device_ids()

# allow user to configure ddp
model = model.configure_ddp(model, device_ids)

# set up training routine
self.barrier('ddp_setup')
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()

# clean up memory
torch.cuda.empty_cache()

return results

def training_step(self, args):
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model(*args)
else:
output = self.trainer.model(*args)
return output

def validation_step(self, args):
output = self.training_step(args)
return output

def test_step(self, args):
output = self.training_step(args)
return output

def barrier(self, name: str = None):
if torch_distrib.is_initialized():
torch_distrib.barrier()

def _check_can_spawn_children(self):
if self._has_spawned_children:
@@ -149,15 +261,7 @@ def set_world_ranks(self, process_idx):
self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

def model_to_device(self, model, process_idx, is_master):
gpu_idx = process_idx

# when using ddp, the master process (proc 0) continues running as the main one
# this means that the local rank will always be 0
# (even if cuda visible devices has other visible gpus)
# this means that the master process needs to pull the 0th visible index as the device number
if is_master:
available_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
gpu_idx = int(available_gpus[self.trainer.local_rank])
gpu_idx = int(os.environ.get('PL_DDP_PID', process_idx))

gpu_idx = int(os.environ.get('PL_DDP_PID', gpu_idx))

@@ -168,3 +272,6 @@ def model_to_device(self, model, process_idx, is_master):
def get_device_ids(self):
device_ids = [self.trainer.root_gpu]
return device_ids

def on_train_end(self):
pass
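
For context: the new __ddp_script_mode_setup path relaunches the running script once per extra local rank and relies on the PL_IN_DDP_SUBPROCESS environment variable to stop the children from spawning again (the early return at the top of the method). Below is a minimal, self-contained sketch of that guard-and-relaunch pattern; the function name maybe_spawn_ddp_children and the fixed 2-second stagger are illustrative assumptions, not part of this commit.

import os
import subprocess
import sys
from time import sleep


def maybe_spawn_ddp_children(num_processes: int):
    """Sketch of the relaunch guard used by script-mode DDP (illustrative)."""
    # children see PL_IN_DDP_SUBPROCESS=1 and skip the spawning branch entirely
    if os.environ.get('PL_IN_DDP_SUBPROCESS', '0') == '1':
        return []

    procs = []
    command = [sys.executable] + sys.argv
    for local_rank in range(1, num_processes):
        env_copy = os.environ.copy()
        env_copy['LOCAL_RANK'] = f'{local_rank}'
        env_copy['PL_IN_DDP_SUBPROCESS'] = '1'
        procs.append(subprocess.Popen(command, env=env_copy))
        # starting all processes at once can cause issues with dataloaders,
        # so stagger the launches slightly
        sleep(2)
    return procs
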
36 changes: 18 additions & 18 deletions tests/backends/test_ddp.py
@@ -37,21 +37,21 @@ def test_multi_gpu_model_ddp_test_only(tmpdir, cli_args):
assert result['status'] == 'complete'


# @pytest.mark.parametrize('cli_args', [
# pytest.param('--max_epochs 1 --gpus 2 --distributed_backend ddp'),
# ])
# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
# def test_multi_gpu_model_ddp_fit_test(tmpdir, cli_args):
# # call the script
# call_training_script(ddp_model, cli_args, 'fit_test', tmpdir, timeout=20)
#
# # load the results of the script
# result_path = os.path.join(tmpdir, 'ddp.result')
# result = torch.load(result_path)
#
# # verify the file wrote the expected outputs
# assert result['status'] == 'complete'
#
# model_outs = result['result']
# for out in model_outs:
# assert out['test_acc'] > 0.90
@pytest.mark.parametrize('cli_args', [
pytest.param('--max_epochs 1 --gpus 2 --distributed_backend ddp'),
])
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_multi_gpu_model_ddp_fit_test(tmpdir, cli_args):
# call the script
call_training_script(ddp_model, cli_args, 'fit_test', tmpdir, timeout=20)

# load the results of the script
result_path = os.path.join(tmpdir, 'ddp.result')
result = torch.load(result_path)

# verify the file wrote the expected outputs
assert result['status'] == 'complete'

model_outs = result['result']
for out in model_outs:
assert out['test_acc'] > 0.90

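The re-enabled test above relies on the spawned training script writing its outcome to ddp.result inside tmpdir, which the test reloads with torch.load. A minimal sketch of the producer side of that contract follows; the helper name write_ddp_result is hypothetical and only illustrates the dict layout the assertions expect.

import os
import torch


def write_ddp_result(tmpdir, model_outs):
    # hypothetical helper: persist the outcome so the parent pytest process
    # can reload and verify it after the training subprocess exits
    result = {'status': 'complete', 'result': model_outs}
    torch.save(result, os.path.join(str(tmpdir), 'ddp.result'))


# the test then reloads and checks it:
# result = torch.load(os.path.join(tmpdir, 'ddp.result'))
# assert result['status'] == 'complete'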