refactor context parallel state #17213

Merged
Fridge003 merged 26 commits into sgl-project:main from dongjiyingdjy:support_cp
Feb 13, 2026

Conversation

@dongjiyingdjy
Contributor

@dongjiyingdjy dongjiyingdjy commented Jan 16, 2026

Motivation

Context parallelism is essential in long context LLM inference. It splits a long input sequence across multiple GPUs so attention can be computed in parallel, drastically reducing latency, which enables practical million-token context windows.
DeepSeek-V3.2 already supported CP and could combine it with DP. This PR aims to support the combination of CP, DP, and TP, and to make CP easier to apply to other models. As a first step, we refactored the original implementation.
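
To make the idea concrete, here is a minimal, self-contained sketch of the sequence-splitting step behind context parallelism. It is illustrative only and not SGLang's internal API; the function name and chunking policy (contiguous, near-equal chunks) are assumptions.

```python
# Illustrative sketch: a long prompt is split into contiguous chunks,
# one per CP rank, so each GPU attends over a fraction of the sequence.
# This is NOT SGLang's actual implementation, just the basic idea.

def split_for_context_parallel(tokens, cp_size):
    """Split a token sequence into cp_size contiguous, near-equal chunks."""
    n = len(tokens)
    base, rem = divmod(n, cp_size)
    chunks, start = [], 0
    for rank in range(cp_size):
        length = base + (1 if rank < rem else 0)  # spread the remainder over early ranks
        chunks.append(tokens[start:start + length])
        start += length
    return chunks

print(split_for_context_parallel(list(range(10)), cp_size=4))
# [[0, 1, 2], [3, 4, 5], [6, 7], [8, 9]]
```

Each rank then computes attention for its chunk, exchanging keys/values (or partial attention results) with the other CP ranks as needed.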

Modifications

  • Add _ATTN_CP, _ATTN_TP, and _MOE_CP as new GroupCoordinators;
  • Add attn_cp_size and moe_cp_size as new server_args;
  • Update the relevant interfaces and processing logic in Scheduler.
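
For intuition, the rank groups behind a coordinator like _ATTN_CP can be sketched as below. The (cp, tp) layout, the function name, and the stride choice are assumptions for illustration; the PR's actual layout may differ.

```python
# Hedged sketch of deriving rank groups for an attention-CP coordinator.
# Assumption: within each block of attn_tp_size * attn_cp_size ranks,
# TP neighbours are adjacent and CP neighbours are attn_tp_size apart.

def build_attn_cp_groups(world_size, attn_tp_size, attn_cp_size):
    """Each group holds ranks that share an attn-TP slice but differ in CP."""
    block = attn_tp_size * attn_cp_size
    assert world_size % block == 0, "world size must be divisible by the CP*TP block"
    groups = []
    for base in range(0, world_size, block):
        for tp in range(attn_tp_size):
            # CP peers of this TP slice, spaced attn_tp_size ranks apart
            groups.append([base + tp + c * attn_tp_size for c in range(attn_cp_size)])
    return groups

print(build_attn_cp_groups(world_size=8, attn_tp_size=2, attn_cp_size=2))
# [[0, 2], [1, 3], [4, 6], [5, 7]]
```

In SGLang/vLLM-style parallel state, each such rank list would back one process group inside a GroupCoordinator.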

Accuracy Tests

============================================================
DeepSeek-V3.2-Exp CP Single Node Results Summary
Dataset: gsm8k
Baseline: 0.935
============================================================

Model 1: deepseek-ai/DeepSeek-V3.2-Exp
  Accuracy: PASS
  Score: 0.980

Model 2: deepseek-ai/DeepSeek-V3.2-Exp
  Accuracy: PASS
  Score: 0.980

============================================================
OVERALL: ALL TESTS PASSED
============================================================

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Summary of Changes

Hello @dongjiyingdjy, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refactors the context parallel state to support attention and MoE context parallelism. It introduces new group coordinators and updates initialization functions to manage these parallel processing capabilities. The changes also include modifications to scheduler processes and the addition of new server arguments for configuration.

Highlights

  • Context Parallelism Refactor: This PR introduces context parallelism for attention and MoE layers, enhancing parallel processing capabilities.
  • Group Coordinator Updates: New group coordinators (_ATTN_TP, _ATTN_CP, _MOE_CP) are added to manage attention and MoE context parallelism.
  • Initialization Function Modification: The initialize_model_parallel function is updated to include parameters for attention and MoE context parallel sizes.
  • Scheduler Processes: Scheduler processes are modified to pass attention and MoE context parallel ranks.
  • Server Arguments: New server arguments attn_cp_size and moe_cp_size are added to configure context parallelism.
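
The new size arguments imply divisibility constraints on the parallel layout. The check below is a hypothetical sketch, assuming the attention world decomposes as tp_size = attn_tp_size * attn_cp_size; the actual validation in the PR may differ.

```python
# Hypothetical sanity check for the new server arguments (not SGLang code).
# Assumption: attn_cp_size and moe_cp_size must each divide tp_size, and
# the implied attn-TP degree is tp_size // attn_cp_size.

def validate_cp_args(tp_size, attn_cp_size, moe_cp_size):
    if tp_size % attn_cp_size != 0:
        raise ValueError(f"attn_cp_size={attn_cp_size} must divide tp_size={tp_size}")
    if tp_size % moe_cp_size != 0:
        raise ValueError(f"moe_cp_size={moe_cp_size} must divide tp_size={tp_size}")
    return tp_size // attn_cp_size  # implied attn_tp_size

print(validate_cp_args(tp_size=8, attn_cp_size=2, moe_cp_size=2))  # 4
```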


Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review — /gemini review: Performs a code review for the current pull request in its current state.
  • Pull Request Summary — /gemini summary: Provides a summary of the current pull request in its current state.
  • Comment — @gemini-code-assist: Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help — /gemini help: Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces context parallelism for attention and MoE layers, which is a significant refactoring. The changes are extensive, touching many files to plumb through the new configuration and rank information. The core logic for creating the new parallel groups has been added.

My review focuses on the correctness of the new parallelism group initialization. While the logic for attention context parallelism (_ATTN_CP) and other groups seems correct, I've found a critical issue in the initialization of the MoE context parallel group (_MOE_CP). The current implementation appears to create groups across pipeline stages instead of within a single stage, which is incorrect for context parallelism.

I've provided a detailed comment with a suggested fix for this issue. Please address this to ensure the correctness of MoE context parallelism.

Also, there is a small typo in the pull request title: "refatcor" should be "refactor".

Comment on lines +1677 to +1682
for i in range(num_tensor_model_parallel_groups):
    for j in range(moe_tp_size * moe_ep_size):
        st = i * tensor_model_parallel_size + j
        en = (i + 1) * tensor_model_parallel_size + j
        ranks = list(range(st, en, moe_tp_size * moe_ep_size))
        group_ranks.append(ranks)

critical

The logic for creating the _MOE_CP (MoE Context Parallel) group appears to be incorrect. It seems to be creating groups across pipeline parallel stages, similar to how pipeline parallel groups are formed, rather than creating context parallel groups within a single pipeline stage.

A context parallel group for MoE should group ranks that handle different parts of the context but the same expert and tensor slice. The current implementation:

for i in range(num_tensor_model_parallel_groups):
    for j in range(moe_tp_size * moe_ep_size):
        st = i * tensor_model_parallel_size + j
        en = (i + 1) * tensor_model_parallel_size + j
        ranks = list(range(st, en, moe_tp_size * moe_ep_size))
        group_ranks.append(ranks)

Here, i iterates through pipeline stages, and en points to a rank in the next pipeline stage, which is incorrect for a context parallel group.

A correct implementation should iterate within a single pipeline stage. Assuming a rank layout of (cp, ep, tp) within a tensor parallel group, the logic should be something like this:

        for i in range(num_tensor_model_parallel_groups):
            for j in range(moe_ep_size):
                for k in range(moe_tp_size):
                    # Assuming a rank layout of (cp, ep, tp)
                    base = i * tensor_model_parallel_size + j * moe_tp_size + k
                    stride = moe_ep_size * moe_tp_size
                    ranks = [base + c * stride for c in range(moe_cp_size)]
                    group_ranks.append(ranks)
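
The suggested loops can be checked mechanically. The script below reuses the reviewer's variable names with illustrative sizes (2 TP-groups of 8 ranks, moe_tp = moe_ep = moe_cp = 2) and asserts that every resulting MoE-CP group stays inside a single tensor-parallel block, which is the property the original code violated.

```python
# Verifying the suggested fix with toy sizes (the sizes are illustrative,
# not taken from the PR). Rank layout assumption: (cp, ep, tp) within each
# contiguous tensor-parallel block.

tensor_model_parallel_size = 8          # ranks per TP block
moe_tp_size, moe_ep_size, moe_cp_size = 2, 2, 2
num_tensor_model_parallel_groups = 2    # e.g. two pipeline stages

group_ranks = []
for i in range(num_tensor_model_parallel_groups):
    for j in range(moe_ep_size):
        for k in range(moe_tp_size):
            base = i * tensor_model_parallel_size + j * moe_tp_size + k
            stride = moe_ep_size * moe_tp_size
            ranks = [base + c * stride for c in range(moe_cp_size)]
            group_ranks.append(ranks)

# Every group stays inside one 8-rank TP block...
for ranks in group_ranks:
    assert len({r // tensor_model_parallel_size for r in ranks}) == 1
# ...and the groups together partition all 16 ranks exactly once.
assert sorted(sum(group_ranks, [])) == list(range(16))

print(group_ranks)
# [[0, 4], [1, 5], [2, 6], [3, 7], [8, 12], [9, 13], [10, 14], [11, 15]]
```

With the original loops, `en` lands in the next TP block, so the same check fails there.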

@Fridge003
Collaborator

@Fridge003 Fridge003 self-assigned this Jan 18, 2026
@Fridge003 Fridge003 changed the title refatcor context parallel state refactor context parallel state Jan 20, 2026
@dongjiyingdjy
Contributor Author

> Please make sure this PR can pass this unit test https://github.com/sgl-project/sglang/blob/main/test/registered/8-gpu-models/test_deepseek_v32_cp_single_node.py

Both tests already passed. Thanks!

@github-actions github-actions bot added the documentation Improvements or additions to documentation label Jan 21, 2026
@Fridge003
Collaborator

/tag-and-rerun-ci

@Fridge003 Fridge003 added the format Auto Format Code label Feb 11, 2026
Collaborator

@Fridge003 Fridge003 left a comment


Wait for CIs

@Fridge003 Fridge003 merged commit 8b4c364 into sgl-project:main Feb 13, 2026
193 of 206 checks passed
Johnsonms pushed a commit to Johnsonms/sglang that referenced this pull request Feb 14, 2026
Co-authored-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
slin1237 pushed a commit that referenced this pull request Feb 17, 2026

PR #17213 added attn_cp_rank and moe_dp_rank parameters to run_scheduler_process, but the gRPC scheduler_launcher was not updated, causing startup failure due to missing arguments.
@llc-kc
Contributor

llc-kc commented Feb 27, 2026

This PR disables PP+CP; will this be supported in the future?

magicYang1573 pushed a commit to magicYang1573/sglang that referenced this pull request Mar 9, 2026
Co-authored-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
Co-authored-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>

Labels

  • deepseek
  • documentation (Improvements or additions to documentation)
  • format (Auto Format Code)
  • high priority
  • run-ci


7 participants