Skip to content

Commit

Permalink
ref: merge backends x/n (#3478)
Browse files Browse the repository at this point in the history
* ref: merge backends x/n

* ref: merge backends x/n

* ref: merge backends x/n

* ref: merge backends x/n
  • Loading branch information
williamFalcon authored Sep 12, 2020
1 parent 00d155a commit 0045119
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 110 deletions.
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pytorch_lightning.accelerators.ddp2_backend import DDP2Backend
from pytorch_lightning.accelerators.ddp_backend import DDPBackend
from pytorch_lightning.accelerators.ddp_spawn_backend import DDPSpawnBackend
from pytorch_lightning.accelerators.ddp_cpu_spawn_backend import DDPCPUSpawnBackend
from pytorch_lightning.accelerators.dp_backend import DataParallelBackend
from pytorch_lightning.accelerators.gpu_backend import GPUBackend
from pytorch_lightning.accelerators.tpu_backend import TPUBackend
Expand Down
6 changes: 5 additions & 1 deletion pytorch_lightning/accelerators/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ def select_accelerator(self):
te_flags_passed = 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ)
use_torchelastic_ddp = self.trainer.use_ddp and te_flags_passed

use_ddp_spawn = self.trainer.use_ddp and self.trainer.distributed_backend in ['ddp_cpu', 'ddp_spawn']
use_ddp_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == 'ddp_spawn'
use_ddp_cpu_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == 'ddp_cpu'

# choose the appropriate accelerator backend
if self.trainer.use_ddp2:
Expand All @@ -144,6 +145,9 @@ def select_accelerator(self):
elif use_ddp_spawn:
accelerator_backend = accelerators.DDPSpawnBackend(self.trainer, nprocs=self.trainer.num_processes)

elif use_ddp_cpu_spawn:
accelerator_backend = accelerators.DDPCPUSpawnBackend(self.trainer, nprocs=self.trainer.num_processes)

elif self.trainer.distributed_backend == 'ddp':
accelerator_backend = accelerators.DDPBackend(self.trainer, mode='ddp')

Expand Down
102 changes: 101 additions & 1 deletion pytorch_lightning/accelerators/ddp_base_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import torch.distributed as torch_distrib
import torch.distributed as dist
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.distributed import rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_only
from pytorch_lightning import _logger as log

try:
from hydra.utils import to_absolute_path, get_original_cwd
Expand Down Expand Up @@ -88,3 +89,102 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
atomic_save(model.state_dict(), last_path)
mp_queue.put(last_path)

def ddp_train_tmp(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
"""
Entry point for ddp
Args:
process_idx:
mp_queue: multiprocessing queue
model:
Returns:
"""
# offset the process id if requested
process_idx = process_idx + proc_offset

# show progressbar only on progress_rank 0
if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
self.trainer.progress_bar_callback.disable()

# determine which process we are and world size
self.set_world_ranks(process_idx)

# set warning rank
rank_zero_only.rank = self.trainer.global_rank

# set up server using proc 0's ip address
# try to init for 20 times at max in case ports are taken
# where to store ip_table
model.trainer = self.trainer
model.init_ddp_connection(
self.trainer.global_rank,
self.trainer.world_size,
self.trainer.is_slurm_managing_tasks
)

# call setup after the ddp process has connected
self.trainer.call_setup_hook(model)

# on world_size=0 let everyone know training is starting
if self.trainer.is_global_zero:
log.info('-' * 100)
log.info(f'distributed_backend={self.trainer.distributed_backend}')
log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
log.info('-' * 100)

# call sync_bn before .cuda(), configure_apex and configure_ddp
if self.trainer.sync_batchnorm:
model = model.configure_sync_batchnorm(model)

# move the model to the correct device
self.model_to_device(model, process_idx)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
self.trainer.optimizers = optimizers
self.trainer.lr_schedulers = lr_schedulers
self.trainer.optimizer_frequencies = optimizer_frequencies

# set model properties before going into wrapper
self.trainer.model_connector.copy_trainer_model_properties(model)

# AMP -
# run through amp wrapper before going to distributed DP
if self.trainer.amp_backend == AMPType.APEX:
model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
self.trainer.optimizers = optimizers
self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)

# device ids change depending on the DDP setup
device_ids = self.get_device_ids()

# allow user to configure ddp
model = model.configure_ddp(model, device_ids)

# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()

# get original model
model = self.trainer.get_model()

# persist info in ddp_spawn
self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

# clean up memory
torch.cuda.empty_cache()

def set_world_ranks(self, process_idx):
raise NotImplementedError('to create a ddp backend, please implement set_world_ranks')

def model_to_device(self, model, process_idx):
raise NotImplementedError('to create a ddp backend, please implement model_to_device')

def get_device_ids(self):
raise NotImplementedError('to create a ddp backend, please implement get_device_ids')
82 changes: 82 additions & 0 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import os

import torch
import torch.multiprocessing as mp

from pytorch_lightning.utilities.distributed import find_free_network_port
from pytorch_lightning.accelerators.ddp_base_backend import DDPBase

try:
from apex import amp
except ImportError:
amp = None


class DDPCPUSpawnBackend(DDPBase):

