
[PP] Refactor PP to async mode #11852

Merged
ShangmingCai merged 42 commits into sgl-project:main from openanolis:Xuchun/pp-dev
Dec 12, 2025

Conversation

@XucSh
Collaborator

@XucSh XucSh commented Oct 20, 2025

Motivation

See #11857.

User can test with below command now:

python3 -m sglang.launch_server --model /opt/models/Qwen/Qwen3-8b --pp-size 4 --tp 2 --pp-async-batch-depth 1

Co-authors: @ShangmingCai @merrymercy @alpha-baby

Cc: @ShangmingCai @merrymercy @whybeyoung @bluecoffee8

Checklist

merrymercy and others added 3 commits October 20, 2025 15:07
Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
@XucSh XucSh marked this pull request as draft October 20, 2025 07:29
@gemini-code-assist
Contributor

Summary of Changes

Hello @XucSh, I'm Gemini Code Assist [1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly upgrades the pipeline parallelism (PP) implementation by introducing asynchronous processing. The primary goal is to enhance efficiency by allowing GPU computation and CPU processing to overlap, particularly for the last rank in the pipeline, thereby mitigating potential performance bottlenecks. The changes involve a substantial refactoring of the PP scheduling logic into a new mixin and the addition of a configurable parameter to fine-tune the asynchronous batching behavior.

Highlights

  • Asynchronous Pipeline Parallelism (PP) Support: Introduced asynchronous capabilities for pipeline parallelism, allowing for better overlap of computation and communication, especially for the last PP rank.
  • Refactored PP Logic: The core pipeline parallelism event loop (event_loop_pp) has been extracted from scheduler.py into a new dedicated mixin, SchedulerPPMixin, improving modularity and maintainability.
  • Configurable Asynchronous Batch Depth: A new command-line argument --pp-async-batch-depth was added to server_args.py, enabling users to specify the depth of asynchronous batching for PP.
  • Asynchronous Point-to-Point Communication: The point_to_point_pyobj utility function was updated to support asynchronous sending, crucial for non-blocking communication in the new PP implementation.
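The depth-limited batching idea behind --pp-async-batch-depth can be sketched in pure Python. This is only a minimal illustration, not sglang's actual SchedulerPPMixin logic; run_stage, send_async, and wait below are hypothetical stand-ins for the forward pass, a non-blocking P2P send, and its completion wait:

```python
from collections import deque

# Minimal sketch of depth-limited asynchronous batching, loosely modeled on
# the --pp-async-batch-depth idea. The real logic lives in SchedulerPPMixin;
# `send_async` and `wait` here are hypothetical stand-ins.

def run_stage(batches, depth):
    """Process batches, keeping at most `depth` async sends in flight."""
    in_flight = deque()          # pending (hypothetical) send handles
    completed = []

    def send_async(result):      # stand-in for a non-blocking P2P send
        return result            # a real impl would return a waitable handle

    def wait(handle):            # stand-in for handle.wait()
        completed.append(handle)

    for batch in batches:
        result = batch * 2       # stand-in for the forward pass
        in_flight.append(send_async(result))
        # Throttle: never exceed `depth` outstanding sends, so CPU-side work
        # can overlap with at most `depth` pending micro-batches.
        while len(in_flight) > depth:
            wait(in_flight.popleft())

    while in_flight:             # drain remaining sends at the end
        wait(in_flight.popleft())
    return completed

print(run_stage([1, 2, 3, 4], depth=1))  # → [2, 4, 6, 8]
```

A larger depth allows more overlap at the cost of buffering more in-flight micro-batches.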
Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the pipeline parallelism (PP) logic into a new SchedulerPPMixin and adds support for asynchronous operations to improve performance by overlapping communication and computation. The changes are quite extensive, introducing a new pp_async_batch_depth server argument and modifying the point_to_point_pyobj utility for async sends.

While the overall direction is good, I've found a few critical issues in the new event_loop_pp implementation within scheduler_pp_mixin.py where essential logic for receiving data from previous pipeline stages appears to be commented out, which would break the pipeline. I've also pointed out some dead code that should be cleaned up for better maintainability. Please see the detailed comments for suggestions on how to fix these issues.
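For intuition on what sending a Python object between pipeline stages involves: the object is serialized and the payload is framed so the receiver knows how many bytes to read. The sketch below shows only this generic length-prefixed framing; it is not sglang's actual wire format for point_to_point_pyobj, and a real async path would hand the bytes to a non-blocking transport (e.g. torch.distributed.isend):

```python
import pickle
import struct

# Generic length-prefixed framing for shipping a Python object as bytes.
# NOT sglang's actual point_to_point_pyobj wire format -- just a sketch of
# the serialize-then-frame idea behind object-level P2P communication.

def frame_pyobj(obj) -> bytes:
    payload = pickle.dumps(obj)
    # 8-byte big-endian length header, then the pickled payload.
    return struct.pack("!Q", len(payload)) + payload

def unframe_pyobj(buf: bytes):
    (length,) = struct.unpack("!Q", buf[:8])
    return pickle.loads(buf[8 : 8 + length])

msg = {"mb_id": 3, "tokens": [101, 102]}
assert unframe_pyobj(frame_pyobj(msg)) == msg
```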

Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
ShangmingCai and others added 6 commits October 21, 2025 11:44
Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
@XucSh XucSh marked this pull request as ready for review October 21, 2025 06:35
XucSh added 2 commits October 22, 2025 11:15
Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
@XucSh XucSh changed the title from "[PP] support async PP" to "[PP] Refactor PP to async mode" on Oct 22, 2025
Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
@whybeyoung
Collaborator

whybeyoung commented Oct 22, 2025

Here is the benchmark result on A800 80G ×8.
Model: qwen3-8b
SGLang server command: python -m sglang.launch_server --model-path /work/models/qwen8b --disable-radix-cache --pp-size 4 --trust-remote --host 0.0.0.0 --port 8001 --mem-fraction-static 0.8 --tokenizer-worker-num 8 --tp-size 2 --pp-async-batch-depth 1 --torch-compile-max-bs 8 --max-running-requests 20
Benchmark command: python -m sglang.bench_serving --port 8001 --dataset-name random-ids --num-prompts 128 --random-input-len 1000 --random-output-len 1000 --random-range-ratio 0.9 --disable-stream
Before (main branch, commit 01f14a7):

#Input tokens: 121127
#Output tokens: 121703
Starting warmup with 1 sequences...
Warmup completed with 1 sequences. Starting main benchmark run...
100%|█████████████████████████████████████████| 128/128 [03:00<00:00,  1.41s/it]

============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 not set   
Successful requests:                     128       
Benchmark duration (s):                  180.53    
Total input tokens:                      121127    
Total input text tokens:                 121127    
Total input vision tokens:               0         
Total generated tokens:                  121703    
Total generated tokens (retokenized):    120510    
Request throughput (req/s):              0.71      
Input token throughput (tok/s):          670.94    
Output token throughput (tok/s):         674.13    
Total token throughput (tok/s):          1345.08   
Concurrency:                             71.00     
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   100144.97 
Median E2E Latency (ms):                 105706.80 
---------------Time to First Token----------------
Mean TTFT (ms):                          100145.09 
Median TTFT (ms):                        105706.89 
P99 TTFT (ms):                           179895.75 
---------------Inter-Token Latency----------------
Mean ITL (ms):                           0.00      
Median ITL (ms):                         0.00      
P95 ITL (ms):                            0.00      
P99 ITL (ms):                            0.00      
Max ITL (ms):                            0.00      
==================================================

After (commit 8fec316):

#Input tokens: 121127
#Output tokens: 121703
Starting warmup with 1 sequences...
Warmup completed with 1 sequences. Starting main benchmark run...
100%|█████████████████████████████████████████| 128/128 [01:41<00:00,  1.26it/s]

============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 not set   
Successful requests:                     128       
Benchmark duration (s):                  101.86    
Total input tokens:                      121127    
Total input text tokens:                 121127    
Total input vision tokens:               0         
Total generated tokens:                  121703    
Total generated tokens (retokenized):    120204    
Request throughput (req/s):              1.26      
Input token throughput (tok/s):          1189.17   
Output token throughput (tok/s):         1194.83   
Total token throughput (tok/s):          2384.00   
Concurrency:                             68.74     
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   54700.18  
Median E2E Latency (ms):                 58085.45  
---------------Time to First Token----------------
Mean TTFT (ms):                          54700.27  
Median TTFT (ms):                        58085.52  
P99 TTFT (ms):                           101474.87 
---------------Inter-Token Latency----------------
Mean ITL (ms):                           0.00      
Median ITL (ms):                         0.00      
P95 ITL (ms):                            0.00      
P99 ITL (ms):                            0.00      
Max ITL (ms):                            0.00    
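Putting the two runs side by side (all numbers taken from the benchmark tables above):

```python
# Speedup computed from the two benchmark runs above (same 128-prompt load).
before_tput, after_tput = 1345.08, 2384.00   # total token throughput (tok/s)
before_e2e, after_e2e = 100144.97, 54700.18  # mean E2E latency (ms)

tput_speedup = after_tput / before_tput
latency_reduction = 1 - after_e2e / before_e2e

print(f"throughput: {tput_speedup:.2f}x")            # → throughput: 1.77x
print(f"mean E2E latency cut: {latency_reduction:.0%}")  # → mean E2E latency cut: 45%
```

So the async refactor delivers roughly a 1.77× throughput gain and a 45% mean end-to-end latency reduction on this workload.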

@ShangmingCai
Collaborator

@ShangmingCai @ByronHsu @zhyncs Could you review this? We found that this PR significantly improves SGLang's PP performance. Thanks!

@nvpohanh We are close to finalizing the design and will merge this into main ASAP. Thanks for the testing and performance verification.

XucSh and others added 7 commits November 21, 2025 15:02
Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
Co-authored-by: bluecoffee8 <jasperli2002@gmail.com>

Co-authored-by: Xuchun Shang <xuchun.shang@gmail.com>

Co-authored-by: ybyang <10629930+whybeyoung@users.noreply.github.com>
Signed-off-by: Shangming Cai <csmthu@gmail.com>
Signed-off-by: Shangming Cai <csmthu@gmail.com>
Signed-off-by: Shangming Cai <csmthu@gmail.com>
mbs[next_mb_id], mb_metadata[next_mb_id], next_pp_outputs
)
d2h_event = torch.cuda.Event()
d2h_event.record(torch.cuda.current_stream())


Why the d2h event and copy stream? There is no copy op in _pp_prep_batch_result, is there?

@MichoChan

would hang when return_logprob=True

@XucSh
Collaborator Author

XucSh commented Dec 1, 2025

would hang when return_logprob=True

Thanks for your feedback; I will dig into it. Could you provide your test command?

@MichoChan

would hang when return_logprob=True

thanks for your feedback. will dig into it. Could you provide your test command?
It only hangs with tp=8, pp=2, nodes=2; with pp=4, tp=8, nodes=4 it is OK.

@weireweire
Contributor

Could we do another rebase so I can run this on torch 2.9 / CUDA 13?

@ShangmingCai
Collaborator

ShangmingCai commented Dec 10, 2025

/rerun-failed-ci 3

@ShangmingCai
Collaborator

/tag-and-rerun-ci

@ShangmingCai
Collaborator

ShangmingCai commented Dec 11, 2025

/rerun-failed-ci 2

Collaborator

@ShangmingCai ShangmingCai left a comment


We think this PR is ready for public testing now. Please ping me in the comments of #11857 (or in the Slack channel, or DM me on Slack) if you find any bugs or compatibility issues with this PR. We will follow up with PRs to fix them ASAP.

The failed CI checks are unrelated to this PR.

@ShangmingCai ShangmingCai merged commit c01b2ee into sgl-project:main Dec 12, 2025
282 of 352 checks passed
@ShangmingCai
Collaborator

Update: @alpha-baby and @liusy58 also put a lot of effort into experimenting with and testing this PR. Even though no related commits of theirs are included, they contributed a great deal as well. Sorry for forgetting to add them as co-authors. My bad.

Prozac614 pushed a commit to Prozac614/sglang that referenced this pull request Dec 17, 2025
Signed-off-by: Shangming Cai <csmthu@gmail.com>
Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: bluecoffee8 <jasperli2002@gmail.com>
Co-authored-by: zhangxiaolei123456 <zhangxiaolei.666@bytedance.com>
Co-authored-by: ybyang <10629930+whybeyoung@users.noreply.github.com>
YChange01 pushed a commit to YChange01/sglang that referenced this pull request Jan 13, 2026
Signed-off-by: Shangming Cai <csmthu@gmail.com>
Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: bluecoffee8 <jasperli2002@gmail.com>
Co-authored-by: zhangxiaolei123456 <zhangxiaolei.666@bytedance.com>
Co-authored-by: ybyang <10629930+whybeyoung@users.noreply.github.com>