Skip to content

Commit cd02edc

Browse files
committed
Update BMZ check with last changes in BiaPy
1 parent 3ec1624 commit cd02edc

File tree

1 file changed

+25
-11
lines changed

1 file changed

+25
-11
lines changed

biapy/biapy_aux_functions.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Optional, Dict, Tuple, List
88
from packaging.version import Version
99

10-
## Copied from BiaPy commit: ff12c305c03c48bcc07510c29887a4281cc108b7 (3.5.2)
10+
## Copied from BiaPy commit: 9e85d64ac040449182794b6b9c2b3ae0c5b652e7 (3.5.2)
1111
def check_bmz_model_compatibility(
1212
model_rdf: Dict,
1313
workflow_specs: Optional[Dict] = None,
@@ -191,7 +191,7 @@ def check_bmz_model_compatibility(
191191

192192
return preproc_info, error, reason_message
193193

194-
# Adapted from BiaPy commit: ff12c305c03c48bcc07510c29887a4281cc108b7 (3.5.2)
194+
# Adapted from BiaPy commit: 9e85d64ac040449182794b6b9c2b3ae0c5b652e7 (3.5.2)
195195
def check_model_restrictions(
196196
model_rdf,
197197
workflow_specs,
@@ -255,12 +255,11 @@ def check_model_restrictions(
255255
raise ValueError("Couldn't load input info from BMZ model's RDF: {}".format(model_rdf["inputs"][0]))
256256
opts["DATA.PATCH_SIZE"] = tuple(input_image_shape[2:]) + (input_image_shape[1],)
257257

258-
# 2) Classes in semantic segmentation. This is slightly different from BiaPy because we are using RDF dict here.
259-
# if (specific_workflow in ["INSTANCE_SEG", "SEMANTIC_SEG", "DETECTION"]):
258+
# 2) Workflow specific restrictions
259+
# Classes in semantic segmentation
260260
if specific_workflow in ["SEMANTIC_SEG"]:
261261
# Check number of classes
262262
classes = -1
263-
print(model_rdf["weights"]["pytorch_state_dict"])
264263
if "kwargs" in model_rdf["weights"]["pytorch_state_dict"]:
265264
if "n_classes" in model_rdf["weights"]["pytorch_state_dict"]["kwargs"]: # BiaPy
266265
classes = model_rdf["weights"]["pytorch_state_dict"]["kwargs"]["n_classes"]
@@ -276,8 +275,23 @@ def check_model_restrictions(
276275

277276
if specific_workflow == "SEMANTIC_SEG" and classes == -1:
278277
raise ValueError("Classes not found for semantic segmentation dir. ")
279-
print(f"classes: {classes}")
280278
opts["MODEL.N_CLASSES"] = max(2,classes)
279+
elif specific_workflow in ["INSTANCE_SEG"]:
280+
# Assumed it's BC. This needs a more elaborated process. Still deciding this:
281+
# https://github.com/bioimage-io/spec-bioimage-io/issues/621
282+
channels = 2
283+
if "out_channels" in bmz_config["original_bmz_config"].weights.pytorch_state_dict.kwargs:
284+
channels = bmz_config["original_bmz_config"].weights.pytorch_state_dict.kwargs["out_channels"]
285+
if channels == 1:
286+
channel_code = "C"
287+
elif channels == 2:
288+
channel_code = "BC"
289+
elif channels == 3:
290+
channel_code = "BCM"
291+
if channels > 3:
292+
raise ValueError(f"Not recognized number of channels for instance segmentation. Obtained {channels}")
293+
294+
opts["PROBLEM.INSTANCE_SEG.DATA_CHANNELS"] = channel_code
281295

282296
if "preprocessing" not in model_rdf["inputs"][0]:
283297
return opts
@@ -756,7 +770,7 @@ def check_value(value, value_range=(0, 1)):
756770
return False
757771

758772

759-
# Copied from BiaPy commit: ff12c305c03c48bcc07510c29887a4281cc108b7 (3.5.2)
773+
# Copied from BiaPy commit: 9e85d64ac040449182794b6b9c2b3ae0c5b652e7 (3.5.2)
760774
def crop_data_with_overlap(data, crop_shape, data_mask=None, overlap=(0, 0), padding=(0, 0), verbose=True,
761775
load_data=True):
762776
"""
@@ -876,7 +890,7 @@ def crop_data_with_overlap(data, crop_shape, data_mask=None, overlap=(0, 0), pad
876890
if p >= crop_shape[i] // 2:
877891
raise ValueError(
878892
"'Padding' can not be greater than the half of 'crop_shape'. Max value for this {} input shape is {}".format(
879-
data.shape, [(crop_shape[0] // 2) - 1, (crop_shape[1] // 2) - 1]
893+
crop_shape, [(crop_shape[0] // 2) - 1, (crop_shape[1] // 2) - 1]
880894
)
881895
)
882896
if len(crop_shape) != 3:
@@ -1006,7 +1020,7 @@ def crop_data_with_overlap(data, crop_shape, data_mask=None, overlap=(0, 0), pad
10061020
else:
10071021
return crop_coords
10081022

1009-
# Copied from BiaPy commit: ff12c305c03c48bcc07510c29887a4281cc108b7 (3.5.2)
1023+
# Copied from BiaPy commit: 9e85d64ac040449182794b6b9c2b3ae0c5b652e7 (3.5.2)
10101024
def crop_3D_data_with_overlap(
10111025
data,
10121026
vol_shape,
@@ -1149,7 +1163,7 @@ def crop_3D_data_with_overlap(
11491163
if p >= vol_shape[i] // 2:
11501164
raise ValueError(
11511165
"'Padding' can not be greater than the half of 'vol_shape'. Max value for this {} input shape is {}".format(
1152-
data.shape,
1166+
vol_shape,
11531167
[
11541168
(vol_shape[0] // 2) - 1,
11551169
(vol_shape[1] // 2) - 1,
@@ -1290,7 +1304,7 @@ def crop_3D_data_with_overlap(
12901304
else:
12911305
return crop_coords
12921306

1293-
# Copied from BiaPy commit: ff12c305c03c48bcc07510c29887a4281cc108b7 (3.5.2)
1307+
# Copied from BiaPy commit: 9e85d64ac040449182794b6b9c2b3ae0c5b652e7 (3.5.2)
12941308
def pad_and_reflect(img, crop_shape, verbose=False):
12951309
"""
12961310
Load data from a directory.

0 commit comments

Comments
 (0)