Skip to content

Commit 8f679cc

Browse files
authored
Merge pull request #138 from Project-MONAI/master
1005 Support different margin for dims in CropForground (#1011)
2 parents 3b9280b + d958ab5 commit 8f679cc

File tree

6 files changed

+37
-12
lines changed

6 files changed

+37
-12
lines changed

monai/transforms/croppad/array.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -384,14 +384,17 @@ class CropForeground(Transform):
384384
"""
385385

386386
def __init__(
387-
self, select_fn: Callable = lambda x: x > 0, channel_indices: Optional[IndexSelection] = None, margin: int = 0
387+
self,
388+
select_fn: Callable = lambda x: x > 0,
389+
channel_indices: Optional[IndexSelection] = None,
390+
margin: Union[Sequence[int], int] = 0,
388391
) -> None:
389392
"""
390393
Args:
391394
select_fn: function to select expected foreground, default is to select values > 0.
392395
channel_indices: if defined, select foreground only on the specified channels
393396
of image. if None, select foreground on the whole image.
394-
margin: add margin to all dims of the bounding box.
397+
margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims.
395398
"""
396399
self.select_fn = select_fn
397400
self.channel_indices = ensure_tuple(channel_indices) if channel_indices is not None else None

monai/transforms/croppad/dictionary.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def __init__(
336336
select_fn: function to select expected foreground, default is to select values > 0.
337337
channel_indices: if defined, select foreground only on the specified channels
338338
of image. if None, select foreground on the whole image.
339-
margin: add margin to all dims of the bounding box.
339+
margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims.
340340
"""
341341
super().__init__(keys)
342342
self.source_key = source_key

monai/transforms/utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import torch
1818

1919
from monai.config import IndexSelection
20-
from monai.utils import ensure_tuple, ensure_tuple_size, fall_back_tuple, min_version, optional_import
20+
from monai.utils import ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, min_version, optional_import
2121

2222
measure, _ = optional_import("skimage.measure", "0.14.2", min_version)
2323

@@ -455,7 +455,7 @@ def generate_spatial_bounding_box(
455455
img: np.ndarray,
456456
select_fn: Callable = lambda x: x > 0,
457457
channel_indices: Optional[IndexSelection] = None,
458-
margin: int = 0,
458+
margin: Union[Sequence[int], int] = 0,
459459
) -> Tuple[List[int], List[int]]:
460460
"""
461461
generate the spatial bounding box of foreground in the image with start-end positions.
@@ -467,19 +467,19 @@ def generate_spatial_bounding_box(
467467
select_fn: function to select expected foreground, default is to select values > 0.
468468
channel_indices: if defined, select foreground only on the specified channels
469469
of image. if None, select foreground on the whole image.
470-
margin: add margin to all dims of the bounding box.
470+
margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims.
471471
"""
472-
assert isinstance(margin, int), "margin must be int type."
473472
data = img[[*(ensure_tuple(channel_indices))]] if channel_indices is not None else img
474473
data = np.any(select_fn(data), axis=0)
475474
nonzero_idx = np.nonzero(data)
475+
margin = ensure_tuple_rep(margin, data.ndim)
476476

477477
box_start = list()
478478
box_end = list()
479479
for i in range(data.ndim):
480480
assert len(nonzero_idx[i]) > 0, f"did not find nonzero index at spatial dim {i}"
481-
box_start.append(max(0, np.min(nonzero_idx[i]) - margin))
482-
box_end.append(min(data.shape[i], np.max(nonzero_idx[i]) + margin + 1))
481+
box_start.append(max(0, np.min(nonzero_idx[i]) - margin[i]))
482+
box_end.append(min(data.shape[i], np.max(nonzero_idx[i]) + margin[i] + 1))
483483
return box_start, box_end
484484

485485

tests/test_crop_foreground.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,15 @@
4040
np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]),
4141
]
4242

43+
TEST_CASE_5 = [
44+
{"select_fn": lambda x: x > 0, "channel_indices": None, "margin": [2, 1]},
45+
np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]),
46+
np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]),
47+
]
48+
4349

4450
class TestCropForeground(unittest.TestCase):
45-
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
51+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
4652
def test_value(self, argments, image, expected_data):
4753
result = CropForeground(**argments)(image)
4854
np.testing.assert_allclose(result, expected_data)

tests/test_crop_foregroundd.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,15 @@
4949
np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]),
5050
]
5151

52+
TEST_CASE_5 = [
53+
{"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": None, "margin": [2, 1]},
54+
{"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]])},
55+
np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]),
56+
]
57+
5258

5359
class TestCropForegroundd(unittest.TestCase):
54-
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
60+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
5561
def test_value(self, argments, image, expected_data):
5662
result = CropForegroundd(**argments)(image)
5763
np.testing.assert_allclose(result["img"], expected_data)

tests/test_generate_spatial_bounding_box.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,19 @@
5656
([0, 0], [4, 5]),
5757
]
5858

59+
TEST_CASE_5 = [
60+
{
61+
"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]),
62+
"select_fn": lambda x: x > 0,
63+
"channel_indices": None,
64+
"margin": [2, 1],
65+
},
66+
([0, 0], [5, 5]),
67+
]
68+
5969

6070
class TestGenerateSpatialBoundingBox(unittest.TestCase):
61-
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
71+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
6272
def test_value(self, input_data, expected_box):
6373
result = generate_spatial_bounding_box(**input_data)
6474
self.assertTupleEqual(result, expected_box)

0 commit comments

Comments
 (0)