Skip to content

Commit d64c6d2

Browse files
authored
Keep n components (#5138)
Fixes #3809.

### Description

Adds a `num_components` argument that sets how many of the largest connected components are preserved. It defaults to 1 for backwards compatibility.

### Types of changes

<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: Richard Brown <[email protected]>
1 parent 8700fee commit d64c6d2

File tree

4 files changed, with 85 additions (+85) and 25 deletions (-25).

4 files changed

+85
-25
lines changed

monai/transforms/post/array.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ def __init__(
273273
is_onehot: Optional[bool] = None,
274274
independent: bool = True,
275275
connectivity: Optional[int] = None,
276+
num_components: int = 1,
276277
) -> None:
277278
"""
278279
Args:
@@ -290,13 +291,15 @@ def __init__(
290291
Accepted values are ranging from 1 to input.ndim. If ``None``, a full
291292
connectivity of ``input.ndim`` is used. for more details:
292293
https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.label.
294+
num_components: The number of largest components to preserve.
293295
294296
"""
295297
super().__init__()
296298
self.applied_labels = ensure_tuple(applied_labels) if applied_labels is not None else None
297299
self.is_onehot = is_onehot
298300
self.independent = independent
299301
self.connectivity = connectivity
302+
self.num_components = num_components
300303

301304
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
302305
"""
@@ -316,7 +319,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
316319
if self.independent:
317320
for i in applied_labels:
318321
foreground = img_[i] > 0 if is_onehot else img_[0] == i
319-
mask = get_largest_connected_component_mask(foreground, self.connectivity)
322+
mask = get_largest_connected_component_mask(foreground, self.connectivity, self.num_components)
320323
if is_onehot:
321324
img_[i][foreground != mask] = 0
322325
else:
@@ -325,12 +328,12 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
325328
if not is_onehot: # not one-hot, union of labels
326329
labels, *_ = convert_to_dst_type(applied_labels, dst=img_, wrap_sequence=True)
327330
foreground = (img_[..., None] == labels).any(-1)[0]
328-
mask = get_largest_connected_component_mask(foreground, self.connectivity)
331+
mask = get_largest_connected_component_mask(foreground, self.connectivity, self.num_components)
329332
img_[0][foreground != mask] = 0
330333
return convert_to_dst_type(img_, dst=img)[0]
331334
# one-hot, union of labels
332335
foreground = (img_[applied_labels, ...] == 1).any(0)
333-
mask = get_largest_connected_component_mask(foreground, self.connectivity)
336+
mask = get_largest_connected_component_mask(foreground, self.connectivity, self.num_components)
334337
for i in applied_labels:
335338
img_[i][foreground != mask] = 0
336339
return convert_to_dst_type(img_, dst=img)[0]

monai/transforms/post/dictionary.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ def __init__(
204204
is_onehot: Optional[bool] = None,
205205
independent: bool = True,
206206
connectivity: Optional[int] = None,
207+
num_components: int = 1,
207208
allow_missing_keys: bool = False,
208209
) -> None:
209210
"""
@@ -224,12 +225,17 @@ def __init__(
224225
Accepted values are ranging from 1 to input.ndim. If ``None``, a full
225226
connectivity of ``input.ndim`` is used. for more details:
226227
https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.label.
228+
num_components: The number of largest components to preserve.
227229
allow_missing_keys: don't raise exception if key is missing.
228230
229231
"""
230232
super().__init__(keys, allow_missing_keys)
231233
self.converter = KeepLargestConnectedComponent(
232-
applied_labels=applied_labels, is_onehot=is_onehot, independent=independent, connectivity=connectivity
234+
applied_labels=applied_labels,
235+
is_onehot=is_onehot,
236+
independent=independent,
237+
connectivity=connectivity,
238+
num_components=num_components,
233239
)
234240

235241
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:

monai/transforms/utils.py

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@
5757
optional_import,
5858
)
5959
from monai.utils.enums import TransformBackends
60-
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor
60+
from monai.utils.type_conversion import convert_data_type, convert_to_cupy, convert_to_dst_type, convert_to_tensor
6161

62-
measure, _ = optional_import("skimage.measure", "0.14.2", min_version)
62+
measure, has_measure = optional_import("skimage.measure", "0.14.2", min_version)
6363
morphology, has_morphology = optional_import("skimage.morphology")
6464
ndimage, _ = optional_import("scipy.ndimage")
6565
cp, has_cp = optional_import("cupy")
@@ -951,7 +951,9 @@ def generate_spatial_bounding_box(
951951
return box_start, box_end
952952

953953

954-
def get_largest_connected_component_mask(img: NdarrayTensor, connectivity: Optional[int] = None) -> NdarrayTensor:
954+
def get_largest_connected_component_mask(
955+
img: NdarrayTensor, connectivity: Optional[int] = None, num_components: int = 1
956+
) -> NdarrayTensor:
955957
"""
956958
Gets the largest connected component mask of an image.
957959
@@ -961,24 +963,40 @@ def get_largest_connected_component_mask(img: NdarrayTensor, connectivity: Optio
961963
Accepted values are ranging from 1 to input.ndim. If ``None``, a full
962964
connectivity of ``input.ndim`` is used. for more details:
963965
https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.label.
964-
"""
965-
if isinstance(img, torch.Tensor) and has_cp and has_cucim:
966-
x_cupy = monai.transforms.ToCupy()(img.short())
967-
x_label = cucim.skimage.measure.label(x_cupy, connectivity=connectivity)
968-
vals, counts = cp.unique(x_label[cp.nonzero(x_label)], return_counts=True)
969-
comp = x_label == vals[cp.ndarray.argmax(counts)]
970-
out_tensor = monai.transforms.ToTensor(device=img.device)(comp)
971-
out_tensor = out_tensor.bool()
972-
973-
return out_tensor # type: ignore
974-
975-
img_arr = convert_data_type(img, np.ndarray)[0]
976-
largest_cc: np.ndarray = np.zeros(shape=img_arr.shape, dtype=img_arr.dtype)
977-
img_arr = measure.label(img_arr, connectivity=connectivity)
978-
if img_arr.max() != 0:
979-
largest_cc[...] = img_arr == (np.argmax(np.bincount(img_arr.flat)[1:]) + 1)
980-
981-
return convert_to_dst_type(largest_cc, dst=img, dtype=largest_cc.dtype)[0]
966+
num_components: The number of largest components to preserve.
967+
"""
968+
# use skimage/cucim.skimage and np/cp depending on whether packages are
969+
# available and input is non-cpu torch.tensor
970+
use_cp = has_cp and has_cucim and isinstance(img, torch.Tensor) and img.device != torch.device("cpu")
971+
if use_cp:
972+
img_ = convert_to_cupy(img.short()) # type: ignore
973+
label = cucim.skimage.measure.label
974+
lib = cp
975+
else:
976+
if not has_measure:
977+
raise RuntimeError("Skimage.measure required.")
978+
img_, *_ = convert_data_type(img, np.ndarray)
979+
label = measure.label
980+
lib = np
981+
982+
# features will be an image -- 0 for background and then each different
983+
# feature will have its own index.
984+
features, num_features = label(img_, connectivity=connectivity, return_num=True)
985+
# if num features less than max desired, nothing to do.
986+
if num_features <= num_components:
987+
out = img_.astype(bool)
988+
else:
989+
# ignore background
990+
nonzeros = features[lib.nonzero(features)]
991+
# get number voxels per feature (bincount). argsort[::-1] to get indices
992+
# of largest components.
993+
features_to_keep = lib.argsort(lib.bincount(nonzeros))[::-1]
994+
# only keep the first n non-background indices
995+
features_to_keep = features_to_keep[:num_components]
996+
# generate labelfield. True if in list of features to keep
997+
out = lib.isin(features, features_to_keep)
998+
999+
return convert_to_dst_type(out, dst=img, dtype=out.dtype)[0]
9821000

9831001

9841002
def remove_small_objects(

tests/test_keep_largest_connected_component.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ def to_onehot(x):
7878
]
7979
grid_5 = [[[0, 0, 1, 0, 0], [0, 1, 1, 1, 1], [1, 1, 1, 0, 0], [1, 1, 0, 1, 0], [1, 1, 0, 0, 1]]]
8080

81+
grid_6 = [[[0, 0, 1, 1, 0, 0, 1], [0, 0, 0, 1, 0, 0, 1], [1, 1, 0, 0, 1, 0, 1], [0, 0, 0, 1, 0, 0, 1]]]
82+
8183
TESTS = []
8284
for p in TEST_NDARRAYS:
8385
TESTS.append(
@@ -343,6 +345,37 @@ def to_onehot(x):
343345
torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]),
344346
]
345347
)
348+
# no connected regions
349+
TESTS.append(["0 regions", {"num_components": 0}, p(grid_6), p(torch.zeros(1, 4, 7))])
350+
# 1 connected region
351+
TESTS.append(
352+
[
353+
"1 region",
354+
{"num_components": 1},
355+
p(grid_6),
356+
p(
357+
torch.tensor(
358+
[[[0, 0, 1, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 1, 0, 0, 0]]]
359+
)
360+
),
361+
]
362+
)
363+
# 2 connected regions
364+
TESTS.append(
365+
[
366+
"2 regions",
367+
{"num_components": 2},
368+
p(grid_6),
369+
p(
370+
torch.tensor(
371+
[[[0, 0, 1, 1, 0, 0, 1], [0, 0, 0, 1, 0, 0, 1], [0, 0, 0, 0, 1, 0, 1], [0, 0, 0, 1, 0, 0, 1]]]
372+
)
373+
),
374+
]
375+
)
376+
# 3+ connected regions unchanged (as input has 3)
377+
for num_connected in (3, 4):
378+
TESTS.append([f"{num_connected} regions", {"num_components": num_connected}, p(grid_6), p(grid_6)])
346379

347380

348381
class TestKeepLargestConnectedComponent(unittest.TestCase):

0 commit comments

Comments
 (0)