
[DeepSeek 3.2] Support and optimize pipeline parallelism when context parallel is enabled #16380

Merged
Fridge003 merged 5 commits into sgl-project:main from antgroup:xyf/cp_opt on Jan 9, 2026
Conversation

@xu-yfei (Contributor) commented Jan 4, 2026

Motivation

For issue #15358. CC @whybeyoung

Support pipeline parallelism for the Prefill CP scenario (--enable-nsa-prefill-context-parallel). In this scenario, the linear layer of attention uses repeated weights, and the received input is scattered. Therefore, the attention TP group does not need to perform an all_gather operation during PP send/recv.

Optimize the indexer as follows: with PP enabled, part of the model execution may run concurrently with the PP recv operation, and the recv kernel persistently occupies one SM. The number of SMs allocated to the indexer's deep_gemm.fp8_mqa_logits must therefore be reduced by 1.

As shown in the figure below, for deep_gemm.fp8_mqa_logits on the first node (i.e., PP rank = 0), the latency is 1.312 ms when it does not overlap with the PP recv operation and increases to 2.211 ms when it does.
Profiling shows that deep_gemm.fp8_mqa_logits occupies all SMs by default (on H20 the grid is (78,)). Since the recv operator already occupies 1 SM, the latency of deep_gemm.fp8_mqa_logits increases significantly.
Solution: in the PP scenario, we reduce the number of SMs for deep_gemm.fp8_mqa_logits by calling deep_gemm.set_num_sms, so as to maintain its performance.
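The SM-budget decision above can be sketched as a small helper. This is a minimal sketch, not sglang's actual code: the function name, `total_sms`, and `is_last_pp_rank` are illustrative; only the rule itself (reserve one SM for the overlapping recv on non-last PP ranks) comes from this PR.

```python
def indexer_num_sms(total_sms: int, pp_size: int, is_last_pp_rank: bool) -> int:
    """Return the SM budget for the indexer's deep_gemm.fp8_mqa_logits.

    Reserve one SM for the overlapping PP recv kernel on every pipeline
    stage except the last one (the last stage has no recv to overlap with).
    """
    if pp_size > 1 and not is_last_pp_rank:
        return total_sms - 1
    return total_sms

# On H20 (78 SMs) with pp_size=2:
# rank 0 (not last) -> 77, rank 1 (last) -> 78
```

Under this rule, a non-last PP rank on H20 would call `deep_gemm.set_num_sms(77)` before launching fp8_mqa_logits.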

[Figure: kernel timelines of deep_gemm.fp8_mqa_logits, 1.312 ms without overlap vs 2.211 ms when overlapped with the PP recv operator]

On 2×8 H20 (96 GB) with 48.5K-token inputs, TTFT is 3.11 s.

Modifications

  1. Support pipeline parallelism for the Prefill CP scenario
  2. Optimize the indexer: reduce the number of SMs allocated to the indexer's deep_gemm.fp8_mqa_logits by 1 when PP is enabled and the current rank is not the last PP rank.

Accuracy Tests

On 2×8 H20 (96 GB).

# node 0
export SGLANG_PP_LAYER_PARTITION=30,31
python3 -m sglang.launch_server --model-path $MODEL_PATH --nnodes 2 --port 8000 --dist-init-addr ${node_rank_ip}:62001 \
--node-rank 0 --tp 8 --pp-size 2  --trust-remote-code --disable-radix-cache --mem-fraction-static 0.8 \
--max-running-requests 128 --watchdog-timeout 3600  --host 0.0.0.0 \
--chunked-prefill-size 8192 \
--enable-nsa-prefill-context-parallel \
--nsa-prefill-cp-mode round-robin-split \
--attention-backend nsa \
--nsa-decode-backend fa3 \
--tool-call-parser deepseekv32 \
--reasoning-parser deepseek-v3
# node 1
export SGLANG_PP_LAYER_PARTITION=30,31
python3 -m sglang.launch_server --model-path $MODEL_PATH  --nnodes 2 --port 8000 --dist-init-addr ${node_rank_ip}:62001 \
--node-rank 1 --tp 8 --pp-size 2  --trust-remote-code --disable-radix-cache --mem-fraction-static 0.8 \
--max-running-requests 128 --watchdog-timeout 3600 \
--chunked-prefill-size 8192 \
--enable-nsa-prefill-context-parallel \
--nsa-prefill-cp-mode round-robin-split \
--nsa-decode-backend fa3 \
--attention-backend nsa \
--tool-call-parser deepseekv32 \
--reasoning-parser deepseek-v3
# gsm8k
100%|██████████| 1319/1319 [05:51<00:00,  3.75it/s]
Accuracy: 0.948
Invalid: 0.000
Latency: 357.199 s
Output throughput: 342.630 token/s
# mmlu
100%|██████████| 14042/14042 [17:44<00:00, 13.19it/s] 
subject: abstract_algebra, #q:100, acc: 0.780
subject: anatomy, #q:135, acc: 0.867
subject: astronomy, #q:152, acc: 0.961
subject: business_ethics, #q:100, acc: 0.860
subject: clinical_knowledge, #q:265, acc: 0.921
subject: college_biology, #q:144, acc: 0.972
subject: college_chemistry, #q:100, acc: 0.650
subject: college_computer_science, #q:100, acc: 0.870
subject: college_mathematics, #q:100, acc: 0.870
subject: college_medicine, #q:173, acc: 0.867
subject: college_physics, #q:102, acc: 0.892
subject: computer_security, #q:100, acc: 0.910
subject: conceptual_physics, #q:235, acc: 0.928
subject: econometrics, #q:114, acc: 0.772
subject: electrical_engineering, #q:145, acc: 0.883
subject: elementary_mathematics, #q:378, acc: 0.944
subject: formal_logic, #q:126, acc: 0.802
subject: global_facts, #q:100, acc: 0.720
subject: high_school_biology, #q:310, acc: 0.961
subject: high_school_chemistry, #q:203, acc: 0.877
subject: high_school_computer_science, #q:100, acc: 0.930
subject: high_school_european_history, #q:165, acc: 0.885
subject: high_school_geography, #q:198, acc: 0.965
subject: high_school_government_and_politics, #q:193, acc: 0.979
subject: high_school_macroeconomics, #q:390, acc: 0.921
subject: high_school_mathematics, #q:270, acc: 0.789
subject: high_school_microeconomics, #q:238, acc: 0.962
subject: high_school_physics, #q:151, acc: 0.841
subject: high_school_psychology, #q:545, acc: 0.969
subject: high_school_statistics, #q:216, acc: 0.875
subject: high_school_us_history, #q:204, acc: 0.966
subject: high_school_world_history, #q:237, acc: 0.954
subject: human_aging, #q:223, acc: 0.843
subject: human_sexuality, #q:131, acc: 0.924
subject: international_law, #q:121, acc: 0.967
subject: jurisprudence, #q:108, acc: 0.917
subject: logical_fallacies, #q:163, acc: 0.933
subject: machine_learning, #q:112, acc: 0.812
subject: management, #q:103, acc: 0.932
subject: marketing, #q:234, acc: 0.949
subject: medical_genetics, #q:100, acc: 0.940
subject: miscellaneous, #q:783, acc: 0.962
subject: moral_disputes, #q:346, acc: 0.879
subject: moral_scenarios, #q:895, acc: 0.799
subject: nutrition, #q:306, acc: 0.928
subject: philosophy, #q:311, acc: 0.920
subject: prehistory, #q:324, acc: 0.932
subject: professional_accounting, #q:282, acc: 0.883
subject: professional_law, #q:1534, acc: 0.726
subject: professional_medicine, #q:272, acc: 0.949
subject: professional_psychology, #q:612, acc: 0.908
subject: public_relations, #q:110, acc: 0.818
subject: security_studies, #q:245, acc: 0.882
subject: sociology, #q:201, acc: 0.965
subject: us_foreign_policy, #q:100, acc: 0.950
subject: virology, #q:166, acc: 0.590
subject: world_religions, #q:171, acc: 0.936
Total latency: 1064.955
Average accuracy: 0.880

Benchmarking and Profiling

Before optimizing the indexer:
[Figure: profiler traces before the SM adjustment]

After optimizing the indexer:
[Figure: profiler traces after the SM adjustment]

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments (/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci) or contact authorized users to do so.
  4. After green CI and required approvals, ask Merge Oncalls to merge.


@xu-yfei (Contributor, Author) commented Jan 4, 2026

@Fridge003 Could you please help review this PR?

@whybeyoung (Collaborator):

Nice job

@Fridge003 (Collaborator):

/tag-and-rerun-ci

@github-actions github-actions bot added the run-ci label Jan 7, 2026
@ShangmingCai (Collaborator) left a comment:

Otherwise, LGTM.

@Fridge003 Fridge003 merged commit 05dfef9 into sgl-project:main Jan 9, 2026
292 of 362 checks passed
@llc-kc (Contributor) commented Feb 27, 2026

Hi, @ShangmingCai @xu-yfei
is this a bug in python/sglang/srt/server_args.py:

    def _handle_context_parallelism(self):
        if self.attn_cp_size > 1:
            # The tp_size is the world size, not the real tensor parallel size
            assert (
                self.tp_size % self.attn_cp_size == 0
            ), "tp_size must be divisible by attn_cp_size"
            assert (
                self.tp_size % (self.dp_size * self.attn_cp_size) == 0
            ), "tp_size must be divisible by dp_size * attn_cp_size"
            assert self.pp_size == 1, "PP is not supported with context parallelism"

When I tried to launch GLM5 with PP2 and TP8 CP8 (sglang v0.5.9):

# node 0
export LWS_LEADER_ADDRESS=xxx
export LWS_GROUP_SIZE=2
export LWS_WORKER_INDEX=0

export SGLANG_PP_LAYER_PARTITION=39,39

python3 -m sglang.launch_server \
--model-path model/GLM-5-FP8 \
--served-model-name GLM-5-FP8 \
--dist-init-addr ${LWS_LEADER_ADDRESS}:20000 --nnodes ${LWS_GROUP_SIZE} --node-rank ${LWS_WORKER_INDEX} \
--cuda-graph-bs 1 2 4 8 16 24 32 40 48 56 64 72 80 88 96 112 120 128 \
--tp 8 \
--pp-size 2 \
--dp-size 1 \
--enable-nsa-prefill-context-parallel \
--attn-cp-size 8 \
--nsa-prefill-cp-mode round-robin-split \
--moe-dense-tp-size 1 \
--tool-call-parser glm47  \
--reasoning-parser glm45 \
--mem-fraction-static 0.86 \
--max-running-requests 128 \
--chunked-prefill-size 16384 \
--context-length 131072 \
--num-reserved-decode-tokens 1024 \
--trust-remote-code \
--host 0.0.0.0 \
--port 12121 \
--page-size 64  \
--enable-metrics \
--log-level debug \
--enable-cache-report \
--decode-log-interval 1 \
--collect-tokens-histogram \
--enable-request-time-stats-logging \
--model-loader-extra-config '{"enable_multithread_load": true, "num_threads": 8}'

I got this Error:

  File "/sgl-workspace/sglang/python/sglang/srt/server_args.py", line 5592, in prepare_server_args
    return ServerArgs.from_cli_args(raw_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/server_args.py", line 5078, in from_cli_args
    return cls(**{attr: getattr(args, attr) for attr in attrs})
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 331, in __init__
  File "/sgl-workspace/sglang/python/sglang/srt/server_args.py", line 754, in __post_init__
    self._handle_context_parallelism()
  File "/sgl-workspace/sglang/python/sglang/srt/server_args.py", line 2092, in _handle_context_parallelism
    assert self.pp_size == 1, "PP is not supported with context parallelism"
           ^^^^^^^^^^^^^^^^^
AssertionError: PP is not supported with context parallelism
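For reference, the failing check can be reproduced in isolation. This is an illustrative standalone sketch of the quoted validation logic, not the sglang source (the function takes plain arguments instead of `self`):

```python
def handle_context_parallelism(tp_size: int, dp_size: int,
                               attn_cp_size: int, pp_size: int) -> None:
    """Standalone version of the quoted _handle_context_parallelism check."""
    if attn_cp_size > 1:
        # tp_size here is the world size, not the real tensor parallel size
        assert tp_size % attn_cp_size == 0, \
            "tp_size must be divisible by attn_cp_size"
        assert tp_size % (dp_size * attn_cp_size) == 0, \
            "tp_size must be divisible by dp_size * attn_cp_size"
        assert pp_size == 1, "PP is not supported with context parallelism"

# The PP2/TP8/CP8 launch above corresponds to:
# handle_context_parallelism(tp_size=8, dp_size=1, attn_cp_size=8, pp_size=2)
# which trips the final assert.
```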

@xu-yfei (Contributor, Author) commented Feb 27, 2026

> Hi, @ShangmingCai @xu-yfei is this a bug in python/sglang/srt/server_args.py: […]

@llc-kc This restriction was introduced by PR #17213.

@ShangmingCai (Collaborator):
> Hi, @ShangmingCai @xu-yfei is this a bug in python/sglang/srt/server_args.py: […]

This is a bug introduced by the context parallelism refactor; we are fixing it now. If you are in a hurry, you can try an older release where that commit had not yet been merged.

5 participants