57
57
optional_import ,
58
58
)
59
59
from monai .utils .enums import TransformBackends
60
- from monai .utils .type_conversion import convert_data_type , convert_to_dst_type , convert_to_tensor
60
+ from monai .utils .type_conversion import convert_data_type , convert_to_cupy , convert_to_dst_type , convert_to_tensor
61
61
62
- measure , _ = optional_import ("skimage.measure" , "0.14.2" , min_version )
62
+ measure , has_measure = optional_import ("skimage.measure" , "0.14.2" , min_version )
63
63
morphology , has_morphology = optional_import ("skimage.morphology" )
64
64
ndimage , _ = optional_import ("scipy.ndimage" )
65
65
cp , has_cp = optional_import ("cupy" )
@@ -951,7 +951,9 @@ def generate_spatial_bounding_box(
951
951
return box_start , box_end
952
952
953
953
954
- def get_largest_connected_component_mask (img : NdarrayTensor , connectivity : Optional [int ] = None ) -> NdarrayTensor :
954
+ def get_largest_connected_component_mask (
955
+ img : NdarrayTensor , connectivity : Optional [int ] = None , num_components : int = 1
956
+ ) -> NdarrayTensor :
955
957
"""
956
958
Gets the largest connected component mask of an image.
957
959
@@ -961,24 +963,40 @@ def get_largest_connected_component_mask(img: NdarrayTensor, connectivity: Optio
961
963
Accepted values are ranging from 1 to input.ndim. If ``None``, a full
962
964
connectivity of ``input.ndim`` is used. for more details:
963
965
https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.label.
964
- """
965
- if isinstance (img , torch .Tensor ) and has_cp and has_cucim :
966
- x_cupy = monai .transforms .ToCupy ()(img .short ())
967
- x_label = cucim .skimage .measure .label (x_cupy , connectivity = connectivity )
968
- vals , counts = cp .unique (x_label [cp .nonzero (x_label )], return_counts = True )
969
- comp = x_label == vals [cp .ndarray .argmax (counts )]
970
- out_tensor = monai .transforms .ToTensor (device = img .device )(comp )
971
- out_tensor = out_tensor .bool ()
972
-
973
- return out_tensor # type: ignore
974
-
975
- img_arr = convert_data_type (img , np .ndarray )[0 ]
976
- largest_cc : np .ndarray = np .zeros (shape = img_arr .shape , dtype = img_arr .dtype )
977
- img_arr = measure .label (img_arr , connectivity = connectivity )
978
- if img_arr .max () != 0 :
979
- largest_cc [...] = img_arr == (np .argmax (np .bincount (img_arr .flat )[1 :]) + 1 )
980
-
981
- return convert_to_dst_type (largest_cc , dst = img , dtype = largest_cc .dtype )[0 ]
966
+ num_components: The number of largest components to preserve.
967
+ """
968
+ # use skimage/cucim.skimage and np/cp depending on whether packages are
969
+ # available and input is non-cpu torch.tensor
970
+ use_cp = has_cp and has_cucim and isinstance (img , torch .Tensor ) and img .device != torch .device ("cpu" )
971
+ if use_cp :
972
+ img_ = convert_to_cupy (img .short ()) # type: ignore
973
+ label = cucim .skimage .measure .label
974
+ lib = cp
975
+ else :
976
+ if not has_measure :
977
+ raise RuntimeError ("Skimage.measure required." )
978
+ img_ , * _ = convert_data_type (img , np .ndarray )
979
+ label = measure .label
980
+ lib = np
981
+
982
+ # features will be an image -- 0 for background and then each different
983
+ # feature will have its own index.
984
+ features , num_features = label (img_ , connectivity = connectivity , return_num = True )
985
+ # if num features less than max desired, nothing to do.
986
+ if num_features <= num_components :
987
+ out = img_ .astype (bool )
988
+ else :
989
+ # ignore background
990
+ nonzeros = features [lib .nonzero (features )]
991
+ # get number voxels per feature (bincount). argsort[::-1] to get indices
992
+ # of largest components.
993
+ features_to_keep = lib .argsort (lib .bincount (nonzeros ))[::- 1 ]
994
+ # only keep the first n non-background indices
995
+ features_to_keep = features_to_keep [:num_components ]
996
+ # generate labelfield. True if in list of features to keep
997
+ out = lib .isin (features , features_to_keep )
998
+
999
+ return convert_to_dst_type (out , dst = img , dtype = out .dtype )[0 ]
982
1000
983
1001
984
1002
def remove_small_objects (
0 commit comments