Skip to content

Commit 0e1298b

Browse files
ibro45surajpaib
andauthored
Refactor cli and runner. Implement reserved config keys. Add feature to pass args to Trainer's methods. (#124)
* Support subloss logging through loss dicts * Replace implicit subloss addition with user's "total" loss key * Implement config reserved keys and allow methods args to be specified * Update the cifar test with the new run function * Small fix * Remove old cli module from suppressed modules * Allow _meta_ and _requires_ at top config level * Fix test and change config_file to config * Fix running without args * Update error messages * Remove the unnecessary default for args and update comments * Add back default value that is necessary for some reason --------- Co-authored-by: Suraj Pai <[email protected]>
1 parent 9eefdb5 commit 0e1298b

File tree

8 files changed

+87
-70
lines changed

8 files changed

+87
-70
lines changed

docs/basics/projects.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,9 @@ In the above example, the path of the dataset is `/home/user/project/my_xray_dat
136136

137137
=== "Terminal"
138138
```
139-
lighter fit --config_file xray.yaml
139+
lighter fit --config xray.yaml
140140
```
141141

142142
</div>
143143

144-
1. Make sure to put an `__init__.py` file in this directory. Remember this is needed for an importable python module
144+
1. Make sure to put an `__init__.py` file in this directory. Remember this is needed for an importable python module

docs/basics/quickstart.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,8 @@ We just combine the Trainer and LighterSystem into a single YAML and run the com
129129
```
130130
=== "Terminal"
131131
```
132-
lighter fit --config_file cifar10.yaml
132+
lighter fit --config cifar10.yaml
133133
```
134134

135135

136-
Congratulations!! You have run your first training example with Lighter.
136+
Congratulations!! You have run your first training example with Lighter.

