Skip to content

Commit af081df

Browse files
committed
adapted asserts to new config structure
1 parent c71981d commit af081df

File tree

3 files changed

+82
-87
lines changed

3 files changed

+82
-87
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ repos:
77
rev: 6.1.0
88
hooks:
99
- id: flake8
10-
args: ["--ignore=E501,E203,E402,F401"]
10+
args: ["--ignore=E501,E203,E402,F401,W503"]
1111
- repo: https://github.com/pre-commit/pre-commit-hooks
1212
rev: v4.4.0
1313
hooks:

bloom/tests/context.py

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import sys
2+
import os
3+
4+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
5+
import models
6+
import load_data
7+
import config

bloom/tests/test_configs.py

+74-86
Original file line numberDiff line numberDiff line change
@@ -1,88 +1,76 @@
1-
import os
2-
import sys
3-
import yaml
4-
5-
default_config = "default_config.yaml"
6-
7-
# ============================== read yaml ==============================================
8-
9-
10-
def read_config_file(config_filepath: str):
11-
"""
12-
reads the configuration from the YAML file specified
13-
returns the config as dictionary object
14-
15-
Args:
16-
config_filepath: path to the YAML file containing the configuration
17-
18-
"""
19-
if not (config_filepath.lower().endswith((".yaml", ".yml"))):
20-
print("Please provide a path to a YAML file.")
21-
quit()
22-
with open(config_filepath, "r") as config_file:
23-
config = yaml.safe_load(config_file)
24-
return config
25-
26-
27-
def get_available_loss(config):
28-
available_loss = []
29-
chosen_task = config["task"]["chosen"]
30-
if chosen_task == "regression":
31-
available_loss = config["loss_functions"]["regression"]["available"]
32-
else:
33-
available_loss = config["loss_functions"]["classification"]["available"]
34-
return available_loss
35-
36-
37-
def get_chosen_loss(config):
38-
chosen_task = config["task"]["chosen"]
39-
if chosen_task == "regression":
40-
return config["loss_functions"]["regression"]["chosen"]
41-
else:
42-
return config["loss_functions"]["classification"]["chosen"]
43-
44-
45-
def get_available_datasets(config):
46-
return config["datasets"]["available"]
47-
48-
49-
def get_chosen_datasets(config):
50-
return config["datasets"]["chosen"]
51-
52-
53-
def get_available_tasks(config):
54-
return config["task"]["available"]
55-
56-
57-
def get_chosen_task(config):
58-
return config["task"]["chosen"]
59-
60-
61-
def get_wandb_active(config):
62-
return config["wandb"]["active_tracking"]
63-
64-
65-
def get_wandb_key(config):
66-
return config["wandb"]["login_key"]
1+
from context import config
2+
3+
4+
def check_chosen_in_available(conf_file):
5+
"""check that chosen options are actually available (for that task)"""
6+
7+
# regresssion
8+
assert config.Config.get_chosen_datasets(
9+
conf_file, "regression"
10+
) in config.Config.get_available_datasets(
11+
conf_file, "regression"
12+
), "chosen dataset for regression is not available"
13+
14+
assert config.Config.get_chosen_optimizers(
15+
conf_file, "regression"
16+
) in config.Config.get_available_optimizers(
17+
conf_file, "regression"
18+
), "chosen optimizer for regression is not available"
19+
20+
assert config.Config.get_chosen_loss(
21+
conf_file, "regression"
22+
) in config.Config.get_available_loss(
23+
conf_file, "regression"
24+
), "chosen loss for regression is not available"
25+
26+
# classification
27+
assert config.Config.get_chosen_datasets(
28+
conf_file, "classification"
29+
) in config.Config.get_available_datasets(
30+
conf_file, "classification"
31+
), "chosen dataset for classification is not available"
32+
33+
assert config.Config.get_chosen_optimizers(
34+
conf_file, "classification"
35+
) in config.Config.get_available_optimizers(
36+
conf_file, "classification"
37+
), "chosen optimizer for classification is not available"
38+
39+
assert config.Config.get_chosen_loss(
40+
conf_file, "classification"
41+
) in config.Config.get_available_loss(
42+
conf_file, "classification"
43+
), "chosen loss for classification is not available"
44+
45+
46+
def check_hyperparams(conf_file):
47+
"""Check that necessary hyperparams are defined"""
48+
hyperparams = config.Config.get_hyperparams(conf_file)
6749

50+
assert (
51+
hyperparams["learning_rate"] is not None
52+
and hyperparams["learning_rate"] <= 1
53+
and hyperparams["learning_rate"] >= 1e-10
54+
), "you sure about that?"
6855

69-
# ============================== asserts ==============================================
56+
assert hyperparams["batch_size"] is not None
7057

58+
assert (
59+
hyperparams["num_workers"] is not None
60+
and hyperparams["num_workers"] >= 1
61+
and hyperparams["num_workers"] <= 32
62+
), "you sure about that?"
7163

72-
def check_chosen_in_available(config):
73-
assert get_chosen_task(config) in get_available_tasks(
74-
config
75-
), "chosen task is not available"
76-
assert get_chosen_datasets(config) in get_available_datasets(
77-
config
78-
), "chosen dataset is not available"
79-
assert get_chosen_loss(config) in get_available_loss(
80-
config
81-
), "chosen loss is not available"
64+
assert (
65+
hyperparams["num_epochs"] is not None
66+
and hyperparams["num_epochs"] >= 1
67+
and hyperparams["num_epochs"] <= 250
68+
), "you sure about that?"
8269

8370

84-
def wandb_config(config):
85-
wandb_key = get_wandb_key(config)
71+
def wandb_config(conf_file):
72+
"""check that the wandb key is set up right"""
73+
wandb_key = config.Config.get_wandb_key(conf_file)
8674
assert wandb_key is not None, "put wandb API key in yaml file if you want tracking"
8775
assert isinstance(
8876
wandb_key, str
@@ -93,15 +81,15 @@ def wandb_config(config):
9381

9482

9583
def main():
96-
config_file_location = os.path.join(
97-
os.path.dirname(__file__), "..", "central", default_config
98-
)
99-
config = read_config_file(config_file_location)
84+
"""run asserts for config file"""
85+
86+
actual_config = config.Config(config.Config.DEFAULT_CONFIG)
10087

101-
if get_wandb_active(config):
102-
wandb_config(config)
88+
if config.Config.get_wand_active(actual_config):
89+
wandb_config(actual_config)
10390

104-
check_chosen_in_available(config)
91+
check_chosen_in_available(actual_config)
92+
check_hyperparams(actual_config)
10593

10694

10795
if __name__ == "__main__":

0 commit comments

Comments
 (0)