Add Training Type Plugins Registry #6982

Merged 27 commits on Apr 16, 2021
11 changes: 11 additions & 0 deletions pytorch_lightning/plugins/__init__.py
@@ -1,4 +1,8 @@
from pytorch_lightning.plugins.base_plugin import Plugin # noqa: F401
from pytorch_lightning.plugins.plugins_registry import ( # noqa: F401
call_training_type_register_plugins,
TrainingTypePluginsRegistry,
)
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401
@@ -47,3 +51,10 @@
'DDPShardedPlugin',
'DDPSpawnShardedPlugin',
]

from pathlib import Path

FILE_ROOT = Path(__file__).parent
TRAINING_TYPE_BASE_MODULE = "pytorch_lightning.plugins.training_type"

call_training_type_register_plugins(FILE_ROOT, TRAINING_TYPE_BASE_MODULE)
146 changes: 146 additions & 0 deletions pytorch_lightning/plugins/plugins_registry.py
@@ -0,0 +1,146 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from collections import UserDict
from inspect import getmembers, isclass
from pathlib import Path
from typing import Any, Callable, List, Optional

from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class _TrainingTypePluginsRegistry(UserDict):
Review comment (Member):
should we maybe make this a general registry and reuse this with flash? We also have a registry in flash and duplicating this does not make sense...

I think you only have to change the naming of the fields...

cc @tchaton

Reply (Contributor Author):
I did think of a generic LightningRegistry that could be used for different types. But for now, there has only been a need for TrainingTypePlugins, we could change it down the road as well, as _TrainingTypePluginsRegistry is internal.
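
A minimal sketch of the generic registry discussed here (hypothetical, not part of this PR; the class and field names below are made up for illustration) could simply drop the plugin-specific naming so Lightning and Flash could share it:

    # Hypothetical sketch only -- not part of this PR. A component-agnostic
    # registry with generic field names that Lightning and Flash could both reuse.
    from collections import UserDict
    from typing import Any, Callable, Optional


    class _Registry(UserDict):

        def register(
            self,
            name: str,
            component: Optional[Callable] = None,
            description: str = "",
            **init_params: Any,
        ) -> Callable:
            def do_register(component: Callable) -> Callable:
                # store the class together with its metadata under the given name
                self[name] = {"component": component, "description": description, "init_params": init_params}
                return component

            return do_register(component) if component is not None else do_register

        def get(self, name: str) -> Any:
            # instantiate the registered class with its stored init parameters
            data = self[name]
            return data["component"](**data["init_params"])

Compared to the registry in this PR, only the stored field names change, which matches the reviewer's point that mostly the naming would need to be generalized.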

"""
This class is a Registry that stores information about the Training Type Plugins.

The Plugins are mapped to strings. These strings are names that identify
a plugin, e.g., "deepspeed". The registry also stores an optional description and
the parameters used to initialize the Plugin, which were defined during
registration.

The motivation for having a TrainingTypePluginsRegistry is to make it convenient
for users to try different Plugins by passing just strings
to the plugins flag of the Trainer.

Example::

@TrainingTypePluginsRegistry.register("lightning", description="Super fast", a=1, b=True)
class LightningPlugin:
def __init__(self, a, b):
...

or

TrainingTypePluginsRegistry.register("lightning", LightningPlugin, description="Super fast", a=1, b=True)

"""

def register(
self,
name: str,
plugin: Optional[Callable] = None,
description: Optional[str] = None,
override: bool = False,
**init_params: Any,
) -> Callable:
"""
Registers a plugin mapped to a name, along with required metadata.

Args:
name (str): the name that identifies a plugin, e.g. "deepspeed_stage_3"
plugin (callable): plugin class
description (str): plugin description
override (bool): overrides the registered plugin, if True
init_params: parameters to initialize the plugin
"""
if not (name is None or isinstance(name, str)):
raise TypeError(f'`name` must be a str, found {name}')

if name in self and not override:
raise MisconfigurationException(
f"'{name}' is already present in the registry."
" HINT: Use `override=True`."
)

data = {}
data["description"] = description if description is not None else ""

data["init_params"] = init_params

def do_register(plugin: Callable) -> Callable:
data["plugin"] = plugin
self[name] = data
return plugin

if plugin is not None:
return do_register(plugin)

return do_register

def get(self, name: str) -> Any:
"""
Calls the registered plugin with the required parameters
and returns the plugin object

Args:
name (str): the name that identifies a plugin, e.g. "deepspeed_stage_3"
"""
if name in self:
data = self[name]
return data["plugin"](**data["init_params"])

err_msg = "'{}' not found in registry. Available names: {}"
available_names = ", ".join(sorted(self.keys())) or "none"
raise KeyError(err_msg.format(name, available_names))

def remove(self, name: str) -> None:
"""Removes the registered plugin by name"""
self.pop(name)

def available_plugins(self) -> List:
"""Returns a list of registered plugins"""
return list(self.keys())

def __str__(self) -> str:
return "Registered Plugins: {}".format(", ".join(self.keys()))


TrainingTypePluginsRegistry = _TrainingTypePluginsRegistry()


def is_register_plugins_overriden(plugin: Callable) -> bool:
method_name = "register_plugins"
plugin_attr = getattr(plugin, method_name)
super_attr = getattr(TrainingTypePlugin, method_name)

if hasattr(plugin_attr, 'patch_loader_code'):
is_overridden = plugin_attr.patch_loader_code != str(super_attr.__code__)
else:
is_overridden = plugin_attr.__code__ is not super_attr.__code__
return is_overridden


def call_training_type_register_plugins(root: Path, base_module: str) -> None:
# Ref: https://github.com/facebookresearch/ClassyVision/blob/master/classy_vision/generic/registry_utils.py#L14
directory = "training_type"
for file in os.listdir(root / directory):
if file.endswith(".py") and not file.startswith("_"):
module = file[:file.find(".py")]
module = importlib.import_module(".".join([base_module, module]))
for _, mod in getmembers(module, isclass):
if issubclass(mod, TrainingTypePlugin) and is_register_plugins_overriden(mod):
mod.register_plugins(TrainingTypePluginsRegistry)
break
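
Example (illustrative, not part of the diff): how the registry added above can be inspected and used once the module-level registrations have run. The deepspeed entries referenced here are registered by DeepSpeedPlugin.register_plugins later in this diff, and get() only succeeds if DeepSpeed is importable::

    from pytorch_lightning.plugins import TrainingTypePluginsRegistry

    # names registered at import time via call_training_type_register_plugins
    print(TrainingTypePluginsRegistry.available_plugins())

    # stored metadata for one entry
    print(TrainingTypePluginsRegistry["deepspeed_stage_2"]["init_params"])  # {'stage': 2}

    # get() instantiates the registered class with its stored init parameters,
    # e.g. DeepSpeedPlugin(stage=2); requires deepspeed to be installed:
    # plugin = TrainingTypePluginsRegistry.get("deepspeed_stage_2")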
20 changes: 20 additions & 0 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -520,3 +520,23 @@ def update_global_step(self, total_batch_idx: int, current_global_step: int) ->
if total_batch_idx % self._original_accumulate_grad_batches == 0:
current_global_step += 1
return current_global_step

@classmethod
def register_plugins(cls, plugin_registry):
plugin_registry.register("deepspeed", cls, description="Default DeepSpeed Plugin")
plugin_registry.register("deepspeed_stage_2", cls, description="DeepSpeed with ZeRO Stage 2 enabled", stage=2)
plugin_registry.register(
"deepspeed_stage_2_offload",
cls,
description="DeepSpeed ZeRO Stage 2 and CPU Offload",
stage=2,
cpu_offload=True
)
plugin_registry.register("deepspeed_stage_3", cls, description="DeepSpeed ZeRO Stage 3", stage=3)
plugin_registry.register(
"deepspeed_stage_3_offload",
cls,
description="DeepSpeed ZeRO Stage 3 and CPU Offload",
stage=3,
cpu_offload=True
)
pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -284,3 +284,7 @@ def call_configure_sharded_model_hook(self) -> bool:
@call_configure_sharded_model_hook.setter
def call_configure_sharded_model_hook(self, mode: bool) -> None:
self._call_configure_sharded_model_hook = mode

@classmethod
def register_plugins(cls, plugin_registry):
pass
14 changes: 12 additions & 2 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -42,6 +42,7 @@
TPUHalfPrecisionPlugin,
TPUSpawnPlugin,
TrainingTypePlugin,
TrainingTypePluginsRegistry,
)
from pytorch_lightning.plugins.environments import (
ClusterEnvironment,
@@ -163,7 +164,16 @@ def handle_given_plugins(
cluster_environment = None

for plug in plugins:
if isinstance(plug, str):
if isinstance(plug, str) and plug in TrainingTypePluginsRegistry:
if training_type is None:
training_type = TrainingTypePluginsRegistry.get(plug)
else:
raise MisconfigurationException(
'You can only specify one precision and one training type plugin.'
' Found more than 1 training type plugin:'
f' {TrainingTypePluginsRegistry[plug]["plugin"]} registered to {plug}'
)
elif isinstance(plug, str):
# Reset the distributed type as the user has overridden training type
# via the plugins argument
self._distrib_type = None
@@ -515,7 +525,7 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
rank_zero_warn(
'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.'
)
# todo: in some cases it yield in comarison None and int
# todo: in some cases it yield in comparison None and int
if (self.num_nodes and self.num_nodes > 1) or (self.num_processes and self.num_processes > 1):
self._distrib_type = DistributedType.DDP
else:
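
With the connector change above, a registered plugin name can be passed straight to the Trainer as a string, as the tests below also exercise (sketch; the deepspeed names require DeepSpeed to be installed)::

    from pytorch_lightning import Trainer

    # the string is looked up in TrainingTypePluginsRegistry by the accelerator
    # connector and resolved to DeepSpeedPlugin(stage=3)
    trainer = Trainer(plugins="deepspeed_stage_3", precision=16)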
83 changes: 83 additions & 0 deletions tests/plugins/test_plugins_registry.py
@@ -0,0 +1,83 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import TrainingTypePluginsRegistry
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin
from tests.helpers.runif import RunIf


def test_training_type_plugins_registry_with_new_plugin():

class TestPlugin:

def __init__(self, param1, param2):
self.param1 = param1
self.param2 = param2

plugin_name = "test_plugin"
plugin_description = "Test Plugin"

TrainingTypePluginsRegistry.register(
plugin_name, TestPlugin, description=plugin_description, param1="abc", param2=123
)

assert plugin_name in TrainingTypePluginsRegistry
assert TrainingTypePluginsRegistry[plugin_name]["description"] == plugin_description
assert TrainingTypePluginsRegistry[plugin_name]["init_params"] == {"param1": "abc", "param2": 123}
assert isinstance(TrainingTypePluginsRegistry.get(plugin_name), TestPlugin)

TrainingTypePluginsRegistry.remove(plugin_name)
assert plugin_name not in TrainingTypePluginsRegistry


@pytest.mark.parametrize(
"plugin_name, init_params",
[
("deepspeed", {}),
("deepspeed_stage_2", {
"stage": 2
}),
("deepspeed_stage_2_offload", {
"stage": 2,
"cpu_offload": True
}),
("deepspeed_stage_3", {
"stage": 3
}),
("deepspeed_stage_3_offload", {
"stage": 3,
"cpu_offload": True
}),
],
)
def test_training_type_plugins_registry_with_deepspeed_plugins(plugin_name, init_params):

assert plugin_name in TrainingTypePluginsRegistry
assert TrainingTypePluginsRegistry[plugin_name]["init_params"] == init_params
assert TrainingTypePluginsRegistry[plugin_name]["plugin"] == DeepSpeedPlugin


@RunIf(deepspeed=True)
@pytest.mark.parametrize("plugin", ["deepspeed", "deepspeed_stage_2_offload", "deepspeed_stage_3"])
def test_training_type_plugins_registry_with_trainer(tmpdir, plugin):

trainer = Trainer(
default_root_dir=tmpdir,
plugins=plugin,
precision=16,
)

assert isinstance(trainer.training_type_plugin, DeepSpeedPlugin)