@@ -1216,16 +1216,16 @@ def set_default_config(cfg, gpu_info, sample_info):
     # TRAINING PARAMETERS #
     #######################
     if cfg['TRAIN']['ENABLE']:
-        cfg['TRAIN']['EPOCHS'] = 1000
-        cfg['TRAIN']['PATIENCE'] = 50
+        cfg['TRAIN']['EPOCHS'] = 300
+        cfg['TRAIN']['PATIENCE'] = 30
         cfg['TRAIN']['OPTIMIZER'] = "ADAMW"
         cfg['TRAIN']['LR'] = 1.E-4
         # Learning rate scheduler
         if "LR_SCHEDULER" not in cfg['TRAIN']:
             cfg['TRAIN']['LR_SCHEDULER'] = {}
         cfg['TRAIN']['LR_SCHEDULER']['NAME'] = 'warmupcosine'
         cfg['TRAIN']['LR_SCHEDULER']['MIN_LR'] = 5.E-6
-        cfg['TRAIN']['LR_SCHEDULER']['WARMUP_COSINE_DECAY_EPOCHS'] = 5
+        cfg['TRAIN']['LR_SCHEDULER']['WARMUP_COSINE_DECAY_EPOCHS'] = 10
         if "LOSS" not in cfg:
             cfg['LOSS'] = {}
         cfg['LOSS']['CLASS_REBALANCE'] = True
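For context, the 'warmupcosine' schedule configured above ramps the learning rate up linearly for WARMUP_COSINE_DECAY_EPOCHS (now 10) and then decays it from LR (1.E-4) towards MIN_LR (5.E-6). A minimal sketch of that shape, assuming the standard warmup-cosine formulation (BiaPy's actual scheduler implementation may differ):

    import math

    def warmup_cosine_lr(epoch, total_epochs=300, warmup_epochs=10, base_lr=1.0e-4, min_lr=5.0e-6):
        # Linear warmup up to base_lr over the first `warmup_epochs` epochs
        if epoch < warmup_epochs:
            return base_lr * (epoch + 1) / warmup_epochs
        # Cosine decay from base_lr down to min_lr over the remaining epochs
        progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
        return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))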
@@ -1235,7 +1235,7 @@ def set_default_config(cfg, gpu_info, sample_info):
     #########
     # TODO: calculate BMZ/torchvision model size. For now, a high number to ensure batch_size is set to 1 so it
     # always works
-    network_memory = 9999999
+    network_base_memory = 9999999
     if cfg['MODEL']['SOURCE'] == "biapy":
         if cfg["PROBLEM"]["TYPE"] in [
             "SEMANTIC_SEG",
@@ -1247,18 +1247,23 @@ def set_default_config(cfg, gpu_info, sample_info):
             cfg["MODEL"]["ARCHITECTURE"] = "resunet"
             if cfg['PROBLEM']['NDIM'] == "3D":
                 cfg["MODEL"]["Z_DOWN"] = [1, 1, 1, 1]
-            network_memory = 220 if cfg['PROBLEM']['NDIM'] == "2D" else 615
+            network_base_memory = 400
         elif cfg["PROBLEM"]["TYPE"] == "SUPER_RESOLUTION":
-            cfg["MODEL"]["ARCHITECTURE"] = "rcan" if cfg['PROBLEM']['NDIM'] == "2D" else "resunet"
-            network_memory = 876.92 if cfg['PROBLEM']['NDIM'] == "2D" else 615
+            if cfg['PROBLEM']['NDIM'] == "3D":
+                cfg["MODEL"]["ARCHITECTURE"] = "resunet"
+                cfg["MODEL"]["Z_DOWN"] = [1, 1, 1, 1]
+                network_base_memory = 400
+            else:
+                cfg["MODEL"]["ARCHITECTURE"] = "rcan"
+                network_base_memory = 3500
         elif cfg["PROBLEM"]["TYPE"] == "SELF_SUPERVISED":
-            cfg["MODEL"]["ARCHITECTURE"] = "unet"
-            network_memory = 400
+            cfg["MODEL"]["ARCHITECTURE"] = "resunet"
+            network_base_memory = 400
             if cfg['PROBLEM']['NDIM'] == "3D":
                 cfg["MODEL"]["Z_DOWN"] = [1, 1, 1, 1]
         elif cfg["PROBLEM"]["TYPE"] == "CLASSIFICATION":
             cfg["MODEL"]["ARCHITECTURE"] = "vit"
-            network_memory = 342.67 if cfg['PROBLEM']['NDIM'] == "2D" else 354.55
+            network_base_memory = 2200

     #####################
     # DATA AUGMENTATION #
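To summarize the new wizard defaults in one place: each workflow now gets a fixed per-architecture memory baseline instead of the previous per-dimension measured values, and that baseline is later scaled by input size (see the approx_network_memory change below). A sketch with an illustrative name (the SEMANTIC_SEG branch also covers the other problem types listed in the surrounding `if`, which this hunk truncates):

    # (architecture, network_base_memory in MB) per wizard workflow after this commit
    WIZARD_MODEL_DEFAULTS = {
        "SEMANTIC_SEG": ("resunet", 400),        # 3D variants also set Z_DOWN = [1, 1, 1, 1]
        "SUPER_RESOLUTION_2D": ("rcan", 3500),
        "SUPER_RESOLUTION_3D": ("resunet", 400),
        "SELF_SUPERVISED": ("resunet", 400),     # previously "unet"
        "CLASSIFICATION": ("vit", 2200),
    }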
@@ -1353,6 +1358,12 @@ def set_default_config(cfg, gpu_info, sample_info):
     elif cfg["PROBLEM"]["TYPE"] == "CLASSIFICATION":
         pass

+    # Remove the detection key added by the wizard if we're not in that workflow
+    if not cfg['TEST']['ENABLE'] or cfg['PROBLEM']['TYPE'] != 'DETECTION':
+        del cfg['TEST']['POST_PROCESSING']["REMOVE_CLOSE_POINTS_RADIUS"]
+        if len(cfg['TEST']['POST_PROCESSING']) == 0:
+            del cfg['TEST']['POST_PROCESSING']
+
     # Calculate data channels
     # cfg.PROBLEM.TYPE == "CLASSIFICATION" or "SELF_SUPERVISED" and PRETEXT_TASK == "masking". But as there is no SSL in the wizard
     # we reduce the if for samples_each_time to this
@@ -1386,7 +1397,7 @@ def set_default_config(cfg, gpu_info, sample_info):
         sample_info=sample_info,
         patch_size=cfg['DATA']['PATCH_SIZE'],
         channels_per_sample=channels_per_sample,
-        network_memory=network_memory,
+        network_base_memory=network_base_memory,
         y_upsampling=y_upsampling,
         max_batch_size_allowed=max_batch_size_allowed,
     )
@@ -1404,18 +1415,18 @@ def batch_size_calculator(
     sample_info,
     patch_size,
     channels_per_sample,
-    network_memory=400,
+    network_base_memory=400,
     sample_size_reference=256,
-    sample_memory_reference=45,
+    sample_memory_reference=100,
     max_batch_size_allowed=32,
     y_upsampling=(1, 1),
 ):
     """
     Calculate a reasonable value for the batch size by measuring how much memory an item returned by the BiaPy generator
     (i.e. its __getitem__ function) can consume. It takes into account the GPUs available (taking the minimum memory among them), the number
-    of samples, the memory of the network (``network_memory``) and a sample reference memory (``sample_size_reference`` and
+    of samples, the memory of the network (``network_base_memory``) and a sample reference memory (``sample_size_reference`` and
     ``sample_memory_reference``). The default settings of these variables take as reference samples of (256,256,1), where each
-    sample is calculated to consume 45MB of memory (more or less).
+    sample is calculated to consume 100MB of memory (more or less).

     Parameters
     ----------
@@ -1438,7 +1449,7 @@ def batch_size_calculator(
         * 'x': int. Channels for X data.
         * 'y': int. Channels for Y data.

-    network_memory : int, optional
+    network_base_memory : int, optional
         Memory consumed by the deep learning model in MB. This needs to be measured beforehand. For instance, the default
         U-Net of BiaPy takes 400MB on the GPU.

@@ -1471,7 +1482,7 @@ def batch_size_calculator(
     # the data and calculating the total number of samples, with the sample reference (composed of sample_size_reference and
     # sample_memory_reference).
     x_analisis_ref_ratio = x_data_to_analize['crop_shape'][1] / sample_size_reference
-    x_selected_patch_ref_ratio = patch_size[1] / sample_size_reference
+    x_selected_patch_ref_ratio = (patch_size[1] / sample_size_reference) ** 2
     x_sample_memory_ratio = x_analisis_ref_ratio * x_selected_patch_ref_ratio * sample_memory_reference

     y_data_to_analize = None
@@ -1484,7 +1495,7 @@ def batch_size_calculator(
         y_crop = y_data_to_analize['crop_shape'][1] * y_upsampling[1]
         # Patch size is always square in the wizard; that's why we always use y_data_to_analize['crop_shape'][1]
         y_analisis_ref_ratio = y_crop / sample_size_reference
-        y_selected_patch_ref_ratio = patch_size[1] / sample_size_reference
+        y_selected_patch_ref_ratio = (patch_size[1] / sample_size_reference) ** 2
         y_sample_memory_ratio = y_analisis_ref_ratio * y_selected_patch_ref_ratio * sample_memory_reference
     else:
         y_sample_memory_ratio = 0
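The squared ratio in the two hunks above is the substantive fix: the memory of a square patch grows with its pixel count (side squared), not with its side length, so the old linear ratio underestimated large patches. A quick check with the new defaults and a hypothetical 512-pixel patch:

    sample_size_reference = 256     # reference side length
    patch_side = 512                # hypothetical selected patch size

    old_ratio = patch_side / sample_size_reference          # 2.0: scales with side length
    new_ratio = (patch_side / sample_size_reference) ** 2   # 4.0: scales with pixel count

    # Doubling the side quadruples the pixels, which the squared ratio now reflects
    print(old_ratio, new_ratio)  # 2.0 4.0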
@@ -1500,8 +1511,12 @@ def batch_size_calculator(
         # This value will be close to the one being used during a __getitem__ of the BiaPy generator
         item_memory_consumption = (x_sample_memory_ratio * x_data) + (y_sample_memory_ratio * y_data)

+        # The network will hold more memory depending on the input size. It's calculated to be roughly 210MB for each
+        # (256,256,1) sample, so we use x_sample_memory_ratio here, as it was measured with sample_memory_reference=100MB by default
+        approx_network_memory = network_base_memory + (2 * x_sample_memory_ratio * x_data)
+
         if item_memory_consumption == 0: item_memory_consumption = 1
-        number_of_samples = (min_mem_gpu - network_memory) // item_memory_consumption if (min_mem_gpu - network_memory) > 0 else 1
+        number_of_samples = (min_mem_gpu - approx_network_memory) // item_memory_consumption if (min_mem_gpu - approx_network_memory) > 0 else 1
     else:
         number_of_samples = max_batch_size_allowed // 2