
[PD] improve kv offset calculation for MHA model with different tp size#18163

Merged
ShangmingCai merged 3 commits into sgl-project:main from Ch3ngY1:feat/numpy-optimize
Feb 5, 2026
Conversation

Contributor

@Ch3ngY1 commented Feb 3, 2026


Motivation

When deploying an MHA model with PD disaggregation, if the prefill and decode nodes use different TP sizes, the KV cache is transferred token by token, which leads to high TTFT.

This PR adds SGLANG_DISAGGREGATION_DIFF_TP_NUMPY: when enabled, NumPy vectorization is used to accelerate the computation of destination pointers on the prefill side.

Modifications

Use NumPy vectorization to accelerate destination-pointer computation in send_kvcache_slice.
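To illustrate the idea, here is a standalone sketch of the vectorized address computation (with made-up page sizes and pointers; the real code lives in send_kvcache_slice and ends in a batch_transfer_sync call, omitted here). The loop-based path and the broadcasted path below produce identical address lists:

```python
import numpy as np

# Hypothetical sizes for illustration only.
page_size = 4
src_kv_item_len, dst_kv_item_len = 1024, 512   # bytes per KV page
src_head_slice_offset, dst_head_slice_offset = 128, 64
src_ptr, dst_ptr = 0x10000, 0x20000
prefill_kv_indices = np.array([3, 7], dtype=np.int32)
dst_kv_indices = np.array([1, 2], dtype=np.int32)

# Loop-based path (one address per token slot per page).
loop_src, loop_dst = [], []
for p_idx, d_idx in zip(prefill_kv_indices, dst_kv_indices):
    for slot in range(page_size):
        loop_src.append(src_ptr + int(p_idx) * src_kv_item_len
                        + slot * (src_kv_item_len // page_size)
                        + src_head_slice_offset)
        loop_dst.append(dst_ptr + int(d_idx) * dst_kv_item_len
                        + slot * (dst_kv_item_len // page_size)
                        + dst_head_slice_offset)

# Vectorized path: one (num_pages, page_size) broadcast instead of the loop.
token_offsets = np.arange(page_size, dtype=np.int64).reshape(1, -1)
src_base = token_offsets * (src_kv_item_len // page_size) + src_head_slice_offset
dst_base = token_offsets * (dst_kv_item_len // page_size) + dst_head_slice_offset
src_addrs = (src_ptr + prefill_kv_indices.astype(np.int64).reshape(-1, 1)
             * src_kv_item_len + src_base).reshape(-1).tolist()
dst_addrs = (dst_ptr + dst_kv_indices.astype(np.int64).reshape(-1, 1)
             * dst_kv_item_len + dst_base).reshape(-1).tolist()

assert src_addrs == loop_src and dst_addrs == loop_dst
```

The int64 cast matters: page_index * item_len can exceed the int32 range for large KV pools, so the indices are widened before the pointer arithmetic.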

Accuracy Tests

| deploy     | Acc   | Latency (s) |
|------------|-------|-------------|
| baseline   | 0.844 | 71.910      |
| w/ this PR | 0.849 | 68.468      |

w/o this PR:

root@dsw-292900-7db9ddf7db-2tll9:/sgl-workspace/sglang# bash tmp/eval.sh 
Loading GSM8K Platinum dataset from HuggingFace...
100%|██████████████████████████████████████████████████████████████████████████████████████| 1209/1209 [01:11<00:00, 16.86it/s]
Accuracy: 0.844
Invalid: 0.001
Latency: 71.910 s
Output throughput: 2805.579 token/s

w/ this PR (SGLANG_DISAGGREGATION_DIFF_TP_NUMPY enabled):

root@dsw-292900-7db9ddf7db-2tll9:/sgl-workspace/sglang# bash tmp/eval.sh 
Loading GSM8K Platinum dataset from HuggingFace...
100%|██████████████████████████████████████████████████████████████████████████████████████| 1209/1209 [01:08<00:00, 17.70it/s]
Accuracy: 0.849
Invalid: 0.000
Latency: 68.468 s
Output throughput: 2800.606 token/s

Benchmarking and Profiling

Results from benchmarking a 4K-input : 1K-output workload with max-concurrency=64:

| deploy   | TTFT (ms) | TPOT (ms) |
|----------|-----------|-----------|
| baseline | 899.37    | 30.74     |
| +numpy   | 719.44    | 30.89     |

w/o this PR:

============ Serving Benchmark Result ============
Backend:                                 vllm      
Traffic request rate:                    inf       
Max request concurrency:                 64        
Successful requests:                     1000      
Benchmark duration (s):                  264.08    
Total input tokens:                      1995681   
Total generated tokens:                  500246    
Total generated tokens (retokenized):    499851    
Request throughput (req/s):              3.79      
Input token throughput (tok/s):          7557.10   
Output token throughput (tok/s):         1894.30   
Total token throughput (tok/s):          9451.39   
Concurrency:                             61.54     
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   16252.07  
Median E2E Latency (ms):                 15713.56  
---------------Time to First Token----------------
Mean TTFT (ms):                          899.37    
Median TTFT (ms):                        436.68    
P99 TTFT (ms):                           9384.51   
---------------Inter-Token Latency----------------
Mean ITL (ms):                           30.91     
Median ITL (ms):                         30.90     
P95 ITL (ms):                            33.34     
P99 ITL (ms):                            36.31     
Max ITL (ms):                            290.79    
Mean TPOT (ms):                          30.74     
Median TPOT (ms):                        30.96     
P95 TPOT (ms):                           32.12     
P99 TPOT (ms):                           32.50     
Max TPOT (ms):                           35.81     
==================================================

w/ this PR (SGLANG_DISAGGREGATION_DIFF_TP_NUMPY enabled):

============ Serving Benchmark Result ============
Backend:                                 vllm      
Traffic request rate:                    inf       
Max request concurrency:                 64        
Successful requests:                     1000      
Benchmark duration (s):                  260.38    
Total input tokens:                      1995681   
Total generated tokens:                  500246    
Total generated tokens (retokenized):    499864    
Request throughput (req/s):              3.84      
Input token throughput (tok/s):          7664.46   
Output token throughput (tok/s):         1921.21   
Total token throughput (tok/s):          9585.67   
Concurrency:                             61.99     
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   16141.67  
Median E2E Latency (ms):                 15858.06  
---------------Time to First Token----------------
Mean TTFT (ms):                          719.44    
Median TTFT (ms):                        395.77    
P99 TTFT (ms):                           6386.17   
---------------Inter-Token Latency----------------
Mean ITL (ms):                           31.03     
Median ITL (ms):                         31.03     
P95 ITL (ms):                            33.90     
P99 ITL (ms):                            36.21     
Max ITL (ms):                            236.73    
Mean TPOT (ms):                          30.89     
Median TPOT (ms):                        31.06     
P95 TPOT (ms):                           32.85     
P99 TPOT (ms):                           33.57     
Max TPOT (ms):                           34.59     
==================================================

Scripts

Prefill (TP4)

# ====== start of numpy opt ======
export SGLANG_DISAGGREGATION_DIFF_TP_NUMPY=1
# ====== end of numpy opt ======

python3 -m sglang.launch_server \
    --model-path /dev/shm/GLM-4.5-Air-FP8 \
    --served-model-name=GLM-4.5-Air-FP8 \
    --disaggregation-mode prefill \
    --tp-size 4 \
    --mem-fraction-static 0.8 \
    --disable-cuda-graph \
    --max-running-requests 320 \
    --kv-cache-dtype fp8_e4m3 \
    --disable-radix-cache \
    --host 0.0.0.0 \
    --enable-metrics \
    --attention-backend fa3 \
    --port 30001 \
    --page-size 64 \
    --chunked-prefill-size $((16384 * 4)) \
    --max-prefill-tokens $((16384 * 4))

Decode (2xTP2)

python3 -m sglang.launch_server \
    --model-path /dev/shm/GLM-4.5-Air-FP8 \
    --served-model-name=GLM-4.5-Air-FP8 \
    --disaggregation-mode decode \
    --tp-size 2 \
    --mem-fraction-static 0.8 \
    --max-running-requests 128 \
    --kv-cache-dtype fp8_e4m3 \
    --host 0.0.0.0 \
    --enable-metrics \
    --attention-backend fa3 \
    --port 30011 \
    --page-size 64 \
    --speculative-algorithm EAGLE \
    --speculative-num-steps 1 \
    --speculative-eagle-topk 1 \
    --speculative-num-draft-tokens 2 \
    --disable-shared-experts-fusion

sglang-router

python3 -m sglang_router.launch_router \
  --pd-disaggregation \
  --prefill http://localhost:30001 \
  --decode http://localhost:30011 \
  --decode http://localhost:30021 --mini-lb


@gemini-code-assist
Contributor

Summary of Changes


This pull request addresses a performance bottleneck in MHA models utilizing PD disaggregation, particularly when prefill and decode nodes operate with differing Tensor Parallelism sizes. The core issue was inefficient token-by-token KV transfer, leading to high Time To First Token. By implementing NumPy vectorization for memory address calculations during KV cache slicing, the change dramatically improves the efficiency of this transfer process, resulting in a notable reduction in TTFT and overall better latency without impacting accuracy.

Highlights

  • Performance Optimization for KV Transfer: Introduces NumPy vectorization to accelerate the computation of destination pointers during KV cache transfer, specifically addressing scenarios where MHA models use PD disaggregation with different Tensor Parallelism (TP) sizes on prefill and decode nodes.
  • Reduced Time To First Token (TTFT): The optimization significantly reduces the Time To First Token (TTFT), as demonstrated by benchmarks showing a decrease from 899.37ms to 719.44ms in a 4K:1K, max-concurrency=64 setup.
  • Configurable Optimization: Adds a new environment variable, SGLANG_DISAGGREGATION_OPT_DIFF_TP_NUMPY, allowing users to enable or disable this NumPy-based optimization for KV transfer.


Changelog
  • python/sglang/srt/disaggregation/mooncake/conn.py
    • Refactored send_kvcache_slice function to conditionally use NumPy vectorization for calculating source and destination memory addresses for KV cache slices.
    • Introduced a new execution path that leverages NumPy array operations for efficient batch processing of address computations, replacing iterative Python loops for improved performance.
    • The original loop-based logic is retained as a fallback when the NumPy optimization is not enabled.
  • python/sglang/srt/environ.py
    • Added a new environment variable SGLANG_DISAGGREGATION_OPT_DIFF_TP_NUMPY (boolean, default False) to control the activation of the NumPy-based KV transfer optimization.
Activity
  • The pull request introduces a new feature and performance improvement.
  • Benchmarking results for accuracy and latency have been provided, showing positive impacts.
  • Example scripts for configuring prefill, decode, and router with the new optimization are included.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a NumPy-based optimization to accelerate KV cache transfer when tensor parallelism sizes differ between prefill and decode nodes. The change is gated by a new environment variable. My review focuses on the correctness of the new implementation, code maintainability, and consistency with the documentation. I've identified a critical inconsistency in the environment variable naming that will prevent users from enabling this feature, and a significant code duplication that impacts maintainability. I've also noted an incorrect type hint that should be fixed for clarity.

SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE = EnvInt(2)
SGLANG_DISAGGREGATION_WAITING_TIMEOUT = EnvInt(300)
SGLANG_DISAGGREGATION_NIXL_BACKEND = EnvStr("UCX")
SGLANG_DISAGGREGATION_OPT_DIFF_TP_NUMPY = EnvBool(False)
Contributor


Severity: high

The new environment variable is named SGLANG_DISAGGREGATION_OPT_DIFF_TP_NUMPY, but the pull request description and example scripts use SGLANG_DISAGGREGATION_DIFF_TP_NUMPY. This inconsistency will likely cause confusion for users trying to enable this optimization. For consistency and clarity, I recommend renaming it to match the documentation.

Suggested change
SGLANG_DISAGGREGATION_OPT_DIFF_TP_NUMPY = EnvBool(False)
SGLANG_DISAGGREGATION_DIFF_TP_NUMPY = EnvBool(False)

Comment on lines +470 to +602
if use_numpy_opt:
    # Compute indices and offsets
    prefill_kv_indices_reshaped = prefill_kv_indices.astype(np.int64).reshape(
        -1, 1
    )
    dst_kv_indices_reshaped = dst_kv_indices.astype(np.int64).reshape(-1, 1)
    token_offsets = np.arange(page_size, dtype=np.int64).reshape(1, -1)
    bytes_per_token_on_prefill = src_kv_item_len // page_size
    bytes_per_token_on_decode = dst_kv_item_len // page_size
    src_token_offsets_base = (
        token_offsets * bytes_per_token_on_prefill + src_head_slice_offset
    )
    dst_token_offsets_base = (
        token_offsets * bytes_per_token_on_decode + dst_head_slice_offset
    )

    def process_layer_tp_aware(ptrs):
        src_ptr, dst_ptr = ptrs
        src_page_starts = (
            src_ptr + prefill_kv_indices_reshaped * src_kv_item_len
        )
        dst_page_starts = dst_ptr + dst_kv_indices_reshaped * dst_kv_item_len
        src_addrs = src_page_starts + src_token_offsets_base
        dst_addrs = dst_page_starts + dst_token_offsets_base
        src_addr_list = src_addrs.reshape(-1).tolist()
        if not src_addr_list:
            return 0
        dst_addr_list = dst_addrs.reshape(-1).tolist()
        total_chunks = len(src_addr_list)
        length_list = [heads_bytes_per_token_to_send] * total_chunks
        return self.engine.batch_transfer_sync(
            mooncake_session_id, src_addr_list, dst_addr_list, length_list
        )

    futures = []
    for i in range(layers_current_pp_stage):
        futures.append(
            executor.submit(
                process_layer_tp_aware, (src_k_ptrs[i], dst_k_ptrs[i])
            )
        )
    for i in range(layers_current_pp_stage):
        futures.append(
            executor.submit(
                process_layer_tp_aware, (src_v_ptrs[i], dst_v_ptrs[i])
            )
        )
else:
    layers_params = [
        (
            src_k_ptrs[layer_id],
            dst_k_ptrs[layer_id],
            src_kv_item_len,
            dst_kv_item_len,
            src_head_slice_offset,
            dst_head_slice_offset,
            heads_bytes_per_token_to_send,
        )
        for layer_id in range(layers_current_pp_stage)
    ] + [
        (
            src_v_ptrs[layer_id],
            dst_v_ptrs[layer_id],
            src_kv_item_len,
            dst_kv_item_len,
            src_head_slice_offset,
            dst_head_slice_offset,
            heads_bytes_per_token_to_send,
        )
        for layer_id in range(layers_current_pp_stage)
    ]

    def process_layer_tp_aware(layer_params):
        (
            src_ptr,
            dst_ptr,
            src_item_len,
            dst_item_len,
            src_head_slice_offset,
            dst_head_slice_offset,
            heads_bytes_per_token_to_send,
        ) = layer_params
        src_addr_list = []
        dst_addr_list = []
        length_list = []

        # Calculate strides for a single token slot
        bytes_per_token_on_prefill = src_item_len // page_size
        bytes_per_token_on_decode = dst_item_len // page_size

        for i in range(len(prefill_kv_indices)):
            prefill_page_idx = int(prefill_kv_indices[i])
            decode_page_idx = int(dst_kv_indices[i])

            # Get the starting addresses for the current src and dst pages
            src_page_start_addr = src_ptr + prefill_page_idx * src_item_len
            dst_page_start_addr = dst_ptr + decode_page_idx * dst_item_len

            # Iterate through each valid token slot within the current page
            for token_slot_in_page in range(page_size):
                # Calculate the start address of the current token slot
                src_token_slot_start_addr = (
                    src_page_start_addr
                    + token_slot_in_page * bytes_per_token_on_prefill
                )
                dst_token_slot_start_addr = (
                    dst_page_start_addr
                    + token_slot_in_page * bytes_per_token_on_decode
                )

                # Calculate final src and dst addresses by applying head-slice offsets
                src_slice_addr = src_token_slot_start_addr + src_head_slice_offset
                dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset

                src_addr_list.append(src_slice_addr)
                dst_addr_list.append(dst_slice_addr)
                length_list.append(heads_bytes_per_token_to_send)

        return self.engine.batch_transfer_sync(
            mooncake_session_id, src_addr_list, dst_addr_list, length_list
        )

    futures = [
        executor.submit(
            process_layer_tp_aware,
            layer_params,
        )
        for layer_params in layers_params
    ]
Contributor


Severity: medium

This if/else block introduces significant code duplication. The else block contains a full copy of the original implementation. This makes the code harder to read and maintain, as future changes might need to be applied in two places.

To improve maintainability, consider refactoring the original loop-based implementation into a separate private helper method. The else block would then simply call this new method. This would make the send_kvcache_slice function cleaner and the new optimized path more prominent.
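For instance, the suggested refactor could look like the following sketch (the class, method names, and stub bodies are hypothetical stand-ins for the real transfer logic, and the real switch reads envs.SGLANG_DISAGGREGATION_OPT_DIFF_TP_NUMPY rather than os.environ):

```python
import os

class KVCacheSender:
    """Sketch only: dispatch once on the env var, keep each
    implementation in its own private helper method."""

    def send_kvcache_slice(self, kv_indices):
        # Hypothetical dispatch; bodies below are stubs.
        if os.environ.get("SGLANG_DISAGGREGATION_OPT_DIFF_TP_NUMPY") == "1":
            return self._send_kvcache_slice_numpy(kv_indices)
        return self._send_kvcache_slice_loop(kv_indices)

    def _send_kvcache_slice_numpy(self, kv_indices):
        return ("numpy", len(kv_indices))  # stub for the vectorized path

    def _send_kvcache_slice_loop(self, kv_indices):
        return ("loop", len(kv_indices))  # stub for the original loop path

sender = KVCacheSender()
path, n = sender.send_kvcache_slice([3, 7])
```

This keeps send_kvcache_slice short and makes the optimized path the obvious entry point, while future fixes to either implementation land in exactly one method.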

Comment on lines +472 to +475
prefill_kv_indices_reshaped = prefill_kv_indices.astype(np.int64).reshape(
-1, 1
)
dst_kv_indices_reshaped = dst_kv_indices.astype(np.int64).reshape(-1, 1)
Contributor


Severity: medium

The type hints for prefill_kv_indices and dst_kv_indices in the send_kvcache_slice signature are npt.NDArray[np.int64], but the data actually passed in from transfer_worker is np.int32 — under the current hints, the .astype(np.int64) cast looks redundant.

To improve clarity and correctness, please update the signature's type hints to npt.NDArray[np.int32]. This accurately reflects the input types and makes clear why the astype(np.int64) conversion is needed for pointer arithmetic.
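Concretely, the suggested hint change looks like this (shown on a trimmed-down stub, not the full send_kvcache_slice parameter list):

```python
import numpy as np
import numpy.typing as npt

# Stub illustrating only the signature fix: callers pass int32 index
# arrays, so the hints should say int32, and the int64 cast inside is
# what keeps the byte-offset arithmetic overflow-safe.
def send_kvcache_slice(
    prefill_kv_indices: npt.NDArray[np.int32],
    dst_kv_indices: npt.NDArray[np.int32],
    src_kv_item_len: int,
    dst_kv_item_len: int,
):
    src_pages = prefill_kv_indices.astype(np.int64) * src_kv_item_len
    dst_pages = dst_kv_indices.astype(np.int64) * dst_kv_item_len
    return src_pages.tolist(), dst_pages.tolist()

src, dst = send_kvcache_slice(
    np.array([3], dtype=np.int32),
    np.array([1], dtype=np.int32),
    1024,
    512,
)
```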

Comment on lines +467 to +468
# Whether to use numpy optimization for different TP size on prefill node
use_numpy_opt = envs.SGLANG_DISAGGREGATION_OPT_DIFF_TP_NUMPY.get()
Collaborator


If the performance gain is guaranteed, we should replace the original code with the new one, instead of adding a switch.

Or will this change be incompatible with some features?

Contributor Author


I’ve already tested the two models glm4.5-air-fp8 and qwen3-235b-fp8 in PD disaggregation. The testing setup also covered:

  1. 1×TP4 → 2×TP2
  2. 1×TP4 → 1×TP4DP4 (dp-attention, meaning attention-tp-size = 1)

Therefore, I believe compatibility with MHA models can be guaranteed. To be safe, should I run any additional experiments?

Collaborator


I mean we don't need this use_numpy_opt, just replace the original code, this is a general optimization, we don't need a fallback for the original one.

Contributor Author


Got it. I think we can directly replace the existing code. I will make a new commit.

src_addr_list.append(src_slice_addr)
dst_addr_list.append(dst_slice_addr)
length_list.append(heads_bytes_per_token_to_send)
# The original for loop is replaced by numpy optimization for different TP size on prefill node
Collaborator


This comment is not needed, since there is no original implementation anymore.

@ShangmingCai
Collaborator

/tag-and-rerun-ci

@github-actions github-actions bot added the run-ci label Feb 4, 2026
@ShangmingCai ShangmingCai changed the title [PD Disaggregation] improve kv transfer for MHA model with different tp size on prefill node [PD] improve kv offset calculation for MHA model with different tp size Feb 4, 2026
@ShangmingCai
Collaborator

ShangmingCai commented Feb 4, 2026

/rerun-failed-ci 1

@ShangmingCai
Collaborator

CI has passed.

stage-c-test-large-4-gpu is flaky and not related to this PR.

@ShangmingCai ShangmingCai merged commit f730c18 into sgl-project:main Feb 5, 2026
435 of 464 checks passed
charlesHsuGG pushed a commit to charlesHsuGG/sglang that referenced this pull request Feb 5, 2026
Johnsonms pushed a commit to Johnsonms/sglang that referenced this pull request Feb 14, 2026
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
