|
1 |
| -from typing import Any, Dict |
| 1 | +from typing import Any |
2 | 2 |
|
| 3 | +from functools import partial |
| 4 | + |
| 5 | +import fire |
3 | 6 | from monai.bundle.config_parser import ConfigParser
|
4 | 7 | from pytorch_lightning import seed_everything
|
5 | 8 |
|
| 9 | +from lighter.system import LighterSystem |
6 | 10 | from lighter.utils.dynamic_imports import import_module_from_path
|
7 | 11 |
|
| 12 | +CONFIG_STRUCTURE = {"project": None, "system": {}, "trainer": {}, "args": {}, "vars": {}} |
| 13 | +TRAINER_METHOD_NAMES = ["fit", "validate", "test", "predict", "lr_find", "scale_batch_size"] |
| 14 | + |
| 15 | + |
| 16 | +def cli() -> None: |
| 17 | + """Defines the command line interface for running lightning trainer's methods.""" |
| 18 | + commands = {method: partial(run, method) for method in TRAINER_METHOD_NAMES} |
| 19 | + fire.Fire(commands) |
| 20 | + |
8 | 21 |
|
9 | 22 | def parse_config(**kwargs) -> ConfigParser:
|
10 | 23 | """
|
11 | 24 | Parses configuration files and updates the provided parser
|
12 | 25 | with given keyword arguments. Returns an updated parser object.
|
13 | 26 |
|
14 | 27 | Args:
|
15 |
| - **kwargs (dict): Keyword arguments containing configuration data. |
16 |
| - config_file (str): Path to the main configuration file. |
17 |
| - args_file (str, optional): Path to secondary configuration file for additional arguments. |
18 |
| - Additional key-value pairs can also be provided to be added or updated in the parser. |
19 |
| -
|
| 28 | + **kwargs (dict): Keyword arguments containing 'config' and, optionally, config overrides. |
20 | 29 | Returns:
|
21 |
| - An instance of ConfigParser with parsed and merged configuration data. |
| 30 | + An instance of ConfigParser with configuration and overrides merged and parsed. |
22 | 31 | """
|
| 32 | + # Ensure a config file is specified. |
| 33 | + config = kwargs.pop("config", None) |
| 34 | + if config is None: |
| 35 | + raise ValueError("'--config' not specified. Please provide a valid configuration file.") |
| 36 | + |
| 37 | + # Read the config file and update it with overrides. |
| 38 | + parser = ConfigParser(CONFIG_STRUCTURE, globals=False) |
| 39 | + parser.read_config(config) |
| 40 | + parser.update(kwargs) |
| 41 | + return parser |
23 | 42 |
|
24 |
| - # Check that a config file is specified. |
25 |
| - if "config_file" not in kwargs: |
26 |
| - raise ValueError("--config_file not specified. Exiting.") |
27 | 43 |
|
28 |
| - # Parse the config file(s). |
29 |
| - parser = ConfigParser() |
30 |
| - parser.read_config(kwargs.pop("config_file")) |
31 |
| - parser.update(pairs=kwargs) |
| 44 | +def validate_config(parser: ConfigParser) -> None: |
| 45 | + """ |
| 46 | + Validates the configuration parser against predefined structures and allowed method names. |
32 | 47 |
|
33 |
| - # Import the project folder as a module, if specified. |
34 |
| - project = parser.get("project", None) |
35 |
| - if project is not None: |
36 |
| - import_module_from_path("project", project) |
| 48 | + This function checks if the keys in the top-level of the configuration parser are valid according to the |
| 49 | + CONFIG_STRUCTURE. It also verifies that the 'args' section of the configuration only contains keys that |
| 50 | + correspond to valid trainer method names as defined in TRAINER_METHOD_NAMES. |
37 | 51 |
|
38 |
| - return parser |
| 52 | + Args: |
| 53 | + parser (ConfigParser): The configuration parser instance to validate. |
| 54 | +
|
| 55 | + Raises: |
| 56 | + ValueError: If there are invalid keys in the top-level configuration. |
| 57 | + ValueError: If there are invalid method names specified in the 'args' section. |
| 58 | + """ |
| 59 | + # Validate parser keys against structure |
| 60 | + root_keys = parser.get().keys() |
| 61 | + invalid_root_keys = set(root_keys) - set(CONFIG_STRUCTURE.keys()) - {"_meta_", "_requires_"} |
| 62 | + if invalid_root_keys: |
| 63 | + raise ValueError(f"Invalid top-level config keys: {invalid_root_keys}. Allowed keys: {CONFIG_STRUCTURE.keys()}") |
| 64 | + |
| 65 | + # Validate that 'args' contains only valid trainer method names. |
| 66 | + args_keys = parser.get("args", {}).keys() |
| 67 | + invalid_args_keys = set(args_keys) - set(TRAINER_METHOD_NAMES) |
| 68 | + if invalid_args_keys: |
| 69 | + raise ValueError(f"Invalid trainer method in 'args': {invalid_args_keys}. Allowed methods are: {TRAINER_METHOD_NAMES}") |
39 | 70 |
|
40 | 71 |
|
41 |
| -def run_trainer_method(method: Dict, **kwargs: Any): |
42 |
| - """Call monai.bundle.run() on a Trainer method. If a project path |
43 |
| - is defined in the config file(s), import it. |
| 72 | +def run(method: str, **kwargs: Any): |
| 73 | + """Run the trainer method. |
44 | 74 |
|
45 | 75 | Args:
|
46 |
| - method (str): name of the Trainer method to run. ["fit", "validate", "test", "predict", "tune"]. |
47 |
| - **kwargs (Any): keyword arguments passed to the `monai.bundle.run` function. |
| 76 | + method (str): name of the trainer method to run. |
| 77 | + **kwargs (Any): keyword arguments that include 'config' and specific config overrides passed to `parse_config()`. |
48 | 78 | """
|
49 |
| - # Sets the random seed to `PL_GLOBAL_SEED` env variable. If not specified, it picks a random seed. |
50 | 79 | seed_everything()
|
51 | 80 |
|
52 |
| - # Parse the config file(s). |
| 81 | + # Parse and validate the config. |
53 | 82 | parser = parse_config(**kwargs)
|
| 83 | + validate_config(parser) |
54 | 84 |
|
55 |
| - # Get trainer and system |
56 |
| - trainer = parser.get_parsed_content("trainer") |
| 85 | + # Import the project folder as a module, if specified. |
| 86 | + project = parser.get_parsed_content("project") |
| 87 | + if project is not None: |
| 88 | + import_module_from_path("project", project) |
| 89 | + |
| 90 | + # Get the main components from the parsed config. |
57 | 91 | system = parser.get_parsed_content("system")
|
| 92 | + trainer = parser.get_parsed_content("trainer") |
| 93 | + trainer_method_args = parser.get_parsed_content(f"args#{method}", default={}) |
58 | 94 |
|
| 95 | + # Checks |
| 96 | + if not isinstance(system, LighterSystem): |
| 97 | + raise ValueError(f"Expected 'system' to be an instance of LighterSystem, got {system.__class__.__name__}.") |
| 98 | + if not hasattr(trainer, method): |
| 99 | + raise ValueError(f"{trainer.__class__.__name__} has no method named '{method}'.") |
| 100 | + if any("dataloaders" in key or "datamodule" in key for key in trainer_method_args): |
| 101 | + raise ValueError("All dataloaders should be defined as part of the LighterSystem, not passed as method arguments.") |
| 102 | + |
| 103 | + # Save the config to checkpoints under "hyper_parameters" and log it if a logger is defined. |
59 | 104 | config = parser.get()
|
60 |
| - config.pop("_meta_") |
61 |
| - # Save the config to model checkpoints under the "hyper_parameters" key. |
| 105 | + config.pop("_meta_") # MONAI Bundle adds this automatically, remove it. |
62 | 106 | system.save_hyperparameters(config)
|
63 |
| - # Log the config. |
64 | 107 | if trainer.logger is not None:
|
65 | 108 | trainer.logger.log_hyperparams(config)
|
66 | 109 |
|
67 |
| - # Run the Trainer method. |
68 |
| - if not hasattr(trainer, method): |
69 |
| - raise ValueError(f"Trainer has no method named {method}.") |
70 |
| - getattr(trainer, method)(system) |
| 110 | + # Run the trainer method. |
| 111 | + getattr(trainer, method)(system, **trainer_method_args) |
0 commit comments