Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recursive Item Mapping for Nested Lists in Compose #8187

Open
wants to merge 14 commits into
base: dev
Choose a base branch
from
37 changes: 26 additions & 11 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
def execute_compose(
data: NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor],
transforms: Sequence[Any],
map_items: bool = True,
map_items: bool | int = True,
unpack_items: bool = False,
start: int = 0,
end: int | None = None,
Expand All @@ -65,8 +65,13 @@ def execute_compose(
Args:
data: a tensor-like object to be transformed
transforms: a sequence of transforms to be carried out
map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
defaults to `True`.
map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple,
it can behave as follows:
- Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied
to the first level of items in `data`.
- If an integer is provided, it specifies the maximum level of nesting to which the transformation
should be recursively applied. This allows treating multi-sample transforms applied after another
multi-sample transform while controlling how deep the mapping goes.
unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
defaults to `False`.
start: the index of the first transform to be executed. If not set, this defaults to 0
Expand Down Expand Up @@ -205,8 +210,13 @@ class Compose(Randomizable, InvertibleTransform, LazyTransform):

Args:
transforms: sequence of callables.
map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
defaults to `True`.
map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple,
it can behave as follows:
- Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied
to the first level of items in `data`.
- If an integer is provided, it specifies the maximum level of nesting to which the transformation
should be recursively applied. This allows treating multi-sample transforms applied after another
multi-sample transform while controlling how deep the mapping goes.
unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
defaults to `False`.
log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution.
Expand All @@ -227,7 +237,7 @@ class Compose(Randomizable, InvertibleTransform, LazyTransform):
def __init__(
self,
transforms: Sequence[Callable] | Callable | None = None,
map_items: bool = True,
map_items: bool | int = True,
unpack_items: bool = False,
log_stats: bool | str = False,
lazy: bool | None = False,
Expand All @@ -238,9 +248,9 @@ def __init__(
if transforms is None:
transforms = []

if not isinstance(map_items, bool):
if not isinstance(map_items, (bool, int)):
raise ValueError(
f"Argument 'map_items' should be boolean. Got {type(map_items)}."
f"Argument 'map_items' should be boolean or int. Got {type(map_items)}."
"Check brackets when passing a sequence of callables."
)

Expand Down Expand Up @@ -391,8 +401,13 @@ class OneOf(Compose):
transforms: sequence of callables.
weights: probabilities corresponding to each callable in transforms.
Probabilities are normalized to sum to one.
map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
defaults to `True`.
map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple,
it can behave as follows:
- Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied
to the first level of items in `data`.
- If an integer is provided, it specifies the maximum level of nesting to which the transformation
should be recursively applied. This allows treating multi-sample transforms applied after another
multi-sample transform while controlling how deep the mapping goes.
unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
defaults to `False`.
log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution.
Expand All @@ -414,7 +429,7 @@ def __init__(
self,
transforms: Sequence[Callable] | Callable | None = None,
weights: Sequence[float] | float | None = None,
map_items: bool = True,
map_items: bool | int = True,
unpack_items: bool = False,
log_stats: bool | str = False,
lazy: bool | None = False,
Expand Down
21 changes: 15 additions & 6 deletions monai/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,12 @@ def _apply_transform(
def apply_transform(
transform: Callable[..., ReturnType],
data: Any,
map_items: bool = True,
map_items: bool | int = True,
unpack_items: bool = False,
log_stats: bool | str = False,
lazy: bool | None = None,
overrides: dict | None = None,
) -> list[ReturnType] | ReturnType:
) -> list[Any] | ReturnType:
"""
Transform `data` with `transform`.

Expand All @@ -117,8 +117,13 @@ def apply_transform(
Args:
transform: a callable to be used to transform `data`.
data: an object to be transformed.
map_items: whether to apply transform to each item in `data`,
if `data` is a list or tuple. Defaults to True.
map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple,
it can behave as follows:
- Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied
to the first level of items in `data`.
- If an integer is provided, it specifies the maximum level of nesting to which the transformation
should be recursively applied. This allows treating multi-sample transforms applied after another
multi-sample transform while controlling how deep the mapping goes.
unpack_items: whether to unpack parameters using `*`. Defaults to False.
log_stats: log errors when they occur in the processing pipeline. By default, this is set to False, which
disables the logger for processing pipeline errors. Setting it to None or True will enable logging to the
Expand All @@ -136,8 +141,12 @@ def apply_transform(
Union[List[ReturnType], ReturnType]: The return type of `transform` or a list thereof.
"""
try:
if isinstance(data, (list, tuple)) and map_items:
return [_apply_transform(transform, item, unpack_items, lazy, overrides, log_stats) for item in data]
map_items_ = int(map_items) if isinstance(map_items, bool) else map_items
if isinstance(data, (list, tuple)) and map_items_ > 0:
return [
apply_transform(transform, item, map_items - 1, unpack_items, log_stats, lazy, overrides)
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
for item in data
]
return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
except Exception as e:
# if in debug mode, don't swallow exception so that the breakpoint
Expand Down
14 changes: 14 additions & 0 deletions tests/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,20 @@ def b(i, i2):
self.assertEqual(mt.Compose(transforms, unpack_items=True)(data), expected)
self.assertEqual(execute_compose(data, transforms, unpack_items=True), expected)

def test_list_non_dict_compose_with_unpack_map_2(self):

def a(i, i2):
return i + "a", i2 + "a2"

def b(i, i2):
return i + "b", i2 + "b2"

transforms = [a, b, a, b]
data = [[("", ""), ("", "")], [("t", "t"), ("t", "t")]]
expected = [[("abab", "a2b2a2b2"), ("abab", "a2b2a2b2")], [("tabab", "ta2b2a2b2"), ("tabab", "ta2b2a2b2")]]
self.assertEqual(mt.Compose(transforms, map_items=2, unpack_items=True)(data), expected)
self.assertEqual(execute_compose(data, transforms, map_items=2, unpack_items=True), expected)

def test_list_dict_compose_no_map(self):

def a(d): # transform to handle dict data
Expand Down
Loading