
[Qwen3.5] mamba slice fix (Prefill TP != Decode TP & decode TP size > 1) #20655

Merged
ShangmingCai merged 2 commits into sgl-project:main from YAMY1234:mamba_slice_fix
Mar 17, 2026
Conversation

@YAMY1234
Contributor

Motivation

Fix incorrect Mamba state transfer when prefill and decode use different attention TP sizes in disaggregated serving with Mooncake. When prefill_attn_tp_size > decode_attn_tp_size (e.g., prefill TP=4 with decode dp-attention attn_tp_size=2), multiple prefill ranks write to the same decode rank's state buffer. The original code used local_tp_rank_in_group directly as the destination offset, which caused out-of-bounds writes for higher-ranked prefill workers, resulting in corrupted Mamba states and degraded accuracy.
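The overflow is a few lines of arithmetic. The sketch below is illustrative, not the sglang code: `src_dim` is a hypothetical slice width, and the buffer layout simply assumes each decode rank holds one slice per mapped prefill rank, as described above.

```python
# Illustrative sketch (not the actual sglang code) of why using the global
# prefill TP rank as the destination offset overflows the decode buffer.
prefill_attn_tp_size = 4
decode_attn_tp_size = 2
src_dim = 128  # hypothetical width of one prefill rank's state slice

# Each decode rank's buffer holds one slice per mapped prefill rank.
writers_per_decode = prefill_attn_tp_size // decode_attn_tp_size  # 2
decode_buffer_len = writers_per_decode * src_dim                  # 256

# Original code: destination offset = global prefill TP rank * src_dim.
buggy = [r * src_dim for r in range(prefill_attn_tp_size)]
print(buggy)  # [0, 128, 256, 384] -- ranks 2 and 3 start at or past byte 256
```

With a 256-element buffer, ranks 2 and 3 begin their writes at offsets 256 and 384, i.e. entirely out of bounds.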

Modifications

In _send_mamba_state_slice, replaced the destination offset calculation for the attn_tp_size > dst_attn_tp_size branch:

  • Compute writers_per_decode = self.attn_tp_size // dst_attn_tp_size to determine how many prefill ranks map to each decode rank.
  • Compute local_writer_idx = local_tp_rank_in_group % writers_per_decode to get the correct writer index within the target decode rank, instead of using the global TP rank.
  • Use local_writer_idx * src_dim as dst_dim_start, ensuring each prefill rank writes to the correct non-overlapping slice of the decode buffer without exceeding its bounds.
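The three steps above can be sketched as a small helper. The names (writers_per_decode, local_writer_idx, src_dim) follow the PR description, but the function itself is an assumption for illustration, not the actual _send_mamba_state_slice code in conn.py.

```python
# Hedged sketch of the corrected offset logic described above; only the
# attn_tp_size > dst_attn_tp_size branch is touched by this PR.
def mamba_dst_dim_start(attn_tp_size: int, dst_attn_tp_size: int,
                        local_tp_rank_in_group: int, src_dim: int) -> int:
    # How many prefill ranks share one decode rank's buffer.
    writers_per_decode = attn_tp_size // dst_attn_tp_size
    # Writer slot within the target decode rank, not the global TP rank.
    local_writer_idx = local_tp_rank_in_group % writers_per_decode
    return local_writer_idx * src_dim

# prefill TP=4, decode attn TP=2: ranks 0..3 land on writer slots 0,1,0,1,
# so every write stays inside the two-slice decode buffer.
offsets = [mamba_dst_dim_start(4, 2, r, 128) for r in range(4)]
print(offsets)  # [0, 128, 0, 128]
```

Each pair of prefill ranks thus writes to non-overlapping halves of its decode rank's buffer, which is exactly the invariant the original global-rank offset violated.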

Accuracy Tests

  • gpu_type: "gb200"
  • gpus_per_node: 4
  • prefill_nodes: 1
  • decode_nodes: 1
  • prefill_workers: 1
  • decode_workers: 1

Verify with GPQA, repeat=8:

Before this fix:

Repeat: 8, mean: 0.744
Scores: ['0.727', '0.768', '0.717', '0.712', '0.758', '0.758', '0.768', '0.747']

After this fix:

Repeat: 8, mean: 0.864
Scores: ['0.879', '0.874', '0.848', '0.864', '0.843', '0.859', '0.859', '0.884']

Test configuration:

  prefill_environment:
    TORCH_DISTRIBUTED_DEFAULT_TIMEOUT: "1800"
    PYTHONUNBUFFERED: "1"
    NCCL_MNNVL_ENABLE: "1"
    NCCL_CUMEM_ENABLE: "1"
    MC_FORCE_MNNVL: "1"
    SGLANG_DG_CACHE_DIR: "/configs/deepgemm-cache"
    FLASHINFER_WORKSPACE_BASE: "/configs/flashinfer-cache"
    SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE: "100000"
    SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT: "100000"
    SGLANG_DISAGGREGATION_WAITING_TIMEOUT: "100000"
    SGLANG_MOONCAKE_CUSTOM_MEM_POOL: "True"
    SGLANG_USE_MESSAGE_QUEUE_BROADCASTER: "0"
    SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK: "1"

  decode_environment:
    TORCH_DISTRIBUTED_DEFAULT_TIMEOUT: "1800"
    PYTHONUNBUFFERED: "1"
    NCCL_MNNVL_ENABLE: "1"
    NCCL_CUMEM_ENABLE: "1"
    MC_FORCE_MNNVL: "1"
    SGLANG_DG_CACHE_DIR: "/configs/deepgemm-cache"
    FLASHINFER_WORKSPACE_BASE: "/configs/flashinfer-cache"
    SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE: "100000"
    SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT: "100000"
    SGLANG_DISAGGREGATION_WAITING_TIMEOUT: "100000"
    SGLANG_DECODE_BOOTSTRAP_TIMEOUT: "1000"
    SGLANG_HACK_SEQ_BOOTSTRAP_ROOM: "1"
    SGLANG_MOONCAKE_CUSTOM_MEM_POOL: "True"
    SGLANG_USE_MESSAGE_QUEUE_BROADCASTER: "0"
    SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK: "1"

  sglang_config:
    prefill:
      served-model-name: "Qwen/Qwen3.5-397B-A17B-FP8"
      model-path: "/model/"
      trust-remote-code: true

      attention-backend: "trtllm_mha"
      kv-cache-dtype: "fp8_e4m3"

      tensor-parallel-size: 4
      data-parallel-size: 1
      expert-parallel-size: 1

      mamba-scheduler-strategy: "no_buffer"
      disable-radix-cache: true
      mamba-track-interval: 2048
      mamba-ssm-dtype: "bfloat16"

      disaggregation-mode: "prefill"

      mem-fraction-static: 0.80
      chunked-prefill-size: 16384
      load-balance-method: "round_robin"
      watchdog-timeout: 1000000

    decode:
      served-model-name: "Qwen/Qwen3.5-397B-A17B-FP8"
      model-path: "/model/"
      trust-remote-code: true

      attention-backend: "trtllm_mha"
      quantization: "fp8"
      kv-cache-dtype: "fp8_e4m3"
      moe-runner-backend: "flashinfer_trtllm"

      # DEP2: DP2 + TP4 + EP2 with dp-attention (attn_tp_size=2)
      tp-size: 4
      dp-size: 2
      ep-size: 2
      enable-dp-attention: true
      enable-dp-lm-head: true
      moe-dense-tp-size: 1

      mamba-scheduler-strategy: "no_buffer"
      disable-radix-cache: true
      mamba-track-interval: 2048
      mamba-ssm-dtype: "bfloat16"

      disaggregation-mode: "decode"

      mem-fraction-static: 0.80
      chunked-prefill-size: 16384
      cuda-graph-max-bs: 1024
      decode-log-interval: 1
      stream-interval: 50
      watchdog-timeout: 1000000

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical bug in the Mamba state transfer mechanism within a disaggregated serving environment, specifically when prefill and decode attention tensor parallelism configurations are asymmetric. The previous implementation led to out-of-bounds writes and corrupted Mamba states, severely impacting model accuracy. The introduced changes ensure that prefill ranks correctly map and write to their designated, non-overlapping memory regions in the decode state buffer, thereby restoring the integrity of Mamba states and significantly improving model performance.

Highlights

  • Mamba State Transfer Fix: Resolved an issue where Mamba state transfer was incorrect when prefill and decode attention tensor parallelism (TP) sizes differed, specifically when prefill_attn_tp_size > decode_attn_tp_size.
  • Destination Offset Calculation: Modified the _send_mamba_state_slice function to correctly calculate the destination offset (dst_dim_start) for Mamba state slices. This now uses writers_per_decode and local_writer_idx to ensure each prefill rank writes to a non-overlapping, valid slice of the decode buffer.
  • Accuracy Improvement: Demonstrated a significant improvement in accuracy on GPQA tests (mean score increased from 0.744 to 0.864) after applying the fix, indicating the resolution of corrupted Mamba states.


Changelog
  • python/sglang/srt/disaggregation/mooncake/conn.py
    • Corrected the calculation of dst_dim_start in _send_mamba_state_slice to prevent out-of-bounds writes when prefill attention TP size is greater than decode attention TP size.
Activity
  • The pull request was created by YAMY1234 to fix a Mamba state transfer issue.
  • The author provided a detailed motivation, modification description, and accuracy test results to support the changes.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a critical bug in Mamba state transfer when using different tensor parallelism sizes for prefill and decode stages. The change corrects the destination offset calculation for cases where multiple prefill ranks write to a single decode rank, preventing out-of-bounds memory writes and state corruption. The logic now correctly determines the writer index within the target decode rank's buffer, ensuring data is written to the appropriate slice. The fix is well-contained and appears correct based on the problem description and the significant accuracy improvements shown in testing.

Collaborator

@ShangmingCai ShangmingCai left a comment


LGTM

@ShangmingCai
Collaborator

/rerun-ut test/registered/distributed/test_disaggregation_hybrid_attention.py

@github-actions
Contributor

/rerun-ut is not available for fork PRs (security restriction).

Please ask a maintainer to add the run-ci label and use the normal CI flow, or use /rerun-failed-ci to rerun workflows that have already passed the gate.

@ShangmingCai
Collaborator

/tag-and-rerun-ci

@ShangmingCai
Collaborator

/rerun-stage stage-c-test-8-gpu-h200

@github-actions
Contributor

✅ Triggered stage-c-test-8-gpu-h200 to run independently (skipping dependencies).

@github-actions
Contributor

🔗 View workflow run

@ShangmingCai
Collaborator

Related CI has passed.

@ShangmingCai ShangmingCai disabled auto-merge March 17, 2026 11:30
@ShangmingCai ShangmingCai merged commit cfead25 into sgl-project:main Mar 17, 2026
84 of 91 checks passed
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
0-693 pushed a commit to 0-693/sglang that referenced this pull request Mar 25, 2026