ref: merge backends x/n (#3477)
williamFalcon authored Sep 12, 2020
1 parent 59d8472 commit 00d155a
Showing 5 changed files with 96 additions and 166 deletions.
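
The pattern across the three backend files below is identical: training_step, validation_step, test_step, barrier, early_stopping_should_stop and transfer_distrib_spawn_state_on_fit_end had been copy-pasted into each DDP backend; this commit hoists those bodies into a new shared DDPBase and re-parents the backends onto it. A minimal sketch of the shape of the refactor (the Accelerator stub and placeholder bodies are illustrative only):

    class Accelerator:
        """Stub standing in for pytorch_lightning.accelerators.base_backend.Accelerator."""

    # before: every backend carried its own copy of the shared DDP logic
    class DDPBackendBefore(Accelerator):
        def barrier(self, name=None): ...
        def training_step(self, args): ...

    # after: the shared bodies live once on DDPBase and the backends inherit them
    class DDPBase(Accelerator):
        def barrier(self, name=None): ...
        def training_step(self, args): ...

    class DDPBackendAfter(DDPBase):
        pass  # only backend-specific setup/spawn logic stays in the subclass
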
57 changes: 2 additions & 55 deletions pytorch_lightning/accelerators/ddp2_backend.py
@@ -13,7 +13,6 @@
 # limitations under the License

 import os
-import re

 import torch

@@ -22,11 +21,7 @@
 from pytorch_lightning.utilities.distributed import rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.core.step_result import Result
-from pytorch_lightning.accelerators.base_backend import Accelerator
-import torch.distributed as torch_distrib
-import torch.distributed as dist
-from pytorch_lightning.utilities.cloud_io import atomic_save
-from pytorch_lightning.utilities.distributed import rank_zero_warn
+from pytorch_lightning.accelerators.ddp_base_backend import DDPBase

 try:
     from hydra.utils import to_absolute_path, get_original_cwd
@@ -42,7 +37,7 @@
     amp = None


-class DDP2Backend(Accelerator):
+class DDP2Backend(DDPBase):

     def __init__(self, trainer):
         super().__init__(trainer)
@@ -170,22 +165,6 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
         # clean up memory
         torch.cuda.empty_cache()

-    def training_step(self, args):
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.trainer.model(*args)
-        else:
-            output = self.trainer.model(*args)
-        return output
-
-    def validation_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def test_step(self, args):
-        output = self.training_step(args)
-        return output
-
     def training_step_end(self, output):
         if isinstance(output, Result):
             output.dp_reduce()
@@ -200,35 +179,3 @@ def test_step_end(self, output):
         if isinstance(output, Result):
             output.dp_reduce()
         return output
-
-    def barrier(self, name: str = None):
-        torch_distrib.barrier()
-
-    def early_stopping_should_stop(self, pl_module):
-        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
-        dist.all_reduce(stop, op=dist.reduce_op.SUM)
-        dist.barrier()
-        should_stop = stop == self.trainer.world_size
-        return should_stop
-
-    def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
-        if self.trainer.distributed_backend.lower() not in ['ddp_spawn', 'ddp_cpu', 'tpu']:
-            return
-
-        # track the best model path
-        best_model_path = None
-        if self.trainer.checkpoint_callback is not None:
-            best_model_path = self.trainer.checkpoint_callback.best_model_path
-
-        if self.trainer.global_rank == 0 and mp_queue is not None:
-            rank_zero_warn('cleaning up ddp environment...')
-            # todo, pass complete checkpoint as state dictionary
-            mp_queue.put(best_model_path)
-            mp_queue.put(results)
-
-        # save the last weights
-        last_path = None
-        if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
-            last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
-            atomic_save(model.state_dict(), last_path)
-        mp_queue.put(last_path)
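
Note what does not move: the *_step_end hooks above stay on DDP2Backend instead of migrating to DDPBase, because ddp2 runs DP-style inside each node and the per-GPU outputs of a step still need to be reduced. A toy illustration of that reduction (treating dp_reduce as a mean over the per-GPU dimension is an assumption about Result.dp_reduce, not its verbatim implementation):

    import torch

    # step outputs produced by the two GPUs of one node
    per_gpu_loss = torch.tensor([0.9, 1.1])

    # dp-style reduction collapses the per-GPU dimension into one value
    reduced_loss = per_gpu_loss.mean()  # tensor(1.)
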
57 changes: 2 additions & 55 deletions pytorch_lightning/accelerators/ddp_backend.py
@@ -13,7 +13,6 @@
 # limitations under the License

 import os
-import re
 import subprocess
 import sys
 from os.path import abspath
@@ -25,12 +24,8 @@

 from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities import AMPType
-from pytorch_lightning.accelerators.base_backend import Accelerator
-import torch.distributed as torch_distrib
-import torch.distributed as dist
 from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port
-from pytorch_lightning.utilities.cloud_io import atomic_save
-from pytorch_lightning.utilities.distributed import rank_zero_warn
+from pytorch_lightning.accelerators.ddp_base_backend import DDPBase

 try:
     from hydra.utils import to_absolute_path, get_original_cwd
@@ -46,7 +41,7 @@
     amp = None


-class DDPBackend(Accelerator):
+class DDPBackend(DDPBase):

     def __init__(self, trainer, mode: str = 'ddp'):
         super().__init__(trainer)
@@ -257,57 +252,9 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
         if self.trainer.global_rank == 0 and self.trainer.distributed_backend not in ['ddp_spawn', 'ddp_cpu']:
             return results

-    def training_step(self, args):
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.trainer.model(*args)
-        else:
-            output = self.trainer.model(*args)
-        return output
-
-    def validation_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def test_step(self, args):
-        output = self.training_step(args)
-        return output
-
     def _check_can_spawn_children(self):
         if self._has_spawned_children:
             raise RuntimeError(
                 "You tried to run `.fit` or `.test` multiple times in the same script."
                 " This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead."
             )
-
-    def barrier(self, name: str = None):
-        torch_distrib.barrier()
-
-    def early_stopping_should_stop(self, pl_module):
-        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
-        dist.all_reduce(stop, op=dist.reduce_op.SUM)
-        dist.barrier()
-        should_stop = stop == self.trainer.world_size
-        return should_stop
-
-    def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
-        if self.trainer.distributed_backend.lower() not in ['ddp_spawn', 'ddp_cpu', 'tpu']:
-            return
-
-        # track the best model path
-        best_model_path = None
-        if self.trainer.checkpoint_callback is not None:
-            best_model_path = self.trainer.checkpoint_callback.best_model_path
-
-        if self.trainer.global_rank == 0 and mp_queue is not None:
-            rank_zero_warn('cleaning up ddp environment...')
-            # todo, pass complete checkpoint as state dictionary
-            mp_queue.put(best_model_path)
-            mp_queue.put(results)
-
-        # save the last weights
-        last_path = None
-        if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
-            last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
-            atomic_save(model.state_dict(), last_path)
-        mp_queue.put(last_path)
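
_check_can_spawn_children likewise stays put: it is specific to the script-launching `ddp` mode, which cannot re-create its child processes within one script. A self-contained sketch of the guard pattern (SpawnGuard is a hypothetical stand-in for the trainer/backend pair):

    class SpawnGuard:
        """Hypothetical stand-in illustrating DDPBackend's spawn guard."""

        def __init__(self):
            self._has_spawned_children = False

        def fit(self):
            if self._has_spawned_children:
                raise RuntimeError(
                    "You tried to run `.fit` or `.test` multiple times in the same script."
                    " This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead."
                )
            self._has_spawned_children = True
            # ... launch worker processes and train ...

    guard = SpawnGuard()
    guard.fit()    # first call: spawns children and trains
    # guard.fit()  # a second call would raise the RuntimeError quoted above
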
90 changes: 90 additions & 0 deletions pytorch_lightning/accelerators/ddp_base_backend.py
@@ -0,0 +1,90 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+import re
+import torch
+
+from pytorch_lightning.utilities import AMPType
+from pytorch_lightning.accelerators.base_backend import Accelerator
+import torch.distributed as torch_distrib
+import torch.distributed as dist
+from pytorch_lightning.utilities.cloud_io import atomic_save
+from pytorch_lightning.utilities.distributed import rank_zero_warn
+
+try:
+    from hydra.utils import to_absolute_path, get_original_cwd
+    from hydra.core.hydra_config import HydraConfig
+except ImportError:
+    HYDRA_AVAILABLE = False
+else:
+    HYDRA_AVAILABLE = True
+
+try:
+    from apex import amp
+except ImportError:
+    amp = None
+
+
+class DDPBase(Accelerator):
+
+    def __init__(self, trainer):
+        super().__init__(trainer)
+
+    def training_step(self, args):
+        if self.trainer.amp_backend == AMPType.NATIVE:
+            with torch.cuda.amp.autocast():
+                output = self.trainer.model(*args)
+        else:
+            output = self.trainer.model(*args)
+        return output
+
+    def validation_step(self, args):
+        output = self.training_step(args)
+        return output
+
+    def test_step(self, args):
+        output = self.training_step(args)
+        return output
+
+    def barrier(self, name: str = None):
+        torch_distrib.barrier()
+
+    def early_stopping_should_stop(self, pl_module):
+        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
+        dist.all_reduce(stop, op=dist.reduce_op.SUM)
+        dist.barrier()
+        should_stop = stop == self.trainer.world_size
+        return should_stop
+
+    def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
+        if self.trainer.distributed_backend.lower() not in ['ddp_spawn', 'ddp_cpu', 'tpu']:
+            return
+
+        # track the best model path
+        best_model_path = None
+        if self.trainer.checkpoint_callback is not None:
+            best_model_path = self.trainer.checkpoint_callback.best_model_path
+
+        if self.trainer.global_rank == 0 and mp_queue is not None:
+            rank_zero_warn('cleaning up ddp environment...')
+            # todo, pass complete checkpoint as state dictionary
+            mp_queue.put(best_model_path)
+            mp_queue.put(results)
+
+        # save the last weights
+        last_path = None
+        if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
+            last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
+            atomic_save(model.state_dict(), last_path)
+        mp_queue.put(last_path)
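
This new module is the heart of the commit. training_step routes the forward pass through torch.cuda.amp.autocast only when native AMP is enabled, validation_step and test_step simply reuse it, and early_stopping_should_stop all-reduces a per-rank stop vote and stops only when the sum equals world_size, i.e. when every rank agrees. A backend built on top of it now overrides just the mode-specific pieces; a sketch against this commit's module layout (MyDDPFlavor is hypothetical):

    from pytorch_lightning.accelerators.ddp_base_backend import DDPBase

    class MyDDPFlavor(DDPBase):
        """Hypothetical backend that inherits all shared DDP behaviour."""

        def __init__(self, trainer):
            super().__init__(trainer)

        def ddp_train(self, process_idx, mp_queue, model):
            # only process/device setup and teardown belong here;
            # training_step, barrier, early_stopping_should_stop and the
            # spawn-state transfer are all inherited from DDPBase
            ...
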
57 changes: 2 additions & 55 deletions pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -12,27 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 import os
-import re

 import torch
 import torch.multiprocessing as mp

 from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port
-from pytorch_lightning.accelerators.base_backend import Accelerator
-import torch.distributed as torch_distrib
-import torch.distributed as dist
-from pytorch_lightning.utilities.cloud_io import atomic_save
-from pytorch_lightning.utilities.distributed import rank_zero_warn
+from pytorch_lightning.accelerators.ddp_base_backend import DDPBase

 try:
     from apex import amp
 except ImportError:
     amp = None


-class DDPSpawnBackend(Accelerator):
+class DDPSpawnBackend(DDPBase):

     def __init__(self, trainer, nprocs):
         super().__init__(trainer)
@@ -180,51 +175,3 @@ def ddp_train(self, process_idx, mp_queue, model):

         # clean up memory
         torch.cuda.empty_cache()
-
-    def training_step(self, args):
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.trainer.model(*args)
-        else:
-            output = self.trainer.model(*args)
-        return output
-
-    def validation_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def test_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def barrier(self, name: str = None):
-        torch_distrib.barrier()
-
-    def early_stopping_should_stop(self, pl_module):
-        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
-        dist.all_reduce(stop, op=dist.reduce_op.SUM)
-        dist.barrier()
-        should_stop = stop == self.trainer.world_size
-        return should_stop
-
-    def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
-        if self.trainer.distributed_backend.lower() not in ['ddp_spawn', 'ddp_cpu', 'tpu']:
-            return
-
-        # track the best model path
-        best_model_path = None
-        if self.trainer.checkpoint_callback is not None:
-            best_model_path = self.trainer.checkpoint_callback.best_model_path
-
-        if self.trainer.global_rank == 0 and mp_queue is not None:
-            rank_zero_warn('cleaning up ddp environment...')
-            # todo, pass complete checkpoint as state dictionary
-            mp_queue.put(best_model_path)
-            mp_queue.put(results)
-
-        # save the last weights
-        last_path = None
-        if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
-            last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
-            atomic_save(model.state_dict(), last_path)
-        mp_queue.put(last_path)
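
For the spawn-based modes, transfer_distrib_spawn_state_on_fit_end is the bridge between worker processes and the parent: rank 0 pushes the best checkpoint path, the results and the temporary last-weights path onto mp_queue, and the parent reads them back after mp.spawn returns. A stripped-down, runnable sketch of that handoff (the worker body and paths are illustrative, not the backend's actual code):

    import torch.multiprocessing as mp

    def worker(process_idx, mp_queue):
        # ... ddp_train would run here ...
        if process_idx == 0:
            mp_queue.put('/tmp/best.ckpt')           # best_model_path
            mp_queue.put({'val_loss': 0.25})         # results
            mp_queue.put('/tmp/best.tmp_end.ckpt')   # last_path

    if __name__ == '__main__':
        smp = mp.get_context('spawn')
        mp_queue = smp.SimpleQueue()
        mp.spawn(worker, nprocs=2, args=(mp_queue,))
        best_model_path = mp_queue.get()
        results = mp_queue.get()
        last_path = mp_queue.get()
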
1 change: 0 additions & 1 deletion pytorch_lightning/core/lightning.py
@@ -876,7 +876,6 @@ def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managi
             global_rank: The global process idx.
             world_size: Number of GPUs being use across all nodes. (num_nodes * num_gpus).
             is_slurm_managing_tasks: is cluster managed by SLURM.
-
         """
         if is_slurm_managing_tasks:
             self._init_slurm_connection()
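
For context, init_ddp_connection is the LightningModule hook that brings up the process group that the backends above barrier and all-reduce against. A rough sketch of what such a hook does (the env-var defaults and backend choice are assumptions, not the verbatim body):

    import os
    import torch.distributed as torch_distrib

    def init_ddp_connection(global_rank: int, world_size: int, on_gpu: bool = True) -> None:
        # default the rendezvous address/port if the launcher did not set them
        os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
        os.environ.setdefault('MASTER_PORT', '12910')
        backend = 'nccl' if on_gpu else 'gloo'
        torch_distrib.init_process_group(backend, rank=global_rank, world_size=world_size)
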
