1313
1414import collections .abc
1515import math
16- import pickle
1716import shutil
1817import sys
1918import tempfile
2221import warnings
2322from collections .abc import Callable , Sequence
2423from copy import copy , deepcopy
24+ from io import BytesIO
2525from multiprocessing .managers import ListProxy
2626from multiprocessing .pool import ThreadPool
2727from pathlib import Path
28+ from pickle import UnpicklingError
2829from typing import IO , TYPE_CHECKING , Any , cast
2930
3031import numpy as np
@@ -207,6 +208,11 @@ class PersistentDataset(Dataset):
207208 not guaranteed, so caution should be used when modifying transforms to avoid unexpected
208209 errors. If in doubt, it is advisable to clear the cache directory.
209210
211+ Cached data is expected to be tensors, primitives, or dictionaries keying to these values. Numpy arrays will
212+ be converted to tensors, however any other object type returned by transforms will not be loadable since
213+ `torch.load` will be used with `weights_only=True` to prevent loading of potentially malicious objects.
214+ Legacy cache files may not be loadable and may need to be recomputed.
215+
210216 Lazy Resampling:
211217 If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to
212218 its documentation to familiarize yourself with the interaction between `PersistentDataset` and
@@ -248,8 +254,8 @@ def __init__(
248254 this arg is used by `torch.save`, for more details, please check:
249255 https://pytorch.org/docs/stable/generated/torch.save.html#torch.save,
250256 and ``monai.data.utils.SUPPORTED_PICKLE_MOD``.
251- pickle_protocol: can be specified to override the default protocol, default to `2 `.
252- this arg is used by ` torch.save`, for more details, please check:
257+ pickle_protocol: specifies the pickle protocol used when saving, via `torch.save`.
258+ Defaults to `torch.serialization.DEFAULT_PROTOCOL`. For more details, please check:
253259 https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
254260 hash_transform: a callable to compute hash from the transform information when caching.
255261 This may reduce errors due to transforms changing during experiments. Default to None (no hash).
@@ -371,12 +377,12 @@ def _cachecheck(self, item_transformed):
371377
372378 if hashfile is not None and hashfile .is_file (): # cache hit
373379 try :
374- return torch .load (hashfile , weights_only = False )
380+ return torch .load (hashfile , weights_only = True )
375381 except PermissionError as e :
376382 if sys .platform != "win32" :
377383 raise e
378- except RuntimeError as e :
379- if "Invalid magic number; corrupt file" in str (e ):
384+ except ( UnpicklingError , RuntimeError ) as e : # corrupt or unloadable cached files are recomputed
385+ if "Invalid magic number; corrupt file" in str (e ) or isinstance ( e , UnpicklingError ) :
380386 warnings .warn (f"Corrupt cache file detected: { hashfile } . Deleting and recomputing." )
381387 hashfile .unlink ()
382388 else :
@@ -392,7 +398,7 @@ def _cachecheck(self, item_transformed):
392398 with tempfile .TemporaryDirectory () as tmpdirname :
393399 temp_hash_file = Path (tmpdirname ) / hashfile .name
394400 torch .save (
395- obj = _item_transformed ,
401+ obj = convert_to_tensor ( _item_transformed , convert_numeric = False ) ,
396402 f = temp_hash_file ,
397403 pickle_module = look_up_option (self .pickle_module , SUPPORTED_PICKLE_MOD ),
398404 pickle_protocol = self .pickle_protocol ,
@@ -455,8 +461,8 @@ def __init__(
455461 this arg is used by `torch.save`, for more details, please check:
456462 https://pytorch.org/docs/stable/generated/torch.save.html#torch.save,
457463 and ``monai.data.utils.SUPPORTED_PICKLE_MOD``.
458- pickle_protocol: can be specified to override the default protocol, default to `2 `.
459- this arg is used by ` torch.save`, for more details, please check:
464+ pickle_protocol: specifies the pickle protocol used when saving, via `torch.save`.
465+ Defaults to `torch.serialization.DEFAULT_PROTOCOL`. For more details, please check:
460466 https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
461467 hash_transform: a callable to compute hash from the transform information when caching.
462468 This may reduce errors due to transforms changing during experiments. Default to None (no hash).
@@ -531,7 +537,7 @@ def __init__(
531537 hash_func : Callable [..., bytes ] = pickle_hashing ,
532538 db_name : str = "monai_cache" ,
533539 progress : bool = True ,
534- pickle_protocol = pickle . HIGHEST_PROTOCOL ,
540+ pickle_protocol = DEFAULT_PROTOCOL ,
535541 hash_transform : Callable [..., bytes ] | None = None ,
536542 reset_ops_id : bool = True ,
537543 lmdb_kwargs : dict | None = None ,
@@ -551,8 +557,9 @@ def __init__(
551557 defaults to `monai.data.utils.pickle_hashing`.
552558 db_name: lmdb database file name. Defaults to "monai_cache".
553559 progress: whether to display a progress bar.
554- pickle_protocol: pickle protocol version. Defaults to pickle.HIGHEST_PROTOCOL.
555- https://docs.python.org/3/library/pickle.html#pickle-protocols
560+ pickle_protocol: specifies the pickle protocol used when saving, via `torch.save`.
561+ Defaults to `torch.serialization.DEFAULT_PROTOCOL`. For more details, please check:
562+ https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
556563 hash_transform: a callable to compute hash from the transform information when caching.
557564 This may reduce errors due to transforms changing during experiments. Default to None (no hash).
558565 Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.
@@ -594,6 +601,15 @@ def set_data(self, data: Sequence):
594601 super ().set_data (data = data )
595602 self ._read_env = self ._fill_cache_start_reader (show_progress = self .progress )
596603
def _safe_serialize(self, val):
    """
    Serialize ``val`` to bytes with ``torch.save`` so it can later be reloaded with ``weights_only=True``.

    The value is first passed through ``convert_to_tensor`` so the saved stream contains only
    tensors/primitives and remains loadable under the restricted (``weights_only``) unpickler
    used by ``_safe_deserialize``.

    Args:
        val: value to serialize; expected to be tensors, primitives, or containers of these.

    Returns:
        bytes: serialized representation of ``val``.
    """
    buffer = BytesIO()
    torch.save(convert_to_tensor(val), buffer, pickle_protocol=self.pickle_protocol)
    # getvalue() returns the full buffer contents directly — no seek(0)/read() round trip needed
    return buffer.getvalue()
609+
610+ def _safe_deserialize (self , val ):
611+ return torch .load (BytesIO (val ), map_location = "cpu" , weights_only = True )
612+
597613 def _fill_cache_start_reader (self , show_progress = True ):
598614 """
599615 Check the LMDB cache and write the cache if needed. py-lmdb doesn't have a good support for concurrent write.
@@ -619,7 +635,8 @@ def _fill_cache_start_reader(self, show_progress=True):
619635 continue
620636 if val is None :
621637 val = self ._pre_transform (deepcopy (item )) # keep the original hashed
622- val = pickle .dumps (val , protocol = self .pickle_protocol )
638+ # val = pickle.dumps(val, protocol=self.pickle_protocol)
639+ val = self ._safe_serialize (val )
623640 with env .begin (write = True ) as txn :
624641 txn .put (key , val )
625642 done = True
@@ -664,7 +681,8 @@ def _cachecheck(self, item_transformed):
664681 warnings .warn ("LMDBDataset: cache key not found, running fallback caching." )
665682 return super ()._cachecheck (item_transformed )
666683 try :
667- return pickle .loads (data )
684+ # return pickle.loads(data)
685+ return self ._safe_deserialize (data )
668686 except Exception as err :
669687 raise RuntimeError ("Invalid cache value, corrupted lmdb file?" ) from err
670688
@@ -1650,7 +1668,7 @@ def _create_new_cache(self, data, data_hashfile, meta_hash_file_name):
16501668 meta_hash_file = self .cache_dir / meta_hash_file_name
16511669 temp_hash_file = Path (tmpdirname ) / meta_hash_file_name
16521670 torch .save (
1653- obj = self ._meta_cache [meta_hash_file_name ],
1671+ obj = convert_to_tensor ( self ._meta_cache [meta_hash_file_name ], convert_numeric = False ) ,
16541672 f = temp_hash_file ,
16551673 pickle_module = look_up_option (self .pickle_module , SUPPORTED_PICKLE_MOD ),
16561674 pickle_protocol = self .pickle_protocol ,
@@ -1670,4 +1688,4 @@ def _load_meta_cache(self, meta_hash_file_name):
16701688 if meta_hash_file_name in self ._meta_cache :
16711689 return self ._meta_cache [meta_hash_file_name ]
16721690 else :
1673- return torch .load (self .cache_dir / meta_hash_file_name , weights_only = False )
1691+ return torch .load (self .cache_dir / meta_hash_file_name , weights_only = True )
0 commit comments