Skip to content

Use weights_only=True for remaining torch.load() calls#28421

Merged
adrianlizarraga merged 2 commits into
mainfrom
adrianl/TorchLoad_WeightsOnly_TestCode
May 13, 2026
Merged

Use weights_only=True for remaining torch.load() calls#28421
adrianlizarraga merged 2 commits into
mainfrom
adrianl/TorchLoad_WeightsOnly_TestCode

Conversation

@adrianlizarraga

@adrianlizarraga adrianlizarraga commented May 8, 2026

Copy link
Copy Markdown
Contributor

Description

Follow-up to PR #28097. Applies the same _torch_load_weights_only() wrapper to the two remaining torch.load() call sites.

torch.load can deserialize arbitrary Python pickle payloads. Using weights_only=True restricts loading to tensor/checkpoint data on supported PyTorch versions and is the safer default. The wrapper gracefully falls back to the default torch.load behavior on older PyTorch versions that do not support the weights_only parameter.

Summary of Changes

File Change
onnxruntime/test/testdata/test_data_generation/lr_scheduler/lr_scheduler_test_data_generator.py Adds _torch_load_weights_only() helper and uses it when loading scheduler/optimizer state dicts.
orttraining/orttraining/test/python/orttraining_test_ortmodule_pytorch_ddp.py Adds _torch_load_weights_only() helper and uses it when loading DDP model checkpoint.

Motivation and Context

These were the last two torch.load() calls in the repository without weights_only=True. While both are in test/tooling code with low direct risk, this change ensures consistency with the pattern established in PR #28097 and eliminates all unsafe deserialization call sites.

Add _torch_load_weights_only() wrapper to lr_scheduler_test_data_generator.py
and orttraining_test_ortmodule_pytorch_ddp.py, matching the pattern from PR #28097.
This mitigates arbitrary code execution risk from malicious PyTorch checkpoints.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR is a security-focused follow-up that updates the remaining torch.load() call sites in the repo to prefer weights_only=True (when supported), reducing exposure to arbitrary pickle deserialization while keeping compatibility with older PyTorch versions via a fallback wrapper.

Changes:

  • Added a local _torch_load_weights_only() helper (with compatibility fallback) in two Python test/tooling files.
  • Replaced direct torch.load(...) usages with _torch_load_weights_only(...) for checkpoint/state-dict loading.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File Description
orttraining/orttraining/test/python/orttraining_test_ortmodule_pytorch_ddp.py Adds a torch.load wrapper preferring weights_only=True and uses it when loading the DDP checkpoint.
onnxruntime/test/testdata/test_data_generation/lr_scheduler/lr_scheduler_test_data_generator.py Adds the same wrapper and uses it when reloading saved optimizer/scheduler state.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread orttraining/orttraining/test/python/orttraining_test_ortmodule_pytorch_ddp.py Outdated

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated no new comments.

@adrianlizarraga adrianlizarraga marked this pull request as ready for review May 8, 2026 22:26
@adrianlizarraga adrianlizarraga requested a review from tianleiwu May 8, 2026 23:13

@tianleiwu tianleiwu left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

APPROVE — Clean, correct security hardening of the last two torch.load() call sites.

The inspect.signature(torch.load) probe at module level is a nice improvement over the try/except TypeError fallback used in PR #28097 — it avoids masking unrelated TypeError exceptions from torch.load itself (e.g., corrupted files or bad arguments). Both files correctly forward **kwargs and load payloads (state dicts) that are safe for the restricted weights_only=True loader.

Nitpick (not blocking): The repo now has two different implementations of the same helper — the PR #28097 files (t5_helper.py, nv_run_pretraining.py) still use try/except TypeError, while these new files use inspect.signature. Consider back-porting the improved probe to the older helpers in a follow-up for consistency, or extracting a shared utility.

@adrianlizarraga adrianlizarraga merged commit 4e377a1 into main May 13, 2026
93 of 94 checks passed
@adrianlizarraga adrianlizarraga deleted the adrianl/TorchLoad_WeightsOnly_TestCode branch May 13, 2026 17:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants