@@ -575,6 +575,7 @@ def __init__(
575575 cache_rate : float = 1.0 ,
576576 num_workers : Optional [int ] = None ,
577577 progress : bool = True ,
578+ copy_cache : bool = True ,
578579 ) -> None :
579580 """
580581 Args:
@@ -587,11 +588,16 @@ def __init__(
587588 num_workers: the number of worker processes to use.
588589 If num_workers is None then the number returned by os.cpu_count() is used.
589590 progress: whether to display a progress bar.
591+ copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
592+ default to `True`. if the random transforms don't modify the cache content
593+ or every cache item is only used once in a `multi-processing` environment,
594+ may set `copy=False` for better performance.
590595 """
591596 if not isinstance (transform , Compose ):
592597 transform = Compose (transform )
593598 super ().__init__ (data = data , transform = transform )
594599 self .progress = progress
600+ self .copy_cache = copy_cache
595601 self .cache_num = min (int (cache_num ), int (len (data ) * cache_rate ), len (data ))
596602 self .num_workers = num_workers
597603 if self .num_workers is not None :
@@ -656,7 +662,8 @@ def _transform(self, index: int):
656662 # only need to deep copy data on first non-deterministic transform
657663 if not start_run :
658664 start_run = True
659- data = deepcopy (data )
665+ if self .copy_cache :
666+ data = deepcopy (data )
660667 data = apply_transform (_transform , data )
661668 return data
662669
@@ -722,6 +729,10 @@ class SmartCacheDataset(Randomizable, CacheDataset):
722729 shuffle: whether to shuffle the whole data list before preparing the cache content for first epoch.
723730 it will not modify the original input data sequence in-place.
724731 seed: random seed if shuffle is `True`, default to `0`.
732+ copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
733+ default to `True`. if the random transforms don't modify the cache content
734+ or every cache item is only used once in a `multi-processing` environment,
735+ may set `copy=False` for better performance.
725736 """
726737
727738 def __init__ (
@@ -736,14 +747,15 @@ def __init__(
736747 progress : bool = True ,
737748 shuffle : bool = True ,
738749 seed : int = 0 ,
750+ copy_cache : bool = True ,
739751 ) -> None :
740752 if shuffle :
741753 self .set_random_state (seed = seed )
742754 data = copy (data )
743755 self .randomize (data )
744756 self .shuffle = shuffle
745757
746- super ().__init__ (data , transform , cache_num , cache_rate , num_init_workers , progress )
758+ super ().__init__ (data , transform , cache_num , cache_rate , num_init_workers , progress , copy_cache )
747759 if self ._cache is None :
748760 self ._cache = self ._fill_cache ()
749761 if self .cache_num >= len (data ):
0 commit comments