
Commit c9a1ecc

Merge branch 'dev' into update-cucim-dep
2 parents: 4badb80 + d3d1743

File tree: 14 files changed (+66, -55 lines)

monai/data/dataset.py

Lines changed: 14 additions & 2 deletions
```diff
@@ -575,6 +575,7 @@ def __init__(
         cache_rate: float = 1.0,
         num_workers: Optional[int] = None,
         progress: bool = True,
+        copy_cache: bool = True,
     ) -> None:
         """
         Args:
@@ -587,11 +588,16 @@ def __init__(
             num_workers: the number of worker processes to use.
                 If num_workers is None then the number returned by os.cpu_count() is used.
             progress: whether to display a progress bar.
+            copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
+                defaults to `True`. If the random transforms don't modify the cache content,
+                or every cache item is only used once in a multi-processing environment,
+                `copy_cache=False` can be set for better performance.
         """
         if not isinstance(transform, Compose):
             transform = Compose(transform)
         super().__init__(data=data, transform=transform)
         self.progress = progress
+        self.copy_cache = copy_cache
         self.cache_num = min(int(cache_num), int(len(data) * cache_rate), len(data))
         self.num_workers = num_workers
         if self.num_workers is not None:
@@ -656,7 +662,8 @@ def _transform(self, index: int):
             # only need to deep copy data on first non-deterministic transform
             if not start_run:
                 start_run = True
-                data = deepcopy(data)
+                if self.copy_cache:
+                    data = deepcopy(data)
             data = apply_transform(_transform, data)
         return data
 
@@ -722,6 +729,10 @@ class SmartCacheDataset(Randomizable, CacheDataset):
         shuffle: whether to shuffle the whole data list before preparing the cache content for first epoch.
             it will not modify the original input data sequence in-place.
         seed: random seed if shuffle is `True`, default to `0`.
+        copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
+            defaults to `True`. If the random transforms don't modify the cache content,
+            or every cache item is only used once in a multi-processing environment,
+            `copy_cache=False` can be set for better performance.
     """
 
     def __init__(
@@ -736,14 +747,15 @@ def __init__(
         progress: bool = True,
         shuffle: bool = True,
         seed: int = 0,
+        copy_cache: bool = True,
     ) -> None:
         if shuffle:
             self.set_random_state(seed=seed)
             data = copy(data)
             self.randomize(data)
         self.shuffle = shuffle
 
-        super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress)
+        super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress, copy_cache)
         if self._cache is None:
             self._cache = self._fill_cache()
         if self.cache_num >= len(data):
```

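The new flag is threaded from `SmartCacheDataset` down to `CacheDataset._transform`, where it gates the `deepcopy` of cached items. A minimal usage sketch (the file list and transform chain are illustrative, not from this commit):

```python
from monai.data import CacheDataset
from monai.transforms import Compose, LoadImaged, RandFlipd

# Hypothetical data dicts; any MONAI-style dataset list works here.
data = [{"image": f"img_{i}.nii.gz"} for i in range(10)]
transforms = Compose([LoadImaged(keys="image"), RandFlipd(keys="image", prob=0.5)])

# copy_cache=False skips the per-item deepcopy before the first random
# transform -- only safe if the random transforms never modify the cached
# content in place, or each cached item is used once per process.
ds = CacheDataset(data=data, transform=transforms, cache_rate=1.0, copy_cache=False)
```
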
monai/handlers/utils.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -130,12 +130,12 @@ class mean median max 5percentile 95percentile notnans
     if summary_ops is not None:
         supported_ops = OrderedDict(
             {
-                "mean": lambda x: np.nanmean(x),
-                "median": lambda x: np.nanmedian(x),
-                "max": lambda x: np.nanmax(x),
-                "min": lambda x: np.nanmin(x),
+                "mean": np.nanmean,
+                "median": np.nanmedian,
+                "max": np.nanmax,
+                "min": np.nanmin,
                 "90percentile": lambda x: np.nanpercentile(x[0], x[1]),
-                "std": lambda x: np.nanstd(x),
+                "std": np.nanstd,
                 "notnans": lambda x: (~np.isnan(x)).sum(),
             }
         )
@@ -149,7 +149,7 @@ def _compute_op(op: str, d: np.ndarray):
                 return c_op(d)
 
             threshold = int(op.split("percentile")[0])
-            return supported_ops["90percentile"]((d, threshold))
+            return supported_ops["90percentile"]((d, threshold))  # type: ignore
 
         with open(os.path.join(save_dir, f"{k}_summary.csv"), "w") as f:
             f.write(f"class{deli}{deli.join(ops)}\n")
```

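The refactor works because NumPy's nan-aware reducers are already plain callables, so a one-argument forwarding lambda adds nothing but indirection:

```python
import numpy as np

x = np.array([1.0, 2.0, np.nan, 4.0])

# The lambda only forwards its argument; storing the function reference
# directly produces the identical result.
assert (lambda v: np.nanmean(v))(x) == np.nanmean(x)
```
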
monai/transforms/compose.py

Lines changed: 16 additions & 18 deletions
```diff
@@ -204,14 +204,13 @@ def __init__(
     def _normalize_probabilities(self, weights):
         if len(weights) == 0:
             return weights
-        else:
-            weights = np.array(weights)
-            if np.any(weights < 0):
-                raise AssertionError("Probabilities must be greater than or equal to zero.")
-            if np.all(weights == 0):
-                raise AssertionError("At least one probability must be greater than zero.")
-            weights = weights / weights.sum()
-            return list(weights)
+        weights = np.array(weights)
+        if np.any(weights < 0):
+            raise AssertionError("Probabilities must be greater than or equal to zero.")
+        if np.all(weights == 0):
+            raise AssertionError("At least one probability must be greater than zero.")
+        weights = weights / weights.sum()
+        return list(weights)
 
     def flatten(self):
         transforms = []
@@ -232,16 +231,15 @@ def flatten(self):
     def __call__(self, data):
         if len(self.transforms) == 0:
             return data
-        else:
-            index = self.R.multinomial(1, self.weights).argmax()
-            _transform = self.transforms[index]
-            data = apply_transform(_transform, data, self.map_items, self.unpack_items)
-            # if the data is a mapping (dictionary), append the OneOf transform to the end
-            if isinstance(data, Mapping):
-                for key in data.keys():
-                    if key + InverseKeys.KEY_SUFFIX in data:
-                        self.push_transform(data, key, extra_info={"index": index})
-            return data
+        index = self.R.multinomial(1, self.weights).argmax()
+        _transform = self.transforms[index]
+        data = apply_transform(_transform, data, self.map_items, self.unpack_items)
+        # if the data is a mapping (dictionary), append the OneOf transform to the end
+        if isinstance(data, Mapping):
+            for key in data.keys():
+                if key + InverseKeys.KEY_SUFFIX in data:
+                    self.push_transform(data, key, extra_info={"index": index})
+        return data
 
     def inverse(self, data):
         if len(self.transforms) == 0:
```

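Both hunks remove an `else` whose `if` branch returns early, so the body dedents one level with no behavioural change. For reference, a standalone sketch of what the weight normalization computes:

```python
import numpy as np

def normalize_probabilities(weights):
    # Early return handles the empty case; no else branch needed after it.
    if len(weights) == 0:
        return weights
    weights = np.array(weights)
    if np.any(weights < 0):
        raise AssertionError("Probabilities must be greater than or equal to zero.")
    if np.all(weights == 0):
        raise AssertionError("At least one probability must be greater than zero.")
    return list(weights / weights.sum())

print(normalize_probabilities([1, 1, 2]))  # [0.25, 0.25, 0.5]
```
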
monai/transforms/croppad/dictionary.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -663,7 +663,6 @@ def __init__(
             random_size=random_size,
             allow_missing_keys=allow_missing_keys,
         )
-        MapTransform.__init__(self, keys, allow_missing_keys)
         self.roi_scale = roi_scale
         self.max_roi_scale = max_roi_scale
 
```

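Presumably the `super().__init__(...)` call just above already runs `MapTransform.__init__` through the normal MRO, making the explicit second call redundant. A minimal sketch of the pattern (class names are illustrative):

```python
class Base:
    def __init__(self, keys):
        self.keys = keys  # runs once via super().__init__

class Child(Base):
    def __init__(self, keys):
        super().__init__(keys)
        # Base.__init__(self, keys)  # redundant re-initialization, as removed above
```
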
monai/transforms/intensity/array.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1457,7 +1457,7 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor,
             raise RuntimeError("Image needs a channel direction.")
         if isinstance(self.loc[0], int) and len(img.shape) == 4 and len(self.loc) == 2:
             raise RuntimeError("Input images of dimension 4 need location tuple to be length 3 or 4")
-        if isinstance(self.loc[0], Sequence) and len(img.shape) == 4 and min(map(lambda x: len(x), self.loc)) == 2:
+        if isinstance(self.loc[0], Sequence) and len(img.shape) == 4 and min(map(len, self.loc)) == 2:
             raise RuntimeError("Input images of dimension 4 need location tuple to be length 3 or 4")
 
         n_dims = len(img.shape[1:])
```

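Same lambda removal as in `handlers/utils.py`: `len` is itself a callable, so it can be passed to `map` directly:

```python
locs = [(1, 2), (3, 4, 5)]
assert min(map(len, locs)) == min(map(lambda x: len(x), locs)) == 2
```
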
monai/transforms/post/array.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -205,7 +205,7 @@ def __call__(
 
         rounding = self.rounding if rounding is None else rounding
         if rounding is not None:
-            rounding = look_up_option(rounding, ["torchrounding"])
+            look_up_option(rounding, ["torchrounding"])
             img = torch.round(img)
 
         return img.float()
```

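`look_up_option` is called here purely for validation; since `rounding` is never read afterwards, rebinding the return value was dead code. A sketch of the validation behaviour (assuming `look_up_option` raises `ValueError` on an unsupported option, as in `monai.utils`):

```python
from monai.utils import look_up_option

look_up_option("torchrounding", ["torchrounding"])  # matches, no error
try:
    look_up_option("floor", ["torchrounding"])  # hypothetical bad value
except ValueError as err:
    print("rejected:", err)
```
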
monai/transforms/spatial/array.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -335,8 +335,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
         """
         if isinstance(img, np.ndarray):
             return np.ascontiguousarray(np.flip(img, map_spatial_axes(img.ndim, self.spatial_axis)))
-        else:
-            return torch.flip(img, map_spatial_axes(img.ndim, self.spatial_axis))
+        return torch.flip(img, map_spatial_axes(img.ndim, self.spatial_axis))
 
 
 class Resize(Transform):
```

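The NumPy branch returns early, so the `else` is redundant. The two branches remain numerically equivalent; a quick parity check (axes chosen arbitrarily):

```python
import numpy as np
import torch

img = np.arange(8, dtype=np.float32).reshape(2, 2, 2)
flipped_np = np.ascontiguousarray(np.flip(img, (1,)))
flipped_pt = torch.flip(torch.as_tensor(img), (1,))
assert np.allclose(flipped_np, flipped_pt.numpy())
```
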
monai/transforms/utility/array.py

Lines changed: 8 additions & 9 deletions
```diff
@@ -391,9 +391,8 @@ def __call__(self, data: NdarrayOrTensor):
         if self.data_type == "tensor":
             dtype_ = get_equivalent_dtype(self.dtype, torch.Tensor)
             return convert_to_tensor(data, dtype=dtype_, device=self.device)
-        else:
-            dtype_ = get_equivalent_dtype(self.dtype, np.ndarray)
-            return convert_to_numpy(data, dtype=dtype_)
+        dtype_ = get_equivalent_dtype(self.dtype, np.ndarray)
+        return convert_to_numpy(data, dtype=dtype_)
 
 
 class ToNumpy(Transform):
@@ -1091,11 +1090,11 @@ def __call__(
         img_ = img[mask]
 
         supported_ops = {
-            "mean": lambda x: np.nanmean(x),
-            "median": lambda x: np.nanmedian(x),
-            "max": lambda x: np.nanmax(x),
-            "min": lambda x: np.nanmin(x),
-            "std": lambda x: np.nanstd(x),
+            "mean": np.nanmean,
+            "median": np.nanmedian,
+            "max": np.nanmax,
+            "min": np.nanmin,
+            "std": np.nanstd,
         }
 
         def _compute(op: Callable, data: np.ndarray):
@@ -1107,7 +1106,7 @@ def _compute(op: Callable, data: np.ndarray):
         for o in self.ops:
             if isinstance(o, str):
                 o = look_up_option(o, supported_ops.keys())
-                meta_data[self.key_prefix + "_" + o] = _compute(supported_ops[o], img_)
+                meta_data[self.key_prefix + "_" + o] = _compute(supported_ops[o], img_)  # type: ignore
             elif callable(o):
                 meta_data[self.key_prefix + "_custom_" + str(custom_index)] = _compute(o, img_)
                 custom_index += 1
```

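The `# type: ignore` here (and in `handlers/utils.py`) is presumably needed because the op table now mixes plain NumPy function references with lambdas of a different arity, which the type checker cannot unify; runtime behaviour is unchanged. A condensed sketch of the dispatch pattern:

```python
from typing import Callable, Dict
import numpy as np

supported_ops: Dict[str, Callable] = {"mean": np.nanmean, "std": np.nanstd}

def compute(op_name: str, data: np.ndarray) -> float:
    # Resolve the op once, then apply it to the (masked) data.
    return float(supported_ops[op_name](data))

print(compute("mean", np.array([1.0, np.nan, 3.0])))  # 2.0
```
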
monai/transforms/utility/dictionary.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -15,7 +15,6 @@
 Class names are ended with 'd' to denote dictionary-based transforms.
 """
 
-import copy
 import logging
 import re
 from copy import deepcopy
@@ -886,7 +885,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
             if isinstance(val, torch.Tensor):
                 d[new_key] = val.detach().clone()
             else:
-                d[new_key] = copy.deepcopy(val)
+                d[new_key] = deepcopy(val)
         return d
 

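With `from copy import deepcopy` already present, the bare `import copy` was redundant and the qualified call shortens to `deepcopy(val)`. The surrounding branch copies tensors and everything else differently; a sketch:

```python
from copy import deepcopy
import torch

def copy_value(val):
    # Tensors: copy the storage without autograd history.
    if isinstance(val, torch.Tensor):
        return val.detach().clone()
    # Anything else: recursive deep copy.
    return deepcopy(val)
```
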
monai/transforms/utils.py

Lines changed: 2 additions & 5 deletions
```diff
@@ -20,12 +20,11 @@
 import torch
 
 import monai
-import monai.transforms.transform
 from monai.config import DtypeLike, IndexSelection
 from monai.config.type_definitions import NdarrayOrTensor
 from monai.networks.layers import GaussianFilter
 from monai.transforms.compose import Compose, OneOf
-from monai.transforms.transform import MapTransform, Transform
+from monai.transforms.transform import MapTransform, Transform, apply_transform
 from monai.transforms.utils_pytorch_numpy_unification import any_np_pt, nonzero, ravel, unravel_index
 from monai.utils import (
     GridSampleMode,
@@ -1330,9 +1329,7 @@ def _get_data(obj, key):
         prev_data = _get_data(test_data, key)
         prev_type = type(prev_data)
         prev_device = prev_data.device if isinstance(prev_data, torch.Tensor) else None
-        test_data = monai.transforms.transform.apply_transform(
-            _transform, test_data, transform.map_items, transform.unpack_items
-        )
+        test_data = apply_transform(_transform, test_data, transform.map_items, transform.unpack_items)
         # every time the type or device changes, increment the counter
         curr_data = _get_data(test_data, key)
         curr_device = curr_data.device if isinstance(curr_data, torch.Tensor) else None
```

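Importing `apply_transform` at module top level collapses the three-line qualified call into one. For reference, a small usage sketch assuming the `(transform, data, map_items, unpack_items)` signature shown in the diff:

```python
import numpy as np
from monai.transforms import Flip
from monai.transforms.transform import apply_transform

# map_items=True applies the transform to each element of a list input.
batch = [np.zeros((1, 4, 4)), np.ones((1, 4, 4))]
out = apply_transform(Flip(spatial_axis=0), batch, map_items=True)
print(len(out), out[0].shape)  # 2 (1, 4, 4)
```
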