Enable custom apex and amp plugins #4355

Merged 4 commits on Oct 25, 2020

23 changes: 23 additions & 0 deletions docs/source/plugins.rst
@@ -4,6 +4,29 @@ Plugins

Plugins allow custom integrations with the internals of the Trainer, such as a custom AMP or DDP implementation.

For example, to customize your own DistributedDataParallel, you could do something like this:

.. code-block:: python

    class MyDDP(DDPPlugin):
        ...

    # use your own ddp algorithm
    my_ddp = MyDDP()
    trainer = Trainer(plugins=[my_ddp])

**********
ApexPlugin
**********

.. autoclass:: pytorch_lightning.plugins.apex.ApexPlugin

***************
NativeAMPPlugin
***************

.. autoclass:: pytorch_lightning.plugins.native_amp.NativeAMPPlugin

*********
DDPPlugin
*********
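
For reference, the behaviour documented above can be exercised roughly like this (a minimal sketch mirroring the tests added in this PR; MyNativeAMP is a hypothetical subclass):

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.native_amp import NativeAMPPlugin


class MyNativeAMP(NativeAMPPlugin):
    # hypothetical subclass; override hooks such as connect() to customize AMP behaviour
    pass


# the custom plugin is attached in place of the default NativeAMPPlugin backend
trainer = Trainer(precision=16, amp_backend='native', plugins=[MyNativeAMP()])
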
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/apex.py
@@ -27,7 +27,7 @@

class ApexPlugin:

-    def __init__(self, trainer):
+    def __init__(self, trainer=None):
        self.trainer = trainer

    def connect(self, model, optimizers):
7 changes: 5 additions & 2 deletions pytorch_lightning/plugins/native_amp.py
@@ -15,9 +15,12 @@
import torch


-class NativeAMP:
+class NativeAMPPlugin:

-    def __init__(self, trainer):
+    def __init__(self, trainer=None):
        """
        Integrates native amp into Lightning's internals.
        """
        self.trainer = trainer

    def connect(self, model, optimizers):
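
As a rough illustration (not part of this PR), the renamed class can now be subclassed and constructed without a trainer, which is what makes passing it through Trainer(plugins=[...]) possible. Only connect() and the trainer=None default come from the diff above; the rest is an assumption:

from pytorch_lightning.plugins.native_amp import NativeAMPPlugin


class VerboseNativeAMP(NativeAMPPlugin):
    # hypothetical subclass: log before delegating to the default native-amp wiring
    def connect(self, model, optimizers):
        print("connecting native AMP")
        return super().connect(model, optimizers)


plugin = VerboseNativeAMP()  # no trainer needed at construction time
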
35 changes: 35 additions & 0 deletions pytorch_lightning/plugins/plugin_connector.py
@@ -14,6 +14,9 @@
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.apex import ApexPlugin
from pytorch_lightning.plugins.native_amp import NativeAMPPlugin
from pytorch_lightning.utilities import AMPType


class PluginConnector:
@@ -23,6 +26,8 @@ def __init__(self, trainer):
        self.plugins = []
        self.ddp_plugin = DDPPlugin()
        self.cloud_environment = None
        self.amp_plugin = NativeAMPPlugin(trainer)
        self.apex_plugin = ApexPlugin(trainer)

    def on_trainer_init(self, plugins):
        self.plugins = plugins
@@ -31,6 +36,36 @@ def on_trainer_init(self, plugins):

        self.__attach_ddp()
        self.__attach_cluster()
        self.__attach_amp()
        self.__attach_apex()

    def __attach_amp(self):
        amp_plugin = self.__attach_plugin(NativeAMPPlugin)
        if amp_plugin:
            self.trainer.amp_backend = AMPType.NATIVE
            self.trainer.precision_connector.backend = amp_plugin

    def __attach_apex(self):
        apex_plugin = self.__attach_plugin(ApexPlugin)
        if apex_plugin:
            self.trainer.amp_backend = AMPType.APEX
            self.trainer.precision_connector.backend = apex_plugin

    def __attach_plugin(self, plugin_type, limit=1):
        count = 0
        plugin_result = None
        for plugin in self.plugins:
            if isinstance(plugin, plugin_type):

                # count the plugins of this type
                count += 1
                if count > limit:
                    m = f'you can only use one {plugin_type.__name__} in plugins. You passed in: {count}'
                    raise MisconfigurationException(m)

                plugin_result = plugin

        return plugin_result

    def __attach_ddp(self, limit=1):
        count = 0
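
The attach logic above boils down to "find at most one plugin of a given type in the user-supplied list". A standalone sketch of that pattern (illustrative only, not the library code; the exception type is simplified):

def attach_plugin(plugins, plugin_type, limit=1):
    # return the matching user plugin, or None to keep the default backend
    matches = [p for p in plugins if isinstance(p, plugin_type)]
    if len(matches) > limit:
        raise ValueError(f'you can only use one {plugin_type.__name__} in plugins. You passed in: {len(matches)}')
    return matches[-1] if matches else None


# attach_plugin([MyNativeAMP()], NativeAMPPlugin) -> the user plugin
# attach_plugin([], NativeAMPPlugin)              -> None (default backend stays)
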
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/connectors/precision_connector.py
@@ -14,7 +14,7 @@

from pytorch_lightning import _logger as log
from pytorch_lightning.plugins.apex import ApexPlugin
-from pytorch_lightning.plugins.native_amp import NativeAMP
+from pytorch_lightning.plugins.native_amp import NativeAMPPlugin
from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE, AMPType, rank_zero_warn


@@ -56,7 +56,7 @@ def _setup_amp_backend(self, amp_type: str):
            else:
                log.info('Using native 16bit precision.')
                self.trainer.amp_backend = AMPType.NATIVE
-                self.backend = NativeAMP(self.trainer)
+                self.backend = NativeAMPPlugin(self.trainer)

        if amp_type == 'apex':
            if not APEX_AVAILABLE:
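
For context, the selection logic that this rename touches roughly follows the shape below (a simplified sketch reconstructed from the visible hunk, not the full method; the availability flags and the fallback wording are assumptions):

from pytorch_lightning.plugins.apex import ApexPlugin
from pytorch_lightning.plugins.native_amp import NativeAMPPlugin


def setup_amp_backend_sketch(trainer, amp_type, native_amp_available, apex_available):
    # pick a default precision backend; a user plugin attached by PluginConnector
    # is assigned to the same backend attribute
    backend = None
    if amp_type == 'native':
        if not native_amp_available:
            amp_type = 'apex'  # fall back to apex when native amp is unavailable
        else:
            backend = NativeAMPPlugin(trainer)
    if amp_type == 'apex' and apex_available:
        backend = ApexPlugin(trainer)
    return backend
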
86 changes: 86 additions & 0 deletions tests/plugins/test_amp_plugin.py
@@ -0,0 +1,86 @@
from pytorch_lightning.callbacks import Callback
from tests.base.boring_model import BoringModel
from pytorch_lightning import Trainer
import pytest
import os
from unittest import mock
from pytorch_lightning.plugins.native_amp import NativeAMPPlugin
from distutils.version import LooseVersion
import torch


@pytest.mark.skipif(
    LooseVersion(torch.__version__) < LooseVersion("1.6.0"),
    reason="Minimal PT version is set to 1.6",
)
@mock.patch.dict(os.environ, {
    "CUDA_VISIBLE_DEVICES": "0,1",
    "SLURM_NTASKS": "2",
    "SLURM_JOB_NAME": "SOME_NAME",
    "SLURM_NODEID": "0",
    "LOCAL_RANK": "0",
    "SLURM_LOCALID": "0"
})
@mock.patch('torch.cuda.device_count', return_value=2)
@pytest.mark.parametrize(['ddp_backend', 'gpus', 'num_processes'],
                         [('ddp_cpu', None, None), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)])
def test_amp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):

    class CB(Callback):
        def on_fit_start(self, trainer, pl_module):
            assert isinstance(trainer.precision_connector.backend, NativeAMPPlugin)
            raise SystemExit()

    model = BoringModel()
    trainer = Trainer(
        fast_dev_run=True,
        precision=16,
        amp_backend='native',
        gpus=gpus,
        num_processes=num_processes,
        distributed_backend=ddp_backend,
        callbacks=[CB()]
    )

    with pytest.raises(SystemExit):
        trainer.fit(model)


@pytest.mark.skipif(
    LooseVersion(torch.__version__) < LooseVersion("1.6.0"),
    reason="Minimal PT version is set to 1.6",
)
@mock.patch.dict(os.environ, {
    "CUDA_VISIBLE_DEVICES": "0,1",
    "SLURM_NTASKS": "2",
    "SLURM_JOB_NAME": "SOME_NAME",
    "SLURM_NODEID": "0",
    "LOCAL_RANK": "0",
    "SLURM_LOCALID": "0"
})
@mock.patch('torch.cuda.device_count', return_value=2)
@pytest.mark.parametrize(['ddp_backend', 'gpus', 'num_processes'],
                         [('ddp_cpu', None, None), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)])
def test_amp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
    class MyNativeAMP(NativeAMPPlugin):
        pass

    class CB(Callback):
        def on_fit_start(self, trainer, pl_module):
            assert isinstance(trainer.precision_connector.backend, MyNativeAMP)
            raise SystemExit()

    model = BoringModel()
    trainer = Trainer(
        fast_dev_run=True,
        precision=16,
        amp_backend='native',
        gpus=gpus,
        num_processes=num_processes,
        distributed_backend=ddp_backend,
        plugins=[MyNativeAMP()],
        callbacks=[CB()]
    )

    with pytest.raises(SystemExit):
        trainer.fit(model)
79 changes: 79 additions & 0 deletions tests/plugins/test_apex_plugin.py
@@ -0,0 +1,79 @@
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import APEX_AVAILABLE
from tests.base.boring_model import BoringModel
from pytorch_lightning import Trainer
import pytest
import os
from unittest import mock
from pytorch_lightning.plugins.apex import ApexPlugin


@pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex")
@mock.patch.dict(os.environ, {
    "CUDA_VISIBLE_DEVICES": "0,1",
    "SLURM_NTASKS": "2",
    "SLURM_JOB_NAME": "SOME_NAME",
    "SLURM_NODEID": "0",
    "LOCAL_RANK": "0",
    "SLURM_LOCALID": "0"
})
@mock.patch('torch.cuda.device_count', return_value=2)
@pytest.mark.parametrize(['ddp_backend', 'gpus', 'num_processes'],
                         [('ddp_cpu', None, None), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)])
def test_amp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):

    class CB(Callback):
        def on_fit_start(self, trainer, pl_module):
            assert isinstance(trainer.precision_connector.backend, ApexPlugin)
            raise SystemExit()

    model = BoringModel()
    trainer = Trainer(
        fast_dev_run=True,
        precision=16,
        amp_backend='apex',
        gpus=gpus,
        num_processes=num_processes,
        distributed_backend=ddp_backend,
        callbacks=[CB()]
    )

    with pytest.raises(SystemExit):
        trainer.fit(model)


@pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex")
@mock.patch.dict(os.environ, {
    "CUDA_VISIBLE_DEVICES": "0,1",
    "SLURM_NTASKS": "2",
    "SLURM_JOB_NAME": "SOME_NAME",
    "SLURM_NODEID": "0",
    "LOCAL_RANK": "0",
    "SLURM_LOCALID": "0"
})
@mock.patch('torch.cuda.device_count', return_value=2)
@pytest.mark.parametrize(['ddp_backend', 'gpus', 'num_processes'],
                         [('ddp_cpu', None, None), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)])
def test_amp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
    class MyApexPlugin(ApexPlugin):
        pass

    class CB(Callback):
        def on_fit_start(self, trainer, pl_module):
            assert isinstance(trainer.precision_connector.backend, MyApexPlugin)
            raise SystemExit()

    model = BoringModel()
    trainer = Trainer(
        fast_dev_run=True,
        precision=16,
        amp_backend='apex',
        gpus=gpus,
        num_processes=num_processes,
        distributed_backend=ddp_backend,
        plugins=[MyApexPlugin()],
        callbacks=[CB()]
    )

    with pytest.raises(SystemExit):
        trainer.fit(model)