[PD] improve kv offset calculation for MHA model with different tp size #18163
ShangmingCai merged 3 commits into sgl-project:main
Conversation
…efore KV transfer
Summary of Changes

Hello @Ch3ngY1, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request addresses a performance bottleneck in MHA models using PD disaggregation, particularly when the prefill and decode nodes run with different tensor parallelism sizes. The core issue was inefficient token-by-token KV transfer, which led to a high time to first token (TTFT). By vectorizing the memory address calculations for KV cache slicing with NumPy, the change makes the transfer far more efficient, yielding a notable TTFT reduction and better overall latency without affecting accuracy.

Highlights
Activity
Code Review
This pull request introduces a NumPy-based optimization to accelerate KV cache transfer when tensor parallelism sizes differ between prefill and decode nodes. The change is gated by a new environment variable. My review focuses on the correctness of the new implementation, code maintainability, and consistency with the documentation. I've identified a critical inconsistency in the environment variable naming that will prevent users from enabling this feature, and a significant code duplication that impacts maintainability. I've also noted an incorrect type hint that should be fixed for clarity.
python/sglang/srt/environ.py
Outdated
```python
SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE = EnvInt(2)
SGLANG_DISAGGREGATION_WAITING_TIMEOUT = EnvInt(300)
SGLANG_DISAGGREGATION_NIXL_BACKEND = EnvStr("UCX")
SGLANG_DISAGGREGATION_OPT_DIFF_TP_NUMPY = EnvBool(False)
```
The new environment variable is named SGLANG_DISAGGREGATION_OPT_DIFF_TP_NUMPY, but the pull request description and example scripts use SGLANG_DISAGGREGATION_DIFF_TP_NUMPY. This inconsistency will likely cause confusion for users trying to enable this optimization. For consistency and clarity, I recommend renaming it to match the documentation.
```diff
- SGLANG_DISAGGREGATION_OPT_DIFF_TP_NUMPY = EnvBool(False)
+ SGLANG_DISAGGREGATION_DIFF_TP_NUMPY = EnvBool(False)
```
```python
if use_numpy_opt:
    # Compute indices and offsets
    prefill_kv_indices_reshaped = prefill_kv_indices.astype(np.int64).reshape(-1, 1)
    dst_kv_indices_reshaped = dst_kv_indices.astype(np.int64).reshape(-1, 1)
    token_offsets = np.arange(page_size, dtype=np.int64).reshape(1, -1)
    bytes_per_token_on_prefill = src_kv_item_len // page_size
    bytes_per_token_on_decode = dst_kv_item_len // page_size
    src_token_offsets_base = (
        token_offsets * bytes_per_token_on_prefill + src_head_slice_offset
    )
    dst_token_offsets_base = (
        token_offsets * bytes_per_token_on_decode + dst_head_slice_offset
    )

    def process_layer_tp_aware(ptrs):
        src_ptr, dst_ptr = ptrs
        src_page_starts = src_ptr + prefill_kv_indices_reshaped * src_kv_item_len
        dst_page_starts = dst_ptr + dst_kv_indices_reshaped * dst_kv_item_len
        src_addrs = src_page_starts + src_token_offsets_base
        dst_addrs = dst_page_starts + dst_token_offsets_base
        src_addr_list = src_addrs.reshape(-1).tolist()
        if not src_addr_list:
            return 0
        dst_addr_list = dst_addrs.reshape(-1).tolist()
        total_chunks = len(src_addr_list)
        length_list = [heads_bytes_per_token_to_send] * total_chunks
        return self.engine.batch_transfer_sync(
            mooncake_session_id, src_addr_list, dst_addr_list, length_list
        )

    futures = []
    for i in range(layers_current_pp_stage):
        futures.append(
            executor.submit(process_layer_tp_aware, (src_k_ptrs[i], dst_k_ptrs[i]))
        )
    for i in range(layers_current_pp_stage):
        futures.append(
            executor.submit(process_layer_tp_aware, (src_v_ptrs[i], dst_v_ptrs[i]))
        )
else:
    layers_params = [
        (
            src_k_ptrs[layer_id],
            dst_k_ptrs[layer_id],
            src_kv_item_len,
            dst_kv_item_len,
            src_head_slice_offset,
            dst_head_slice_offset,
            heads_bytes_per_token_to_send,
        )
        for layer_id in range(layers_current_pp_stage)
    ] + [
        (
            src_v_ptrs[layer_id],
            dst_v_ptrs[layer_id],
            src_kv_item_len,
            dst_kv_item_len,
            src_head_slice_offset,
            dst_head_slice_offset,
            heads_bytes_per_token_to_send,
        )
        for layer_id in range(layers_current_pp_stage)
    ]

    def process_layer_tp_aware(layer_params):
        (
            src_ptr,
            dst_ptr,
            src_item_len,
            dst_item_len,
            src_head_slice_offset,
            dst_head_slice_offset,
            heads_bytes_per_token_to_send,
        ) = layer_params
        src_addr_list = []
        dst_addr_list = []
        length_list = []

        # Calculate strides for a single token slot
        bytes_per_token_on_prefill = src_item_len // page_size
        bytes_per_token_on_decode = dst_item_len // page_size

        for i in range(len(prefill_kv_indices)):
            prefill_page_idx = int(prefill_kv_indices[i])
            decode_page_idx = int(dst_kv_indices[i])

            # Get the starting addresses for the current src and dst pages
            src_page_start_addr = src_ptr + prefill_page_idx * src_item_len
            dst_page_start_addr = dst_ptr + decode_page_idx * dst_item_len

            # Iterate through each valid token slot within the current page
            for token_slot_in_page in range(page_size):
                # Calculate the start address of the current token slot
                src_token_slot_start_addr = (
                    src_page_start_addr
                    + token_slot_in_page * bytes_per_token_on_prefill
                )
                dst_token_slot_start_addr = (
                    dst_page_start_addr
                    + token_slot_in_page * bytes_per_token_on_decode
                )

                # Calculate final src and dst addresses by applying head-slice offsets
                src_slice_addr = src_token_slot_start_addr + src_head_slice_offset
                dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset

                src_addr_list.append(src_slice_addr)
                dst_addr_list.append(dst_slice_addr)
                length_list.append(heads_bytes_per_token_to_send)

        return self.engine.batch_transfer_sync(
            mooncake_session_id, src_addr_list, dst_addr_list, length_list
        )

    futures = [
        executor.submit(process_layer_tp_aware, layer_params)
        for layer_params in layers_params
    ]
```
This if/else block introduces significant code duplication. The else block contains a full copy of the original implementation. This makes the code harder to read and maintain, as future changes might need to be applied in two places.
To improve maintainability, consider refactoring the original loop-based implementation into a separate private helper method. The else block would then simply call this new method. This would make the send_kvcache_slice function cleaner and the new optimized path more prominent.
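A minimal sketch of this refactor direction, with hypothetical names (the class and helpers below are illustrative stand-ins, not sglang's actual API): the loop-based address computation lives in one private helper, and the public entry point only dispatches.

```python
import numpy as np

class SliceAddressPlanner:
    """Illustrative stand-in for the sender; names are hypothetical."""

    def __init__(self, use_numpy_opt: bool):
        self.use_numpy_opt = use_numpy_opt

    def plan(self, base_ptr, kv_indices, item_len, page_size):
        # Dispatch only; each path is defined exactly once.
        if self.use_numpy_opt:
            return self._plan_vectorized(base_ptr, kv_indices, item_len, page_size)
        return self._plan_loop(base_ptr, kv_indices, item_len, page_size)

    def _plan_vectorized(self, base_ptr, kv_indices, item_len, page_size):
        # (N, 1) page starts broadcast against (page_size,) per-token offsets.
        page_starts = base_ptr + kv_indices.astype(np.int64).reshape(-1, 1) * item_len
        token_offsets = np.arange(page_size, dtype=np.int64) * (item_len // page_size)
        return (page_starts + token_offsets).reshape(-1).tolist()

    def _plan_loop(self, base_ptr, kv_indices, item_len, page_size):
        # Original token-by-token computation, kept in a single place.
        bytes_per_token = item_len // page_size
        return [
            base_ptr + int(idx) * item_len + slot * bytes_per_token
            for idx in kv_indices
            for slot in range(page_size)
        ]
```

Both paths produce the same flat address list, so the fallback can be verified against the optimized path directly.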
```python
prefill_kv_indices_reshaped = prefill_kv_indices.astype(np.int64).reshape(-1, 1)
dst_kv_indices_reshaped = dst_kv_indices.astype(np.int64).reshape(-1, 1)
```
The type hints for prefill_kv_indices and dst_kv_indices in the send_kvcache_slice function signature are npt.NDArray[np.int64]. However, the actual data passed from transfer_worker is np.int32. This makes the .astype(np.int64) cast appear redundant.
To improve code clarity and correctness, please update the function signature's type hints to npt.NDArray[np.int32]. This will accurately reflect the input types and justify the necessity of the astype(np.int64) conversion for pointer arithmetic.
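The cast matters for address arithmetic: with int32 page indices, the product `index * item_len` can exceed the int32 range and silently wrap before it is added to the base pointer. A small illustration with made-up values (not the PR's actual sizes):

```python
import numpy as np

kv_indices = np.array([100_000], dtype=np.int32)  # page indices, int32 as received
item_len = 1_048_576  # hypothetical bytes per page; product exceeds 2**31 - 1

# Widening to int64 first keeps the byte offset exact
widened = kv_indices.astype(np.int64) * item_len

# Without the cast, the int32 product wraps modulo 2**32
wrapped = kv_indices * np.int32(item_len)

assert widened[0] == 104_857_600_000
assert wrapped[0] != widened[0]
```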
```python
use_numpy_opt = envs.SGLANG_DISAGGREGATION_OPT_DIFF_TP_NUMPY.get()
# Whether to use numpy optimization for different TP size on prefill node
```
If the performance gain is guaranteed, we should replace the original code with the new one, instead of adding a switch.
Or will this change be incompatible with some features?
I’ve already tested the two models glm4.5-air-fp8 and qwen3-235b-fp8 in PD disaggregation. The testing setup also covered:
- 1×TP4 → 2×TP2
- 1×TP4 → 1×TP4DP4 (dp-attention, meaning attention-tp-size = 1)

Therefore, I believe compatibility with MHA models can be guaranteed. To be safe, should I run any additional experiments?
I mean we don't need this use_numpy_opt; just replace the original code. This is a general optimization, so we don't need a fallback to the original implementation.
Got it. I think we can directly replace the existing code. I will make a new commit
```python
src_addr_list.append(src_slice_addr)
dst_addr_list.append(dst_slice_addr)
length_list.append(heads_bytes_per_token_to_send)
# The original for loop is replaced by numpy optimization for different TP size on prefill node
```
This comment is not needed, since there is no original implementation anymore.
/tag-and-rerun-ci

/rerun-failed-ci 1
…ze (sgl-project#18163) Co-authored-by: Shangming Cai <csmthu@gmail.com>

…efore KV transfer
Motivation
When deploying an MHA model with PD disaggregation, if the TP sizes of the prefill and decode nodes are different, KV transfer is performed token by token, which leads to high TTFT.
SGLANG_DISAGGREGATION_DIFF_TP_NUMPY: use NumPy vectorization to accelerate the computation of destination pointers on the prefill side.
Modifications
Use NumPy vectorization to accelerate dst ptr computation in send_kvcache_slice
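The modification replaces a nested Python loop over pages and token slots with one broadcasted NumPy expression per layer. A self-contained sketch of that pattern with made-up pointers and sizes (the real code also accounts for head slices per layer and hands the lists to Mooncake's batch transfer):

```python
import numpy as np

page_size = 16
src_kv_item_len, dst_kv_item_len = 8192, 4096   # differ when prefill/decode TP differ
src_head_slice_offset, dst_head_slice_offset = 0, 1024
src_ptr, dst_ptr = 0x7F00_0000_0000, 0x7E00_0000_0000  # illustrative base pointers

prefill_kv_indices = np.array([3, 7, 11], dtype=np.int32)
dst_kv_indices = np.array([1, 4, 9], dtype=np.int32)

# (N, 1) page base addresses broadcast against (1, page_size) per-token offsets
src_page_starts = src_ptr + prefill_kv_indices.astype(np.int64).reshape(-1, 1) * src_kv_item_len
dst_page_starts = dst_ptr + dst_kv_indices.astype(np.int64).reshape(-1, 1) * dst_kv_item_len
token_offsets = np.arange(page_size, dtype=np.int64).reshape(1, -1)

src_addrs = (src_page_starts + token_offsets * (src_kv_item_len // page_size)
             + src_head_slice_offset).reshape(-1)
dst_addrs = (dst_page_starts + token_offsets * (dst_kv_item_len // page_size)
             + dst_head_slice_offset).reshape(-1)

# One flat address per (page, token slot) pair, ready for a batched transfer call
assert src_addrs.shape == (len(prefill_kv_indices) * page_size,)
```

Each flattened array holds one address per token slot, so the per-token Python loops disappear entirely from the hot path.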
Accuracy Tests
w/o this PR:
w/ this PR, enable SGLANG_DISAGGREGATION_DIFF_TP_NUMPY
Benchmarking and Profiling
The results by benching 4K:1K, max-concurrency=64
w/o this PR:
w/ this PR, enable SGLANG_DISAGGREGATION_DIFF_TP_NUMPY
Scripts
Prefill (TP4)
Decode (2xTP2)
sglang-router
Checklist
Review Process
/tag-run-ci-label,/rerun-failed-ci,/tag-and-rerun-ci