
Accelerator docs #4583

Merged
merged 2 commits on Nov 8, 2020
182 changes: 182 additions & 0 deletions docs/source/accelerators.rst
@@ -0,0 +1,182 @@
############
Accelerators
############
Accelerators connect a Lightning Trainer to arbitrary hardware (CPUs, GPUs, TPUs, etc.). Accelerators
also manage distributed training strategies (such as DP, DDP, or running on an HPC cluster).

Accelerators can also be configured to run on arbitrary clusters using Plugins, or to link up to
computational strategies such as 16-bit precision via AMP and Apex.
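
For the built-in accelerators, you normally don't construct anything by hand; you select them through
Trainer flags. A minimal sketch (the flag values below are illustrative and depend on your hardware):

.. code-block:: python

    from pytorch_lightning import Trainer

    # DDP across 2 GPUs with 16-bit precision (illustrative values)
    trainer = Trainer(accelerator='ddp', gpus=2, precision=16)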

----------

******************************
Implement a custom accelerator
******************************
To link up arbitrary hardware, implement your own Accelerator subclass:

.. code-block:: python

    from typing import Any, Optional, Union

    import torch
    from torch.distributed import ReduceOp

    from pytorch_lightning.accelerators.accelerator import Accelerator


    class MyAccelerator(Accelerator):

        def __init__(self, trainer=None, cluster_environment=None):
            super().__init__(trainer, cluster_environment)
            self.nickname = 'my_accelerator'

        def setup(self):
            # find local rank, etc., custom things to implement
            ...

        def train(self):
            # implement what happens during training
            ...

        def training_step(self):
            # implement how to do a training_step on this accelerator
            ...

        def validation_step(self):
            # implement how to do a validation_step on this accelerator
            ...

        def test_step(self):
            # implement how to do a test_step on this accelerator
            ...

        def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
            # implement how to do a backward pass with this accelerator
            ...

        def barrier(self, name: Optional[str] = None):
            # implement this accelerator's barrier
            ...

        def broadcast(self, obj, src=0):
            # implement this accelerator's broadcast function
            ...

        def sync_tensor(self,
                        tensor: Union[torch.Tensor],
                        group: Optional[Any] = None,
                        reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
            # implement how to sync tensors when reducing metrics across accelerators
            ...
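
Once implemented, the custom accelerator is passed straight to the Trainer. A minimal usage sketch, assuming
the ``trainer=None`` default shown above and any ``LightningModule`` instance named ``model``:

.. code-block:: python

    from pytorch_lightning import Trainer

    accelerator = MyAccelerator()
    trainer = Trainer(accelerator=accelerator)
    trainer.fit(model)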

********
Examples
********
The following examples illustrate customizing accelerators.

Example 1: Arbitrary HPC cluster
================================
To link an accelerator to an arbitrary cluster (SLURM, Condor, etc.), pass in a cluster plugin
(a ``ClusterEnvironment``), which will be handed to the accelerator.

First, implement your own ClusterEnvironment. Here is the TorchElastic implementation:

.. code-block:: python

    import os
    from pytorch_lightning import _logger as log
    from pytorch_lightning.utilities import rank_zero_warn
    from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment


    class TorchElasticEnvironment(ClusterEnvironment):

        def __init__(self):
            super().__init__()

        def master_address(self):
            if "MASTER_ADDR" not in os.environ:
                rank_zero_warn(
                    "MASTER_ADDR environment variable is not defined. Set as localhost"
                )
                os.environ["MASTER_ADDR"] = "127.0.0.1"
            log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
            master_address = os.environ.get('MASTER_ADDR')
            return master_address

        def master_port(self):
            if "MASTER_PORT" not in os.environ:
                rank_zero_warn(
                    "MASTER_PORT environment variable is not defined. Set as 12910"
                )
                os.environ["MASTER_PORT"] = "12910"
            log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

            port = os.environ.get('MASTER_PORT')
            return port

        def world_size(self):
            return os.environ.get('WORLD_SIZE')

        def local_rank(self):
            return int(os.environ['LOCAL_RANK'])

Now, pass it to the Trainer, which will use TorchElastic with your accelerator of choice.

.. code-block:: python

    cluster = TorchElasticEnvironment()
    accelerator = MyAccelerator()
    trainer = Trainer(plugins=[cluster], accelerator=accelerator)

In this example, MyAccelerator can target arbitrary hardware (like IPUs or TPUs) and link it to an arbitrary
compute cluster.
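
The same pattern extends to other schedulers. Here is a minimal sketch of a hypothetical ClusterEnvironment for
a scheduler that exposes ``MY_MASTER_ADDR``, ``MY_MASTER_PORT``, ``MY_WORLD_SIZE`` and ``MY_LOCAL_RANK``
environment variables (the class and variable names are illustrative, not a real scheduler integration):

.. code-block:: python

    import os

    from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment


    class MySchedulerEnvironment(ClusterEnvironment):
        # hypothetical scheduler that publishes rendezvous info via MY_* environment variables

        def master_address(self):
            return os.environ.get('MY_MASTER_ADDR', '127.0.0.1')

        def master_port(self):
            return os.environ.get('MY_MASTER_PORT', '12910')

        def world_size(self):
            return os.environ.get('MY_WORLD_SIZE')

        def local_rank(self):
            return int(os.environ.get('MY_LOCAL_RANK', 0))


    trainer = Trainer(plugins=[MySchedulerEnvironment()], accelerator=MyAccelerator())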

------------

**********************
Available Accelerators
**********************

CPU Accelerator
===============

.. autoclass:: pytorch_lightning.accelerators.cpu_accelerator.CPUAccelerator
:noindex:

DDP Accelerator
===============

.. autoclass:: pytorch_lightning.accelerators.ddp_accelerator.DDPAccelerator
:noindex:

DDP2 Accelerator
================

.. autoclass:: pytorch_lightning.accelerators.ddp2_accelerator.DDP2Accelerator
:noindex:

DDP CPU HPC Accelerator
=======================

.. autoclass:: pytorch_lightning.accelerators.ddp_cpu_hpc_accelerator.DDPCPUHPCAccelerator
:noindex:

DDP CPU Spawn Accelerator
=========================

.. autoclass:: pytorch_lightning.accelerators.ddp_cpu_spawn_accelerator.DDPCPUSpawnAccelerator
:noindex:

DDP HPC Accelerator
===================

.. autoclass:: pytorch_lightning.accelerators.ddp_hpc_accelerator.DDPHPCAccelerator
:noindex:

DDP Spawn Accelerator
=====================

.. autoclass:: pytorch_lightning.accelerators.ddp_spawn_accelerator.DDPSpawnAccelerator
:noindex:

GPU Accelerator
===============

.. autoclass:: pytorch_lightning.accelerators.gpu_accelerator.GPUAccelerator
:noindex:

Horovod Accelerator
===================

.. autoclass:: pytorch_lightning.accelerators.horovod_accelerator.HorovodAccelerator
:noindex:

TPU Accelerator
===============

.. autoclass:: pytorch_lightning.accelerators.tpu_accelerator.TPUAccelerator
:noindex:
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -39,6 +39,7 @@ PyTorch Lightning Documentation
    :name: docs
    :caption: Optional extensions

    accelerators
    callbacks
    datamodules
    logging
2 changes: 2 additions & 0 deletions pytorch_lightning/accelerators/accelerator.py
@@ -221,11 +221,13 @@ def sync_tensor(self,
                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
        """
        Function to reduce a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to sum.
                Can also be a string of 'avg', 'mean' to calculate the mean during reduction.

        Return:
            reduced value
        """
9 changes: 9 additions & 0 deletions pytorch_lightning/accelerators/cpu_accelerator.py
@@ -21,6 +21,15 @@
class CPUAccelerator(Accelerator):

    def __init__(self, trainer, cluster_environment=None):
        """
        Runs training on CPU

        Example::

            # default
            trainer = Trainer(accelerator=CPUAccelerator())

        """
        super().__init__(trainer, cluster_environment)
        self.nickname = None

9 changes: 9 additions & 0 deletions pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -38,6 +38,15 @@
class DDP2Accelerator(Accelerator):

    def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
        """
        Runs training using DDP2 strategy on a cluster

        Example::

            # default
            trainer = Trainer(accelerator=DDP2Accelerator())

        """
        super().__init__(trainer, cluster_environment, ddp_plugin)
        self.task_idx = None
        self.dist = LightningDistributed()
12 changes: 12 additions & 0 deletions pytorch_lightning/accelerators/ddp_accelerator.py
@@ -47,6 +47,15 @@
class DDPAccelerator(Accelerator):

    def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
        """
        Runs training using DDP strategy on a single machine (manually, not via cluster start)

        Example::

            # default
            trainer = Trainer(accelerator=DDPAccelerator())

        """
        super().__init__(trainer, cluster_environment, ddp_plugin)
        self.task_idx = None
        self._has_spawned_children = False
@@ -304,4 +313,7 @@ def sync_tensor(self,
                    tensor: Union[torch.Tensor],
                    group: Optional[Any] = None,
                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
        """

        """
        return sync_ddp_if_available(tensor, group, reduce_op)
9 changes: 9 additions & 0 deletions pytorch_lightning/accelerators/ddp_cpu_hpc_accelerator.py
@@ -26,6 +26,15 @@
class DDPCPUHPCAccelerator(DDPHPCAccelerator):

    def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
        """
        Runs training using DDP (with CPUs) strategy on a cluster

        Example::

            # default
            trainer = Trainer(accelerator=DDPCPUHPCAccelerator())

        """
        super().__init__(trainer, cluster_environment, ddp_plugin)
        self.nickname = 'ddp_cpu'

9 changes: 9 additions & 0 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -40,6 +40,15 @@
class DDPCPUSpawnAccelerator(Accelerator):

    def __init__(self, trainer, nprocs, cluster_environment=None, ddp_plugin=None):
        """
        Runs training using DDP (on a single machine or manually on multiple machines), using mp.spawn

        Example::

            # default
            trainer = Trainer(accelerator=DDPCPUSpawnAccelerator())

        """
        super().__init__(trainer, cluster_environment, ddp_plugin)
        self.mp_queue = None
        self.nprocs = nprocs
9 changes: 9 additions & 0 deletions pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -39,6 +39,15 @@
class DDPHPCAccelerator(Accelerator):

    def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
        """
        Runs training using DDP on an HPC cluster

        Example::

            # default
            trainer = Trainer(accelerator=DDPHPCAccelerator())

        """
        super().__init__(trainer, cluster_environment, ddp_plugin)
        self.task_idx = None
        self._has_spawned_children = False
9 changes: 9 additions & 0 deletions pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -43,6 +43,15 @@
class DDPSpawnAccelerator(Accelerator):

    def __init__(self, trainer, nprocs, cluster_environment=None, ddp_plugin=None):
        """
        Runs training using DDP via mp.spawn, launched manually (not via a cluster launcher)

        Example::

            # default
            trainer = Trainer(accelerator=DDPSpawnAccelerator())

        """
        super().__init__(trainer, cluster_environment, ddp_plugin)
        self.mp_queue = None
        self.nprocs = nprocs
9 changes: 9 additions & 0 deletions pytorch_lightning/accelerators/dp_accelerator.py
@@ -26,6 +26,15 @@
class DataParallelAccelerator(Accelerator):

    def __init__(self, trainer, cluster_environment=None):
        """
        Runs training using DP via manual start (not HPC cluster)

        Example::

            # default
            trainer = Trainer(accelerator=DataParallelAccelerator())

        """
        super().__init__(trainer, cluster_environment)
        self.model_autocast_original_forward = None
        self.dist = LightningDistributed()
9 changes: 9 additions & 0 deletions pytorch_lightning/accelerators/gpu_accelerator.py
@@ -23,6 +23,15 @@ class GPUAccelerator(Accelerator):
    amp_backend: AMPType

    def __init__(self, trainer, cluster_environment=None):
        """
        Runs training using a single GPU

        Example::

            # default
            trainer = Trainer(accelerator=GPUAccelerator())

        """
        super().__init__(trainer, cluster_environment)
        self.dist = LightningDistributed()
        self.nickname = None
9 changes: 9 additions & 0 deletions pytorch_lightning/accelerators/horovod_accelerator.py
@@ -33,6 +33,15 @@ class HorovodAccelerator(Accelerator):
    amp_backend: AMPType

    def __init__(self, trainer, cluster_environment=None):
        """
        Runs training using Horovod

        Example::

            # default
            trainer = Trainer(accelerator=HorovodAccelerator())

        """
        super().__init__(trainer, cluster_environment)
        self.nickname = 'horovod'
