Add get_weights_checksum API and refactor update_weights_from_tensor tests with SHA256 verification#18913
aeft wants to merge 1 commit into sgl-project:main from
Conversation
…tests with SHA256 verification
Summary of Changes

Hello @aeft, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the reliability of model weight verification by replacing approximate checks with full SHA256 checksums. It introduces a new API to retrieve these checksums from the model engine and integrates this robust verification into the existing weight update tests, ensuring that model parameters are accurately loaded and updated across various formats. This is a foundational step toward improving the integrity of model state management.

Highlights
Activity
Code Review
This pull request introduces a get_weights_checksum API to verify model weights using SHA256, which is a significant improvement over the previous partial tensor comparison in tests. The implementation is well-structured, propagating the new API through the engine, tokenizer manager, scheduler, and model runner. The tests have been refactored to use this new checksum verification, making them much more robust. I have a few suggestions to improve maintainability and test coverage.
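The checksum utility behind the API described above could look roughly like the following. This is a hedged sketch, not the actual implementation in `sglang/srt/utils/weight_checksum.py`: it hashes `(name, bytes)` pairs with plain byte strings standing in for torch tensors, and the sorted iteration order is an assumption.

```python
import hashlib

def compute_weights_checksum(named_params) -> str:
    # Sketch: fold every parameter's name and raw bytes into one SHA256
    # digest. Sorting by name makes the result independent of the order
    # in which parameters are yielded.
    h = hashlib.sha256()
    for name, data in sorted(named_params):
        h.update(name.encode("utf-8"))
        h.update(data)
    return h.hexdigest()

print(compute_weights_checksum([("b", b"\x02"), ("a", b"\x01")]))
```

Because the digest covers every byte of every parameter, any single-bit change in any weight changes the checksum, which is what makes it a much stronger test oracle than a prefix comparison.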
```python
class TestUpdateWeightsFromTensor(CustomTestCase):
    def test_update_weights_from_tensor(self):
```
The test refactoring has removed testing for tensor parallelism (tp_size > 1). The previous implementation of test_update_weights_from_tensor iterated over tp_sizes = [1, 2]. While the new checksum-based verification is much more robust for a single GPU, it's important to also test the weight update functionality in a multi-GPU tensor parallel setup to prevent regressions. The get_weights_checksum API is designed to work with TP by computing a shard-local checksum, so it would be valuable to add a test case that verifies this.
sglang/test/registered/rl/test_update_weights_from_tensor.py, lines 3 to 4 in 5e3103a
The CI suite is stage-b-test-small-1-gpu, so tp_size=2 was always skipped.
Furthermore, cross-verifying shard-local checksums under TP > 1 would require replicating the sharding logic inside the test itself. I am open to suggestions if there's a better approach.
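To make the replication burden concrete, here is a hypothetical sketch of a shard-local checksum over an evenly split flat byte buffer. The real sharding in SGLang is per-parameter and layout-dependent (column- vs. row-parallel, fused weights, etc.), so this only illustrates why a test must mirror that logic exactly to predict each rank's digest.

```python
import hashlib

def shard_checksums(flat_weights: bytes, tp_size: int) -> list[str]:
    # Assumed even contiguous split: each TP rank sees one slice of the
    # buffer and hashes only its own slice (a shard-local checksum).
    shard_len = len(flat_weights) // tp_size
    return [
        hashlib.sha256(flat_weights[i * shard_len:(i + 1) * shard_len]).hexdigest()
        for i in range(tp_size)
    ]

print(shard_checksums(b"abcdefgh", 2))
```

Any mismatch between the split used here and the engine's actual partitioning would make every shard digest differ even when the weights are identical, which is the core difficulty with cross-verifying TP > 1 checksums from the test side.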
```python
if isinstance(t, DTensor):
    t = t._local_tensor
```
Using _local_tensor accesses a private attribute of DTensor, which could be unstable across different PyTorch versions. It's better to use the public API to_local() to get the local tensor shard. This will improve code maintainability and stability.
```diff
 if isinstance(t, DTensor):
-    t = t._local_tensor
+    t = t.to_local()
```
```python
]
checksum_before = engine.get_weights_checksum()
hf_params = _load_perturbed_hf_params()
engine.update_weights_from_tensor(list(hf_params))
```
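The before/after checksum pattern in this excerpt can be sketched end to end with a minimal in-memory stand-in. `FakeEngine` below is hypothetical; only the two method names (`get_weights_checksum`, `update_weights_from_tensor`) are taken from this PR, and the internal hashing scheme is an assumption.

```python
import hashlib

class FakeEngine:
    # Minimal stand-in mimicking the two methods this PR touches.
    def __init__(self, weights):
        self.weights = dict(weights)  # name -> raw bytes, in place of tensors

    def get_weights_checksum(self):
        h = hashlib.sha256()
        for name in sorted(self.weights):
            h.update(name.encode("utf-8"))
            h.update(self.weights[name])
        return h.hexdigest()

    def update_weights_from_tensor(self, named_tensors):
        self.weights.update(named_tensors)

engine = FakeEngine({"layer.weight": b"\x01\x02"})
checksum_before = engine.get_weights_checksum()
engine.update_weights_from_tensor([("layer.weight", b"\x03\x04")])
checksum_after = engine.get_weights_checksum()
assert checksum_after != checksum_before  # the update is observable in the digest
```

The refactored test follows the same shape: capture the checksum, push perturbed HF weights in, and assert the new checksum matches the expected reference digest rather than spot-checking a few elements.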
@zhaochenyang20 @dreamyang-liu Could you please take a look when you have a chance? Thanks! My overall plan is to start with
Motivation
The previous tests only compared the first 5 elements of each tensor (atol=0.002), which is a weak guarantee for tensors with millions of elements. This PR replaces them with full SHA256 verification, as the first step of #18893.
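The weakness of the old check can be demonstrated in a few lines. This sketch uses plain Python lists packed as float32 bytes to stand in for tensors: a corruption past the inspected prefix passes the old-style check but is caught by a full SHA256 digest.

```python
import hashlib
import struct

# Two "weight vectors" that agree in the first 5 elements but diverge later.
w_ref = [0.1 * i for i in range(1000)]
w_bad = list(w_ref)
w_bad[500] += 1.0  # corruption far past the inspected prefix

# Old-style check: only the first 5 elements, loose tolerance.
prefix_ok = all(abs(a - b) <= 0.002 for a, b in zip(w_ref[:5], w_bad[:5]))

# Full SHA256 over the raw float32 bytes catches the divergence.
digest = lambda w: hashlib.sha256(struct.pack(f"{len(w)}f", *w)).hexdigest()
print(prefix_ok, digest(w_ref) == digest(w_bad))  # → True False
```

The prefix check accepts the corrupted weights while the checksum rejects them, which is exactly the gap this PR closes.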
Modifications
- Add a `compute_weights_checksum` utility in `sglang/srt/utils/weight_checksum.py` (duplicated with the one in "[Feature] Implement update_weights_from_disk for SGLang-D (Diffusion …" #18306; we may consolidate them later)
- Expose a `get_weights_checksum` API through the full calling chain (engine -> tokenizer_manager -> scheduler -> tp_worker -> model_runner)
- Add `GetWeightsChecksumReqInput`/`Output` IO structs
- Refactor `TestUpdateWeightsFromTensor` to use checksum verification

Tests
The affected unit tests passed on a GPU.
Accuracy Tests
Benchmarking and Profiling
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci