@@ -1238,7 +1238,8 @@ def check_configuration(cfg, jobname, check_data_paths=True):
12381238 "vit" ,
12391239 "mae" ,
12401240 "unext_v1" ,
1241- ], "MODEL.ARCHITECTURE not in ['unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'simple_cnn', 'efficientnet_b[0-7]', 'unetr', 'edsr', 'rcan', 'dfcan', 'wdsr', 'vit', 'mae', 'unext_v1']"
1241+ "unext_v2" ,
1242+ ], "MODEL.ARCHITECTURE not in ['unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'simple_cnn', 'efficientnet_b[0-7]', 'unetr', 'edsr', 'rcan', 'dfcan', 'wdsr', 'vit', 'mae', 'unext_v1', 'unext_v2']"
12421243 if (
12431244 model_arch
12441245 not in [
@@ -1253,6 +1254,7 @@ def check_configuration(cfg, jobname, check_data_paths=True):
12531254 "vit" ,
12541255 "mae" ,
12551256 "unext_v1" ,
1257+ "unext_v2" ,
12561258 ]
12571259 and cfg .PROBLEM .NDIM == "3D"
12581260 and cfg .PROBLEM .TYPE != "CLASSIFICATION"
@@ -1271,6 +1273,7 @@ def check_configuration(cfg, jobname, check_data_paths=True):
12711273 "vit" ,
12721274 "mae" ,
12731275 "unext_v1" ,
1276+ "unext_v2" ,
12741277 ]
12751278 )
12761279 )
@@ -1288,10 +1291,11 @@ def check_configuration(cfg, jobname, check_data_paths=True):
12881291 "multiresunet" ,
12891292 "unetr" ,
12901293 "unext_v1" ,
1294+ "unext_v2" ,
12911295 ]
12921296 ):
12931297 raise ValueError (
1294- "'MODEL.N_CLASSES' > 2 can only be used with 'MODEL.ARCHITECTURE' in ['unet', 'resunet', 'resunet++', 'seunet', 'resunet_se', 'attention_unet', 'multiresunet', 'unetr', 'unext_v1']"
1298+ "'MODEL.N_CLASSES' > 2 can only be used with 'MODEL.ARCHITECTURE' in ['unet', 'resunet', 'resunet++', 'seunet', 'resunet_se', 'attention_unet', 'multiresunet', 'unetr', 'unext_v1', 'unext_v2' ]"
12951299 )
12961300
12971301 assert len (cfg .MODEL .FEATURE_MAPS ) > 2 , "'MODEL.FEATURE_MAPS' needs to have at least 3 values"
@@ -1372,10 +1376,10 @@ def check_configuration(cfg, jobname, check_data_paths=True):
13721376 "resunet_se" ,
13731377 "unetr" ,
13741378 "multiresunet" ,
1375- "unext_v1" ,
1379+ "unext_v1" ,"unext_v2" ,
13761380 ]:
13771381 raise ValueError (
1378- "Architectures available for {} are: ['unet', 'resunet', 'resunet++', 'seunet', 'attention_unet', 'resunet_se', 'unetr', 'multiresunet', 'unext_v1']" .format (
1382+ "Architectures available for {} are: ['unet', 'resunet', 'resunet++', 'seunet', 'attention_unet', 'resunet_se', 'unetr', 'multiresunet', 'unext_v1', 'unext_v2' ]" .format (
13791383 cfg .PROBLEM .TYPE
13801384 )
13811385 )
@@ -1393,9 +1397,10 @@ def check_configuration(cfg, jobname, check_data_paths=True):
13931397 "attention_unet" ,
13941398 "multiresunet" ,
13951399 "unext_v1" ,
1400+ "unext_v2" ,
13961401 ]:
13971402 raise ValueError (
1398- "Architectures available for 2D 'SUPER_RESOLUTION' are: ['edsr', 'rcan', 'dfcan', 'wdsr', 'unet', 'resunet', 'resunet++', 'seunet', 'resunet_se', 'attention_unet', 'multiresunet', 'unext_v1']"
1403+ "Architectures available for 2D 'SUPER_RESOLUTION' are: ['edsr', 'rcan', 'dfcan', 'wdsr', 'unet', 'resunet', 'resunet++', 'seunet', 'resunet_se', 'attention_unet', 'multiresunet', 'unext_v1', 'unext_v2' ]"
13991404 )
14001405 elif cfg .PROBLEM .NDIM == "3D" :
14011406 if model_arch not in [
@@ -1406,9 +1411,10 @@ def check_configuration(cfg, jobname, check_data_paths=True):
14061411 "attention_unet" ,
14071412 "multiresunet" ,
14081413 "unext_v1" ,
1414+ "unext_v2" ,
14091415 ]:
14101416 raise ValueError (
1411- "Architectures available for 3D 'SUPER_RESOLUTION' are: ['unet', 'resunet', 'resunet++', 'seunet', 'resunet_se', 'attention_unet', 'multiresunet', 'unext_v1']"
1417+ "Architectures available for 3D 'SUPER_RESOLUTION' are: ['unet', 'resunet', 'resunet++', 'seunet', 'resunet_se', 'attention_unet', 'multiresunet', 'unext_v1', 'unext_v2' ]"
14121418 )
14131419 assert cfg .MODEL .UNET_SR_UPSAMPLE_POSITION in [
14141420 "pre" ,
@@ -1429,9 +1435,10 @@ def check_configuration(cfg, jobname, check_data_paths=True):
14291435 "unetr" ,
14301436 "multiresunet" ,
14311437 "unext_v1" ,
1438+ "unext_v2" ,
14321439 ]:
14331440 raise ValueError (
1434- "Architectures available for 'IMAGE_TO_IMAGE' are: ['edsr', 'rcan', 'dfcan', 'wdsr', 'unet', 'resunet', 'resunet++', 'resunet_se', 'seunet', 'attention_unet', 'unetr', 'multiresunet', 'unext_v1']"
1441+ "Architectures available for 'IMAGE_TO_IMAGE' are: ['edsr', 'rcan', 'dfcan', 'wdsr', 'unet', 'resunet', 'resunet++', 'resunet_se', 'seunet', 'attention_unet', 'unetr', 'multiresunet', 'unext_v1', 'unext_v2' ]"
14351442 )
14361443 elif cfg .PROBLEM .TYPE == "SELF_SUPERVISED" :
14371444 if model_arch not in [
@@ -1444,6 +1451,7 @@ def check_configuration(cfg, jobname, check_data_paths=True):
14441451 "resunet_se" ,
14451452 "unetr" ,
14461453 "unext_v1" ,
1454+ "unext_v2" ,
14471455 "edsr" ,
14481456 "rcan" ,
14491457 "dfcan" ,
@@ -1482,6 +1490,7 @@ def check_configuration(cfg, jobname, check_data_paths=True):
14821490 "attention_unet" ,
14831491 "multiresunet" ,
14841492 "unext_v1" ,
1493+ "unext_v2" ,
14851494 ]:
14861495 z_size = cfg .DATA .PATCH_SIZE [0 ]
14871496 sizes = cfg .DATA .PATCH_SIZE [1 :- 1 ]
@@ -1684,10 +1693,10 @@ def check_configuration(cfg, jobname, check_data_paths=True):
16841693 "'TEST.POST_PROCESSING.REMOVE_CLOSE_POINTS' needs to be set when 'TEST.POST_PROCESSING.REMOVE_CLOSE_POINTS' is True"
16851694 )
16861695
1687- def compare_configurations_without_model (actual_cfg , old_cfg , header_message = "" ):
1696+ def compare_configurations_without_model (actual_cfg , old_cfg , header_message = "" , old_cfg_version = None ):
16881697 """
16891698 Compares two configurations and throws an error if they differ in some critical variables that change workflow behaviour. This
1690- comparisdon does not take into account model specs.
1699+ comparisdon does not take into account model specs.
16911700 """
16921701 print ("Comparing configurations . . ." )
16931702
@@ -1699,21 +1708,33 @@ def compare_configurations_without_model(actual_cfg, old_cfg, header_message="")
16991708 "PROBLEM.SELF_SUPERVISED.PRETEXT_TASK" ,
17001709 "PROBLEM.SUPER_RESOLUTION.UPSCALING" ,
17011710 "MODEL.N_CLASSES" ,
1702- ]
1703-
1711+ ]
1712+
17041713 def get_attribute_recursive (var , attr ):
17051714 att = attr .split ("." )
17061715 if len (att ) == 1 :
17071716 return getattr (var , att [0 ])
17081717 else :
17091718 return get_attribute_recursive (getattr (var , att [0 ]), "." .join (att [1 :]))
1710-
1719+
1720+ # Old configuration translation
1721+ dim_count = 2 if old_cfg .PROBLEM .NDIM == "2D" else 3
1722+ # BiaPy version less than 3.5.5
1723+ if old_cfg_version is None :
1724+ if isinstance (old_cfg ["PROBLEM" ]["SUPER_RESOLUTION" ]["UPSCALING" ], int ):
1725+ old_cfg ["PROBLEM" ]["SUPER_RESOLUTION" ]["UPSCALING" ] = (old_cfg ["PROBLEM" ]["SUPER_RESOLUTION" ]["UPSCALING" ],) * dim_count
1726+
17111727 for var_to_compare in vars_to_compare :
17121728 if get_attribute_recursive (actual_cfg , var_to_compare ) != get_attribute_recursive (old_cfg , var_to_compare ):
1713- raise ValueError (header_message + f"The '{ var_to_compare } ' value of the compared configurations does not match" )
1714-
1729+ raise ValueError (
1730+ header_message + f"The '{ var_to_compare } ' value of the compared configurations does not match: " + \
1731+ f"{ get_attribute_recursive (actual_cfg , var_to_compare )} (current configuration) vs { get_attribute_recursive (old_cfg , var_to_compare )} (from loaded configuration)"
1732+ )
1733+
17151734 print ("Configurations seem to be compatible. Continuing . . ." )
17161735
1736+
1737+
17171738def get_checkpoint_path (cfg , jobname ):
17181739 """Get the checkpoint file path"""
17191740 checkpoint_dir = Path (cfg .PATHS .CHECKPOINT )
@@ -1942,7 +1963,10 @@ def check_torchvision_available_models(workflow, ndim):
19421963 return models , model_restrictions_description , model_restrictions
19431964
19441965def convert_old_model_cfg_to_current_version (old_cfg ):
1945- # https://github.com/BiaPyX/BiaPy/compare/6aa291baa9bc5d7fb410454bfcea3a3da0c23604...v3.5.5
1966+ """
1967+ Backward compatibility until commit 6aa291baa9bc5d7fb410454bfcea3a3da0c23604 (version 3.2.0)
1968+ Commit url: https://github.com/BiaPyX/BiaPy/commit/6aa291baa9bc5d7fb410454bfcea3a3da0c23604
1969+ """
19461970 if "TEST" in old_cfg :
19471971 if "STATS" in old_cfg ["TEST" ]:
19481972 full_image = old_cfg ["TEST" ]["STATS" ]["FULL_IMG" ]
@@ -2015,7 +2039,7 @@ def convert_old_model_cfg_to_current_version(old_cfg):
20152039 old_cfg ["DATA" ]["TRAIN" ]["FILTER_SAMPLES" ] = {}
20162040 old_cfg ["DATA" ]["TRAIN" ]["FILTER_SAMPLES" ]["PROPS" ] = [['foreground' ]]
20172041 old_cfg ["DATA" ]["TRAIN" ]["FILTER_SAMPLES" ]["VALUES" ] = [[min_fore ]]
2018- old_cfg ["DATA" ]["TRAIN" ]["FILTER_SAMPLES" ]["SIGN " ] = [['lt' ]]
2042+ old_cfg ["DATA" ]["TRAIN" ]["FILTER_SAMPLES" ]["SIGNS " ] = [['lt' ]]
20192043 if "VAL" in old_cfg ["DATA" ]:
20202044 if "BINARY_MASKS" in old_cfg ["DATA" ]["VAL" ]:
20212045 del old_cfg ["DATA" ]["VAL" ]["BINARY_MASKS" ]
@@ -2098,4 +2122,4 @@ def convert_old_model_cfg_to_current_version(old_cfg):
20982122 except :
20992123 pass
21002124
2101- return old_cfg
2125+ return old_cfg
0 commit comments