
Initialize trainer with None in DDPAccelerator #4915

Merged: 19 commits, Dec 10, 2020
Changes from 12 commits
9 changes: 8 additions & 1 deletion pytorch_lightning/accelerators/accelerator.py
@@ -19,8 +19,12 @@
import torch.distributed as torch_distrib
from torch.optim import Optimizer

import pytorch_lightning as pl
Member: pls do not use this call inside PL

Contributor Author: It doesn't work with the doc when I do.

from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.parsing import AttributeDict

@@ -33,7 +37,10 @@ class ReduceOp:

class Accelerator(object):

def __init__(self, trainer=None, cluster_environment=None, ddp_plugin=None):
def __init__(self,
trainer: Optional['pl.Trainer'] = None,
Member: Suggested change — replace
    trainer: Optional['pl.Trainer'] = None,
with
    trainer: Optional['Trainer'] = None,
as it is a string, you do not need the pl.

Contributor Author: You need it for the doc.

Contributor Author: Same, it breaks the doc.

cluster_environment: Optional[ClusterEnvironment] = None,
ddp_plugin: Optional[DDPPlugin] = None):
self.trainer = trainer
self.nickname = None
self.cluster_environment = cluster_environment
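For reference, a minimal sketch (not part of this diff) of the annotation pattern debated in the thread above: the class name stays a string, so there is no circular import at runtime, while the module-level `pl` alias gives documentation tooling a name it can resolve, which is the author's stated reason for keeping the `pl.` prefix. `ExampleAccelerator` is a hypothetical class used only for illustration.

```python
from typing import Optional

import pytorch_lightning as pl  # the alias is only referenced from inside string annotations


class ExampleAccelerator:
    """Hypothetical accelerator used to illustrate the annotation style."""

    def __init__(self, trainer: Optional['pl.Trainer'] = None):
        # The string form defers evaluation, so importing this module does not
        # require pytorch_lightning.Trainer to be importable yet (no import cycle).
        self.trainer = trainer
```

A common alternative is to import `Trainer` only under `typing.TYPE_CHECKING`, but that is not the approach taken in this diff.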
4 changes: 3 additions & 1 deletion pytorch_lightning/accelerators/cpu_accelerator.py
@@ -15,14 +15,16 @@

import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.utilities import AMPType, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class CPUAccelerator(Accelerator):

def __init__(self, trainer, cluster_environment=None):
def __init__(self, trainer: 'pl.Trainer', cluster_environment: Optional[ClusterEnvironment] = None):
"""
Runs training on CPU

10 changes: 8 additions & 2 deletions pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -18,14 +18,17 @@
import torch.distributed as torch_distrib
from torch.nn.parallel import DistributedDataParallel

import pytorch_lightning as pl
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available, all_gather_ddp_if_available
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available

if HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
@@ -34,7 +37,10 @@

class DDP2Accelerator(Accelerator):

def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
def __init__(self,
trainer: 'pl.Trainer',
cluster_environment: Optional[ClusterEnvironment] = None,
ddp_plugin: Optional[DDPPlugin] = None):
"""
Runs training using DDP2 strategy on a cluster

16 changes: 13 additions & 3 deletions pytorch_lightning/accelerators/ddp_accelerator.py
@@ -23,14 +23,21 @@
import torch.distributed as torch_distrib
from torch.nn.parallel import DistributedDataParallel

import pytorch_lightning as pl
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
from pytorch_lightning.utilities.distributed import find_free_network_port, rank_zero_only, sync_ddp_if_available
from pytorch_lightning.utilities.distributed import (
all_gather_ddp_if_available,
find_free_network_port,
rank_zero_only,
sync_ddp_if_available,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything

@@ -41,7 +48,10 @@

class DDPAccelerator(Accelerator):

def __init__(self, trainer=None, cluster_environment=None, ddp_plugin=None):
def __init__(self,
trainer: Optional['pl.Trainer'] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
ddp_plugin: Optional[DDPPlugin] = None):
"""
Runs training using DDP strategy on a single machine (manually, not via cluster start)

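Since this file is where `trainer` becomes optional, here is a hedged usage sketch of what the relaxed signature allows, assuming the `Trainer` accepts an `Accelerator` instance through its `accelerator` argument and attaches itself to the accelerator during setup. The arguments below are illustrative, not taken from the PR.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.ddp_accelerator import DDPAccelerator

# The accelerator can now be constructed before any Trainer exists...
accelerator = DDPAccelerator(trainer=None)

# ...and handed to the Trainer afterwards; previously the constructor effectively
# assumed a trainer was available at construction time.
trainer = Trainer(accelerator=accelerator, gpus=2, max_epochs=1)
```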
13 changes: 11 additions & 2 deletions pytorch_lightning/accelerators/ddp_cpu_hpc_accelerator.py
@@ -11,17 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from typing import Optional

import pytorch_lightning as pl
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.ddp_hpc_accelerator import DDPHPCAccelerator
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE

if HYDRA_AVAILABLE:
from hydra.utils import to_absolute_path, get_original_cwd
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path


class DDPCPUHPCAccelerator(DDPHPCAccelerator):

def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
def __init__(self,
trainer: 'pl.Trainer',
cluster_environment: Optional[ClusterEnvironment] = None,
ddp_plugin: Optional[DDPPlugin] = None):
"""
Runs training using DDP (with CPUs) strategy on a cluster

16 changes: 11 additions & 5 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -16,22 +16,24 @@

import torch
import torch.distributed as torch_distrib
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel

import pytorch_lightning as pl
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import (
all_gather_ddp_if_available,
find_free_network_port,
rank_zero_only,
rank_zero_warn,
sync_ddp_if_available,
all_gather_ddp_if_available,
)

if HYDRA_AVAILABLE:
@@ -41,7 +43,11 @@

class DDPCPUSpawnAccelerator(Accelerator):

def __init__(self, trainer, nprocs, cluster_environment=None, ddp_plugin=None):
def __init__(self,
trainer: 'pl.Trainer',
nprocs: int,
cluster_environment: Optional[ClusterEnvironment] = None,
ddp_plugin: Optional[DDPPlugin] = None):
"""
Runs training using DDP (on a single machine or manually on multiple machines), using mp.spawn

@@ -197,8 +203,8 @@ def broadcast(self, obj, src=0):

def early_stopping_should_stop(self, pl_module):
stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
dist.all_reduce(stop, op=dist.reduce_op.SUM)
dist.barrier()
torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM)
torch_distrib.barrier()
should_stop = stop == self.trainer.world_size
return should_stop

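The hunk above only unifies the `dist`/`torch_distrib` aliases; the stopping logic itself is unchanged. A self-contained sketch of that consensus pattern follows, assuming `torch.distributed.init_process_group` has already been called; `ReduceOp.SUM` is the current spelling of the deprecated `reduce_op.SUM` used in the diff.

```python
import torch
import torch.distributed as torch_distrib


def early_stopping_consensus(local_should_stop: bool, world_size: int, device: torch.device) -> bool:
    # Each rank contributes 0 or 1; the SUM all-reduce makes the vote total
    # visible on every rank, and training stops only if all ranks voted to stop.
    stop = torch.tensor(int(local_should_stop), device=device)
    torch_distrib.all_reduce(stop, op=torch_distrib.ReduceOp.SUM)
    torch_distrib.barrier()
    return stop.item() == world_size
```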
10 changes: 8 additions & 2 deletions pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -19,13 +19,16 @@
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

import pytorch_lightning as pl
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available, all_gather_ddp_if_available
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available

if HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
@@ -34,7 +37,10 @@

class DDPHPCAccelerator(Accelerator):

def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
def __init__(self,
trainer: 'pl.Trainer',
cluster_environment: Optional[ClusterEnvironment] = None,
ddp_plugin: Optional[DDPPlugin] = None):
"""
Runs training using DDP on an HPC cluster

16 changes: 11 additions & 5 deletions pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -17,24 +17,26 @@

import torch
import torch.distributed as torch_distrib
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel

import pytorch_lightning as pl
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.distributed import (
all_gather_ddp_if_available,
find_free_network_port,
rank_zero_only,
rank_zero_warn,
sync_ddp_if_available,
all_gather_ddp_if_available,
)
from pytorch_lightning.utilities.seed import seed_everything

@@ -45,7 +47,11 @@

class DDPSpawnAccelerator(Accelerator):

def __init__(self, trainer, nprocs, cluster_environment=None, ddp_plugin=None):
def __init__(self,
trainer: 'pl.Trainer',
nprocs: int,
cluster_environment: Optional[ClusterEnvironment] = None,
ddp_plugin: Optional[DDPPlugin] = None):
"""
Runs training using DDP using mp.spawn via manual launch (not cluster launch)

@@ -226,8 +232,8 @@ def barrier(self, name: Optional[str] = None):

def early_stopping_should_stop(self, pl_module):
stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
dist.all_reduce(stop, op=dist.reduce_op.SUM)
dist.barrier()
torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM)
torch_distrib.barrier()
should_stop = stop == self.trainer.world_size
return should_stop

7 changes: 5 additions & 2 deletions pytorch_lightning/accelerators/dp_accelerator.py
@@ -11,12 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from typing import Optional, Union

import torch
from torch import optim

import pytorch_lightning as pl
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.distributed import LightningDistributed
@@ -27,7 +30,7 @@

class DataParallelAccelerator(Accelerator):

def __init__(self, trainer, cluster_environment=None):
def __init__(self, trainer: 'pl.Trainer', cluster_environment: Optional[ClusterEnvironment] = None):
"""
Runs training using DP via manual start (not HPC cluster)

5 changes: 4 additions & 1 deletion pytorch_lightning/accelerators/gpu_accelerator.py
@@ -15,15 +15,18 @@

import torch

import pytorch_lightning as pl
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.utilities import AMPType


class GPUAccelerator(Accelerator):
amp_backend: AMPType

def __init__(self, trainer, cluster_environment=None):
def __init__(self, trainer: 'pl.Trainer', cluster_environment: Optional[ClusterEnvironment] = None):
"""
Runs training using a single GPU

5 changes: 4 additions & 1 deletion pytorch_lightning/accelerators/horovod_accelerator.py
@@ -17,7 +17,10 @@
import torch
from torch.optim.lr_scheduler import _LRScheduler

import pytorch_lightning as pl
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.utilities import HOROVOD_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only

@@ -28,7 +31,7 @@
class HorovodAccelerator(Accelerator):
amp_backend: AMPType

def __init__(self, trainer, cluster_environment=None):
def __init__(self, trainer: 'pl.Trainer', cluster_environment: Optional[ClusterEnvironment] = None):
"""
Runs training using horovod

4 changes: 3 additions & 1 deletion pytorch_lightning/accelerators/tpu_accelerator.py
@@ -20,8 +20,10 @@
import torch.multiprocessing as mp
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core import LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities import (
@@ -43,7 +45,7 @@

class TPUAccelerator(Accelerator):

def __init__(self, trainer, cluster_environment=None):
def __init__(self, trainer: 'pl.Trainer', cluster_environment: Optional[ClusterEnvironment] = None):
"""
Runs training using TPUs (colab, single machine or pod)
