Minor CLI improvements [1/3] (#9553)
carmocca authored Sep 16, 2021
1 parent b845414 commit d2ca81b
Showing 3 changed files with 33 additions and 32 deletions.
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/plugins_registry.py
@@ -13,7 +13,6 @@
# limitations under the License.
import importlib
import inspect
- from collections import UserDict
from inspect import getmembers, isclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
@@ -22,7 +21,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException


- class _TrainingTypePluginsRegistry(UserDict):
+ class _TrainingTypePluginsRegistry(dict):
"""This class is a Registry that stores information about the Training Type Plugins.
The Plugins are mapped to strings. These strings are names that identify
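Aside (not part of the diff): the registry only needs ordinary mapping behaviour plus a few helper methods, so a plain dict subclass is sufficient and the UserDict wrapper can be dropped. A minimal, self-contained sketch in the same spirit, illustrative only and not the Lightning implementation:

# Sketch only: a registry as a plain ``dict`` subclass (not the Lightning code).
from typing import Any, Type


class _SimpleRegistry(dict):
    def register(self, name: str, cls: Type, description: str = "", **init_params: Any) -> None:
        # store everything needed to describe and later instantiate the class
        self[name] = {"class": cls, "description": description, "init_params": init_params}

    def instantiate(self, name: str) -> Any:
        data = self[name]
        return data["class"](**data["init_params"])


registry = _SimpleRegistry()
registry.register("plain", dict, description="demo entry")
assert isinstance(registry.instantiate("plain"), dict)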
43 changes: 25 additions & 18 deletions pytorch_lightning/utilities/cli.py
@@ -54,7 +54,9 @@ def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> Non
"--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format."
)
self.callback_keys: List[str] = []
- self.optimizers_and_lr_schedulers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {}
+ # separate optimizers and lr schedulers to know which were added
+ self._optimizers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {}
+ self._lr_schedulers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {}

def add_lightning_class_args(
self,
@@ -115,10 +117,10 @@ def add_optimizer_args(
assert issubclass(optimizer_class, Optimizer)
kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}}
if isinstance(optimizer_class, tuple):
- self.add_subclass_arguments(optimizer_class, nested_key, required=True, **kwargs)
+ self.add_subclass_arguments(optimizer_class, nested_key, **kwargs)
else:
self.add_class_arguments(optimizer_class, nested_key, **kwargs)
- self.optimizers_and_lr_schedulers[nested_key] = (optimizer_class, link_to)
+ self._optimizers[nested_key] = (optimizer_class, link_to)

def add_lr_scheduler_args(
self,
@@ -139,10 +141,10 @@
assert issubclass(lr_scheduler_class, LRSchedulerTypeTuple)
kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}}
if isinstance(lr_scheduler_class, tuple):
- self.add_subclass_arguments(lr_scheduler_class, nested_key, required=True, **kwargs)
+ self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs)
else:
self.add_class_arguments(lr_scheduler_class, nested_key, **kwargs)
- self.optimizers_and_lr_schedulers[nested_key] = (lr_scheduler_class, link_to)
+ self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to)


class SaveConfigCallback(Callback):
@@ -374,7 +376,8 @@ def _prepare_subcommand_parser(self, klass: Type, subcommand: str, **kwargs: Any
@staticmethod
def link_optimizers_and_lr_schedulers(parser: LightningArgumentParser) -> None:
"""Creates argument links for optimizers and learning rate schedulers that specified a ``link_to``."""
- for key, (class_type, link_to) in parser.optimizers_and_lr_schedulers.items():
+ optimizers_and_lr_schedulers = {**parser._optimizers, **parser._lr_schedulers}
+ for key, (class_type, link_to) in optimizers_and_lr_schedulers.items():
if link_to == "AUTOMATIC":
continue
if isinstance(class_type, tuple):
@@ -423,7 +426,7 @@ def _instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]
config["callbacks"].append(config_callback)
return self.trainer_class(**config)

- def _parser(self, subcommand: Optional[str]) -> ArgumentParser:
+ def _parser(self, subcommand: Optional[str]) -> LightningArgumentParser:
if subcommand is None:
return self.parser
# return the subcommand parser for the subcommand passed
@@ -438,19 +441,20 @@ def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -
`configure_optimizers` method is automatically implemented in the model class.
"""
parser = self._parser(subcommand)
- optimizers_and_lr_schedulers = parser.optimizers_and_lr_schedulers

- def get_automatic(class_type: Union[Type, Tuple[Type, ...]]) -> List[str]:
+ def get_automatic(
+ class_type: Union[Type, Tuple[Type, ...]], register: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]]
+ ) -> List[str]:
automatic = []
- for key, (base_class, link_to) in optimizers_and_lr_schedulers.items():
+ for key, (base_class, link_to) in register.items():
if not isinstance(base_class, tuple):
base_class = (base_class,)
if link_to == "AUTOMATIC" and any(issubclass(c, class_type) for c in base_class):
automatic.append(key)
return automatic

- optimizers = get_automatic(Optimizer)
- lr_schedulers = get_automatic(LRSchedulerTypeTuple)
+ optimizers = get_automatic(Optimizer, parser._optimizers)
+ lr_schedulers = get_automatic(LRSchedulerTypeTuple, parser._lr_schedulers)

if len(optimizers) == 0:
return
@@ -470,14 +474,17 @@ def get_automatic(class_type: Union[Type, Tuple[Type, ...]]) -> List[str]:
f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model`."
)

- optimizer_class = optimizers_and_lr_schedulers[optimizers[0]][0]
- optimizer_init = self._get(self.config_init, optimizers[0], default={})
+ optimizer_class = parser._optimizers[optimizers[0]][0]
+ optimizer_init = self._get(self.config_init, optimizers[0])
if not isinstance(optimizer_class, tuple):
optimizer_init = _global_add_class_path(optimizer_class, optimizer_init)
+ if not optimizer_init:
+ # optimizers were registered automatically but not passed by the user
+ return
lr_scheduler_init = None
if lr_schedulers:
- lr_scheduler_class = optimizers_and_lr_schedulers[lr_schedulers[0]][0]
- lr_scheduler_init = self._get(self.config_init, lr_schedulers[0], default={})
+ lr_scheduler_class = parser._lr_schedulers[lr_schedulers[0]][0]
+ lr_scheduler_init = self._get(self.config_init, lr_schedulers[0])
if not isinstance(lr_scheduler_class, tuple):
lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init)

@@ -524,8 +531,8 @@ def _prepare_subcommand_kwargs(self, subcommand: str) -> Dict[str, Any]:
return fn_kwargs


- def _global_add_class_path(class_type: Type, init_args: Dict[str, Any]) -> Dict[str, Any]:
- return {"class_path": class_type.__module__ + "." + class_type.__name__, "init_args": init_args}
+ def _global_add_class_path(class_type: Type, init_args: Dict[str, Any] = None) -> Dict[str, Any]:
+ return {"class_path": class_type.__module__ + "." + class_type.__name__, "init_args": init_args or {}}


def _add_class_path_generator(class_type: Type) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
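Aside (not part of the diff): `_global_add_class_path` builds the `{"class_path": ..., "init_args": ...}` dictionaries that the automatically implemented `configure_optimizers` later consumes. A sketch of how such a dictionary can be turned back into an object; `instantiate_class_path` is a hypothetical helper written for this example, not a Lightning API:

# Sketch only; ``instantiate_class_path`` is made up for illustration.
import importlib
from typing import Any, Dict

import torch


def instantiate_class_path(init: Dict[str, Any], **extra: Any) -> Any:
    # import the class named by ``class_path`` and call it with ``init_args`` plus ``extra``
    module_name, _, class_name = init["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**{**init.get("init_args", {}), **extra})


# the dictionary shape produced by _global_add_class_path above
optimizer_init = {"class_path": "torch.optim.Adam", "init_args": {"lr": 0.01}}
model = torch.nn.Linear(4, 2)
optimizer = instantiate_class_path(optimizer_init, params=model.parameters())
assert type(optimizer).__name__ == "Adam"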
19 changes: 7 additions & 12 deletions tests/utilities/test_cli.py
@@ -134,9 +134,9 @@ def _raise():
def test_parse_args_parsing(cli_args, expected):
"""Test parsing simple types and None optionals not modified."""
cli_args = cli_args.split(" ") if cli_args else []
- parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
- parser.add_lightning_class_args(Trainer, None)
with mock.patch("sys.argv", ["any.py"] + cli_args):
+ parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
+ parser.add_lightning_class_args(Trainer, None)
args = parser.parse_args()

for k, v in expected.items():
@@ -155,9 +155,9 @@
)
def test_parse_args_parsing_complex_types(cli_args, expected, instantiate):
"""Test parsing complex types."""
- parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
- parser.add_lightning_class_args(Trainer, None)
with mock.patch("sys.argv", ["any.py"] + cli_args):
+ parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
+ parser.add_lightning_class_args(Trainer, None)
args = parser.parse_args()

for k, v in expected.items():
@@ -171,9 +171,9 @@ def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu):
"""Test parsing of gpus and instantiation of Trainer."""
monkeypatch.setattr("torch.cuda.device_count", lambda: 2)
cli_args = cli_args.split(" ") if cli_args else []
- parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
- parser.add_lightning_class_args(Trainer, None)
with mock.patch("sys.argv", ["any.py"] + cli_args):
+ parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
+ parser.add_lightning_class_args(Trainer, None)
args = parser.parse_args()

trainer = Trainer.from_argparse_args(args)
@@ -639,12 +639,7 @@ def add_arguments_to_parser(self, parser):
parser.add_optimizer_args(torch.optim.Adam)
parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)

- cli_args = [
- "fit",
- f"--trainer.default_root_dir={tmpdir}",
- "--trainer.fast_dev_run=1",
- "--lr_scheduler.gamma=0.8",
- ]
+ cli_args = ["fit", f"--trainer.default_root_dir={tmpdir}", "--trainer.fast_dev_run=1", "--lr_scheduler.gamma=0.8"]

with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = MyLightningCLI(BoringModel)
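Aside (not part of the diff): the updated tests share one pattern: the parser is built and parsed inside a patched `sys.argv`. A minimal standalone version of that pattern, assuming pytorch_lightning with its CLI extras is installed:

# Sketch of the test pattern above; not taken from the repository.
from unittest import mock

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.cli import LightningArgumentParser

cli_args = ["--max_epochs=1"]
with mock.patch("sys.argv", ["any.py"] + cli_args):
    # build the parser inside the patched context, as the updated tests do
    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    parser.add_lightning_class_args(Trainer, None)
    args = parser.parse_args()
assert args.max_epochs == 1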
