Merged
27 commits
502950b  Updating torch.load To Load Weights Only (ericspod, Sep 11, 2025)
77d2d82  Autofix (ericspod, Sep 11, 2025)
79c2cf8  StateCacher should be fine with default pickle protocol (ericspod, Sep 11, 2025)
5f1f57c  Merge branch 'dev' into torch_load_fix (ericspod, Sep 11, 2025)
f6f9867  Docstring Update (ericspod, Sep 11, 2025)
93a5dd1  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 11, 2025)
10b5de3  Removing pickle_operations (ericspod, Sep 12, 2025)
64221d0  Fixes loading with weights_only for PersistenDataset by force convert… (ericspod, Sep 12, 2025)
8a75795  Tweak (ericspod, Sep 12, 2025)
a60569c  Comment unneeded components (ericspod, Sep 13, 2025)
b54e55d  Modify convert_to_tensor to skip converting primitives (ericspod, Sep 14, 2025)
7dc3ad3  Merge branch 'pickle_fixes' into torch_load_fix (ericspod, Sep 14, 2025)
52f8694  Trying safe torch load save usage in place of pickle (ericspod, Sep 14, 2025)
14e5e6b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 14, 2025)
38a618b  Updates to further remove pickle usage (ericspod, Sep 14, 2025)
28c7df2  Merge branch 'torch_load_fix' of github.com:ericspod/MONAI into torch… (ericspod, Sep 14, 2025)
77d6992  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 14, 2025)
79e0966  Autofix (ericspod, Sep 14, 2025)
2a88d83  Merge branch 'torch_load_fix' of github.com:ericspod/MONAI into torch… (ericspod, Sep 14, 2025)
11c0ee5  Removing commented code (ericspod, Sep 14, 2025)
2edf46c  Pass argument in recursive call of convert_to_tensor (ericspod, Sep 14, 2025)
9b171d4  Type fix (ericspod, Sep 14, 2025)
3d6e0ca  Merge branch 'dev' into torch_load_fix (ericspod, Sep 15, 2025)
65a7b6d  Merge branch 'dev' into torch_load_fix (KumoLiu, Sep 16, 2025)
149a5bb  Fixing pickle protocol issue (ericspod, Sep 16, 2025)
58561b3  Merge branch 'torch_load_fix' of github.com:ericspod/MONAI into torch… (ericspod, Sep 16, 2025)
dd1de4a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 16, 2025)
26 changes: 17 additions & 9 deletions monai/apps/nnunet/nnunet_bundle.py
@@ -133,7 +133,7 @@ def get_nnunet_trainer(
         cudnn.benchmark = True
 
     if pretrained_model is not None:
-        state_dict = torch.load(pretrained_model)
+        state_dict = torch.load(pretrained_model, weights_only=True)
         if "network_weights" in state_dict:
             nnunet_trainer.network._orig_mod.load_state_dict(state_dict["network_weights"])
     return nnunet_trainer
@@ -182,7 +182,9 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name
         parameters = []
 
         checkpoint = torch.load(
-            join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu")
+            join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"),
+            map_location=torch.device("cpu"),
+            weights_only=True,
         )
         trainer_name = checkpoint["trainer_name"]
         configuration_name = checkpoint["init_args"]["configuration"]
@@ -192,7 +194,9 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name
             else None
         )
         if Path(model_training_output_dir).joinpath(model_name).is_file():
-            monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device("cpu"))
+            monai_checkpoint = torch.load(
+                join(model_training_output_dir, model_name), map_location=torch.device("cpu"), weights_only=True
+            )
             if "network_weights" in monai_checkpoint.keys():
                 parameters.append(monai_checkpoint["network_weights"])
             else:
@@ -383,8 +387,12 @@ def convert_nnunet_to_monai_bundle(nnunet_config: dict, bundle_root_folder: str,
         dataset_name, f"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}"
     )
 
-    nnunet_checkpoint_final = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"))
-    nnunet_checkpoint_best = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"))
+    nnunet_checkpoint_final = torch.load(
+        Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"), weights_only=True
+    )
+    nnunet_checkpoint_best = torch.load(
+        Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"), weights_only=True
+    )
 
     nnunet_checkpoint = {}
     nnunet_checkpoint["inference_allowed_mirroring_axes"] = nnunet_checkpoint_final["inference_allowed_mirroring_axes"]
@@ -470,7 +478,7 @@ def get_network_from_nnunet_plans(
     if model_ckpt is None:
         return network
     else:
-        state_dict = torch.load(model_ckpt)
+        state_dict = torch.load(model_ckpt, weights_only=True)
         network.load_state_dict(state_dict[model_key_in_ckpt])
         return network
 
@@ -534,7 +542,7 @@ def subfiles(
 
     Path(nnunet_model_folder).joinpath(f"fold_{fold}").mkdir(parents=True, exist_ok=True)
 
-    nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth")
+    nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth", weights_only=True)
     latest_checkpoints: list[str] = subfiles(
         Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), prefix="checkpoint_epoch", sort=True
     )
@@ -545,7 +553,7 @@
     epochs.sort()
     final_epoch: int = epochs[-1]
     monai_last_checkpoint: dict = torch.load(
-        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt"
+        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt", weights_only=True
     )
 
     best_checkpoints: list[str] = subfiles(
@@ -558,7 +566,7 @@
     key_metrics.sort()
     best_key_metric: str = key_metrics[-1]
     monai_best_checkpoint: dict = torch.load(
-        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt"
+        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt", weights_only=True
    )
 
    nnunet_checkpoint["optimizer_state"] = monai_last_checkpoint["optimizer_state"]
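The change applied throughout this file is uniform: these checkpoints are plain dictionaries of tensors, primitives, and containers of these, so they survive the restricted unpickler. A minimal sketch of the round trip (hypothetical file name and keys, not code from this PR):

```python
import torch

# a checkpoint holding only tensors, primitives and containers of these
ckpt = {"network_weights": {"layer.weight": torch.zeros(3, 3)}, "epoch": 100}
torch.save(ckpt, "checkpoint_final.pth")

# weights_only=True rejects arbitrary pickled objects, so only data of the
# kinds above can be restored; anything else raises an UnpicklingError
state_dict = torch.load("checkpoint_final.pth", map_location=torch.device("cpu"), weights_only=True)
assert state_dict["epoch"] == 100
```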
13 changes: 9 additions & 4 deletions monai/data/dataset.py
@@ -207,6 +207,10 @@ class PersistentDataset(Dataset):
     not guaranteed, so caution should be used when modifying transforms to avoid unexpected
     errors. If in doubt, it is advisable to clear the cache directory.
 
+    Cached data is expected to be tensors, primitives, or dictionaries mapping to these values. Numpy arrays
+    will be converted to tensors; however, any other object type returned by transforms will not be loadable,
+    since `torch.load` will be used with `weights_only=True` to prevent loading of potentially malicious objects.
+
     Lazy Resampling:
         If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to
         its documentation to familiarize yourself with the interaction between `PersistentDataset` and
@@ -371,7 +375,8 @@ def _cachecheck(self, item_transformed):
 
         if hashfile is not None and hashfile.is_file():  # cache hit
             try:
-                return torch.load(hashfile, weights_only=False)
+                # loading with weights_only=True is expected to be safe, as this should be the user's own cached data
+                return torch.load(hashfile, weights_only=True)
             except PermissionError as e:
                 if sys.platform != "win32":
                     raise e
@@ -392,7 +397,7 @@ def _cachecheck(self, item_transformed):
         with tempfile.TemporaryDirectory() as tmpdirname:
             temp_hash_file = Path(tmpdirname) / hashfile.name
             torch.save(
-                obj=_item_transformed,
+                obj=convert_to_tensor(_item_transformed),
                 f=temp_hash_file,
                 pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
                 pickle_protocol=self.pickle_protocol,
@@ -1650,7 +1655,7 @@ def _create_new_cache(self, data, data_hashfile, meta_hash_file_name):
             meta_hash_file = self.cache_dir / meta_hash_file_name
             temp_hash_file = Path(tmpdirname) / meta_hash_file_name
             torch.save(
-                obj=self._meta_cache[meta_hash_file_name],
+                obj=convert_to_tensor(self._meta_cache[meta_hash_file_name]),
                 f=temp_hash_file,
                 pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
                 pickle_protocol=self.pickle_protocol,
@@ -1670,4 +1675,4 @@ def _load_meta_cache(self, meta_hash_file_name):
         if meta_hash_file_name in self._meta_cache:
             return self._meta_cache[meta_hash_file_name]
         else:
-            return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
+            return torch.load(self.cache_dir / meta_hash_file_name, weights_only=True)
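The save side pairs with `convert_to_tensor`: numpy arrays are not on the `weights_only` allowlist, so converting before `torch.save` keeps the cache loadable, as the updated docstring explains. A sketch of the round trip (hypothetical data, assuming `convert_to_tensor` recurses through containers as the docstring describes):

```python
import numpy as np
import torch
from monai.utils import convert_to_tensor

item = {"image": np.ones((1, 8, 8), dtype=np.float32), "label": 1}

# numpy arrays become tensors so the restricted loader can rebuild them
cached = convert_to_tensor(item)
torch.save(cached, "cache_entry.pt")

restored = torch.load("cache_entry.pt", weights_only=True)
assert isinstance(restored["image"], torch.Tensor)
```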
2 changes: 1 addition & 1 deletion monai/data/meta_tensor.py
@@ -611,4 +611,4 @@ def print_verbose(self) -> None:
 
 # needed in later versions of Pytorch to indicate the class is safe for serialisation
 if hasattr(torch.serialization, "add_safe_globals"):
-    torch.serialization.add_safe_globals([MetaTensor])
+    torch.serialization.add_safe_globals([MetaObj, MetaTensor, MetaKeys, SpaceKeys])
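With these types registered, a pickled `MetaTensor` can be rebuilt under `weights_only=True` on PyTorch versions that provide `add_safe_globals`. A sketch (not from this PR):

```python
import torch
from monai.data import MetaTensor  # importing MONAI runs the registration above

m = MetaTensor(torch.rand(2, 2))
torch.save(m, "meta.pt")

# succeeds because MetaTensor and its companion types are allowlisted;
# without add_safe_globals this load would raise an UnpicklingError
m2 = torch.load("meta.pt", weights_only=True)
```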
2 changes: 1 addition & 1 deletion monai/handlers/checkpoint_loader.py
@@ -122,7 +122,7 @@ def __call__(self, engine: Engine) -> None:
         Args:
             engine: Ignite Engine, it can be a trainer, validator or evaluator.
         """
-        checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=False)
+        checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=True)
 
         k, _ = list(self.load_dict.items())[0]
         # single object and checkpoint is directly a state_dict
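For user checkpoints that contain a trusted custom class rather than plain tensors, recent PyTorch releases (those that ship `torch.serialization.safe_globals`) offer a scoped allowlist as an alternative to falling back to `weights_only=False`. A hedged sketch with a made-up class, not something this handler does:

```python
import torch

class RunStats:  # stand-in for a user-defined object stored in a checkpoint
    def __init__(self, best=0.0):
        self.best = best

torch.save({"stats": RunStats(0.9), "weights": torch.ones(2)}, "ckpt.pt")

# only RunStats is trusted inside this block; everything else is still rejected
with torch.serialization.safe_globals([RunStats]):
    checkpoint = torch.load("ckpt.pt", weights_only=True)
```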
2 changes: 1 addition & 1 deletion monai/utils/state_cacher.py
@@ -124,7 +124,7 @@ def retrieve(self, key: Hashable) -> Any:
         fn = self.cached[key]["obj"]  # pytype: disable=attribute-error
         if not os.path.exists(fn):  # pytype: disable=wrong-arg-types
             raise RuntimeError(f"Failed to load state in {fn}. File doesn't exist anymore.")
-        data_obj = torch.load(fn, map_location=lambda storage, location: storage, weights_only=False)
+        data_obj = torch.load(fn, map_location=lambda storage, location: storage, weights_only=True)
         # copy back to device if necessary
         if "device" in self.cached[key]:
             data_obj = data_obj.to(self.cached[key]["device"])
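After this change, objects handed to `StateCacher` in on-disk mode must survive a `torch.save`/`torch.load(..., weights_only=True)` round trip, i.e. be tensors or containers of tensors and primitives. A usage sketch (hypothetical key and data):

```python
import tempfile
import torch
from monai.utils import StateCacher

cacher = StateCacher(in_memory=False, cache_dir=tempfile.gettempdir())
cacher.store("model_state", {"w": torch.randn(4)})  # written via torch.save
state = cacher.retrieve("model_state")  # read back with weights_only=True
```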
2 changes: 1 addition & 1 deletion tests/data/meta_tensor/test_meta_tensor.py
@@ -245,7 +245,7 @@ def test_pickling(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             fname = os.path.join(tmp_dir, "im.pt")
             torch.save(m, fname)
-            m2 = torch.load(fname, weights_only=False)
+            m2 = torch.load(fname, weights_only=True)
             self.check(m2, m, ids=False)
 
     @skip_if_no_cuda
8 changes: 7 additions & 1 deletion tests/utils/test_state_cacher.py
@@ -27,7 +27,13 @@
 TEST_CASE_0 = [torch.Tensor([1]).to(DEVICE), {"in_memory": True}]
 TEST_CASE_1 = [
     torch.Tensor([1]).to(DEVICE),
-    {"in_memory": False, "cache_dir": gettempdir(), "pickle_module": None, "pickle_protocol": pickle.HIGHEST_PROTOCOL},
+    {
+        "in_memory": False,
+        "cache_dir": gettempdir(),
+        "pickle_module": None,
+        # TODO: was pickle.HIGHEST_PROTOCOL but this wasn't compatible with torch.load, need to improve compatibility
+        "pickle_protocol": torch.serialization.DEFAULT_PROTOCOL,
+    },
 ]
 TEST_CASE_2 = [torch.Tensor([1]).to(DEVICE), {"in_memory": False, "allow_overwrite": False}]
 TEST_CASE_3 = [torch.Tensor([1]).to(DEVICE), {"in_memory": False, "cache_dir": Path(gettempdir())}]
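The TODO reflects a protocol mismatch: `pickle.HIGHEST_PROTOCOL` is 5 on current CPython while `torch.serialization.DEFAULT_PROTOCOL` is 2, and files written with the higher protocol were reportedly not loadable here under `weights_only=True`. A small sketch that stays on torch's default:

```python
import pickle
import torch

print(pickle.HIGHEST_PROTOCOL)  # 5 on current CPython
print(torch.serialization.DEFAULT_PROTOCOL)  # 2

torch.save(torch.ones(1), "t.pt", pickle_protocol=torch.serialization.DEFAULT_PROTOCOL)
t = torch.load("t.pt", weights_only=True)  # round trip succeeds
```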