Skip to content

Commit a65eba6

Browse files
committed
Fix bug configuring DATA field in the YAML when selecting a pretrained model
1 parent a93db36 commit a65eba6

File tree

1 file changed

+104
-90
lines changed

1 file changed

+104
-90
lines changed

ui_utils.py

Lines changed: 104 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from packaging.version import Version
1919
from bs4 import BeautifulSoup
2020
from urllib.request import urlopen
21+
import collections.abc
2122

2223
from PySide6 import QtCore
2324
from PySide6.QtCore import QObject, QThread
@@ -100,7 +101,10 @@ def set_folder_checked(data_constraints, key, sample_info):
100101
if "data_constraints" not in main_window.cfg.settings["wizard_answers"]:
101102
main_window.cfg.settings["wizard_answers"]["data_constraints"] = data_constraints
102103
else:
103-
main_window.cfg.settings["wizard_answers"]["data_constraints"].update(data_constraints)
104+
main_window.cfg.settings["wizard_answers"]["data_constraints"] = update_dict(
105+
main_window.cfg.settings["wizard_answers"]["data_constraints"],
106+
data_constraints
107+
)
104108
main_window.cfg.settings["wizard_answers"][f"CHECKED {key}"] = 1
105109
set_text(main_window.ui.wizard_data_checked_label, "<span style='color:#04aa6d'>Data checked!</span>")
106110

@@ -2493,97 +2497,100 @@ def create_yaml_file(main_window):
24932497

24942498
# Model definition
24952499
biapy_config['MODEL'] = {}
2496-
model_name = main_window.cfg.translate_model_names(get_text(main_window.ui.MODEL__ARCHITECTURE__INPUT), get_text(main_window.ui.PROBLEM__NDIM__INPUT))
2497-
biapy_config['MODEL']['ARCHITECTURE'] = model_name
2498-
if model_name in ['unet', 'resunet', 'resunet++', 'seunet', 'attention_unet']:
2499-
try:
2500-
biapy_config['MODEL']['FEATURE_MAPS'] = ast.literal_eval(get_text(main_window.ui.MODEL__FEATURE_MAPS__INPUT))
2501-
except:
2502-
main_window.dialog_exec("There was an error in model's feature maps field (MODEL.FEATURE_MAPS). Please check its syntax!", reason="error")
2503-
return True, False
2504-
try:
2505-
biapy_config['MODEL']['DROPOUT_VALUES'] = ast.literal_eval(get_text(main_window.ui.MODEL__DROPOUT_VALUES__INPUT))
2506-
except:
2507-
main_window.dialog_exec("There was an error in model's dropout values field (MODEL.DROPOUT_VALUES). Please check its syntax!", reason="error")
2508-
return True, False
2509-
if get_text(main_window.ui.MODEL__NORMALIZATION__INPUT) != "bn":
2510-
biapy_config['MODEL']['NORMALIZATION'] = get_text(main_window.ui.MODEL__NORMALIZATION__INPUT)
2511-
if int(get_text(main_window.ui.MODEL__KERNEL_SIZE__INPUT)) != 3:
2512-
biapy_config['MODEL']['KERNEL_SIZE'] = int(get_text(main_window.ui.MODEL__KERNEL_SIZE__INPUT))
2513-
if get_text(main_window.ui.MODEL__UPSAMPLE_LAYER__INPUT) != "convtranspose":
2514-
biapy_config['MODEL']['UPSAMPLE_LAYER'] = get_text(main_window.ui.MODEL__UPSAMPLE_LAYER__INPUT)
2515-
if get_text(main_window.ui.MODEL__ACTIVATION__INPUT) != 'elu':
2516-
biapy_config['MODEL']['ACTIVATION'] = get_text(main_window.ui.MODEL__ACTIVATION__INPUT)
2517-
if get_text(main_window.ui.MODEL__LAST_ACTIVATION__INPUT) != 'sigmoid':
2518-
biapy_config['MODEL']['LAST_ACTIVATION'] = get_text(main_window.ui.MODEL__LAST_ACTIVATION__INPUT)
2519-
if get_text(main_window.ui.PROBLEM__NDIM__INPUT) == "3D":
2520-
try:
2521-
biapy_config['MODEL']['Z_DOWN'] = ast.literal_eval(get_text(main_window.ui.MODEL__Z_DOWN__INPUT))
2522-
except:
2523-
main_window.dialog_exec("There was an error in model's z axis downsampling field (MODEL.Z_DOWN). Please check its syntax!", reason="error")
2524-
return True, False
2525-
if get_text(main_window.ui.MODEL__ISOTROPY__INPUT) != "[True, True, True, True, True]":
2500+
if not (get_text(main_window.ui.LOAD_PRETRAINED_MODEL__INPUT) == "Yes"
2501+
and get_text(main_window.ui.MODEL__SOURCE__INPUT) != "I have a model trained with BiaPy"
2502+
):
2503+
model_name = main_window.cfg.translate_model_names(get_text(main_window.ui.MODEL__ARCHITECTURE__INPUT), get_text(main_window.ui.PROBLEM__NDIM__INPUT))
2504+
biapy_config['MODEL']['ARCHITECTURE'] = model_name
2505+
if model_name in ['unet', 'resunet', 'resunet++', 'seunet', 'attention_unet']:
25262506
try:
2527-
biapy_config['MODEL']['ISOTROPY'] = ast.literal_eval(get_text(main_window.ui.MODEL__ISOTROPY__INPUT))
2507+
biapy_config['MODEL']['FEATURE_MAPS'] = ast.literal_eval(get_text(main_window.ui.MODEL__FEATURE_MAPS__INPUT))
25282508
except:
2529-
main_window.dialog_exec("There was an error in model's isotropy field (MODEL.ISOTROPY). Please check its syntax!", reason="error")
2509+
main_window.dialog_exec("There was an error in model's feature maps field (MODEL.FEATURE_MAPS). Please check its syntax!", reason="error")
25302510
return True, False
2531-
if get_text(main_window.ui.MODEL__LAGER_IO__INPUT) == "Yes":
2532-
biapy_config['MODEL']['LAGER_IO'] = True
2533-
if workflow_key_name == "SUPER_RESOLUTION" and get_text(main_window.ui.PROBLEM__NDIM__INPUT) == "3D": # SR
2534-
r = "pre" if get_text(main_window.ui.MODEL__UNET_SR_UPSAMPLE_POSITION__INPUT) == "Before model" else "post"
2535-
biapy_config['MODEL']['UNET_SR_UPSAMPLE_POSITION'] = r
2536-
elif model_name in ["unetr", "mae", "ViT"]:
2537-
biapy_config['MODEL']['VIT_TOKEN_SIZE'] = int(get_text(main_window.ui.MODEL__VIT_TOKEN_SIZE__INPUT))
2538-
biapy_config['MODEL']['VIT_EMBED_DIM'] = int(get_text(main_window.ui.MODEL__VIT_EMBED_DIM__INPUT))
2539-
biapy_config['MODEL']['VIT_NUM_LAYERS'] = int(get_text(main_window.ui.MODEL__VIT_NUM_LAYERS__INPUT))
2540-
biapy_config['MODEL']['VIT_MLP_RATIO'] = get_text(main_window.ui.MODEL__VIT_MLP_RATIO__INPUT)
2541-
biapy_config['MODEL']['VIT_NUM_HEADS'] = int(get_text(main_window.ui.MODEL__VIT_NUM_HEADS__INPUT))
2542-
biapy_config['MODEL']['VIT_NORM_EPS'] = get_text(main_window.ui.MODEL__VIT_NORM_EPS__INPUT)
2543-
2544-
# UNETR
2545-
if model_name in "unetr":
2546-
biapy_config['MODEL']['UNETR_VIT_HIDD_MULT'] = int(get_text(main_window.ui.MODEL__UNETR_VIT_HIDD_MULT__INPUT))
2547-
biapy_config['MODEL']['UNETR_VIT_NUM_FILTERS'] = int(get_text(main_window.ui.MODEL__UNETR_VIT_NUM_FILTERS__INPUT))
2548-
biapy_config['MODEL']['UNETR_DEC_ACTIVATION'] = get_text(main_window.ui.MODEL__UNETR_DEC_ACTIVATION__INPUT)
2549-
biapy_config['MODEL']['UNETR_DEC_KERNEL_SIZE'] = int(get_text(main_window.ui.MODEL__UNETR_DEC_KERNEL_SIZE__INPUT))
2550-
2551-
# MAE
2552-
if model_name in "mae":
2553-
biapy_config['MODEL']['MAE_MASK_TYPE'] = get_text(main_window.ui.MODEL__MAE_MASK_TYPE__INPUT)
2554-
if get_text(main_window.ui.MODEL__MAE_MASK_TYPE__INPUT) == "random":
2555-
biapy_config['MODEL']['MAE_MASK_RATIO'] = float(get_text(main_window.ui.MODEL__MAE_MASK_RATIO__INPUT))
2556-
biapy_config['MODEL']['MAE_DEC_HIDDEN_SIZE'] = int(get_text(main_window.ui.MODEL__MAE_DEC_HIDDEN_SIZE__INPUT))
2557-
biapy_config['MODEL']['MAE_DEC_NUM_LAYERS'] = int(get_text(main_window.ui.MODEL__MAE_DEC_NUM_LAYERS__INPUT))
2558-
biapy_config['MODEL']['MAE_DEC_NUM_HEADS'] = get_text(main_window.ui.MODEL__MAE_DEC_NUM_HEADS__INPUT)
2559-
biapy_config['MODEL']['MAE_DEC_MLP_DIMS'] = get_text(main_window.ui.MODEL__MAE_DEC_MLP_DIMS__INPUT)
2560-
2561-
# ConvNeXT
2562-
if model_name in "unext_v1":
25632511
try:
2564-
biapy_config['MODEL']['CONVNEXT_LAYERS'] = ast.literal_eval(get_text(main_window.ui.MODEL__CONVNEXT_LAYERS__INPUT))
2512+
biapy_config['MODEL']['DROPOUT_VALUES'] = ast.literal_eval(get_text(main_window.ui.MODEL__DROPOUT_VALUES__INPUT))
25652513
except:
2566-
main_window.dialog_exec("There was an error in model's convnext layers field (MODEL.CONVNEXT_LAYERS). Please check its syntax!", reason="error")
2514+
main_window.dialog_exec("There was an error in model's dropout values field (MODEL.DROPOUT_VALUES). Please check its syntax!", reason="error")
25672515
return True, False
2568-
biapy_config['MODEL']['CONVNEXT_SD_PROB'] = float(get_text(main_window.ui.MODEL__CONVNEXT_SD_PROB__INPUT))
2569-
biapy_config['MODEL']['CONVNEXT_LAYER_SCALE'] = float(get_text(main_window.ui.MODEL__CONVNEXT_LAYER_SCALE__INPUT))
2570-
biapy_config['MODEL']['CONVNEXT_STEM_K_SIZE'] = int(get_text(main_window.ui.MODEL__CONVNEXT_STEM_K_SIZE__INPUT))
2571-
2572-
if workflow_key_name in ["SEMANTIC_SEG","INSTANCE_SEG","DETECTION","CLASSIFICATION"]:
2573-
classes = 2
2574-
if workflow_key_name == "SEMANTIC_SEG" and int(get_text(main_window.ui.MODEL__N_CLASSES__INPUT)) != 2:
2575-
classes = int(get_text(main_window.ui.MODEL__N_CLASSES__INPUT))
2576-
elif workflow_key_name == "INSTANCE_SEG" and int(get_text(main_window.ui.MODEL__N_CLASSES__INST_SEG__INPUT)) != 2:
2577-
classes = int(get_text(main_window.ui.MODEL__N_CLASSES__INST_SEG__INPUT))
2578-
elif workflow_key_name == "DETECTION" and int(get_text(main_window.ui.MODEL__N_CLASSES__DET__INPUT)) != 2:
2579-
classes = int(get_text(main_window.ui.MODEL__N_CLASSES__DET__INPUT))
2580-
else: # CLASSIFICATION
2581-
classes = int(get_text(main_window.ui.MODEL__N_CLASSES__CLS__INPUT))
2582-
2583-
if classes == 1:
2516+
if get_text(main_window.ui.MODEL__NORMALIZATION__INPUT) != "bn":
2517+
biapy_config['MODEL']['NORMALIZATION'] = get_text(main_window.ui.MODEL__NORMALIZATION__INPUT)
2518+
if int(get_text(main_window.ui.MODEL__KERNEL_SIZE__INPUT)) != 3:
2519+
biapy_config['MODEL']['KERNEL_SIZE'] = int(get_text(main_window.ui.MODEL__KERNEL_SIZE__INPUT))
2520+
if get_text(main_window.ui.MODEL__UPSAMPLE_LAYER__INPUT) != "convtranspose":
2521+
biapy_config['MODEL']['UPSAMPLE_LAYER'] = get_text(main_window.ui.MODEL__UPSAMPLE_LAYER__INPUT)
2522+
if get_text(main_window.ui.MODEL__ACTIVATION__INPUT) != 'elu':
2523+
biapy_config['MODEL']['ACTIVATION'] = get_text(main_window.ui.MODEL__ACTIVATION__INPUT)
2524+
if get_text(main_window.ui.MODEL__LAST_ACTIVATION__INPUT) != 'sigmoid':
2525+
biapy_config['MODEL']['LAST_ACTIVATION'] = get_text(main_window.ui.MODEL__LAST_ACTIVATION__INPUT)
2526+
if get_text(main_window.ui.PROBLEM__NDIM__INPUT) == "3D":
2527+
try:
2528+
biapy_config['MODEL']['Z_DOWN'] = ast.literal_eval(get_text(main_window.ui.MODEL__Z_DOWN__INPUT))
2529+
except:
2530+
main_window.dialog_exec("There was an error in model's z axis downsampling field (MODEL.Z_DOWN). Please check its syntax!", reason="error")
2531+
return True, False
2532+
if get_text(main_window.ui.MODEL__ISOTROPY__INPUT) != "[True, True, True, True, True]":
2533+
try:
2534+
biapy_config['MODEL']['ISOTROPY'] = ast.literal_eval(get_text(main_window.ui.MODEL__ISOTROPY__INPUT))
2535+
except:
2536+
main_window.dialog_exec("There was an error in model's isotropy field (MODEL.ISOTROPY). Please check its syntax!", reason="error")
2537+
return True, False
2538+
if get_text(main_window.ui.MODEL__LAGER_IO__INPUT) == "Yes":
2539+
biapy_config['MODEL']['LAGER_IO'] = True
2540+
if workflow_key_name == "SUPER_RESOLUTION" and get_text(main_window.ui.PROBLEM__NDIM__INPUT) == "3D": # SR
2541+
r = "pre" if get_text(main_window.ui.MODEL__UNET_SR_UPSAMPLE_POSITION__INPUT) == "Before model" else "post"
2542+
biapy_config['MODEL']['UNET_SR_UPSAMPLE_POSITION'] = r
2543+
elif model_name in ["unetr", "mae", "ViT"]:
2544+
biapy_config['MODEL']['VIT_TOKEN_SIZE'] = int(get_text(main_window.ui.MODEL__VIT_TOKEN_SIZE__INPUT))
2545+
biapy_config['MODEL']['VIT_EMBED_DIM'] = int(get_text(main_window.ui.MODEL__VIT_EMBED_DIM__INPUT))
2546+
biapy_config['MODEL']['VIT_NUM_LAYERS'] = int(get_text(main_window.ui.MODEL__VIT_NUM_LAYERS__INPUT))
2547+
biapy_config['MODEL']['VIT_MLP_RATIO'] = get_text(main_window.ui.MODEL__VIT_MLP_RATIO__INPUT)
2548+
biapy_config['MODEL']['VIT_NUM_HEADS'] = int(get_text(main_window.ui.MODEL__VIT_NUM_HEADS__INPUT))
2549+
biapy_config['MODEL']['VIT_NORM_EPS'] = get_text(main_window.ui.MODEL__VIT_NORM_EPS__INPUT)
2550+
2551+
# UNETR
2552+
if model_name in "unetr":
2553+
biapy_config['MODEL']['UNETR_VIT_HIDD_MULT'] = int(get_text(main_window.ui.MODEL__UNETR_VIT_HIDD_MULT__INPUT))
2554+
biapy_config['MODEL']['UNETR_VIT_NUM_FILTERS'] = int(get_text(main_window.ui.MODEL__UNETR_VIT_NUM_FILTERS__INPUT))
2555+
biapy_config['MODEL']['UNETR_DEC_ACTIVATION'] = get_text(main_window.ui.MODEL__UNETR_DEC_ACTIVATION__INPUT)
2556+
biapy_config['MODEL']['UNETR_DEC_KERNEL_SIZE'] = int(get_text(main_window.ui.MODEL__UNETR_DEC_KERNEL_SIZE__INPUT))
2557+
2558+
# MAE
2559+
if model_name in "mae":
2560+
biapy_config['MODEL']['MAE_MASK_TYPE'] = get_text(main_window.ui.MODEL__MAE_MASK_TYPE__INPUT)
2561+
if get_text(main_window.ui.MODEL__MAE_MASK_TYPE__INPUT) == "random":
2562+
biapy_config['MODEL']['MAE_MASK_RATIO'] = float(get_text(main_window.ui.MODEL__MAE_MASK_RATIO__INPUT))
2563+
biapy_config['MODEL']['MAE_DEC_HIDDEN_SIZE'] = int(get_text(main_window.ui.MODEL__MAE_DEC_HIDDEN_SIZE__INPUT))
2564+
biapy_config['MODEL']['MAE_DEC_NUM_LAYERS'] = int(get_text(main_window.ui.MODEL__MAE_DEC_NUM_LAYERS__INPUT))
2565+
biapy_config['MODEL']['MAE_DEC_NUM_HEADS'] = get_text(main_window.ui.MODEL__MAE_DEC_NUM_HEADS__INPUT)
2566+
biapy_config['MODEL']['MAE_DEC_MLP_DIMS'] = get_text(main_window.ui.MODEL__MAE_DEC_MLP_DIMS__INPUT)
2567+
2568+
# ConvNeXT
2569+
if model_name in "unext_v1":
2570+
try:
2571+
biapy_config['MODEL']['CONVNEXT_LAYERS'] = ast.literal_eval(get_text(main_window.ui.MODEL__CONVNEXT_LAYERS__INPUT))
2572+
except:
2573+
main_window.dialog_exec("There was an error in model's convnext layers field (MODEL.CONVNEXT_LAYERS). Please check its syntax!", reason="error")
2574+
return True, False
2575+
biapy_config['MODEL']['CONVNEXT_SD_PROB'] = float(get_text(main_window.ui.MODEL__CONVNEXT_SD_PROB__INPUT))
2576+
biapy_config['MODEL']['CONVNEXT_LAYER_SCALE'] = float(get_text(main_window.ui.MODEL__CONVNEXT_LAYER_SCALE__INPUT))
2577+
biapy_config['MODEL']['CONVNEXT_STEM_K_SIZE'] = int(get_text(main_window.ui.MODEL__CONVNEXT_STEM_K_SIZE__INPUT))
2578+
2579+
if workflow_key_name in ["SEMANTIC_SEG","INSTANCE_SEG","DETECTION","CLASSIFICATION"]:
25842580
classes = 2
2585-
biapy_config['MODEL']['N_CLASSES'] = classes
2586-
2581+
if workflow_key_name == "SEMANTIC_SEG" and int(get_text(main_window.ui.MODEL__N_CLASSES__INPUT)) != 2:
2582+
classes = int(get_text(main_window.ui.MODEL__N_CLASSES__INPUT))
2583+
elif workflow_key_name == "INSTANCE_SEG" and int(get_text(main_window.ui.MODEL__N_CLASSES__INST_SEG__INPUT)) != 2:
2584+
classes = int(get_text(main_window.ui.MODEL__N_CLASSES__INST_SEG__INPUT))
2585+
elif workflow_key_name == "DETECTION" and int(get_text(main_window.ui.MODEL__N_CLASSES__DET__INPUT)) != 2:
2586+
classes = int(get_text(main_window.ui.MODEL__N_CLASSES__DET__INPUT))
2587+
else: # CLASSIFICATION
2588+
classes = int(get_text(main_window.ui.MODEL__N_CLASSES__CLS__INPUT))
2589+
2590+
if classes == 1:
2591+
classes = 2
2592+
biapy_config['MODEL']['N_CLASSES'] = classes
2593+
25872594
if get_text(main_window.ui.LOAD_PRETRAINED_MODEL__INPUT) == "Yes":
25882595
if get_text(main_window.ui.MODEL__SOURCE__INPUT) == "I have a model trained with BiaPy":
25892596
biapy_config['MODEL']['LOAD_CHECKPOINT'] = True
@@ -3024,8 +3031,7 @@ def create_yaml_file(main_window):
30243031
create_dict_from_key(key, value, model_restrictions)
30253032
else:
30263033
raise ValueError(f"Error found in config: {key}. Contact BiaPy team!")
3027-
3028-
biapy_config.update(model_restrictions)
3034+
biapy_config = update_dict(biapy_config, model_restrictions)
30293035

30303036
if not main_window.cfg.settings['yaml_config_filename'].endswith(".yaml") and not main_window.cfg.settings['yaml_config_filename'].endswith(".yml"):
30313037
main_window.cfg.settings['yaml_config_filename'] = main_window.cfg.settings['yaml_config_filename']+".yaml"
@@ -3366,7 +3372,7 @@ def analyze_dict(self, conf, sep=""):
33663372
err, _vars = self.analyze_dict(v,sep+k if sep == "" else sep+"__"+k)
33673373
if err is not None:
33683374
errors +=err
3369-
variables_set.update(_vars)
3375+
variables_set = update_dict(variables_set,_vars)
33703376
else:
33713377
widget_name = sep+"__"+k+"__INPUT"
33723378
other_widgets_to_set = []
@@ -3792,7 +3798,7 @@ def save_biapy_config(main_window, data, biapy_version=""):
37923798
old_data = json.load(file)
37933799
except:
37943800
old_data = {}
3795-
old_data.update(data)
3801+
old_data = update_dict(old_data, data)
37963802

37973803
if biapy_version != "":
37983804
if "GUI_VERSION" not in old_data:
@@ -3806,4 +3812,12 @@ def save_biapy_config(main_window, data, biapy_version=""):
38063812
old_data["GUI_VERSION"] = biapy_version
38073813

38083814
with open(main_window.log_info["config_file"], "w") as outfile:
3809-
json.dump(old_data, outfile, indent=4)
3815+
json.dump(old_data, outfile, indent=4)
3816+
3817+
def update_dict(d, u):
3818+
for k, v in u.items():
3819+
if isinstance(v, collections.abc.Mapping):
3820+
d[k] = update_dict(d.get(k, {}), v)
3821+
else:
3822+
d[k] = v
3823+
return d

0 commit comments

Comments
 (0)