[Qwen3.5] mamba slice fix (Prefill TP != Decode TP & decode TP size>1)#20655
ShangmingCai merged 2 commits into sgl-project:main
Conversation
Summary of Changes (Gemini Code Assist): This pull request addresses a critical bug in the Mamba state transfer mechanism within a disaggregated serving environment, specifically when the prefill and decode attention tensor parallelism configurations are asymmetric. The previous implementation led to out-of-bounds writes and corrupted Mamba states, severely impacting model accuracy. The changes ensure that each prefill rank maps to and writes into its designated, non-overlapping memory region in the decode state buffer, restoring the integrity of Mamba states and recovering model accuracy.
Activity
Code Review
This pull request addresses a critical bug in Mamba state transfer when using different tensor parallelism sizes for prefill and decode stages. The change corrects the destination offset calculation for cases where multiple prefill ranks write to a single decode rank, preventing out-of-bounds memory writes and state corruption. The logic now correctly determines the writer index within the target decode rank's buffer, ensuring data is written to the appropriate slice. The fix is well-contained and appears correct based on the problem description and the significant accuracy improvements shown in testing.
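The writer-index arithmetic the review describes can be sketched as a small standalone function. This is a hypothetical reconstruction based on the PR description, not the actual sglang source; the names `attn_tp_size`, `dst_attn_tp_size`, `local_tp_rank_in_group`, and `src_dim` follow the PR text.

```python
def dst_dim_start(local_tp_rank_in_group: int,
                  attn_tp_size: int,
                  dst_attn_tp_size: int,
                  src_dim: int) -> int:
    """Start offset of this prefill rank's slice in its decode rank's buffer.

    Sketch of the corrected logic from this PR (names assumed from the
    description): when attn_tp_size > dst_attn_tp_size, several prefill
    ranks share one decode rank's state buffer, and each must write to a
    distinct slice of it.
    """
    assert attn_tp_size % dst_attn_tp_size == 0
    # How many prefill ranks write into each decode rank's buffer.
    writers_per_decode = attn_tp_size // dst_attn_tp_size
    # The fix: index the writer *within* its target decode rank via modulo,
    # instead of using the global TP rank directly.
    local_writer_idx = local_tp_rank_in_group % writers_per_decode
    return local_writer_idx * src_dim
```

With prefill TP=4 and decode attention TP=2, ranks 0..3 land at offsets 0, 128, 0, 128 (for `src_dim=128`): two writers per decode rank, each in its own slice.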
/rerun-ut test/registered/distributed/test_disaggregation_hybrid_attention.py
❌ Please ask a maintainer to add the
/tag-and-rerun-ci
/rerun-stage stage-c-test-8-gpu-h200
✅ Triggered
sgl-project#20655) Co-authored-by: Shangming Cai <csmthu@gmail.com>

Motivation
Fix incorrect Mamba state transfer when prefill and decode use different attention TP sizes in disaggregated serving with Mooncake. When prefill_attn_tp_size > decode_attn_tp_size (e.g., prefill TP=4 with decode dp-attention attn_tp_size=2), multiple prefill ranks write to the same decode rank's state buffer. The original code used local_tp_rank_in_group directly as the destination offset, which caused out-of-bounds writes for higher-ranked prefill workers, resulting in corrupted Mamba states and degraded accuracy.
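A worked example (with hypothetical buffer sizes, not taken from the actual code) of the out-of-bounds write described above: with prefill TP=4 and decode attention TP=2, two prefill ranks share each decode rank's buffer, which spans two slices of `src_dim` elements.

```python
# Assumed illustrative values; only the rank arithmetic comes from the PR.
src_dim = 128
attn_tp_size, dst_attn_tp_size = 4, 2
writers_per_decode = attn_tp_size // dst_attn_tp_size  # 2 prefill writers per decode rank
decode_buf_dim = writers_per_decode * src_dim          # 256 elements per decode rank

# Buggy scheme: destination offset taken from the global prefill TP rank.
old_starts = [rank * src_dim for rank in range(attn_tp_size)]
# Fixed scheme: offset from the writer index local to the target decode rank.
new_starts = [(rank % writers_per_decode) * src_dim for rank in range(attn_tp_size)]
out_of_bounds = [start + src_dim > decode_buf_dim for start in old_starts]
# old_starts -> [0, 128, 256, 384]: ranks 2 and 3 write past the 256-element buffer
# new_starts -> [0, 128, 0, 128]: every rank lands on a valid, non-overlapping slice
```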
Modifications
In `_send_mamba_state_slice`, replaced the destination offset calculation for the `attn_tp_size > dst_attn_tp_size` branch:
- Compute `writers_per_decode = self.attn_tp_size // dst_attn_tp_size` to determine how many prefill ranks map to each decode rank.
- Compute `local_writer_idx = local_tp_rank_in_group % writers_per_decode` to get the correct writer index within the target decode rank, instead of using the global TP rank.
- Use `local_writer_idx * src_dim` as `dst_dim_start`, ensuring each prefill rank writes to the correct non-overlapping slice of the decode buffer without exceeding its bounds.

Accuracy Tests
Verify with GPQA, repeat=8:
Before this fix:
After this fix:
Benchmarking and Profiling
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci