Skip to content

Commit

Permalink
Initialize trainer with None in DDPAccelerator (Lightning-AI#4915)
Browse files: browse the repository at this point in the history
* Initialize trainer with None

* add typing to all accelerators

* resolve imports

* update

* add typing

* removed typo

* update

* Fix formatting and imports in accelerator

Co-authored-by: maxjeblick <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
Co-authored-by: SeanNaren <[email protected]>
Co-authored-by: Roger Shieh <[email protected]>
  • Loading branch information
5 people authored Dec 10, 2020
1 parent d5fa02e commit 2c3d43d
Show file tree
Hide file tree
Showing 12 changed files with 76 additions and 27 deletions.
8 changes: 6 additions & 2 deletions pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from enum import Enum
from typing import Any, Optional, Union

import torch
import torch.distributed as torch_distrib
from torch.optim import Optimizer

from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.parsing import AttributeDict
Expand All @@ -33,7 +34,10 @@ class ReduceOp:

class Accelerator(object):

def __init__(self, trainer=None, cluster_environment=None, ddp_plugin=None):
def __init__(self,
trainer: Optional = None,
cluster_environment: Optional[ClusterEnvironment] = None,
ddp_plugin: Optional[DDPPlugin] = None):
self.trainer = trainer
self.nickname = None
self.cluster_environment = cluster_environment
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/accelerators/cpu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
import torch

from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.utilities import AMPType, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class CPUAccelerator(Accelerator):

def __init__(self, trainer, cluster_environment=None):
def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None):
"""
Runs training on CPU
Expand Down
9 changes: 7 additions & 2 deletions pytorch_lightning/accelerators/ddp2_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available, all_gather_ddp_if_available
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available

if HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
Expand All @@ -34,7 +36,10 @@

class DDP2Accelerator(Accelerator):

def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
def __init__(self,
trainer,
cluster_environment: Optional[ClusterEnvironment] = None,
ddp_plugin: Optional[DDPPlugin] = None):
"""
Runs training using DDP2 strategy on a cluster
Expand Down
15 changes: 12 additions & 3 deletions pytorch_lightning/accelerators/ddp_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,18 @@

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
from pytorch_lightning.utilities.distributed import find_free_network_port, rank_zero_only, sync_ddp_if_available
from pytorch_lightning.utilities.distributed import (
all_gather_ddp_if_available,
find_free_network_port,
rank_zero_only,
sync_ddp_if_available,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything

Expand All @@ -41,7 +47,10 @@

class DDPAccelerator(Accelerator):

def __init__(self, trainer=None, cluster_environment=None, ddp_plugin=None):
def __init__(self,
trainer: Optional = None,
cluster_environment: Optional[ClusterEnvironment] = None,
ddp_plugin: Optional[DDPPlugin] = None):
"""
Runs training using DDP strategy on a single machine (manually, not via cluster start)
Expand Down
12 changes: 10 additions & 2 deletions pytorch_lightning/accelerators/ddp_cpu_hpc_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.ddp_hpc_accelerator import DDPHPCAccelerator
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE

if HYDRA_AVAILABLE:
from hydra.utils import to_absolute_path, get_original_cwd
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path


class DDPCPUHPCAccelerator(DDPHPCAccelerator):

def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
def __init__(self,
trainer,
cluster_environment: Optional[ClusterEnvironment] = None,
ddp_plugin: Optional[DDPPlugin] = None):
"""
Runs training using DDP (with CPUs) strategy on a cluster
Expand Down
15 changes: 10 additions & 5 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,23 @@

import torch
import torch.distributed as torch_distrib
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import (
all_gather_ddp_if_available,
find_free_network_port,
rank_zero_only,
rank_zero_warn,
sync_ddp_if_available,
all_gather_ddp_if_available,
)

if HYDRA_AVAILABLE:
Expand All @@ -41,7 +42,11 @@

class DDPCPUSpawnAccelerator(Accelerator):

def __init__(self, trainer, nprocs, cluster_environment=None, ddp_plugin=None):
def __init__(self,
trainer,
nprocs: int,
cluster_environment: Optional[ClusterEnvironment] = None,
ddp_plugin: Optional[DDPPlugin] = None):
"""
Runs training using DDP (on a single machine or manually on multiple machines), using mp.spawn
Expand Down Expand Up @@ -197,8 +202,8 @@ def broadcast(self, obj, src=0):

def early_stopping_should_stop(self, pl_module):
stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
dist.all_reduce(stop, op=dist.reduce_op.SUM)
dist.barrier()
torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM)
torch_distrib.barrier()
should_stop = stop == self.trainer.world_size
return should_stop

Expand Down
9 changes: 7 additions & 2 deletions pytorch_lightning/accelerators/ddp_hpc_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available, all_gather_ddp_if_available
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available

if HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
Expand All @@ -34,7 +36,10 @@

class DDPHPCAccelerator(Accelerator):

def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
def __init__(self,
trainer,
cluster_environment: Optional[ClusterEnvironment] = None,
ddp_plugin: Optional[DDPPlugin] = None):
"""
Runs training using DDP on an HPC cluster
Expand Down
15 changes: 10 additions & 5 deletions pytorch_lightning/accelerators/ddp_spawn_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,25 @@

import torch
import torch.distributed as torch_distrib
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.distributed import (
all_gather_ddp_if_available,
find_free_network_port,
rank_zero_only,
rank_zero_warn,
sync_ddp_if_available,
all_gather_ddp_if_available,
)
from pytorch_lightning.utilities.seed import seed_everything

Expand All @@ -45,7 +46,11 @@

class DDPSpawnAccelerator(Accelerator):

def __init__(self, trainer, nprocs, cluster_environment=None, ddp_plugin=None):
def __init__(self,
trainer,
nprocs: int,
cluster_environment: Optional[ClusterEnvironment] = None,
ddp_plugin: Optional[DDPPlugin] = None):
"""
Runs training using DDP using mp.spawn via manual launch (not cluster launch)
Expand Down Expand Up @@ -226,8 +231,8 @@ def barrier(self, name: Optional[str] = None):

def early_stopping_should_stop(self, pl_module):
stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
dist.all_reduce(stop, op=dist.reduce_op.SUM)
dist.barrier()
torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM)
torch_distrib.barrier()
should_stop = stop == self.trainer.world_size
return should_stop

Expand Down
6 changes: 4 additions & 2 deletions pytorch_lightning/accelerators/dp_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from typing import Optional, Union

import torch
from torch import optim

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.distributed import LightningDistributed
Expand All @@ -27,7 +29,7 @@

class DataParallelAccelerator(Accelerator):

def __init__(self, trainer, cluster_environment=None):
def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None):
"""
Runs training using DP via manual start (not HPC cluster)
Expand Down
4 changes: 3 additions & 1 deletion pytorch_lightning/accelerators/gpu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@

import torch

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.utilities import AMPType


class GPUAccelerator(Accelerator):
amp_backend: AMPType

def __init__(self, trainer, cluster_environment=None):
def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None):
"""
Runs training using a single GPU
Expand Down
4 changes: 3 additions & 1 deletion pytorch_lightning/accelerators/horovod_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
import torch
from torch.optim.lr_scheduler import _LRScheduler

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.utilities import HOROVOD_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only

Expand All @@ -28,7 +30,7 @@
class HorovodAccelerator(Accelerator):
amp_backend: AMPType

def __init__(self, trainer, cluster_environment=None):
def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None):
"""
Runs training using horovod
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/accelerators/tpu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core import LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities import (
Expand All @@ -43,7 +44,7 @@

class TPUAccelerator(Accelerator):

def __init__(self, trainer, cluster_environment=None):
def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None):
"""
Runs training using TPUs (colab, single machine or pod)
Expand Down

0 comments on commit 2c3d43d

Please sign in to comment.