def __init__(self, trainer, nprocs):
super().__init__(trainer)
self.mp_queue = None
self.nprocs = nprocs

def setup(self, model):
os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port()))

# pass in a state q
smp = mp.get_context('spawn')
self.mp_queue = smp.SimpleQueue()

self.trainer.model = model

def train(self):
model = self.trainer.model

# train in children process
mp.spawn(self.ddp_train_tmp, nprocs=self.nprocs, args=(self.mp_queue, model,))

# restore main state with best weights
best_path = self.mp_queue.get()
results = self.mp_queue.get()
last_path = self.mp_queue.get()

# recover the weights of the processes trained in the children
self.__recover_child_process_weights(model, best_path, last_path)
return results

def __recover_child_process_weights(self, model, best_path, last_path):
# transfer back the best path to the trainer
if self.trainer.checkpoint_callback:
self.trainer.checkpoint_callback.best_model_path = best_path
# todo, pass also best score

# load last weights
if last_path is not None and not self.trainer.testing:
ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt)

self.trainer.model = model

def set_world_ranks(self, process_idx):
self.trainer.local_rank = process_idx
self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

def model_to_device(self, model, process_idx):
pass

def get_device_ids(self):
device_ids = None
return device_ids
124 changes: 16 additions & 108 deletions pytorch_lightning/accelerators/ddp_spawn_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
import torch
import torch.multiprocessing as mp

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port
from pytorch_lightning.utilities.distributed import find_free_network_port
from pytorch_lightning.accelerators.ddp_base_backend import DDPBase

try:
Expand Down Expand Up @@ -47,7 +45,7 @@ def train(self):
model = self.trainer.model

# train in children process
mp.spawn(self.ddp_train, nprocs=self.nprocs, args=(self.mp_queue, model,))
mp.spawn(self.ddp_train_tmp, nprocs=self.nprocs, args=(self.mp_queue, model,))

# restore main state with best weights
best_path = self.mp_queue.get()
Expand All @@ -71,107 +69,17 @@ def __recover_child_process_weights(self, model, best_path, last_path):

self.trainer.model = model

def ddp_train(self, process_idx, mp_queue, model):
"""
Entry point for ddp
Args:
process_idx:
mp_queue: multiprocessing queue
model:
Returns:
"""
# show progressbar only on progress_rank 0
if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
self.trainer.progress_bar_callback.disable()

# determine which process we are and world size
if self.trainer.use_ddp:
self.trainer.local_rank = process_idx
self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

elif self.trainer.use_ddp2:
self.trainer.local_rank = self.trainer.node_rank
self.trainer.global_rank = self.trainer.node_rank
self.trainer.world_size = self.trainer.num_nodes

# set warning rank
rank_zero_only.rank = self.trainer.global_rank

# set up server using proc 0's ip address
# try to init for 20 times at max in case ports are taken
# where to store ip_table
model.trainer = self.trainer
model.init_ddp_connection(
self.trainer.global_rank,
self.trainer.world_size,
self.trainer.is_slurm_managing_tasks
)

# call setup after the ddp process has connected
self.trainer.call_setup_hook(model)

# on world_size=0 let everyone know training is starting
if self.trainer.is_global_zero:
log.info('-' * 100)
log.info(f'distributed_backend={self.trainer.distributed_backend}')
log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
log.info('-' * 100)

# call sync_bn before .cuda(), configure_apex and configure_ddp
if self.trainer.sync_batchnorm:
model = model.configure_sync_batchnorm(model)

# MODEL
# copy model to each gpu
if self.trainer.on_gpu:
gpu_idx = process_idx
self.trainer.root_gpu = gpu_idx
torch.cuda.set_device(self.trainer.root_gpu)
model.cuda(self.trainer.root_gpu)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
self.trainer.optimizers = optimizers
self.trainer.lr_schedulers = lr_schedulers
self.trainer.optimizer_frequencies = optimizer_frequencies

# set model properties before going into wrapper
self.trainer.model_connector.copy_trainer_model_properties(model)

# AMP -
# run through amp wrapper before going to distributed DP
if self.trainer.amp_backend == AMPType.APEX:
model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
self.trainer.optimizers = optimizers
self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)

# DDP2 uses all GPUs on the machine
if self.trainer.distributed_backend == 'ddp' or self.trainer.distributed_backend == 'ddp_spawn':
device_ids = [self.trainer.root_gpu]
elif self.trainer.use_ddp2:
device_ids = self.trainer.data_parallel_device_ids
else: # includes ddp_cpu
device_ids = None

# allow user to configure ddp
model = model.configure_ddp(model, device_ids)

# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()

# get original model
model = self.trainer.get_model()

# persist info in ddp_spawn
self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

# clean up memory
torch.cuda.empty_cache()
def set_world_ranks(self, process_idx):
self.trainer.local_rank = process_idx
self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

def model_to_device(self, model, process_idx):
gpu_idx = process_idx
self.trainer.root_gpu = gpu_idx
torch.cuda.set_device(self.trainer.root_gpu)
model.cuda(self.trainer.root_gpu)

def get_device_ids(self):
device_ids = [self.trainer.root_gpu]
return device_ids

0 comments on commit 0045119

Please sign in to comment.