docs/index.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ Say goodbye to messy scripts and notebooks. Lighter is here to help you organize
9999
<div style="width: 49%;">
100100
<h3 style="text-align: center">Lighter</h3>
101101
```bash title="Terminal"
102-
lighter fit --config_file cifar10.yaml
102+
lighter fit --config cifar10.yaml
103103
```
104104
```yaml title="cifar10.yaml"
105105
trainer:

lighter/utils/cli.py

-23
This file was deleted.

lighter/utils/logging.py

-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
SUPPRESSED_MODULES = [
33
"fire",
44
"monai.bundle",
5-
"lighter.utils.cli",
65
"lighter.utils.runner",
76
"pytorch_lightning.trainer",
87
"lightning_utilities",

lighter/utils/runner.py

+76-35
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,111 @@
1-
from typing import Any, Dict
1+
from typing import Any
22

3+
from functools import partial
4+
5+
import fire
36
from monai.bundle.config_parser import ConfigParser
47
from pytorch_lightning import seed_everything
58

9+
from lighter.system import LighterSystem
610
from lighter.utils.dynamic_imports import import_module_from_path
711

12+
CONFIG_STRUCTURE = {"project": None, "system": {}, "trainer": {}, "args": {}, "vars": {}}
13+
TRAINER_METHOD_NAMES = ["fit", "validate", "test", "predict", "lr_find", "scale_batch_size"]
14+
15+
16+
def cli() -> None:
17+
"""Defines the command line interface for running lightning trainer's methods."""
18+
commands = {method: partial(run, method) for method in TRAINER_METHOD_NAMES}
19+
fire.Fire(commands)
20+
821

922
def parse_config(**kwargs) -> ConfigParser:
1023
"""
1124
Parses configuration files and updates the provided parser
1225
with given keyword arguments. Returns an updated parser object.
1326
1427
Args:
15-
**kwargs (dict): Keyword arguments containing configuration data.
16-
config_file (str): Path to the main configuration file.
17-
args_file (str, optional): Path to secondary configuration file for additional arguments.
18-
Additional key-value pairs can also be provided to be added or updated in the parser.
19-
28+
**kwargs (dict): Keyword arguments containing 'config' and, optionally, config overrides.
2029
Returns:
21-
An instance of ConfigParser with parsed and merged configuration data.
30+
An instance of ConfigParser with configuration and overrides merged and parsed.
2231
"""
32+
# Ensure a config file is specified.
33+
config = kwargs.pop("config", None)
34+
if config is None:
35+
raise ValueError("'--config' not specified. Please provide a valid configuration file.")
36+
37+
# Read the config file and update it with overrides.
38+
parser = ConfigParser(CONFIG_STRUCTURE, globals=False)
39+
parser.read_config(config)
40+
parser.update(kwargs)
41+
return parser
2342

24-
# Check that a config file is specified.
25-
if "config_file" not in kwargs:
26-
raise ValueError("--config_file not specified. Exiting.")
2743

28-
# Parse the config file(s).
29-
parser = ConfigParser()
30-
parser.read_config(kwargs.pop("config_file"))
31-
parser.update(pairs=kwargs)
44+
def validate_config(parser: ConfigParser) -> None:
45+
"""
46+
Validates the configuration parser against predefined structures and allowed method names.
3247
33-
# Import the project folder as a module, if specified.
34-
project = parser.get("project", None)
35-
if project is not None:
36-
import_module_from_path("project", project)
48+
This function checks if the keys in the top-level of the configuration parser are valid according to the
49+
CONFIG_STRUCTURE. It also verifies that the 'args' section of the configuration only contains keys that
50+
correspond to valid trainer method names as defined in TRAINER_METHOD_NAMES.
3751
38-
return parser
52+
Args:
53+
parser (ConfigParser): The configuration parser instance to validate.
54+
55+
Raises:
56+
ValueError: If there are invalid keys in the top-level configuration.
57+
ValueError: If there are invalid method names specified in the 'args' section.
58+
"""
59+
# Validate parser keys against structure
60+
root_keys = parser.get().keys()
61+
invalid_root_keys = set(root_keys) - set(CONFIG_STRUCTURE.keys()) - {"_meta_", "_requires_"}
62+
if invalid_root_keys:
63+
raise ValueError(f"Invalid top-level config keys: {invalid_root_keys}. Allowed keys: {CONFIG_STRUCTURE.keys()}")
64+
65+
# Validate that 'args' contains only valid trainer method names.
66+
args_keys = parser.get("args", {}).keys()
67+
invalid_args_keys = set(args_keys) - set(TRAINER_METHOD_NAMES)
68+
if invalid_args_keys:
69+
raise ValueError(f"Invalid trainer method in 'args': {invalid_args_keys}. Allowed methods are: {TRAINER_METHOD_NAMES}")
3970

4071

41-
def run_trainer_method(method: Dict, **kwargs: Any):
42-
"""Call monai.bundle.run() on a Trainer method. If a project path
43-
is defined in the config file(s), import it.
72+
def run(method: str, **kwargs: Any):
73+
"""Run the trainer method.
4474
4575
Args:
46-
method (str): name of the Trainer method to run. ["fit", "validate", "test", "predict", "tune"].
47-
**kwargs (Any): keyword arguments passed to the `monai.bundle.run` function.
76+
method (str): name of the trainer method to run.
77+
**kwargs (Any): keyword arguments that include 'config' and specific config overrides passed to `parse_config()`.
4878
"""
49-
# Sets the random seed to `PL_GLOBAL_SEED` env variable. If not specified, it picks a random seed.
5079
seed_everything()
5180

52-
# Parse the config file(s).
81+
# Parse and validate the config.
5382
parser = parse_config(**kwargs)
83+
validate_config(parser)
5484

55-
# Get trainer and system
56-
trainer = parser.get_parsed_content("trainer")
85+
# Import the project folder as a module, if specified.
86+
project = parser.get_parsed_content("project")
87+
if project is not None:
88+
import_module_from_path("project", project)
89+
90+
# Get the main components from the parsed config.
5791
system = parser.get_parsed_content("system")
92+
trainer = parser.get_parsed_content("trainer")
93+
trainer_method_args = parser.get_parsed_content(f"args#{method}", default={})
5894

95+
# Checks
96+
if not isinstance(system, LighterSystem):
97+
raise ValueError(f"Expected 'system' to be an instance of LighterSystem, got {system.__class__.__name__}.")
98+
if not hasattr(trainer, method):
99+
raise ValueError(f"{trainer.__class__.__name__} has no method named '{method}'.")
100+
if any("dataloaders" in key or "datamodule" in key for key in trainer_method_args):
101+
raise ValueError("All dataloaders should be defined as part of the LighterSystem, not passed as method arguments.")
102+
103+
# Save the config to checkpoints under "hyper_parameters" and log it if a logger is defined.
59104
config = parser.get()
60-
config.pop("_meta_")
61-
# Save the config to model checkpoints under the "hyper_parameters" key.
105+
config.pop("_meta_") # MONAI Bundle adds this automatically, remove it.
62106
system.save_hyperparameters(config)
63-
# Log the config.
64107
if trainer.logger is not None:
65108
trainer.logger.log_hyperparams(config)
66109

67-
# Run the Trainer method.
68-
if not hasattr(trainer, method):
69-
raise ValueError(f"Trainer has no method named {method}.")
70-
getattr(trainer, method)(system)
110+
# Run the trainer method.
111+
getattr(trainer, method)(system, **trainer_method_args)

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ requires = ["poetry_core>=1.0.0"]
44
build-backend = "poetry.core.masonry.api"
55

66
[tool.poetry.scripts]
7-
lighter = "lighter.utils.cli:interface"
7+
lighter = "lighter.utils.runner:cli"
88

99
[tool.poetry]
1010
name = "project-lighter"

tests/integration/test_cifar.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
import pytest
44

5-
from lighter.utils.cli import run_trainer_method
5+
from lighter.utils.runner import run
66

77
test_overrides = "./tests/integration/test_overrides.yaml"
88

99

1010
@pytest.mark.parametrize(
11-
("method_name", "config_file"),
11+
("method_name", "config"),
1212
[
1313
( # Method name
1414
"fit",
@@ -18,9 +18,9 @@
1818
],
1919
)
2020
@pytest.mark.slow
21-
def test_trainer_method(method_name: str, config_file: str):
21+
def test_trainer_method(method_name: str, config: str):
2222
""" """
23-
kwargs = {"config_file": config_file, "args_file": test_overrides}
23+
kwargs = {"config": [config, test_overrides]}
2424

25-
func_return = run_trainer_method(method_name, **kwargs)
25+
func_return = run(method_name, **kwargs)
2626
assert func_return is None

0 commit comments

Comments
 (0)