[v0.13.0][Feature] Support DSA-CP for Hybrid scenario (#5702)#6120
wangxiyuan merged 1 commit into vllm-project:releases/v0.13.0 from
Conversation
Code Review
This pull request extends DSA-CP to support hybrid prefill-decode scenarios, which is a significant feature for improving performance on Ascend NPUs. The changes primarily involve renaming SFA-CP to DSA-CP for consistency and introducing logic to manage o_proj weights differently for prefill and decode stages in mixed-role nodes. While the overall approach appears sound and aligns with the PR's goal, I've identified a critical bug related to a missing tensor reshape that will likely cause a runtime error, and a high-severity performance issue due to a redundant operation. Addressing these issues is crucial for the stability and efficiency of this new feature.
```python
q, _ = self.wq_b(qr)  # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
q = q.view(-1, self.n_head, self.head_dim)  # [n_toks,64,128]

cos_q, sin_q = cos, sin
```
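The commented shapes in the snippet above can be checked with a plain `nn.Linear` stand-in for `self.wq_b` (names and sizes follow the snippet's comments; this is a sketch, not the PR's actual module):

```python
import torch

# Assumed from the snippet's comments: n_head=64, head_dim=128, input dim 1536.
b, s, n_head, head_dim = 2, 3, 64, 128
wq_b = torch.nn.Linear(1536, n_head * head_dim, bias=False)

qr = torch.randn(b, s, 1536)
q = wq_b(qr)                       # [b, s, 64*128]
q = q.view(-1, n_head, head_dim)   # flatten batch and seq -> [n_toks, 64, 128]
assert q.shape == (b * s, n_head, head_dim)
```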
The reshaping of cos and sin tensors for RoPE was removed from indexer_select_post_process, but it is necessary for the torch_npu.npu_rotary_mul operation on line 996. Without reshaping, cos_q and sin_q will have a shape of [num_tokens, rope_dim], which is not broadcastable with q_pe's shape of [num_tokens, num_heads, 1, rope_dim]. This will likely cause a runtime error or incorrect computation. Please restore the reshaping of cos and sin before they are used in npu_rotary_mul.
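The shape problem can be reproduced with plain PyTorch multiplication as a stand-in for `torch_npu.npu_rotary_mul` (sizes are illustrative). Note that with these shapes the un-reshaped multiply does not even fail fast: broadcasting right-aligns `cos` and silently leaks `num_tokens` into the head axis.

```python
import torch

num_tokens, num_heads, rope_dim = 4, 8, 64
q_pe = torch.randn(num_tokens, num_heads, 1, rope_dim)
cos = torch.randn(num_tokens, rope_dim)

# Without the view, [num_tokens, rope_dim] broadcasts as [1, 1, num_tokens, rope_dim],
# producing a wrong [num_tokens, num_heads, num_tokens, rope_dim] result.
bad = q_pe * cos
print(bad.shape)   # (4, 8, 4, 64) — incorrect layout

# With the reshape, each token's cos row lines up with its own q_pe rows.
cos_q = cos.view(-1, 1, 1, rope_dim)
good = q_pe * cos_q
print(good.shape)  # (4, 8, 1, 64)
```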
```python
cos_q, sin_q = cos, sin
cos_q = cos_q.view(-1, 1, 1, self.qk_rope_head_dim)
sin_q = sin_q.view(-1, 1, 1, self.qk_rope_head_dim)
```

```python
if kv_cache is not None:
    torch_npu.npu_scatter_nd_update_(
        kv_cache[2].view(-1, k.shape[-1]),
        attn_metadata.slot_mapping.view(-1, 1),
        k.view(-1, k.shape[-1]))  # b, s, n, d
```
This block introduces a redundant update to kv_cache[2]. The torch_npu.npu_scatter_nd_update_ call on kv_cache[2] with tensor k is performed here, and then again inside the indexer_select_post_process method which is called on line 875. This duplicates the operation, which is inefficient and can lead to unexpected behavior. Please remove this redundant block.
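Why the duplicate call is wasted work can be sketched with `index_copy_` as a CPU stand-in for `torch_npu.npu_scatter_nd_update_` (shapes and names here are illustrative): scattering the same `k` rows to the same slots twice leaves the cache unchanged, so the second call is pure redundant memory traffic.

```python
import torch

# Stand-in for torch_npu.npu_scatter_nd_update_: in-place row-wise
# update of a flattened KV cache via index_copy_.
num_slots, d = 8, 4
kv_cache = torch.zeros(num_slots, d)
k = torch.randn(3, d)                    # new keys for this step
slot_mapping = torch.tensor([5, 0, 2])   # target cache slots

kv_cache.index_copy_(0, slot_mapping, k)  # first update (the one to keep)
snapshot = kv_cache.clone()
kv_cache.index_copy_(0, slot_mapping, k)  # redundant second update

# The second call writes identical rows: no effect, only wasted bandwidth.
assert torch.equal(kv_cache, snapshot)
```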
This pull request has conflicts, please resolve those before we can evaluate the pull request.
> Extracted from PR vllm-project#5513. Based on the Sharded-CP feature PR vllm-project#4702; RFC: vllm-project/vllm#30055.

Extends DSA-CP to handle the FULL_DECODE_ONLY execution mode when running in a prefill-decode mixed (PD-mixed) serving environment, improving throughput and resource utilization for decode-intensive workloads.

**In pure prefill nodes:**
- Both q_proj and o_proj are sharded across world ranks, using **broadcast** for weight distribution.

**In PD-mixed nodes (supporting both prefill and decode):**
- q_proj is fully replicated (not sharded) to avoid communication overhead during decoding.
- o_proj uses the original TP `RowParallelLinear` method to store weights.

**During prefill execution:**
- o_proj forwards through all_gather to collect weights, reconstructing the complete o_proj weights on each card.

**During decode (graph replay phase):**
- Additional all_to_all (before o_proj) and reduce_scatter (after o_proj) are introduced to enable sequence-parallel output aggregation while maintaining correctness under SFA CP.

- TTFT increased by **527%**
- TPOT increased by **180%**

<img width="1550" height="938" alt="image" src="https://github.com/user-attachments/assets/9b7a03d8-a3db-4a99-8923-6e5bfcfecf72" />

- vLLM version: v0.13.0
- vLLM main: vllm-project/vllm@2f4e654

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Co-authored-by: clrs97 <524936896@qq.com>
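The weight-handling scheme above can be sketched in single-process PyTorch, simulating ranks with tensor chunks instead of `torch.distributed` collectives (all names and sizes are illustrative assumptions, not the PR's actual code): `all_gather` is modeled by concatenating the row shards, and `reduce_scatter` by summing the partials and splitting the result across ranks.

```python
import torch

tp, hidden, out, toks = 4, 16, 8, 4
full_w = torch.randn(hidden, out)        # complete o_proj weight
shards = list(full_w.chunk(tp, dim=0))   # per-rank RowParallelLinear row shards

# Prefill path: all_gather-equivalent — every card rebuilds the full
# o_proj weight from the row shards and applies it locally.
gathered = torch.cat(shards, dim=0)
assert torch.equal(gathered, full_w)

x = torch.randn(toks, hidden)            # attention output, full hidden dim
prefill_out = x @ gathered

# Decode path: each rank holds only its hidden-dim slice of x (the layout
# a pre-o_proj all_to_all would produce) plus its weight shard, and
# computes a partial o_proj output for all tokens.
x_slices = x.chunk(tp, dim=1)
partials = [x_slices[r] @ shards[r] for r in range(tp)]

# reduce_scatter-equivalent: sum the partials, then each rank keeps its
# sequence slice of the result (toks // tp tokens per rank).
summed = sum(partials)
per_rank = list(summed.chunk(tp, dim=0))
assert torch.allclose(torch.cat(per_rank), prefill_out, atol=1e-5)
```

The sketch shows why both paths agree numerically: the decode path is just the prefill matmul decomposed along the hidden dimension, with the reduction deferred to the post-o_proj collective.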
What this PR does / why we need it?

Extends DSA-CP to handle the FULL_DECODE_ONLY execution mode when running in a prefill-decode mixed (PD-mixed) serving environment, improving throughput and resource utilization for decode-intensive workloads.

In pure prefill nodes:
- Both q_proj and o_proj are sharded across world ranks, using broadcast for weight distribution.

In PD-mixed nodes (supporting both prefill and decode):
- q_proj is fully replicated (not sharded) to avoid communication overhead during decoding.
- o_proj uses the original TP `RowParallelLinear` method to store weights.

During prefill execution:
- o_proj forwards through all_gather to collect weights, reconstructing the complete o_proj weights on each card.

During decode (graph replay phase):
- Additional all_to_all (before o_proj) and reduce_scatter (after o_proj) are introduced to enable sequence-parallel output aggregation while maintaining correctness under SFA CP.

- TTFT increased by 527%
- TPOT increased by 180%
Does this PR introduce any user-facing change?

None

How was this patch tested?