Commit c119088

Change batch size calculation to adapt better to all scenarios

1 parent 047fa6a

2 files changed: +41 -26 lines

run_functions.py

Lines changed: 7 additions & 7 deletions
@@ -423,7 +423,7 @@ def run(self):
 volumes[p] = {"bind": p, "mode": mode}
 p = path_to_linux(self.config['DATA']['TRAIN']['PATH'], self.main_gui.cfg.settings['os_host'])
 temp_cfg['DATA']['TRAIN']['PATH'] = p
-paths_message += "<tr><td>Train raw image path</td><td>{}</td></tr>".format(
+paths_message += "<tr><td>Train raw image path </td><td>{}</td></tr>".format(
     self.config['DATA']['TRAIN']['PATH']
 )

@@ -436,7 +436,7 @@ def run(self):
 volumes[p] = {"bind": p, "mode": mode}
 p = path_to_linux(self.config['DATA']['TRAIN']['GT_PATH'], self.main_gui.cfg.settings['os_host'])
 temp_cfg['DATA']['TRAIN']['GT_PATH'] = p
-paths_message += "<tr><td>Train target path</td><td>{}</td></tr>".format(
+paths_message += "<tr><td>Train target path </td><td>{}</td></tr>".format(
     self.config['DATA']['TRAIN']['GT_PATH']
 )

@@ -449,7 +449,7 @@ def run(self):
 volumes[p] = {"bind": p, "mode": mode}
 p = path_to_linux(self.config['DATA']['VAL']['PATH'], self.main_gui.cfg.settings['os_host'])
 temp_cfg['DATA']['VAL']['PATH'] = p
-paths_message += "<tr><td>Validation raw image path</td><td>{}</td></tr>".format(
+paths_message += "<tr><td>Validation raw image path </td><td>{}</td></tr>".format(
     self.config['DATA']['VAL']['PATH']
 )

@@ -462,7 +462,7 @@ def run(self):
 volumes[p] = {"bind": p, "mode": mode}
 p = path_to_linux(self.config['DATA']['VAL']['GT_PATH'], self.main_gui.cfg.settings['os_host'])
 temp_cfg['DATA']['VAL']['GT_PATH'] = p
-paths_message += "<tr><td>Validation target path</td><td>{}</td></tr>".format(
+paths_message += "<tr><td>Validation target path </td><td>{}</td></tr>".format(
     self.config['DATA']['VAL']['GT_PATH']
 )
 else:
@@ -478,7 +478,7 @@ def run(self):
 volumes[p] = {"bind": p, "mode": mode}
 p = path_to_linux(self.config['DATA']['TEST']['PATH'], self.main_gui.cfg.settings['os_host'])
 temp_cfg['DATA']['TEST']['PATH'] = p
-paths_message += "<tr><td>Test raw image path</td><td>{}</td></tr>".format(
+paths_message += "<tr><td>Test raw image path </td><td>{}</td></tr>".format(
     self.config['DATA']['TEST']['PATH']
 )

@@ -491,7 +491,7 @@ def run(self):
 volumes[p] = {"bind": p, "mode": mode}
 p = path_to_linux(self.config['DATA']['TEST']['GT_PATH'], self.main_gui.cfg.settings['os_host'])
 temp_cfg['DATA']['TEST']['GT_PATH'] = p
-paths_message += "<tr><td>Test target path</td><td>{}</td></tr>".format(
+paths_message += "<tr><td>Test target path </td><td>{}</td></tr>".format(
     self.config['DATA']['TEST']['GT_PATH']
 )

@@ -562,7 +562,7 @@ def run(self):
 <tr><td>Jobname</td><td>{}</td></tr>\
 <tr><td>YAML</td><td><a href={}>{}</a></td></tr>\
 <tr><td>Device</td><td>{}</td></tr>\
-<tr><td>Output folder </td><td><a href={}>{}</a></td></tr>\
+<tr><td>Output folder</td><td><a href={}>{}</a></td></tr>\
 <tr><td>Output log</td><td><a href={}>{}</a></td></tr>\
 <tr><td>Error log</td><td><a href={}>{}</a></td></tr>\
 {}\

ui_utils.py

Lines changed: 34 additions & 19 deletions
@@ -1216,16 +1216,16 @@ def set_default_config(cfg, gpu_info, sample_info):
 # TRAINING PARAMETERS #
 #######################
 if cfg['TRAIN']['ENABLE']:
-    cfg['TRAIN']['EPOCHS'] = 1000
-    cfg['TRAIN']['PATIENCE'] = 50
+    cfg['TRAIN']['EPOCHS'] = 300
+    cfg['TRAIN']['PATIENCE'] = 30
     cfg['TRAIN']['OPTIMIZER'] = "ADAMW"
     cfg['TRAIN']['LR'] = 1.E-4
     # Learning rate scheduler
     if "LR_SCHEDULER" not in cfg['TRAIN']:
         cfg['TRAIN']['LR_SCHEDULER'] = {}
     cfg['TRAIN']['LR_SCHEDULER']['NAME'] = 'warmupcosine'
     cfg['TRAIN']['LR_SCHEDULER']['MIN_LR'] = 5.E-6
-    cfg['TRAIN']['LR_SCHEDULER']['WARMUP_COSINE_DECAY_EPOCHS'] = 5
+    cfg['TRAIN']['LR_SCHEDULER']['WARMUP_COSINE_DECAY_EPOCHS'] = 10
 if "LOSS" not in cfg:
     cfg['LOSS'] = {}
 cfg['LOSS']['CLASS_REBALANCE'] = True
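Note on the scheduler these defaults select: 'warmupcosine' ramps the learning rate up linearly, then decays it along a cosine curve toward MIN_LR. A minimal sketch of how such a schedule is typically computed, using the defaults set above (LR=1e-4, MIN_LR=5e-6, 10 warmup epochs, 300 total); this is a generic illustration, not BiaPy's actual scheduler implementation:

import math

def warmup_cosine_lr(epoch, total_epochs=300, warmup_epochs=10,
                     base_lr=1e-4, min_lr=5e-6):
    # Linear warmup: ramp from base_lr/warmup_epochs up to base_lr
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    # Cosine decay from base_lr down to min_lr over the remaining epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))

print(warmup_cosine_lr(0))    # 1e-05 (first warmup epoch)
print(warmup_cosine_lr(9))    # 1e-04 (warmup complete)
print(warmup_cosine_lr(299))  # ~5e-06 (settles at MIN_LR)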
@@ -1235,7 +1235,7 @@ def set_default_config(cfg, gpu_info, sample_info):
 #########
 # TODO: calculate BMZ/torchvision model size. Now a high number to ensure batch_size is set to 1 so it
 # works always
-network_memory = 9999999
+network_base_memory = 9999999
 if cfg['MODEL']['SOURCE'] == "biapy":
     if cfg["PROBLEM"]["TYPE"] in [
         "SEMANTIC_SEG",
@@ -1247,18 +1247,23 @@ def set_default_config(cfg, gpu_info, sample_info):
     cfg["MODEL"]["ARCHITECTURE"] = "resunet"
     if cfg['PROBLEM']['NDIM'] == "3D":
         cfg["MODEL"]["Z_DOWN"] = [1, 1, 1, 1]
-    network_memory = 220 if cfg['PROBLEM']['NDIM'] == "2D" else 615
+    network_base_memory = 400
 elif cfg["PROBLEM"]["TYPE"] == "SUPER_RESOLUTION":
-    cfg["MODEL"]["ARCHITECTURE"] = "rcan" if cfg['PROBLEM']['NDIM'] == "2D" else "resunet"
-    network_memory = 876.92 if cfg['PROBLEM']['NDIM'] == "2D" else 615
+    if cfg['PROBLEM']['NDIM'] == "3D":
+        cfg["MODEL"]["ARCHITECTURE"] = "resunet"
+        cfg["MODEL"]["Z_DOWN"] = [1, 1, 1, 1]
+        network_base_memory = 400
+    else:
+        cfg["MODEL"]["ARCHITECTURE"] = "rcan"
+        network_base_memory = 3500
 elif cfg["PROBLEM"]["TYPE"] == "SELF_SUPERVISED":
-    cfg["MODEL"]["ARCHITECTURE"] = "unet"
-    network_memory = 400
+    cfg["MODEL"]["ARCHITECTURE"] = "resunet"
+    network_base_memory = 400
     if cfg['PROBLEM']['NDIM'] == "3D":
         cfg["MODEL"]["Z_DOWN"] = [1, 1, 1, 1]
 elif cfg["PROBLEM"]["TYPE"] == "CLASSIFICATION":
     cfg["MODEL"]["ARCHITECTURE"] = "vit"
-    network_memory = 342.67 if cfg['PROBLEM']['NDIM'] == "2D" else 354.55
+    network_base_memory = 2200

 #####################
 # DATA AUGMENTATION #
@@ -1353,6 +1358,12 @@ def set_default_config(cfg, gpu_info, sample_info):
 elif cfg["PROBLEM"]["TYPE"] == "CLASSIFICATION":
     pass

+# Remove the detection key added by the wizard if we're not in that workflow
+if not cfg['TEST']['ENABLE'] or cfg['PROBLEM']['TYPE'] != 'DETECTION':
+    del cfg['TEST']['POST_PROCESSING']["REMOVE_CLOSE_POINTS_RADIUS"]
+    if len(cfg['TEST']['POST_PROCESSING']) == 0:
+        del cfg['TEST']['POST_PROCESSING']
+
 # Calculate data channels
 # cfg.PROBLEM.TYPE == "CLASSIFICATION" or "SELF_SUPERVISED" and PRETEXT_TASK == "masking". But as in the wizard there is no SSL
 # we reduce the if fo samples_each_time to this
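A small note on the new cleanup block: del raises KeyError if the wizard never added REMOVE_CLOSE_POINTS_RADIUS under cfg['TEST']['POST_PROCESSING'] (per the comment it always does on this path). A defensive variant, my suggestion rather than part of the commit, tolerates an absent key:

# Hypothetical hardening: pop() with a default never raises KeyError
cfg['TEST'].get('POST_PROCESSING', {}).pop("REMOVE_CLOSE_POINTS_RADIUS", None)
if not cfg['TEST'].get('POST_PROCESSING'):
    cfg['TEST'].pop('POST_PROCESSING', None)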
@@ -1386,7 +1397,7 @@ def set_default_config(cfg, gpu_info, sample_info):
     sample_info=sample_info,
     patch_size=cfg['DATA']['PATCH_SIZE'],
     channels_per_sample=channels_per_sample,
-    network_memory=network_memory,
+    network_base_memory=network_base_memory,
     y_upsampling=y_upsampling,
     max_batch_size_allowed=max_batch_size_allowed,
 )
@@ -1404,18 +1415,18 @@ def batch_size_calculator(
     sample_info,
     patch_size,
     channels_per_sample,
-    network_memory=400,
+    network_base_memory=400,
     sample_size_reference=256,
-    sample_memory_reference=45,
+    sample_memory_reference=100,
     max_batch_size_allowed=32,
     y_upsampling=(1,1),
 ):
     """
     Calculate a reasonable value for the batch size measuring how much memory can consume a item returned by BiaPy generator
     (i.e. __getitem__ function). It takes into account the GPUs available (taking the minimum memory among them), number
-    of samples, memory of the network (``network_memory``) and a sample reference memory (``sample_size_reference`` and
+    of samples, memory of the network (``network_base_memory``) and a sample reference memory (``sample_size_reference`` and
     ``sample_memory_reference``). The default settings of this variables takes as reference samples of (256,256,1), where each
-    sample is calculated to consume 45MB of memory (more or less).
+    sample is calculated to consume 100MB of memory (more or less).

     Parameters
     ----------
@@ -1438,7 +1449,7 @@ def batch_size_calculator(
         * 'x': int. Channels for X data.
         * 'y': int. Channels for Y data.

-    network_memory : int, optional
+    network_base_memory : int, optional
         Memory consumed by the deep learning model in MB. This needs to be previously measured. For instace, the default
         U-Net of BiaPy takes 400MB in the GPU.
@@ -1471,7 +1482,7 @@ def batch_size_calculator(
 # the data an calculating the total number of samples, with the sample reference (composed by sample_size_reference and
 # sample_memory_reference).
 x_analisis_ref_ratio = x_data_to_analize['crop_shape'][1] / sample_size_reference
-x_selected_patch_ref_ratio = patch_size[1] / sample_size_reference
+x_selected_patch_ref_ratio = (patch_size[1] / sample_size_reference)**2
 x_sample_memory_ratio = x_analisis_ref_ratio * x_selected_patch_ref_ratio * sample_memory_reference

 y_data_to_analize = None
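The substantive fix in this hunk (and its Y counterpart below) is squaring the patch-to-reference ratio. Memory scales with the number of pixels, which for the wizard's square 2D patches is quadratic in the side length, so the old linear ratio underestimated large patches. A quick check of both formulas, my own arithmetic using the new 100MB reference:

sample_size_reference = 256
sample_memory_reference = 100  # MB per (256, 256, 1) sample, per the new default

patch_side = 512  # 4x the pixels of the reference, so ~4x the memory
old_estimate = (patch_side / sample_size_reference) * sample_memory_reference       # 200 MB: linear, underestimates
new_estimate = (patch_side / sample_size_reference) ** 2 * sample_memory_reference  # 400 MB: tracks pixel count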
@@ -1484,7 +1495,7 @@ def batch_size_calculator(
     y_crop = y_data_to_analize['crop_shape'][1] * y_upsampling[1]
     # Patch size always square in wizard, that's why we use always y_data_to_analize['crop_shape'][1]
     y_analisis_ref_ratio = y_crop / sample_size_reference
-    y_selected_patch_ref_ratio = patch_size[1] / sample_size_reference
+    y_selected_patch_ref_ratio = (patch_size[1] / sample_size_reference)**2
     y_sample_memory_ratio = y_analisis_ref_ratio * y_selected_patch_ref_ratio * sample_memory_reference
 else:
     y_sample_memory_ratio = 0
@@ -1500,8 +1511,12 @@ def batch_size_calculator(
     # This value will be close to the one being used during a __getitem__ of BiaPy generator
     item_memory_consumption = (x_sample_memory_ratio*x_data) + (y_sample_memory_ratio*y_data)

+    # The network will hold more memory depending on the input size. It's calculated to be roughly 210MB per (256,256,1) sample,
+    # so we use x_sample_memory_ratio here, as it was measured with sample_memory_reference=100MB by default
+    approx_network_memory = network_base_memory + (2*x_sample_memory_ratio*x_data)
+
     if item_memory_consumption == 0: item_memory_consumption = 1
-    number_of_samples = (min_mem_gpu - network_memory)//item_memory_consumption if (min_mem_gpu - network_memory) > 0 else 1
+    number_of_samples = (min_mem_gpu - approx_network_memory)//item_memory_consumption if (min_mem_gpu - approx_network_memory) > 0 else 1
 else:
     number_of_samples = max_batch_size_allowed//2
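Putting the hunks together, the updated heuristic reads: estimate the memory of one generator item from the squared patch ratio, let the network footprint grow with the input size (base cost plus roughly twice the X-sample memory), then divide the remaining GPU memory by the per-item cost. A self-contained sketch of that arithmetic, my simplified reconstruction of batch_size_calculator under its defaults (square 2D patches whose analyzed crops match the reference size, so the analysis ratios collapse to 1):

def approx_batch_size(min_mem_gpu, patch_side, x_channels=1, y_channels=1,
                      network_base_memory=400, sample_size_reference=256,
                      sample_memory_reference=100, max_batch_size_allowed=32):
    # Per-sample memory scales with pixel count: (side / 256)^2 * 100MB
    sample_memory_ratio = (patch_side / sample_size_reference) ** 2 * sample_memory_reference

    # One __getitem__ item carries X data plus (when targets exist) Y data
    item_memory = max(sample_memory_ratio * (x_channels + y_channels), 1)

    # Network footprint grows with the input size: base + ~2x the X-sample memory
    approx_network_memory = network_base_memory + 2 * sample_memory_ratio * x_channels

    free = min_mem_gpu - approx_network_memory
    n = free // item_memory if free > 0 else 1
    return int(max(1, min(n, max_batch_size_allowed)))

# Example: a 12GB GPU with 256x256 single-channel patches:
# item = 200MB, network ~600MB, (12000 - 600) // 200 = 57, capped at 32
print(approx_batch_size(12000, 256))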
