diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py
index a67235baa4767..444d2aaef978b 100644
--- a/pytorch_lightning/plugins/__init__.py
+++ b/pytorch_lightning/plugins/__init__.py
@@ -1,4 +1,8 @@
 from pytorch_lightning.plugins.base_plugin import Plugin  # noqa: F401
+from pytorch_lightning.plugins.plugins_registry import (  # noqa: F401
+    call_training_type_register_plugins,
+    TrainingTypePluginsRegistry,
+)
 from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin  # noqa: F401
@@ -47,3 +51,10 @@
     'DDPShardedPlugin',
     'DDPSpawnShardedPlugin',
 ]
+
+from pathlib import Path
+
+FILE_ROOT = Path(__file__).parent
+TRAINING_TYPE_BASE_MODULE = "pytorch_lightning.plugins.training_type"
+
+call_training_type_register_plugins(FILE_ROOT, TRAINING_TYPE_BASE_MODULE)
diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py
new file mode 100644
index 0000000000000..59dd7d8db6bff
--- /dev/null
+++ b/pytorch_lightning/plugins/plugins_registry.py
@@ -0,0 +1,146 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+import os
+from collections import UserDict
+from inspect import getmembers, isclass
+from pathlib import Path
+from typing import Any, Callable, List, Optional
+
+from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+
+class _TrainingTypePluginsRegistry(UserDict):
+    """
+    This class is a Registry that stores information about the Training Type Plugins.
+
+    The Plugins are mapped to strings. These strings are names that identify
+    a plugin, e.g., "deepspeed". The registry also stores an optional description
+    and the parameters used to initialize the Plugin, which were defined during
+    registration.
+
+    The motivation for having a TrainingTypePluginsRegistry is to make it convenient
+    for users to try different Plugins by passing just strings
+    to the plugins flag of the Trainer.
+
+    Example::
+
+        @TrainingTypePluginsRegistry.register("lightning", description="Super fast", a=1, b=True)
+        class LightningPlugin:
+            def __init__(self, a, b):
+                ...
+
+        or
+
+        TrainingTypePluginsRegistry.register("lightning", LightningPlugin, description="Super fast", a=1, b=True)
+
+    """
+
+    def register(
+        self,
+        name: str,
+        plugin: Optional[Callable] = None,
+        description: Optional[str] = None,
+        override: bool = False,
+        **init_params: Any,
+    ) -> Callable:
+        """
+        Registers a plugin mapped to a name, along with the required metadata.
+
+        Args:
+            name : the name that identifies a plugin, e.g. "deepspeed_stage_3"
"deepspeed_stage_3" + plugin : plugin class + description : plugin description + override : overrides the registered plugin, if True + init_params: parameters to initialize the plugin + """ + if not (name is None or isinstance(name, str)): + raise TypeError(f'`name` must be a str, found {name}') + + if name in self and not override: + raise MisconfigurationException( + f"'{name}' is already present in the registry." + " HINT: Use `override=True`." + ) + + data = {} + data["description"] = description if description is not None else "" + + data["init_params"] = init_params + + def do_register(plugin: Callable) -> Callable: + data["plugin"] = plugin + self[name] = data + return plugin + + if plugin is not None: + return do_register(plugin) + + return do_register + + def get(self, name: str) -> Any: + """ + Calls the registered plugin with the required parameters + and returns the plugin object + + Args: + name (str): the name that identifies a plugin, e.g. "deepspeed_stage_3" + """ + if name in self: + data = self[name] + return data["plugin"](**data["init_params"]) + + err_msg = "'{}' not found in registry. Available names: {}" + available_names = ", ".join(sorted(self.keys())) or "none" + raise KeyError(err_msg.format(name, available_names)) + + def remove(self, name: str) -> None: + """Removes the registered plugin by name""" + self.pop(name) + + def available_plugins(self) -> List: + """Returns a list of registered plugins""" + return list(self.keys()) + + def __str__(self) -> str: + return "Registered Plugins: {}".format(", ".join(self.keys())) + + +TrainingTypePluginsRegistry = _TrainingTypePluginsRegistry() + + +def is_register_plugins_overridden(plugin: Callable) -> bool: + method_name = "register_plugins" + plugin_attr = getattr(plugin, method_name) + super_attr = getattr(TrainingTypePlugin, method_name) + + if hasattr(plugin_attr, 'patch_loader_code'): + is_overridden = plugin_attr.patch_loader_code != str(super_attr.__code__) + else: + is_overridden = plugin_attr.__code__ is not super_attr.__code__ + return is_overridden + + +def call_training_type_register_plugins(root: Path, base_module: str) -> None: + # Ref: https://github.com/facebookresearch/ClassyVision/blob/master/classy_vision/generic/registry_utils.py#L14 + directory = "training_type" + for file in os.listdir(root / directory): + if file.endswith(".py") and not file.startswith("_"): + module = file[:file.find(".py")] + module = importlib.import_module(".".join([base_module, module])) + for _, mod in getmembers(module, isclass): + if issubclass(mod, TrainingTypePlugin) and is_register_plugins_overridden(mod): + mod.register_plugins(TrainingTypePluginsRegistry) + break diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 3dc52b60055d8..f3af6346120f8 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -520,3 +520,23 @@ def update_global_step(self, total_batch_idx: int, current_global_step: int) -> if total_batch_idx % self._original_accumulate_grad_batches == 0: current_global_step += 1 return current_global_step + + @classmethod + def register_plugins(cls, plugin_registry): + plugin_registry.register("deepspeed", cls, description="Default DeepSpeed Plugin") + plugin_registry.register("deepspeed_stage_2", cls, description="DeepSpeed with ZeRO Stage 2 enabled", stage=2) + plugin_registry.register( + "deepspeed_stage_2_offload", + cls, + description="DeepSpeed ZeRO Stage 2 and 
CPU Offload", + stage=2, + cpu_offload=True + ) + plugin_registry.register("deepspeed_stage_3", cls, description="DeepSpeed ZeRO Stage 3", stage=3) + plugin_registry.register( + "deepspeed_stage_3_offload", + cls, + description="DeepSpeed ZeRO Stage 3 and CPU Offload", + stage=3, + cpu_offload=True + ) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 6fd02142bf410..8ba002e1641d3 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -284,3 +284,7 @@ def call_configure_sharded_model_hook(self) -> bool: @call_configure_sharded_model_hook.setter def call_configure_sharded_model_hook(self, mode: bool) -> None: self._call_configure_sharded_model_hook = mode + + @classmethod + def register_plugins(cls, plugin_registry): + pass diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index aa52ec1c40d82..ee58bce1c8fc6 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -42,6 +42,7 @@ TPUHalfPrecisionPlugin, TPUSpawnPlugin, TrainingTypePlugin, + TrainingTypePluginsRegistry, ) from pytorch_lightning.plugins.environments import ( ClusterEnvironment, @@ -163,7 +164,16 @@ def handle_given_plugins( cluster_environment = None for plug in plugins: - if isinstance(plug, str): + if isinstance(plug, str) and plug in TrainingTypePluginsRegistry: + if training_type is None: + training_type = TrainingTypePluginsRegistry.get(plug) + else: + raise MisconfigurationException( + 'You can only specify one precision and one training type plugin.' + ' Found more than 1 training type plugin:' + f' {TrainingTypePluginsRegistry[plug]["plugin"]} registered to {plug}' + ) + elif isinstance(plug, str): # Reset the distributed type as the user has overridden training type # via the plugins argument self._distrib_type = None @@ -515,7 +525,7 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None): rank_zero_warn( 'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.' ) - # todo: in some cases it yield in comarison None and int + # todo: in some cases it yield in comparison None and int if (self.num_nodes and self.num_nodes > 1) or (self.num_processes and self.num_processes > 1): self._distrib_type = DistributedType.DDP else: diff --git a/tests/plugins/test_plugins_registry.py b/tests/plugins/test_plugins_registry.py new file mode 100644 index 0000000000000..91d9596578dfc --- /dev/null +++ b/tests/plugins/test_plugins_registry.py @@ -0,0 +1,83 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest
+
+from pytorch_lightning import Trainer
+from pytorch_lightning.plugins import TrainingTypePluginsRegistry
+from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin
+from tests.helpers.runif import RunIf
+
+
+def test_training_type_plugins_registry_with_new_plugin():
+
+    class TestPlugin:
+
+        def __init__(self, param1, param2):
+            self.param1 = param1
+            self.param2 = param2
+
+    plugin_name = "test_plugin"
+    plugin_description = "Test Plugin"
+
+    TrainingTypePluginsRegistry.register(
+        plugin_name, TestPlugin, description=plugin_description, param1="abc", param2=123
+    )
+
+    assert plugin_name in TrainingTypePluginsRegistry
+    assert TrainingTypePluginsRegistry[plugin_name]["description"] == plugin_description
+    assert TrainingTypePluginsRegistry[plugin_name]["init_params"] == {"param1": "abc", "param2": 123}
+    assert isinstance(TrainingTypePluginsRegistry.get(plugin_name), TestPlugin)
+
+    TrainingTypePluginsRegistry.remove(plugin_name)
+    assert plugin_name not in TrainingTypePluginsRegistry
+
+
+@pytest.mark.parametrize(
+    "plugin_name, init_params",
+    [
+        ("deepspeed", {}),
+        ("deepspeed_stage_2", {
+            "stage": 2
+        }),
+        ("deepspeed_stage_2_offload", {
+            "stage": 2,
+            "cpu_offload": True
+        }),
+        ("deepspeed_stage_3", {
+            "stage": 3
+        }),
+        ("deepspeed_stage_3_offload", {
+            "stage": 3,
+            "cpu_offload": True
+        }),
+    ],
+)
+def test_training_type_plugins_registry_with_deepspeed_plugins(plugin_name, init_params):
+
+    assert plugin_name in TrainingTypePluginsRegistry
+    assert TrainingTypePluginsRegistry[plugin_name]["init_params"] == init_params
+    assert TrainingTypePluginsRegistry[plugin_name]["plugin"] == DeepSpeedPlugin
+
+
+@RunIf(deepspeed=True)
+@pytest.mark.parametrize("plugin", ["deepspeed", "deepspeed_stage_2_offload", "deepspeed_stage_3"])
+def test_training_type_plugins_registry_with_trainer(tmpdir, plugin):
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        plugins=plugin,
+        precision=16,
+    )
+
+    assert isinstance(trainer.training_type_plugin, DeepSpeedPlugin)
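
A brief usage sketch of the API introduced by this diff, outside the patch itself. The DeepSpeed string name, the register()/get()/available_plugins() calls, and the Trainer(plugins=...) path are taken from the code above; the "my_plugin" name, the MyPlugin class, and its param1 argument are hypothetical:

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import TrainingTypePluginsRegistry

    # Names registered via DeepSpeedPlugin.register_plugins() can be passed to the
    # Trainer's `plugins` flag as plain strings; handle_given_plugins() resolves them
    # through the registry (requires DeepSpeed to be installed, as in the RunIf test).
    trainer = Trainer(plugins="deepspeed_stage_2_offload", precision=16)

    # A custom plugin can be registered under a string name together with a
    # description and default init parameters ("my_plugin", MyPlugin and param1
    # are hypothetical):
    @TrainingTypePluginsRegistry.register("my_plugin", description="Example plugin", param1="abc")
    class MyPlugin:

        def __init__(self, param1):
            self.param1 = param1

    assert "my_plugin" in TrainingTypePluginsRegistry.available_plugins()
    plugin = TrainingTypePluginsRegistry.get("my_plugin")  # instantiates MyPlugin(param1="abc")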