
Add get_weights_checksum API and refactor update_weights_from_tensor tests with SHA256 verification#18913

Open
aeft wants to merge 1 commit into sgl-project:main from aeft:refactor/refit-unit-tests

Conversation


@aeft aeft commented Feb 17, 2026

Motivation

The previous tests only compared the first 5 elements of each tensor (atol=0.002), which is a weak guarantee for tensors with millions of elements. This PR replaces them with full SHA256 verification, as the first step of #18893.
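To illustrate why a prefix-only comparison is weak, here is a small stand-alone example (not code from the PR): two million-element vectors that pass a first-5-elements check with atol=0.002 while disagreeing almost everywhere.

```python
# Hypothetical illustration: a prefix-only comparison passes even when
# almost every element differs.
a = [0.0] * 1_000_000
b = a[:5] + [1.0] * (1_000_000 - 5)  # perturb everything past the checked prefix

# The old-style check (first 5 elements, atol=0.002) passes...
assert all(abs(x - y) <= 0.002 for x, y in zip(a[:5], b[:5]))
# ...even though the vectors disagree on 999,995 elements.
assert sum(x != y for x, y in zip(a, b)) == 1_000_000 - 5
```

A full SHA256 digest over the tensor bytes, by contrast, changes if any element changes.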

Modifications

  • Add compute_weights_checksum utility in sglang/srt/utils/weight_checksum.py (it duplicates the one in "[Feature] Implement update_weights_from_disk for SGLang-D (Diffusion …" #18306; we may consolidate the two later)
  • Add get_weights_checksum API through the full calling chain: engine -> tokenizer_manager -> scheduler -> tp_worker -> model_runner
  • Add GetWeightsChecksumReqInput/Output IO structs
  • Refactor all 4 test methods in TestUpdateWeightsFromTensor to use checksum verification:
    • Load HF model, apply a deterministic perturbation, merge to SGLang format, compute expected checksum
    • Send perturbed params to engine via each load_format, verify engine checksum matches expected and differs from before
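The steps above can be sketched end to end with a toy engine. FakeEngine and the inline perturbation are hypothetical stand-ins for the real sglang Engine and test helpers; only the method names follow the PR description.

```python
import hashlib

# Toy sketch of the refactored test pattern (hypothetical stand-in for
# the real sglang Engine; only the method names follow the PR).
class FakeEngine:
    def __init__(self, params):
        self.params = dict(params)

    def get_weights_checksum(self):
        # Deterministic SHA256 over (name, bytes) pairs, sorted by name.
        h = hashlib.sha256()
        for name in sorted(self.params):
            h.update(name.encode())
            h.update(bytes(self.params[name]))
        return h.hexdigest()

    def update_weights_from_tensor(self, named_tensors):
        self.params.update(named_tensors)

engine = FakeEngine([("w", [1, 2, 3])])
before = engine.get_weights_checksum()

# Deterministic perturbation, plus the checksum we expect afterwards.
perturbed = [("w", [2, 3, 4])]
expected = FakeEngine(perturbed).get_weights_checksum()

engine.update_weights_from_tensor(perturbed)
after = engine.get_weights_checksum()
assert after == expected  # engine now holds exactly the perturbed weights
assert after != before    # and the weights actually changed
```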

Tests

The affected unit tests passed on a GPU.

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Summary of Changes

Hello @aeft, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the reliability of model weight verification by replacing approximate checks with full SHA256 checksums. It introduces a new API to retrieve these checksums from the model engine and integrates this robust verification into existing weight update tests, ensuring that model parameters are accurately loaded and updated across various formats. This is a foundational step towards improving the integrity of model state management.

Highlights

  • New API for Weight Checksums: Introduced a new get_weights_checksum API across the engine, tokenizer manager, scheduler, TP worker, and model runner to retrieve SHA256 checksums of model weights.
  • Enhanced Test Verification: Refactored all TestUpdateWeightsFromTensor methods to replace approximate tensor comparisons with full SHA256 checksum verification, providing a stronger guarantee of weight integrity.
  • Weight Checksum Utility: Added a new utility function compute_weights_checksum that calculates a deterministic SHA256 hash over all named model parameters, handling DTensor types.
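A minimal sketch of the idea behind compute_weights_checksum, hashing plain byte lists instead of torch tensors (the real utility in sglang/srt/utils/weight_checksum.py also converts DTensor shards to local tensors before hashing):

```python
import hashlib

# Simplified sketch, assuming named_params yields (name, bytes-like) pairs;
# the real utility operates on torch tensors and handles DTensor.
def compute_weights_checksum(named_params):
    h = hashlib.sha256()
    # Sort by name so the digest is independent of iteration order.
    for name, data in sorted(named_params):
        h.update(name.encode())
        h.update(bytes(data))
    return h.hexdigest()

params = [("layer.0.weight", [1, 2, 3]), ("layer.0.bias", [4])]
digest = compute_weights_checksum(params)
# Order-invariant thanks to the sort.
assert digest == compute_weights_checksum(reversed(params))
```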


Changelog
  • python/sglang/srt/entrypoints/engine.py
    • Added get_weights_checksum method to the Engine class.
    • Imported GetWeightsChecksumReqInput.
  • python/sglang/srt/managers/io_struct.py
    • Added GetWeightsChecksumReqInput and GetWeightsChecksumReqOutput dataclasses for checksum requests.
  • python/sglang/srt/managers/scheduler.py
    • Registered the new get_weights_checksum method in the request dispatcher.
    • Imported GetWeightsChecksumReqInput.
  • python/sglang/srt/managers/scheduler_update_weights_mixin.py
    • Implemented get_weights_checksum to forward the request to the TP worker.
    • Imported GetWeightsChecksumReqInput and GetWeightsChecksumReqOutput.
  • python/sglang/srt/managers/tokenizer_communicator_mixin.py
    • Added communicator for get_weights_checksum requests and responses.
    • Implemented async get_weights_checksum to handle distributed checksum retrieval.
    • Imported GetWeightsChecksumReqInput and GetWeightsChecksumReqOutput.
  • python/sglang/srt/managers/tp_worker.py
    • Implemented get_weights_checksum to delegate the call to the model runner.
    • Imported GetWeightsChecksumReqInput.
  • python/sglang/srt/model_executor/model_runner.py
    • Added get_weights_checksum method to compute the SHA-256 checksum of local model parameters.
    • Imported compute_weights_checksum.
  • python/sglang/srt/utils/weight_checksum.py
    • Added a new utility file containing the compute_weights_checksum function.
  • test/registered/rl/test_update_weights_from_tensor.py
    • Refactored all test_update_weights_from_tensor methods to use SHA256 checksum verification.
    • Added helper functions for loading and perturbing Hugging Face model parameters and computing expected checksums.
    • Removed the _check_param helper function.
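For orientation, the two IO structs named in the changelog might look roughly like this; the field layout is an assumption for illustration, not copied from the diff.

```python
from dataclasses import dataclass
from typing import Optional

# Hedged sketch: class names come from the changelog, field layout is assumed.
@dataclass
class GetWeightsChecksumReqInput:
    pass  # the request carries no payload in this sketch

@dataclass
class GetWeightsChecksumReqOutput:
    success: bool
    checksum: Optional[str] = None  # hex SHA256 digest when success is True

out = GetWeightsChecksumReqOutput(success=True, checksum="ab" * 32)
assert out.success and len(out.checksum) == 64
```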


@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces a get_weights_checksum API to verify model weights using SHA256, which is a significant improvement over the previous partial tensor comparison in tests. The implementation is well-structured, propagating the new API through the engine, tokenizer manager, scheduler, and model runner. The tests have been refactored to use this new checksum verification, making them much more robust. I have a few suggestions to improve maintainability and test coverage.



    class TestUpdateWeightsFromTensor(CustomTestCase):
        def test_update_weights_from_tensor(self):
Severity: high

The test refactoring has removed testing for tensor parallelism (tp_size > 1). The previous implementation of test_update_weights_from_tensor iterated over tp_sizes = [1, 2]. While the new checksum-based verification is much more robust for a single GPU, it's important to also test the weight update functionality in a multi-GPU tensor parallel setup to prevent regressions. The get_weights_checksum API is designed to work with TP by computing a shard-local checksum, so it would be valuable to add a test case that verifies this.

Author
    register_cuda_ci(est_time=195, suite="stage-b-test-small-1-gpu")
    register_amd_ci(est_time=195, suite="stage-b-test-small-1-gpu-amd")

The CI suite is stage-b-test-small-1-gpu, so tp_size=2 was always skipped.

Furthermore, cross-verifying shard-local checksums under TP>1 would require replicating the sharding logic in the test. I am open to suggestions if there's a better approach.

Comment on lines +14 to +15
    if isinstance(t, DTensor):
        t = t._local_tensor

Severity: medium

Using _local_tensor accesses a private attribute of DTensor, which could be unstable across different PyTorch versions. It's better to use the public API to_local() to get the local tensor shard. This will improve code maintainability and stability.

Suggested change:

    - if isinstance(t, DTensor):
    -     t = t._local_tensor
    + if isinstance(t, DTensor):
    +     t = t.to_local()

    ]
    checksum_before = engine.get_weights_checksum()
    hf_params = _load_perturbed_hf_params()
    engine.update_weights_from_tensor(list(hf_params))
Severity: medium

The _load_perturbed_hf_params function already returns a list. Calling list() on hf_params is redundant and creates an unnecessary shallow copy of the list.

Suggested change:

    - engine.update_weights_from_tensor(list(hf_params))
    + engine.update_weights_from_tensor(hf_params)


aeft commented Feb 17, 2026

@zhaochenyang20 @dreamyang-liu Could you please take a look when you have a chance? Thanks!

My overall plan is to start with update_weights_from_tensor in this first PR to demonstrate the pattern, then extend it to from_disk and from_distributed in follow-up PRs.
