Skip to content

[DSv32] Fix MTP and CP compatability#19062

Merged
Fridge003 merged 1 commit intosgl-project:mainfrom
vladnosiv:fix-dsv32-cp-and-mtp
Feb 21, 2026
Merged

[DSv32] Fix MTP and CP compatability#19062
Fridge003 merged 1 commit intosgl-project:mainfrom
vladnosiv:fix-dsv32-cp-and-mtp

Conversation

@vladnosiv
Copy link
Copy Markdown
Contributor

@vladnosiv vladnosiv commented Feb 20, 2026

Motivation

PR #17213 introduced separate
get_attention_cp_size() / get_attention_cp_rank() APIs and migrated the main model in deepseek_v2.py to use them, but missed deepseek_nextn.py.

Here is the assigns:

if self.nsa_enable_prefill_cp:
self.cp_rank = get_attention_tp_rank()
self.cp_size = get_attention_tp_size()

Before the refactor this returned 8 (the combined TP+CP group size). After the refactor it returns 1 (pure attention TP size), so the MTP model operates with cp_size=1, skips CP logic, and produces tensors with wrong shapes. Any request crashes with:

RuntimeError: seqlens_k must have shape (batch_size)

Accuracy Tests

Run config

export SGLANG_ENABLE_JIT_DEEPGEMM=1
export SGLANG_JIT_DEEPGEMM_FAST_WARMUP=1

export SGLANG_USE_FUSED_METADATA_COPY=0
export SGLANG_ENABLE_SPEC_V2=0

python3 -m sglang.launch_server \
      --model-path deepseek-ai/DeepSeek-V3.2 \
      --trust-remote-code \
      --port 30031 \
      --host "::" \
      --context-length 65536 \
      --chunked-prefill-size 65536 \
      --watchdog-timeout 3600 \
      --tp-size 8 \
      --attn-cp-size 8 \
      --enable-nsa-prefill-context-parallel \
      --cuda-graph-max-bs 75 \
      --chat-template examples/chat_template/tool_chat_template_deepseekv32.jinja \
      --page-size 64 \
      --model-loader-extra-config '{"enable_multithread_load": "true","num_threads": 64}' \
      --mem-fraction-static 0.8 \
      --speculative-algorithm EAGLE \
      --kv-cache-dtype fp8_e4m3

Bench

python benchmark/gsm8k/bench_sglang.py --port 30031 --num-questions 500 --num-shots 48 --parallel 100

Results

Accuracy: 0.966
Invalid: 0.000
Latency: 99.841 s
Output throughput: 478.503 token/s

Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @vladnosiv, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical compatibility issue within the deepseek_nextn.py model by correctly integrating the attention context parallel (CP) group size and rank calculations. This ensures that the model, particularly when operating in MTP mode, properly utilizes CP logic, preventing tensor shape mismatches and associated runtime errors, thereby enhancing the stability and correctness of the DeepSeekV3.2 implementation.

Highlights

  • Corrected Context Parallelism (CP) calculation: The deepseek_nextn.py model now accurately determines the attention context parallel group size and rank by utilizing get_attention_cp_size() and get_attention_cp_rank(), addressing a previous oversight.
  • Resolved MTP model runtime errors: This fix prevents a RuntimeError related to seqlens_k shape mismatches, which occurred because the MTP model incorrectly operated with cp_size=1 due to the previous misconfiguration, leading to skipped CP logic and malformed tensors.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • python/sglang/srt/models/deepseek_nextn.py
    • Updated import statements to reference get_attention_cp_rank and get_attention_cp_size from sglang.srt.layers.dp_attention.
    • Modified the init method to assign self.cp_size using get_attention_cp_size() when nsa_enable_prefill_cp is active.
    • Adjusted the init method to set self.cp_rank and self.cp_size using get_attention_cp_rank() and get_attention_cp_size() respectively under the nsa_enable_prefill_cp condition.
Activity
  • Accuracy tests were performed using specific SGLANG environment configurations and server launch parameters.
  • Benchmarking was conducted with benchmark/gsm8k/bench_sglang.py against the configured server.
  • Test results indicated an accuracy of 0.966, an output throughput of 478.503 token/s, and a latency of 99.841 s.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request addresses a bug in deepseek_nextn.py that caused a runtime crash due to incorrect API usage for tensor parallelism after a recent refactoring. The changes correctly replace get_attention_tp_rank/get_attention_tp_size with get_attention_cp_rank/get_attention_cp_size. This ensures compatibility with context parallelism (CP) and multi-path transformers (MTP), aligning the implementation with deepseek_v2.py. The fix is accurate and directly resolves the issue described.

@Fridge003 Fridge003 merged commit afd91e8 into sgl-project:main Feb 21, 2026
53 of 61 checks passed
magicYang1573 pushed a commit to magicYang1573/sglang that referenced this pull request Mar 9, 2026
Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants