Merged
Changes from 14 commits
Commits
27 commits
502950b
Updating torch.load To Load Weights Only
ericspod Sep 11, 2025
77d2d82
Autofix
ericspod Sep 11, 2025
79c2cf8
StateCacher should be fine with default pickle protocol
ericspod Sep 11, 2025
5f1f57c
Merge branch 'dev' into torch_load_fix
ericspod Sep 11, 2025
f6f9867
Docstring Update
ericspod Sep 11, 2025
93a5dd1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 11, 2025
10b5de3
Removing pickle_operations
ericspod Sep 12, 2025
64221d0
Fixes loading with weights_only for PersistenDataset by force convert…
ericspod Sep 12, 2025
8a75795
Tweak
ericspod Sep 12, 2025
a60569c
Comment unneeded components
ericspod Sep 13, 2025
b54e55d
Modify convert_to_tensor to skip converting primitives
ericspod Sep 14, 2025
7dc3ad3
Merge branch 'pickle_fixes' into torch_load_fix
ericspod Sep 14, 2025
52f8694
Trying safe torch load save usage in place of pickle
ericspod Sep 14, 2025
14e5e6b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2025
38a618b
Updates to further remove pickle usage
ericspod Sep 14, 2025
28c7df2
Merge branch 'torch_load_fix' of github.com:ericspod/MONAI into torch…
ericspod Sep 14, 2025
77d6992
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 14, 2025
79e0966
Autofix
ericspod Sep 14, 2025
2a88d83
Merge branch 'torch_load_fix' of github.com:ericspod/MONAI into torch…
ericspod Sep 14, 2025
11c0ee5
Removing commented code
ericspod Sep 14, 2025
2edf46c
Pass argument in recursive call of convert_to_tensor
ericspod Sep 14, 2025
9b171d4
Type fix
ericspod Sep 14, 2025
3d6e0ca
Merge branch 'dev' into torch_load_fix
ericspod Sep 15, 2025
65a7b6d
Merge branch 'dev' into torch_load_fix
KumoLiu Sep 16, 2025
149a5bb
Fixing pickle protocol issue
ericspod Sep 16, 2025
58561b3
Merge branch 'torch_load_fix' of github.com:ericspod/MONAI into torch…
ericspod Sep 16, 2025
dd1de4a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 16, 2025
26 changes: 17 additions & 9 deletions monai/apps/nnunet/nnunet_bundle.py
@@ -133,7 +133,7 @@ def get_nnunet_trainer(
         cudnn.benchmark = True

     if pretrained_model is not None:
-        state_dict = torch.load(pretrained_model)
+        state_dict = torch.load(pretrained_model, weights_only=True)
         if "network_weights" in state_dict:
             nnunet_trainer.network._orig_mod.load_state_dict(state_dict["network_weights"])
     return nnunet_trainer
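For context, `weights_only=True` switches `torch.load` to a restricted unpickler that only reconstructs tensors, primitive containers, and explicitly allowlisted classes, instead of executing arbitrary pickle bytecode. A minimal sketch of the behaviour these call sites rely on (file names and the `Custom` class are illustrative):

import torch

# A checkpoint of tensors and primitives loads fine under the restricted unpickler:
ckpt = {"network_weights": {"layer.weight": torch.zeros(3, 3)}, "epoch": 10}
torch.save(ckpt, "checkpoint.pth")
state = torch.load("checkpoint.pth", weights_only=True)

# A pickled arbitrary object is rejected instead of being executed:
class Custom:
    pass

torch.save({"obj": Custom()}, "unsafe.pth")
try:
    torch.load("unsafe.pth", weights_only=True)
except Exception as err:  # pickle.UnpicklingError in recent PyTorch releases
    print(f"refused: {err}")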
@@ -182,7 +182,9 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name
         parameters = []

         checkpoint = torch.load(
-            join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu")
+            join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"),
+            map_location=torch.device("cpu"),
+            weights_only=True,
         )
         trainer_name = checkpoint["trainer_name"]
         configuration_name = checkpoint["init_args"]["configuration"]
@@ -192,7 +194,9 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name
             else None
         )
         if Path(model_training_output_dir).joinpath(model_name).is_file():
-            monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device("cpu"))
+            monai_checkpoint = torch.load(
+                join(model_training_output_dir, model_name), map_location=torch.device("cpu"), weights_only=True
+            )
             if "network_weights" in monai_checkpoint.keys():
                 parameters.append(monai_checkpoint["network_weights"])
         else:
@@ -383,8 +387,12 @@ def convert_nnunet_to_monai_bundle(nnunet_config: dict, bundle_root_folder: str,
         dataset_name, f"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}"
     )

-    nnunet_checkpoint_final = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"))
-    nnunet_checkpoint_best = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"))
+    nnunet_checkpoint_final = torch.load(
+        Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"), weights_only=True
+    )
+    nnunet_checkpoint_best = torch.load(
+        Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"), weights_only=True
+    )

     nnunet_checkpoint = {}
     nnunet_checkpoint["inference_allowed_mirroring_axes"] = nnunet_checkpoint_final["inference_allowed_mirroring_axes"]
@@ -470,7 +478,7 @@ def get_network_from_nnunet_plans(
     if model_ckpt is None:
         return network
     else:
-        state_dict = torch.load(model_ckpt)
+        state_dict = torch.load(model_ckpt, weights_only=True)
         network.load_state_dict(state_dict[model_key_in_ckpt])
         return network
@@ -534,7 +542,7 @@ def subfiles(

     Path(nnunet_model_folder).joinpath(f"fold_{fold}").mkdir(parents=True, exist_ok=True)

-    nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth")
+    nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth", weights_only=True)
     latest_checkpoints: list[str] = subfiles(
         Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), prefix="checkpoint_epoch", sort=True
     )
@@ -545,7 +553,7 @@
     epochs.sort()
     final_epoch: int = epochs[-1]
     monai_last_checkpoint: dict = torch.load(
-        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt"
+        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt", weights_only=True
     )

     best_checkpoints: list[str] = subfiles(
@@ -558,7 +566,7 @@
     key_metrics.sort()
     best_key_metric: str = key_metrics[-1]
     monai_best_checkpoint: dict = torch.load(
-        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt"
+        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt", weights_only=True
     )

     nnunet_checkpoint["optimizer_state"] = monai_last_checkpoint["optimizer_state"]
2 changes: 1 addition & 1 deletion monai/data/__init__.py
@@ -78,7 +78,7 @@
 from .thread_buffer import ThreadBuffer, ThreadDataLoader
 from .torchscript_utils import load_net_with_metadata, save_net_with_metadata
 from .utils import (
-    PICKLE_KEY_SUFFIX,
+    # PICKLE_KEY_SUFFIX,
     affine_to_spacing,
     compute_importance_map,
     compute_shape_offset,
34 changes: 26 additions & 8 deletions monai/data/dataset.py
@@ -12,6 +12,7 @@
 from __future__ import annotations

 import collections.abc
+from io import BytesIO
 import math
 import pickle
 import shutil
@@ -207,6 +208,11 @@ class PersistentDataset(Dataset):
         not guaranteed, so caution should be used when modifying transforms to avoid unexpected
         errors. If in doubt, it is advisable to clear the cache directory.

+    Cached data is expected to be tensors, primitives, or dictionaries mapping to these values. Numpy arrays
+    will be converted to tensors; however, any other object type returned by transforms will not be loadable,
+    since `torch.load` will be used with `weights_only=True` to prevent loading of potentially malicious objects.
+    Legacy cache files may not be loadable and may need to be recomputed.
+
     Lazy Resampling:
         If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to
         its documentation to familiarize yourself with the interaction between `PersistentDataset` and
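A hedged sketch of the caching contract this docstring describes, assuming the usual `PersistentDataset` signature (paths and data are illustrative):

import numpy as np
from monai.data import PersistentDataset
from monai.transforms import Compose

# Dictionaries of tensors, numpy arrays (converted to tensors when the cache is
# written), and primitives stay cacheable; custom Python objects no longer are.
data = [{"image": np.random.rand(1, 8, 8).astype(np.float32), "label": 1}]
ds = PersistentDataset(data=data, transform=Compose([]), cache_dir="./cache_dir")
first = ds[0]   # computed, then persisted via torch.save(convert_to_tensor(...))
again = ds[0]   # cache hit, read back with torch.load(..., weights_only=True)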
@@ -371,12 +377,12 @@ def _cachecheck(self, item_transformed):

         if hashfile is not None and hashfile.is_file():  # cache hit
             try:
-                return torch.load(hashfile, weights_only=False)
+                return torch.load(hashfile, weights_only=True)
             except PermissionError as e:
                 if sys.platform != "win32":
                     raise e
-            except RuntimeError as e:
-                if "Invalid magic number; corrupt file" in str(e):
+            except (pickle.UnpicklingError, RuntimeError) as e:  # corrupt or unloadable cached files are recomputed
+                if "Invalid magic number; corrupt file" in str(e) or isinstance(e, pickle.UnpicklingError):
                     warnings.warn(f"Corrupt cache file detected: {hashfile}. Deleting and recomputing.")
                     hashfile.unlink()
                 else:
@@ -392,7 +398,7 @@ def _cachecheck(self, item_transformed):
         with tempfile.TemporaryDirectory() as tmpdirname:
             temp_hash_file = Path(tmpdirname) / hashfile.name
             torch.save(
-                obj=_item_transformed,
+                obj=convert_to_tensor(_item_transformed, convert_numeric=False),
                 f=temp_hash_file,
                 pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
                 pickle_protocol=self.pickle_protocol,
@@ -594,6 +600,16 @@ def set_data(self, data: Sequence):
         super().set_data(data=data)
         self._read_env = self._fill_cache_start_reader(show_progress=self.progress)

+    def _safe_serialize(self, val):
+        out = BytesIO()
+        torch.save(convert_to_tensor(val), out, pickle_protocol=self.pickle_protocol)
+        out.seek(0)
+        return out.read()
+
+    def _safe_deserialize(self, val):
+        out = BytesIO(val)
+        return torch.load(out, weights_only=True)
+
     def _fill_cache_start_reader(self, show_progress=True):
         """
         Check the LMDB cache and write the cache if needed. py-lmdb doesn't have a good support for concurrent write.
@@ -619,7 +635,8 @@
                             continue
                     if val is None:
                         val = self._pre_transform(deepcopy(item))  # keep the original hashed
-                        val = pickle.dumps(val, protocol=self.pickle_protocol)
+                        # val = pickle.dumps(val, protocol=self.pickle_protocol)
+                        val = self._safe_serialize(val)
                     with env.begin(write=True) as txn:
                         txn.put(key, val)
                     done = True
@@ -664,7 +681,8 @@ def _cachecheck(self, item_transformed):
             warnings.warn("LMDBDataset: cache key not found, running fallback caching.")
             return super()._cachecheck(item_transformed)
         try:
-            return pickle.loads(data)
+            # return pickle.loads(data)
+            return self._safe_deserialize(data)
         except Exception as err:
             raise RuntimeError("Invalid cache value, corrupted lmdb file?") from err

@@ -1650,7 +1668,7 @@ def _create_new_cache(self, data, data_hashfile, meta_hash_file_name):
             meta_hash_file = self.cache_dir / meta_hash_file_name
             temp_hash_file = Path(tmpdirname) / meta_hash_file_name
             torch.save(
-                obj=self._meta_cache[meta_hash_file_name],
+                obj=convert_to_tensor(self._meta_cache[meta_hash_file_name], convert_numeric=False),
                 f=temp_hash_file,
                 pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
                 pickle_protocol=self.pickle_protocol,
@@ -1670,4 +1688,4 @@ def _load_meta_cache(self, meta_hash_file_name):
         if meta_hash_file_name in self._meta_cache:
             return self._meta_cache[meta_hash_file_name]
         else:
-            return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
+            return torch.load(self.cache_dir / meta_hash_file_name, weights_only=True)
2 changes: 1 addition & 1 deletion monai/data/meta_tensor.py
@@ -611,4 +611,4 @@ def print_verbose(self) -> None:

 # needed in later versions of Pytorch to indicate the class is safe for serialisation
 if hasattr(torch.serialization, "add_safe_globals"):
-    torch.serialization.add_safe_globals([MetaTensor])
+    torch.serialization.add_safe_globals([MetaObj, MetaTensor, MetaKeys, SpaceKeys])
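Registering `MetaObj`, `MetaTensor`, `MetaKeys`, and `SpaceKeys` on the allowlist is what lets checkpoints containing MONAI metadata load under `weights_only=True`; without it, the restricted unpickler would reject these classes. A sketch of the effect (the file name is illustrative):

import torch
from monai.data import MetaTensor

m = MetaTensor(torch.rand(2, 2))
torch.save(m, "meta.pt")

# Succeeds because MetaTensor and its supporting types were registered with
# torch.serialization.add_safe_globals; an unregistered class would raise here.
m2 = torch.load("meta.pt", weights_only=True)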
63 changes: 32 additions & 31 deletions monai/data/utils.py
@@ -30,7 +30,6 @@
 import torch
 from torch.utils.data._utils.collate import default_collate

-from monai import config
 from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike
 from monai.data.meta_obj import MetaObj
 from monai.utils import (
@@ -93,7 +92,7 @@
     "remove_keys",
     "remove_extra_metadata",
     "get_extra_metadata_keys",
-    "PICKLE_KEY_SUFFIX",
+    # "PICKLE_KEY_SUFFIX",
     "is_no_channel",
 ]

@@ -418,30 +417,30 @@ def dev_collate(batch, level: int = 1, logger_name: str = "dev_collate"):
     return


-PICKLE_KEY_SUFFIX = TraceKeys.KEY_SUFFIX
+# PICKLE_KEY_SUFFIX = TraceKeys.KEY_SUFFIX


-def pickle_operations(data, key=PICKLE_KEY_SUFFIX, is_encode: bool = True):
-    """
-    Applied_operations are dictionaries with varying sizes, this method converts them to bytes so that we can (de-)collate.
+# def pickle_operations(data, key=PICKLE_KEY_SUFFIX, is_encode: bool = True):
+#     """
+#     Applied_operations are dictionaries with varying sizes, this method converts them to bytes so that we can (de-)collate.

-    Args:
-        data: a list or dictionary with substructures to be pickled/unpickled.
-        key: the key suffix for the target substructures, defaults to "_transforms" (`data.utils.PICKLE_KEY_SUFFIX`).
-        is_encode: whether it's encoding using pickle.dumps (True) or decoding using pickle.loads (False).
-    """
-    if isinstance(data, Mapping):
-        data = dict(data)
-        for k in data:
-            if f"{k}".endswith(key):
-                if is_encode and not isinstance(data[k], bytes):
-                    data[k] = pickle.dumps(data[k], 0)
-                if not is_encode and isinstance(data[k], bytes):
-                    data[k] = pickle.loads(data[k])
-        return {k: pickle_operations(v, key=key, is_encode=is_encode) for k, v in data.items()}
-    elif isinstance(data, (list, tuple)):
-        return [pickle_operations(item, key=key, is_encode=is_encode) for item in data]
-    return data
+#     Args:
+#         data: a list or dictionary with substructures to be pickled/unpickled.
+#         key: the key suffix for the target substructures, defaults to "_transforms" (`data.utils.PICKLE_KEY_SUFFIX`).
+#         is_encode: whether it's encoding using pickle.dumps (True) or decoding using pickle.loads (False).
+#     """
+#     if isinstance(data, Mapping):
+#         data = dict(data)
+#         for k in data:
+#             if f"{k}".endswith(key):
+#                 if is_encode and not isinstance(data[k], bytes):
+#                     data[k] = pickle.dumps(data[k], 0)
+#                 if not is_encode and isinstance(data[k], bytes):
+#                     data[k] = pickle.loads(data[k])
+#         return {k: pickle_operations(v, key=key, is_encode=is_encode) for k, v in data.items()}
+#     elif isinstance(data, (list, tuple)):
+#         return [pickle_operations(item, key=key, is_encode=is_encode) for item in data]
+#     return data


 def collate_meta_tensor_fn(batch, *, collate_fn_map=None):
@@ -500,8 +499,8 @@ def list_data_collate(batch: Sequence):
     key = None
     collate_fn = default_collate
     try:
-        if config.USE_META_DICT:
-            data = pickle_operations(data)  # bc 0.9.0
+        # if config.USE_META_DICT:
+        #     data = pickle_operations(data)  # bc 0.9.0
         if isinstance(elem, Mapping):
             ret = {}
             for k in elem:
@@ -654,15 +653,17 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None):
     if isinstance(deco, Mapping):
         _gen = zip_longest(*deco.values(), fillvalue=fill_value) if pad else zip(*deco.values())
         ret = [dict(zip(deco, item)) for item in _gen]
-        if not config.USE_META_DICT:
-            return ret
-        return pickle_operations(ret, is_encode=False)  # bc 0.9.0
+        # if not config.USE_META_DICT:
+        #     return ret
+        # return pickle_operations(ret, is_encode=False)  # bc 0.9.0
+        return ret
     if isinstance(deco, Iterable):
         _gen = zip_longest(*deco, fillvalue=fill_value) if pad else zip(*deco)
         ret_list = [list(item) for item in _gen]
-        if not config.USE_META_DICT:
-            return ret_list
-        return pickle_operations(ret_list, is_encode=False)  # bc 0.9.0
+        # if not config.USE_META_DICT:
+        #     return ret_list
+        # return pickle_operations(ret_list, is_encode=False)  # bc 0.9.0
+        return ret_list
     raise NotImplementedError(f"Unable to de-collate: {batch}, type: {type(batch)}.")

2 changes: 1 addition & 1 deletion monai/handlers/checkpoint_loader.py
@@ -122,7 +122,7 @@ def __call__(self, engine: Engine) -> None:
         Args:
             engine: Ignite Engine, it can be a trainer, validator or evaluator.
         """
-        checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=False)
+        checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=True)

         k, _ = list(self.load_dict.items())[0]
         # single object and checkpoint is directly a state_dict
2 changes: 1 addition & 1 deletion monai/utils/state_cacher.py
@@ -124,7 +124,7 @@ def retrieve(self, key: Hashable) -> Any:
         fn = self.cached[key]["obj"]  # pytype: disable=attribute-error
         if not os.path.exists(fn):  # pytype: disable=wrong-arg-types
             raise RuntimeError(f"Failed to load state in {fn}. File doesn't exist anymore.")
-        data_obj = torch.load(fn, map_location=lambda storage, location: storage, weights_only=False)
+        data_obj = torch.load(fn, map_location=lambda storage, location: storage, weights_only=True)
         # copy back to device if necessary
         if "device" in self.cached[key]:
             data_obj = data_obj.to(self.cached[key]["device"])
4 changes: 3 additions & 1 deletion monai/utils/type_conversion.py
@@ -117,6 +117,7 @@ def convert_to_tensor(
     wrap_sequence: bool = False,
     track_meta: bool = False,
     safe: bool = False,
+    convert_numeric: bool = True,
 ) -> Any:
     """
     Utility to convert the input data to a PyTorch Tensor, if `track_meta` is True, the output will be a `MetaTensor`,
@@ -136,6 +137,7 @@ def convert_to_tensor(
         safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.
             E.g., `[256, -12]` -> `[tensor(0), tensor(244)]`.
             If `True`, then `[256, -12]` -> `[tensor(255), tensor(0)]`.
+        convert_numeric: if `True` (the default), convert numeric Python values (`float`, `int`, `bool`) to tensors.

     """
@@ -167,7 +169,7 @@ def _convert_tensor(tensor: Any, **kwargs: Any) -> Any:
         if data.ndim > 0:
             data = np.ascontiguousarray(data)
         return _convert_tensor(data, dtype=dtype, device=device)
-    elif (has_cp and isinstance(data, cp_ndarray)) or isinstance(data, (float, int, bool)):
+    elif (has_cp and isinstance(data, cp_ndarray)) or (convert_numeric and isinstance(data, (float, int, bool))):
         return _convert_tensor(data, dtype=dtype, device=device)
     elif isinstance(data, list):
         list_ret = [convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta) for i in data]
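The new flag lets callers keep Python numbers as plain numbers while arrays are still converted; a usage sketch (note that a later commit in this PR, "Pass argument in recursive call of convert_to_tensor", propagates the flag into nested containers):

import torch
from monai.utils.type_conversion import convert_to_tensor

print(convert_to_tensor(5))                         # tensor(5)
print(convert_to_tensor(5, convert_numeric=False))  # 5, left as a Python int
print(convert_to_tensor(torch.ones(2), convert_numeric=False))  # tensors pass through unchanged

This matters for the caches above: saving primitives as primitives via `convert_numeric=False` means a `weights_only=True` load returns the item with its original scalar types.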
2 changes: 1 addition & 1 deletion tests/data/meta_tensor/test_meta_tensor.py
@@ -245,7 +245,7 @@ def test_pickling(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             fname = os.path.join(tmp_dir, "im.pt")
             torch.save(m, fname)
-            m2 = torch.load(fname, weights_only=False)
+            m2 = torch.load(fname, weights_only=True)
             self.check(m2, m, ids=False)

     @skip_if_no_cuda
8 changes: 7 additions & 1 deletion tests/utils/test_state_cacher.py
@@ -27,7 +27,13 @@
 TEST_CASE_0 = [torch.Tensor([1]).to(DEVICE), {"in_memory": True}]
 TEST_CASE_1 = [
     torch.Tensor([1]).to(DEVICE),
-    {"in_memory": False, "cache_dir": gettempdir(), "pickle_module": None, "pickle_protocol": pickle.HIGHEST_PROTOCOL},
+    {
+        "in_memory": False,
+        "cache_dir": gettempdir(),
+        "pickle_module": None,
+        # TODO: was pickle.HIGHEST_PROTOCOL but this wasn't compatible with torch.load, need to improve compatibility
+        "pickle_protocol": torch.serialization.DEFAULT_PROTOCOL,
+    },
 ]
 TEST_CASE_2 = [torch.Tensor([1]).to(DEVICE), {"in_memory": False, "allow_overwrite": False}]
 TEST_CASE_3 = [torch.Tensor([1]).to(DEVICE), {"in_memory": False, "cache_dir": Path(gettempdir())}]
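The TODO above reflects a real mismatch: `pickle.HIGHEST_PROTOCOL` is 5 on Python 3.8+, while `torch.save` defaults to `torch.serialization.DEFAULT_PROTOCOL` (protocol 2), and per the diff comment the higher protocol was not loadable through `torch.load` in this setup. A sketch of the safe combination (the printed protocol values are assumptions to verify against the installed Python and PyTorch):

import pickle

import torch

print(pickle.HIGHEST_PROTOCOL)               # 5 on Python 3.8+
print(torch.serialization.DEFAULT_PROTOCOL)  # 2, torch.save's default

# Saving with torch's default protocol keeps the file readable by the
# restricted weights_only unpickler used throughout this PR:
torch.save(torch.ones(1), "t.pt", pickle_protocol=torch.serialization.DEFAULT_PROTOCOL)
print(torch.load("t.pt", weights_only=True))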