Skip to content

Commit d3d1743

Browse files
Nic-Ma and monai-bot authored
3000 Support not copy in CacheDataset (#3001)
* [DLMED] add copy option Signed-off-by: Nic Ma <[email protected]> * [DLMED] enhance test Signed-off-by: Nic Ma <[email protected]> * [MONAI] python code formatting Signed-off-by: monai-bot <[email protected]> Co-authored-by: monai-bot <[email protected]>
1 parent 790fc8f commit d3d1743

File tree

2 files changed

+27
-6
lines changed

2 files changed

+27
-6
lines changed

monai/data/dataset.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,7 @@ def __init__(
575575
cache_rate: float = 1.0,
576576
num_workers: Optional[int] = None,
577577
progress: bool = True,
578+
copy_cache: bool = True,
578579
) -> None:
579580
"""
580581
Args:
@@ -587,11 +588,16 @@ def __init__(
587588
num_workers: the number of worker processes to use.
588589
If num_workers is None then the number returned by os.cpu_count() is used.
589590
progress: whether to display a progress bar.
591+
copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
592+
default to `True`. if the random transforms don't modify the cache content
593+
or every cache item is only used once in a `multi-processing` environment,
594+
may set `copy_cache=False` for better performance.
590595
"""
591596
if not isinstance(transform, Compose):
592597
transform = Compose(transform)
593598
super().__init__(data=data, transform=transform)
594599
self.progress = progress
600+
self.copy_cache = copy_cache
595601
self.cache_num = min(int(cache_num), int(len(data) * cache_rate), len(data))
596602
self.num_workers = num_workers
597603
if self.num_workers is not None:
@@ -656,7 +662,8 @@ def _transform(self, index: int):
656662
# only need to deep copy data on first non-deterministic transform
657663
if not start_run:
658664
start_run = True
659-
data = deepcopy(data)
665+
if self.copy_cache:
666+
data = deepcopy(data)
660667
data = apply_transform(_transform, data)
661668
return data
662669

@@ -722,6 +729,10 @@ class SmartCacheDataset(Randomizable, CacheDataset):
722729
shuffle: whether to shuffle the whole data list before preparing the cache content for first epoch.
723730
it will not modify the original input data sequence in-place.
724731
seed: random seed if shuffle is `True`, default to `0`.
732+
copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
733+
default to `True`. if the random transforms don't modify the cache content
734+
or every cache item is only used once in a `multi-processing` environment,
735+
may set `copy_cache=False` for better performance.
725736
"""
726737

727738
def __init__(
@@ -736,14 +747,15 @@ def __init__(
736747
progress: bool = True,
737748
shuffle: bool = True,
738749
seed: int = 0,
750+
copy_cache: bool = True,
739751
) -> None:
740752
if shuffle:
741753
self.set_random_state(seed=seed)
742754
data = copy(data)
743755
self.randomize(data)
744756
self.shuffle = shuffle
745757

746-
super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress)
758+
super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress, copy_cache)
747759
if self._cache is None:
748760
self._cache = self._fill_cache()
749761
if self.cache_num >= len(data):

tests/test_cachedataset.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from parameterized import parameterized
2020

2121
from monai.data import CacheDataset, DataLoader, PersistentDataset, SmartCacheDataset
22-
from monai.transforms import Compose, Lambda, LoadImaged, ThreadUnsafe, Transform
22+
from monai.transforms import Compose, Lambda, LoadImaged, RandLambda, ThreadUnsafe, Transform
2323
from monai.utils import get_torch_version_tuple
2424

2525
TEST_CASE_1 = [Compose([LoadImaged(keys=["image", "label", "extra"])]), (128, 128, 128)]
@@ -84,27 +84,36 @@ def test_shape(self, transform, expected_shape):
8484
def test_set_data(self):
8585
data_list1 = list(range(10))
8686

87-
transform = Lambda(func=lambda x: np.array([x * 10]))
87+
transform = Compose(
88+
[
89+
Lambda(func=lambda x: np.array([x * 10])),
90+
RandLambda(func=lambda x: x + 1),
91+
]
92+
)
8893

8994
dataset = CacheDataset(
9095
data=data_list1,
9196
transform=transform,
9297
cache_rate=1.0,
9398
num_workers=4,
9499
progress=True,
100+
copy_cache=False if sys.platform == "linux" else True,
95101
)
96102

97103
num_workers = 2 if sys.platform == "linux" else 0
98104
dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=1)
99105
for i, d in enumerate(dataloader):
100-
np.testing.assert_allclose([[data_list1[i] * 10]], d)
106+
np.testing.assert_allclose([[data_list1[i] * 10 + 1]], d)
107+
# simulate another epoch, the cache content should not be modified
108+
for i, d in enumerate(dataloader):
109+
np.testing.assert_allclose([[data_list1[i] * 10 + 1]], d)
101110

102111
# update the datalist and fill the cache content
103112
data_list2 = list(range(-10, 0))
104113
dataset.set_data(data=data_list2)
105114
# rerun with updated cache content
106115
for i, d in enumerate(dataloader):
107-
np.testing.assert_allclose([[data_list2[i] * 10]], d)
116+
np.testing.assert_allclose([[data_list2[i] * 10 + 1]], d)
108117

109118

110119
class _StatefulTransform(Transform, ThreadUnsafe):

0 commit comments

Comments (0)