1
- import os
2
- import sys
3
- import yaml
4
-
5
- default_config = "default_config.yaml"
6
-
7
- # ============================== read yaml ==============================================
8
-
9
-
10
- def read_config_file (config_filepath : str ):
11
- """
12
- reads the configuration from the YAML file specified
13
- returns the config as dictionary object
14
-
15
- Args:
16
- config_filepath: path to the YAML file containing the configuration
17
-
18
- """
19
- if not (config_filepath .lower ().endswith ((".yaml" , ".yml" ))):
20
- print ("Please provide a path to a YAML file." )
21
- quit ()
22
- with open (config_filepath , "r" ) as config_file :
23
- config = yaml .safe_load (config_file )
24
- return config
25
-
26
-
27
- def get_available_loss (config ):
28
- available_loss = []
29
- chosen_task = config ["task" ]["chosen" ]
30
- if chosen_task == "regression" :
31
- available_loss = config ["loss_functions" ]["regression" ]["available" ]
32
- else :
33
- available_loss = config ["loss_functions" ]["classification" ]["available" ]
34
- return available_loss
35
-
36
-
37
- def get_chosen_loss (config ):
38
- chosen_task = config ["task" ]["chosen" ]
39
- if chosen_task == "regression" :
40
- return config ["loss_functions" ]["regression" ]["chosen" ]
41
- else :
42
- return config ["loss_functions" ]["classification" ]["chosen" ]
43
-
44
-
45
- def get_available_datasets (config ):
46
- return config ["datasets" ]["available" ]
47
-
48
-
49
- def get_chosen_datasets (config ):
50
- return config ["datasets" ]["chosen" ]
51
-
52
-
53
- def get_available_tasks (config ):
54
- return config ["task" ]["available" ]
55
-
56
-
57
- def get_chosen_task (config ):
58
- return config ["task" ]["chosen" ]
59
-
60
-
61
- def get_wandb_active (config ):
62
- return config ["wandb" ]["active_tracking" ]
63
-
64
-
65
- def get_wandb_key (config ):
66
- return config ["wandb" ]["login_key" ]
1
+ from context import config
2
+
3
+
4
+ def check_chosen_in_available (conf_file ):
5
+ """check that chosen options are actually available (for that task)"""
6
+
7
+ # regresssion
8
+ assert config .Config .get_chosen_datasets (
9
+ conf_file , "regression"
10
+ ) in config .Config .get_available_datasets (
11
+ conf_file , "regression"
12
+ ), "chosen dataset for regression is not available"
13
+
14
+ assert config .Config .get_chosen_optimizers (
15
+ conf_file , "regression"
16
+ ) in config .Config .get_available_optimizers (
17
+ conf_file , "regression"
18
+ ), "chosen optimizer for regression is not available"
19
+
20
+ assert config .Config .get_chosen_loss (
21
+ conf_file , "regression"
22
+ ) in config .Config .get_available_loss (
23
+ conf_file , "regression"
24
+ ), "chosen loss for regression is not available"
25
+
26
+ # classification
27
+ assert config .Config .get_chosen_datasets (
28
+ conf_file , "classification"
29
+ ) in config .Config .get_available_datasets (
30
+ conf_file , "classification"
31
+ ), "chosen dataset for classification is not available"
32
+
33
+ assert config .Config .get_chosen_optimizers (
34
+ conf_file , "classification"
35
+ ) in config .Config .get_available_optimizers (
36
+ conf_file , "classification"
37
+ ), "chosen optimizer for classification is not available"
38
+
39
+ assert config .Config .get_chosen_loss (
40
+ conf_file , "classification"
41
+ ) in config .Config .get_available_loss (
42
+ conf_file , "classification"
43
+ ), "chosen loss for classification is not available"
44
+
45
+
46
+ def check_hyperparams (conf_file ):
47
+ """Check that necessary hyperparams are defined"""
48
+ hyperparams = config .Config .get_hyperparams (conf_file )
67
49
50
+ assert (
51
+ hyperparams ["learning_rate" ] is not None
52
+ and hyperparams ["learning_rate" ] <= 1
53
+ and hyperparams ["learning_rate" ] >= 1e-10
54
+ ), "you sure about that?"
68
55
69
- # ============================== asserts ==============================================
56
+ assert hyperparams [ "batch_size" ] is not None
70
57
58
+ assert (
59
+ hyperparams ["num_workers" ] is not None
60
+ and hyperparams ["num_workers" ] >= 1
61
+ and hyperparams ["num_workers" ] <= 32
62
+ ), "you sure about that?"
71
63
72
- def check_chosen_in_available (config ):
73
- assert get_chosen_task (config ) in get_available_tasks (
74
- config
75
- ), "chosen task is not available"
76
- assert get_chosen_datasets (config ) in get_available_datasets (
77
- config
78
- ), "chosen dataset is not available"
79
- assert get_chosen_loss (config ) in get_available_loss (
80
- config
81
- ), "chosen loss is not available"
64
+ assert (
65
+ hyperparams ["num_epochs" ] is not None
66
+ and hyperparams ["num_epochs" ] >= 1
67
+ and hyperparams ["num_epochs" ] <= 250
68
+ ), "you sure about that?"
82
69
83
70
84
- def wandb_config (config ):
85
- wandb_key = get_wandb_key (config )
71
+ def wandb_config (conf_file ):
72
+ """check that the wandb key is set up right"""
73
+ wandb_key = config .Config .get_wandb_key (conf_file )
86
74
assert wandb_key is not None , "put wandb API key in yaml file if you want tracking"
87
75
assert isinstance (
88
76
wandb_key , str
@@ -93,15 +81,15 @@ def wandb_config(config):
93
81
94
82
95
83
def main ():
96
- config_file_location = os .path .join (
97
- os .path .dirname (__file__ ), ".." , "central" , default_config
98
- )
99
- config = read_config_file (config_file_location )
84
+ """run asserts for config file"""
85
+
86
+ actual_config = config .Config (config .Config .DEFAULT_CONFIG )
100
87
101
- if get_wandb_active ( config ):
102
- wandb_config (config )
88
+ if config . Config . get_wand_active ( actual_config ):
89
+ wandb_config (actual_config )
103
90
104
- check_chosen_in_available (config )
91
+ check_chosen_in_available (actual_config )
92
+ check_hyperparams (actual_config )
105
93
106
94
107
95
if __name__ == "__main__" :
0 commit comments