Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recursive Item Mapping for Nested Lists in Compose #8187

Open
wants to merge 14 commits into
base: dev
Choose a base branch
from
16 changes: 8 additions & 8 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
def execute_compose(
data: NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor],
transforms: Sequence[Any],
map_items: bool = True,
map_items: bool | int = True,
unpack_items: bool = False,
start: int = 0,
end: int | None = None,
Expand All @@ -66,7 +66,7 @@ def execute_compose(
data: a tensor-like object to be transformed
transforms: a sequence of transforms to be carried out
map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
defaults to `True`.
defaults to `True`. If set to an integer, recursively map the items in `data` `map_items` times.
unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
defaults to `False`.
start: the index of the first transform to be executed. If not set, this defaults to 0
Expand Down Expand Up @@ -206,7 +206,7 @@ class Compose(Randomizable, InvertibleTransform, LazyTransform):
Args:
transforms: sequence of callables.
map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
defaults to `True`.
defaults to `True`. If set to an integer, recursively map the items in `data` `map_items` times.
unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
defaults to `False`.
log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution.
Expand All @@ -227,7 +227,7 @@ class Compose(Randomizable, InvertibleTransform, LazyTransform):
def __init__(
self,
transforms: Sequence[Callable] | Callable | None = None,
map_items: bool = True,
map_items: bool | int = True,
unpack_items: bool = False,
log_stats: bool | str = False,
lazy: bool | None = False,
Expand All @@ -238,9 +238,9 @@ def __init__(
if transforms is None:
transforms = []

if not isinstance(map_items, bool):
if not isinstance(map_items, (bool, int)):
raise ValueError(
f"Argument 'map_items' should be boolean. Got {type(map_items)}."
f"Argument 'map_items' should be boolean or int. Got {type(map_items)}."
"Check brackets when passing a sequence of callables."
)

Expand Down Expand Up @@ -392,7 +392,7 @@ class OneOf(Compose):
weights: probabilities corresponding to each callable in transforms.
Probabilities are normalized to sum to one.
map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
defaults to `True`.
defaults to `True`. If set to an integer, recursively map the items in `data` `map_items` times.
unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
defaults to `False`.
log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution.
Expand All @@ -414,7 +414,7 @@ def __init__(
self,
transforms: Sequence[Callable] | Callable | None = None,
weights: Sequence[float] | float | None = None,
map_items: bool = True,
map_items: bool | int = True,
unpack_items: bool = False,
log_stats: bool | str = False,
lazy: bool | None = False,
Expand Down
16 changes: 12 additions & 4 deletions monai/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,12 @@ def _apply_transform(
def apply_transform(
transform: Callable[..., ReturnType],
data: Any,
map_items: bool = True,
map_items: bool | int = True,
unpack_items: bool = False,
log_stats: bool | str = False,
lazy: bool | None = None,
overrides: dict | None = None,
) -> list[ReturnType] | ReturnType:
) -> list[Any] | ReturnType:
"""
Transform `data` with `transform`.

Expand All @@ -119,6 +119,7 @@ def apply_transform(
data: an object to be transformed.
map_items: whether to apply transform to each item in `data`,
if `data` is a list or tuple. Defaults to True.
it can also be an int, if so, recursively map the items in `data` `map_items` times.
unpack_items: whether to unpack parameters using `*`. Defaults to False.
log_stats: log errors when they occur in the processing pipeline. By default, this is set to False, which
disables the logger for processing pipeline errors. Setting it to None or True will enable logging to the
Expand All @@ -136,8 +137,15 @@ def apply_transform(
Union[List[ReturnType], ReturnType]: The return type of `transform` or a list thereof.
"""
try:
if isinstance(data, (list, tuple)) and map_items:
return [_apply_transform(transform, item, unpack_items, lazy, overrides, log_stats) for item in data]
if isinstance(data, (list, tuple)) and (map_items or type(map_items) is int):
# if map_items is an int, apply the transform to each item in the list `map_items` times
if type(map_items) is int and map_items > 0:
return [
apply_transform(transform, item, map_items - 2, unpack_items, log_stats, lazy, overrides)
for item in data
]
else:
return [_apply_transform(transform, item, unpack_items, lazy, overrides, log_stats) for item in data]
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
except Exception as e:
# if in debug mode, don't swallow exception so that the breakpoint
Expand Down
14 changes: 14 additions & 0 deletions tests/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,20 @@ def b(i, i2):
self.assertEqual(mt.Compose(transforms, unpack_items=True)(data), expected)
self.assertEqual(execute_compose(data, transforms, unpack_items=True), expected)

def test_list_non_dict_compose_with_unpack_map_2(self):

def a(i, i2):
return i + "a", i2 + "a2"

def b(i, i2):
return i + "b", i2 + "b2"

transforms = [a, b, a, b]
data = [[("", ""), ("", "")], [("t", "t"), ("t", "t")]]
expected = [[("abab", "a2b2a2b2"), ("abab", "a2b2a2b2")], [("tabab", "ta2b2a2b2"), ("tabab", "ta2b2a2b2")]]
self.assertEqual(mt.Compose(transforms, map_items=2, unpack_items=True)(data), expected)
self.assertEqual(execute_compose(data, transforms, map_items=2, unpack_items=True), expected)

def test_list_dict_compose_no_map(self):

def a(d): # transform to handle dict data
Expand Down
Loading