Use weights_only=True for remaining torch.load() calls #28421
Add _torch_load_weights_only() wrapper to lr_scheduler_test_data_generator.py and orttraining_test_ortmodule_pytorch_ddp.py, matching the pattern from PR #28097. This mitigates arbitrary code execution risk from malicious PyTorch checkpoints.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Pull request overview
This PR is a security-focused follow-up that updates the remaining torch.load() call sites in the repo to prefer weights_only=True (when supported), reducing exposure to arbitrary pickle deserialization while keeping compatibility with older PyTorch versions via a fallback wrapper.
Changes:
- Added a local _torch_load_weights_only() helper (with compatibility fallback) in two Python test/tooling files.
- Replaced direct torch.load(...) usages with _torch_load_weights_only(...) for checkpoint/state-dict loading.
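A minimal sketch of how such a helper could look, assuming the capability check is done with inspect.signature as the review notes. The factory name make_weights_only_loader and the two stand-in loaders are hypothetical and exist only so the example runs without PyTorch installed; this is not the PR's actual code.

```python
import inspect
from typing import Any, Callable


def make_weights_only_loader(load_fn: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap load_fn so it passes weights_only=True when load_fn accepts it."""
    # Probe the signature once, up front, instead of on every call.
    # Note: a loader that only accepts **kwargs would not be detected here.
    supports_weights_only = "weights_only" in inspect.signature(load_fn).parameters

    def _load(*args: Any, **kwargs: Any) -> Any:
        if supports_weights_only:
            # Keep an explicit caller choice; otherwise prefer the safer mode.
            kwargs.setdefault("weights_only", True)
        return load_fn(*args, **kwargs)

    return _load


# Hypothetical stand-ins for a newer and an older loader signature.
def new_style_load(path, map_location=None, weights_only=False):
    return {"path": path, "weights_only": weights_only}


def old_style_load(path, map_location=None):
    return {"path": path, "weights_only": None}


new_result = make_weights_only_loader(new_style_load)("ckpt.pt")
old_result = make_weights_only_loader(old_style_load)("ckpt.pt")
explicit_result = make_weights_only_loader(new_style_load)("ckpt.pt", weights_only=False)
```

With PyTorch available, the same factory could presumably be applied as make_weights_only_loader(torch.load) at module level, so that call sites only swap torch.load(...) for the wrapped function.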
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| orttraining/orttraining/test/python/orttraining_test_ortmodule_pytorch_ddp.py | Adds a torch.load wrapper preferring weights_only=True and uses it when loading the DDP checkpoint. |
| onnxruntime/test/testdata/test_data_generation/lr_scheduler/lr_scheduler_test_data_generator.py | Adds the same wrapper and uses it when reloading saved optimizer/scheduler state. |
tianleiwu left a comment
APPROVE — Clean, correct security hardening of the last two torch.load() call sites.
The inspect.signature(torch.load) probe at module level is a nice improvement over the try/except TypeError fallback used in PR #28097 — it avoids masking unrelated TypeError exceptions from torch.load itself (e.g., corrupted files or bad arguments). Both files correctly forward **kwargs and load payloads (state dicts) that are safe for the restricted weights_only=True loader.
Nitpick (not blocking): The repo now has two different implementations of the same helper — the PR #28097 files (t5_helper.py, nv_run_pretraining.py) still use try/except TypeError, while these new files use inspect.signature. Consider back-porting the improved probe to the older helpers in a follow-up for consistency, or extracting a shared utility.
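The difference between the two helper styles discussed in this review can be illustrated with a small sketch. fake_load is a hypothetical stand-in for a loader that supports weights_only and fails with an unrelated TypeError in that mode. Neither wrapper is copied from the repository, and the try/except version only assumes the general shape the review describes for PR #28097.

```python
import inspect


def fake_load(path, weights_only=False):
    """Hypothetical loader; not the real torch.load."""
    if weights_only and path.endswith(".bad"):
        # A failure unrelated to whether the keyword is supported.
        raise TypeError("unrelated failure inside the loader")
    return {"path": path, "weights_only": weights_only}


def load_try_except(path):
    """Fallback style: retry without the keyword on any TypeError."""
    try:
        return fake_load(path, weights_only=True)
    except TypeError:
        # Meant for "unexpected keyword argument", but catches every TypeError,
        # so the real error is swallowed and the unrestricted path is used.
        return fake_load(path)


# Probe style: decide once, from the signature, whether the keyword exists.
_SUPPORTS_WEIGHTS_ONLY = "weights_only" in inspect.signature(fake_load).parameters


def load_probe(path):
    if _SUPPORTS_WEIGHTS_ONLY:
        return fake_load(path, weights_only=True)
    return fake_load(path)


masked = load_try_except("ckpt.bad")  # silently falls back to the unrestricted mode

try:
    load_probe("ckpt.bad")
    probe_error = None
except TypeError as exc:
    probe_error = str(exc)  # the original error reaches the caller
```

In this sketch the try/except wrapper returns a result loaded with weights_only=False, while the probe-based wrapper surfaces the original exception.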
Description
Follow-up to PR #28097. Applies the same _torch_load_weights_only() wrapper to the two remaining torch.load() call sites.
torch.load can deserialize arbitrary Python pickle payloads. Using weights_only=True restricts loading to tensor/checkpoint data on supported PyTorch versions and is the safer default. The wrapper falls back to the default torch.load behavior on older PyTorch versions that do not support the weights_only parameter.
Summary of Changes
- onnxruntime/test/testdata/test_data_generation/lr_scheduler/lr_scheduler_test_data_generator.py: Adds the _torch_load_weights_only() helper and uses it when loading scheduler/optimizer state dicts.
- orttraining/orttraining/test/python/orttraining_test_ortmodule_pytorch_ddp.py: Adds the _torch_load_weights_only() helper and uses it when loading the DDP model checkpoint.
Motivation and Context
These were the last two torch.load() calls in the repository without weights_only=True. Both are in test/tooling code with low direct risk, but this change keeps them consistent with the pattern established in PR #28097 and leaves no torch.load() call site in the repository without the weights_only guard.
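The description's point that pickle payloads can run arbitrary code can be shown with the standard library alone. The sketch below uses a harmless callable (sorted) and a hypothetical allowlist unpickler in the style of the Python documentation's "restricting globals" example. It illustrates the general mechanism only and is not how PyTorch implements weights_only.

```python
import io
import pickle
from collections import OrderedDict


class Payload:
    """Pickles as a call to a callable chosen by whoever wrote the file."""

    def __reduce__(self):
        # On load, pickle calls sorted([3, 1, 2]); a hostile file could name
        # any importable callable here instead.
        return (sorted, ([3, 1, 2],))


class AllowlistUnpickler(pickle.Unpickler):
    """Hypothetical restricted loader: only allowlisted globals may be resolved."""

    ALLOWED = {("collections", "OrderedDict")}

    def find_class(self, module, name):
        if (module, name) in self.ALLOWED:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"global {module}.{name} is not allowed")


def restricted_load(blob: bytes):
    return AllowlistUnpickler(io.BytesIO(blob)).load()


blob = pickle.dumps(Payload())
result = pickle.loads(blob)  # the callable runs during loading

try:
    restricted_load(blob)
    blocked = False
except pickle.UnpicklingError:
    blocked = True  # the restricted loader refuses to resolve sorted

# Plain container data made of allowlisted types still loads.
restored = restricted_load(pickle.dumps(OrderedDict(a=1)))
```

The unrestricted load returns the result of the call rather than a Payload instance, which is why loading untrusted checkpoints without a restricted mode is risky.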