Skip to content

Commit 66fd0c4

Browse files
authored
Fix config structure enforcing and typechecking. Add full Tuner support. (#133)
* Fix config structure enforcing and typechecking. Add full Tuner support. * Fix config loading when in parse_config() * Improve error message when CLI command is incorrect * Reduce system's learning_rate property docs
1 parent b10b8c6 commit 66fd0c4

File tree

3 files changed

+107
-47
lines changed

3 files changed

+107
-47
lines changed

lighter/system.py

+14
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,20 @@ def setup(self, stage: str) -> None:
361361
self.predict_dataloader = partial(self._base_dataloader, mode="predict")
362362
self.predict_step = partial(self._base_step, mode="predict")
363363

364+
@property
def learning_rate(self) -> float:
    """Learning rate of the optimizer. Exposed so that the Tuner's 'lr_find()' method can read it.

    Raises:
        ValueError: If the optimizer has more than one parameter group, since
            a single scalar learning rate is then ambiguous.
    """
    groups = self.optimizer.param_groups
    if len(groups) > 1:
        raise ValueError("The learning rate is not available when there are multiple optimizer parameter groups.")
    return groups[0]["lr"]

@learning_rate.setter
def learning_rate(self, value) -> None:
    """Set the optimizer's learning rate. Exposed so that the Tuner's 'lr_find()' method can write it.

    Raises:
        ValueError: If the optimizer has more than one parameter group.
    """
    groups = self.optimizer.param_groups
    if len(groups) > 1:
        raise ValueError("The learning rate is not available when there are multiple optimizer parameter groups.")
    groups[0]["lr"] = value
377+
364378
def _init_placeholders_for_dataloader_and_step_methods(self) -> None:
365379
"""
366380
Initializes placeholders for dataloader and step methods.

lighter/utils/misc.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,23 @@ def apply_fns(data: Any, fns: Union[Callable, List[Callable]]) -> Any:
119119

120120
def get_optimizer_stats(optimizer: Optimizer) -> Dict[str, float]:
121121
"""
122-
Extract learning rates and momentum values from each parameter group of the optimizer.
122+
Extract learning rates and momentum values from an optimizer into a dictionary.
123+
124+
This function iterates over the parameter groups of the given optimizer and collects
125+
the learning rate and momentum (or beta values) for each group. The collected values
126+
are stored in a dictionary with keys formatted to indicate the optimizer type and
127+
parameter group index (if multiple groups are present).
123128
124129
Args:
125-
optimizer (Optimizer): A PyTorch optimizer.
130+
optimizer (Optimizer): A PyTorch optimizer instance.
126131
127132
Returns:
128-
Dictionary with formatted keys and values for learning rates and momentum.
133+
Dict[str, float]: A dictionary containing the learning rates and momentum values
134+
for each parameter group in the optimizer. The keys are formatted as:
135+
- "optimizer/{optimizer_class_name}/lr" for learning rates
136+
- "optimizer/{optimizer_class_name}/momentum" for momentum values
137+
If there are multiple parameter groups, the keys will include the group index, e.g.,
138+
"optimizer/{optimizer_class_name}/lr/group1".
129139
"""
130140
stats_dict = {}
131141
for group_idx, group in enumerate(optimizer.param_groups):

lighter/utils/runner.py

+80-44
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,45 @@
11
from typing import Any
22

3+
import copy
34
from functools import partial
45

56
import fire
67
from monai.bundle.config_parser import ConfigParser
7-
from pytorch_lightning import seed_everything
8+
from pytorch_lightning import Trainer, seed_everything
9+
from pytorch_lightning.tuner import Tuner
810

911
from lighter.system import LighterSystem
1012
from lighter.utils.dynamic_imports import import_module_from_path
1113

12-
# Template for a Lighter config: the only keys allowed at the top level.
# parse_config() seeds the parser with a deep copy of this dict, and
# validate_config() rejects any config that strays from this structure.
# - "project": optional path; run() imports it as a module when set.
# - "vars": free-form dict (presumably referenced by other config entries —
#   TODO confirm against the config docs).
# - "args": per-method argument dicts; its keys also define the CLI commands
#   (see cli()) and the Trainer/Tuner methods run() is allowed to dispatch to.
# - "system": definition parsed into a LighterSystem instance.
# - "trainer": definition parsed into a PyTorch Lightning Trainer instance.
CONFIG_STRUCTURE = {
    "project": None,
    "vars": {},
    "args": {
        # Keys - names of the methods; values - arguments passed to them.
        "fit": {},
        "validate": {},
        "test": {},
        "predict": {},
        "lr_find": {},
        "scale_batch_size": {},
    },
    "system": {},
    "trainer": {},
}
1429

1530

1631
def cli() -> None:
    """Defines the command line interface for running lightning trainer's methods."""
    # One CLI command per key of CONFIG_STRUCTURE["args"], each bound to run().
    commands = {name: partial(run, name) for name in CONFIG_STRUCTURE["args"]}
    try:
        fire.Fire(commands)
    except TypeError as e:
        # Fire reports stray positional arguments as a TypeError on run();
        # translate that into an actionable message for the user.
        if "run() takes 1 positional argument but" not in str(e):
            raise
        raise ValueError(
            "Ensure that only one command is run at a time (e.g., 'lighter fit') and that "
            "other command line arguments start with '--' (e.g., '--config', '--system#batch_size=1')."
        ) from e
2043

2144

2245
def parse_config(**kwargs) -> ConfigParser:
@@ -29,25 +52,24 @@ def parse_config(**kwargs) -> ConfigParser:
2952
Returns:
3053
An instance of ConfigParser with configuration and overrides merged and parsed.
3154
"""
32-
# Ensure a config file is specified.
3355
config = kwargs.pop("config", None)
3456
if config is None:
3557
raise ValueError("'--config' not specified. Please provide a valid configuration file.")
3658

37-
# Read the config file and update it with overrides.
38-
parser = ConfigParser(CONFIG_STRUCTURE, globals=False)
39-
parser.read_config(config)
59+
# Create a deep copy to ensure the original structure remains unaltered by ConfigParser.
60+
structure = copy.deepcopy(CONFIG_STRUCTURE)
61+
# Initialize the parser with the predefined structure.
62+
parser = ConfigParser(structure, globals=False)
63+
# Update the parser with the configuration file.
64+
parser.update(parser.load_config_files(config))
65+
# Update the parser with the provided cli arguments.
4066
parser.update(kwargs)
4167
return parser
4268

4369

4470
def validate_config(parser: ConfigParser) -> None:
    """
    Validates the configuration parser against the predefined structure.

    Args:
        parser (ConfigParser): The configuration parser instance to validate.

    Raises:
        ValueError: If the top-level config or its 'args' section contains keys
            not present in CONFIG_STRUCTURE, or if any section's value has the
            wrong type.
    """
    # Reject unknown top-level sections.
    # NOTE(review): keys MONAI may add automatically (e.g. "_meta_", "_requires_")
    # are no longer excluded here — confirm parse_config() never injects them.
    invalid_root_keys = set(parser.get()) - set(CONFIG_STRUCTURE)
    if invalid_root_keys:
        raise ValueError(f"Invalid top-level config keys: {invalid_root_keys}. Allowed keys: {list(CONFIG_STRUCTURE)}.")

    # Reject unknown method names inside 'args'.
    invalid_args_keys = set(parser.get("args")) - set(CONFIG_STRUCTURE["args"])
    if invalid_args_keys:
        raise ValueError(f"Invalid key in 'args': {invalid_args_keys}. Allowed keys: {list(CONFIG_STRUCTURE['args'])}.")

    # Type-check every section, including each per-method dict under 'args'.
    typechecks = {
        "project": (str, type(None)),
        "vars": dict,
        "system": dict,
        "trainer": dict,
        "args": dict,
    }
    typechecks.update({f"args#{name}": dict for name in CONFIG_STRUCTURE["args"]})
    for key, dtype in typechecks.items():
        if not isinstance(parser.get(key), dtype):
            raise ValueError(f"Invalid value for key '{key}'. Expected a {dtype}.")
100+
101+
102+
def run(method: str, **kwargs: Any) -> None:
    """Run the trainer method.

    Args:
        method (str): Name of the Trainer/Tuner method to run (a key of
            CONFIG_STRUCTURE["args"], e.g. 'fit' or 'lr_find').
        **kwargs: CLI arguments; must contain 'config', the rest are overrides.
    """
    parser = parse_config(**kwargs)
    validate_config(parser)

    # Project. If specified, the given path is imported as a module.
    project = parser.get_parsed_content("project")
    if project is not None:
        import_module_from_path("project", project)

    # System
    system = parser.get_parsed_content("system")
    if not isinstance(system, LighterSystem):
        raise ValueError("Expected 'system' to be an instance of 'LighterSystem'")

    # Trainer
    trainer = parser.get_parsed_content("trainer")
    if not isinstance(trainer, Trainer):
        raise ValueError("Expected 'trainer' to be an instance of PyTorch Lightning 'Trainer'")

    # Arguments for the chosen Trainer/Tuner method. Dataloaders must come from
    # the system, never from the method arguments.
    run_args = parser.get_parsed_content(f"args#{method}")
    if any("dataloaders" in key or "datamodule" in key for key in run_args):
        raise ValueError("Datasets are defined within the 'system', not passed in `args`.")

    # Persist the config under "hyper_parameters" and log it when a logger exists.
    system.save_hyperparameters(parser.get())
    if trainer.logger is not None:
        trainer.logger.log_hyperparams(parser.get())

    # Resolve the method on the Trainer first, then on the Tuner.
    if hasattr(trainer, method):
        runner = getattr(trainer, method)
    elif hasattr(Tuner, method):
        runner = getattr(Tuner(trainer), method)
    else:
        raise ValueError(f"Method '{method}' is not a valid Trainer or Tuner method [{list(CONFIG_STRUCTURE['args'])}].")
    runner(system, **run_args)

0 commit comments

Comments
 (0)