1
1
from typing import Any
2
2
3
+ import copy
3
4
from functools import partial
4
5
5
6
import fire
6
7
from monai .bundle .config_parser import ConfigParser
7
- from pytorch_lightning import seed_everything
8
+ from pytorch_lightning import Trainer , seed_everything
9
+ from pytorch_lightning .tuner import Tuner
8
10
9
11
from lighter .system import LighterSystem
10
12
from lighter .utils .dynamic_imports import import_module_from_path
11
13
12
- CONFIG_STRUCTURE = {"project" : None , "system" : {}, "trainer" : {}, "args" : {}, "vars" : {}}
13
- TRAINER_METHOD_NAMES = ["fit" , "validate" , "test" , "predict" , "lr_find" , "scale_batch_size" ]
14
+ CONFIG_STRUCTURE = {
15
+ "project" : None ,
16
+ "vars" : {},
17
+ "args" : {
18
+ # Keys - names of the methods; values - arguments passed to them.
19
+ "fit" : {},
20
+ "validate" : {},
21
+ "test" : {},
22
+ "predict" : {},
23
+ "lr_find" : {},
24
+ "scale_batch_size" : {},
25
+ },
26
+ "system" : {},
27
+ "trainer" : {},
28
+ }
14
29
15
30
16
31
def cli () -> None :
17
32
"""Defines the command line interface for running lightning trainer's methods."""
18
- commands = {method : partial (run , method ) for method in TRAINER_METHOD_NAMES }
19
- fire .Fire (commands )
33
+ commands = {method : partial (run , method ) for method in CONFIG_STRUCTURE ["args" ]}
34
+ try :
35
+ fire .Fire (commands )
36
+ except TypeError as e :
37
+ if "run() takes 1 positional argument but" in str (e ):
38
+ raise ValueError (
39
+ "Ensure that only one command is run at a time (e.g., 'lighter fit') and that "
40
+ "other command line arguments start with '--' (e.g., '--config', '--system#batch_size=1')."
41
+ ) from e
42
+ raise
20
43
21
44
22
45
def parse_config (** kwargs ) -> ConfigParser :
@@ -29,25 +52,24 @@ def parse_config(**kwargs) -> ConfigParser:
29
52
Returns:
30
53
An instance of ConfigParser with configuration and overrides merged and parsed.
31
54
"""
32
- # Ensure a config file is specified.
33
55
config = kwargs .pop ("config" , None )
34
56
if config is None :
35
57
raise ValueError ("'--config' not specified. Please provide a valid configuration file." )
36
58
37
- # Read the config file and update it with overrides.
38
- parser = ConfigParser (CONFIG_STRUCTURE , globals = False )
39
- parser .read_config (config )
59
+ # Create a deep copy to ensure the original structure remains unaltered by ConfigParser.
60
+ structure = copy .deepcopy (CONFIG_STRUCTURE )
61
+ # Initialize the parser with the predefined structure.
62
+ parser = ConfigParser (structure , globals = False )
63
+ # Update the parser with the configuration file.
64
+ parser .update (parser .load_config_files (config ))
65
+ # Update the parser with the provided cli arguments.
40
66
parser .update (kwargs )
41
67
return parser
42
68
43
69
44
70
def validate_config (parser : ConfigParser ) -> None :
45
71
"""
46
- Validates the configuration parser against predefined structures and allowed method names.
47
-
48
- This function checks if the keys in the top-level of the configuration parser are valid according to the
49
- CONFIG_STRUCTURE. It also verifies that the 'args' section of the configuration only contains keys that
50
- correspond to valid trainer method names as defined in TRAINER_METHOD_NAMES.
72
+ Validates the configuration parser against predefined structure.
51
73
52
74
Args:
53
75
parser (ConfigParser): The configuration parser instance to validate.
@@ -56,20 +78,28 @@ def validate_config(parser: ConfigParser) -> None:
56
78
ValueError: If there are invalid keys in the top-level configuration.
57
79
ValueError: If there are invalid method names specified in the 'args' section.
58
80
"""
59
- # Validate parser keys against structure
60
- root_keys = parser .get ().keys ()
61
- invalid_root_keys = set (root_keys ) - set (CONFIG_STRUCTURE .keys ()) - {"_meta_" , "_requires_" }
81
+ invalid_root_keys = set (parser .get ()) - set (CONFIG_STRUCTURE )
62
82
if invalid_root_keys :
63
- raise ValueError (f"Invalid top-level config keys: { invalid_root_keys } . Allowed keys: { CONFIG_STRUCTURE . keys () } " )
83
+ raise ValueError (f"Invalid top-level config keys: { invalid_root_keys } . Allowed keys: { list ( CONFIG_STRUCTURE ) } . " )
64
84
65
- # Validate that 'args' contains only valid trainer method names.
66
- args_keys = parser .get ("args" , {}).keys ()
67
- invalid_args_keys = set (args_keys ) - set (TRAINER_METHOD_NAMES )
85
+ invalid_args_keys = set (parser .get ("args" )) - set (CONFIG_STRUCTURE ["args" ])
68
86
if invalid_args_keys :
69
- raise ValueError (f"Invalid trainer method in 'args': { invalid_args_keys } . Allowed methods are: { TRAINER_METHOD_NAMES } " )
70
-
71
-
72
- def run (method : str , ** kwargs : Any ):
87
+ raise ValueError (f"Invalid key in 'args': { invalid_args_keys } . Allowed keys: { list (CONFIG_STRUCTURE ['args' ])} ." )
88
+
89
+ typechecks = {
90
+ "project" : (str , type (None )),
91
+ "vars" : dict ,
92
+ "system" : dict ,
93
+ "trainer" : dict ,
94
+ "args" : dict ,
95
+ ** {f"args#{ k } " : dict for k in CONFIG_STRUCTURE ["args" ]},
96
+ }
97
+ for key , dtype in typechecks .items ():
98
+ if not isinstance (parser .get (key ), dtype ):
99
+ raise ValueError (f"Invalid value for key '{ key } '. Expected a { dtype } ." )
100
+
101
+
102
+ def run (method : str , ** kwargs : Any ) -> None :
73
103
"""Run the trainer method.
74
104
75
105
Args:
@@ -82,30 +112,36 @@ def run(method: str, **kwargs: Any):
82
112
parser = parse_config (** kwargs )
83
113
validate_config (parser )
84
114
85
- # Import the project folder as a module, if specified .
115
+ # Project. If specified, the give path is imported as a module.
86
116
project = parser .get_parsed_content ("project" )
87
117
if project is not None :
88
118
import_module_from_path ("project" , project )
89
119
90
- # Get the main components from the parsed config.
120
+ # System
91
121
system = parser .get_parsed_content ("system" )
122
+ if not isinstance (system , LighterSystem ):
123
+ raise ValueError ("Expected 'system' to be an instance of 'LighterSystem'" )
124
+
125
+ # Trainer
92
126
trainer = parser .get_parsed_content ("trainer" )
93
- trainer_method_args = parser .get_parsed_content (f"args#{ method } " , default = {})
127
+ if not isinstance (trainer , Trainer ):
128
+ raise ValueError ("Expected 'trainer' to be an instance of PyTorch Lightning 'Trainer'" )
94
129
95
- # Checks
96
- if not isinstance (system , LighterSystem ):
97
- raise ValueError (f"Expected 'system' to be an instance of LighterSystem, got { system .__class__ .__name__ } ." )
98
- if not hasattr (trainer , method ):
99
- raise ValueError (f"{ trainer .__class__ .__name__ } has no method named '{ method } '." )
100
- if any ("dataloaders" in key or "datamodule" in key for key in trainer_method_args ):
101
- raise ValueError ("All dataloaders should be defined as part of the LighterSystem, not passed as method arguments." )
102
-
103
- # Save the config to checkpoints under "hyper_parameters" and log it if a logger is defined.
104
- config = parser .get ()
105
- config .pop ("_meta_" ) # MONAI Bundle adds this automatically, remove it.
106
- system .save_hyperparameters (config )
107
- if trainer .logger is not None :
108
- trainer .logger .log_hyperparams (config )
130
+ # Trainer/Tuner method arguments.
131
+ method_args = parser .get_parsed_content (f"args#{ method } " )
132
+ if any ("dataloaders" in key or "datamodule" in key for key in method_args ):
133
+ raise ValueError ("Datasets are defined within the 'system', not passed in `args`." )
109
134
110
- # Run the trainer method.
111
- getattr (trainer , method )(system , ** trainer_method_args )
135
+ # Save the config to checkpoints under "hyper_parameters". Log it if a logger is defined.
136
+ system .save_hyperparameters (parser .get ())
137
+ if trainer .logger is not None :
138
+ trainer .logger .log_hyperparams (parser .get ())
139
+
140
+ # Run the trainer/tuner method.
141
+ if hasattr (trainer , method ):
142
+ getattr (trainer , method )(system , ** method_args )
143
+ elif hasattr (Tuner , method ):
144
+ tuner = Tuner (trainer )
145
+ getattr (tuner , method )(system , ** method_args )
146
+ else :
147
+ raise ValueError (f"Method '{ method } ' is not a valid Trainer or Tuner method [{ list (CONFIG_STRUCTURE ['args' ])} ]." )
0 commit comments