Fix LSE output error in FA2 kv-split #87
Conversation
Signed-off-by: griii <guo_rui@mail.ustc.edu.cn>
Amazing! Thank you, this is much appreciated!
@LucasWilkinson
According to https://x.com/vllm_project/status/1985023958371184836, the latest vLLM has the fix! Nice work!
@danielhanchen Yes, vllm-flash-attn now ships inside vLLM 👍 (the PyPI wheel is no longer used); this primarily makes it easier to keep torch versions in sync. It's a bit hacky, though, so we may go back to a dedicated wheel in the future. We don't actually rev … Hope that helps!
Background
During vLLM inference, some features, such as Cascade Attention, require the LSE (log-sum-exp of the attention scores) output from the attention mechanism.
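As context for why a corrupted LSE causes a precision collapse: the LSE is what lets partial attention results be combined exactly. A toy numpy sketch (single query, single head; all names and shapes are illustrative, not vLLM's API):

```python
import numpy as np

rng = np.random.default_rng(0)
d, n1, n2 = 8, 16, 16          # head dim and two key/value partition sizes
q = rng.standard_normal(d)
k = rng.standard_normal((n1 + n2, d))
v = rng.standard_normal((n1 + n2, d))

def attn(q, k, v):
    """Softmax attention over one query; returns (output, lse)."""
    s = k @ q
    lse = np.log(np.sum(np.exp(s)))   # log-sum-exp of the scores
    p = np.exp(s - lse)               # softmax weights
    return p @ v, lse

# Full attention vs. two partial computations merged via their LSEs.
o_full, _ = attn(q, k, v)
o1, lse1 = attn(q, k[:n1], v[:n1])
o2, lse2 = attn(q, k[n1:], v[n1:])
lse = np.logaddexp(lse1, lse2)
o_merged = np.exp(lse1 - lse) * o1 + np.exp(lse2 - lse) * o2
assert np.allclose(o_merged, o_full)  # exact merge, up to float rounding
```

This same rescale-and-sum trick is what both Cascade Attention and the split_kv combine step rely on, which is why a mislaid-out LSE poisons everything downstream.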
When the FlashAttention-2 kernel operates with `seqlenq_ngroups_swapped = True` (the case for an inference-only batch with GQA), it performs the attention computation with an internal shape of `(b, ngroups, nheads_kv, d)`. After the computation, it restores the `softmax_lse` tensor to the expected layout with a transpose-and-reshape transformation.
The Problem with split_kv
The issue arises when the split_kv path is triggered in FlashAttention-2's `flash_api.cpp`. The decision to use this path is made by the `num_splits_heuristic` function. When split_kv is enabled, the kernel partitions the K/V tensors, stores partial LSE results in `softmax_lse_accum`, and finally launches a `combine_attn_seqk_parallel` kernel to reduce these partial results into the final LSE.
The root cause of the bug lies in the memory layout defined within the `combine_attn_seqk_parallel` kernel. When `seqlenq_ngroups_swapped = True` and `unpadded_lse = True`, the layout for the output tensor `gLSE_unpadded` is constructed so that its physical memory already matches the desired output order. Because of this final layout, `gLSE_unpadded` can be correctly interpreted with a simple reshape such as `softmax_lse.reshape(num_heads * max_seqlen_q, batch_size)`.
However, the original code path unconditionally applies the `.transpose(1, 2)` operation, which is incorrect for the split_kv case. This erroneous transpose corrupts the LSE layout, leading to a complete precision collapse in downstream operations.
Reproducibility
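The corruption is easy to model in miniature. The numpy sketch below encodes the two layouts described above; the `(ngroups, h_kv)` head ordering and the batch-first shapes are my simplification of the actual buffers, not the kernel's exact code:

```python
import numpy as np

# Hypothetical sizes: batch b, kv-heads h_kv, ngroups = nheads / h_kv (GQA).
b, h_kv, ngroups = 2, 2, 4
nheads = h_kv * ngroups

# Ground-truth LSE, one value per (batch, original query head).
truth = np.arange(b * nheads, dtype=np.float32).reshape(b, nheads)

# Non-split path: the kernel computes with heads=h_kv and seqlen=ngroups,
# so original head g * h_kv + kv lands at [b_i, kv, g].
lse_nosplit = truth.reshape(b, ngroups, h_kv).transpose(0, 2, 1).copy()

# The existing restore transform: transpose(1, 2), then flatten heads.
restored = lse_nosplit.transpose(0, 2, 1).reshape(b, nheads)
assert np.array_equal(restored, truth)        # correct for the non-split path

# split_kv path: per the gLSE_unpadded layout, the combine kernel already
# writes values in final (b, ngroups, h_kv) memory order.
lse_split = truth.reshape(b, ngroups, h_kv).copy()

# A plain reshape recovers the truth ...
assert np.array_equal(lse_split.reshape(b, nheads), truth)
# ... but the unconditional transpose(1, 2) scrambles the head order.
corrupted = lse_split.transpose(0, 2, 1).reshape(b, nheads)
assert not np.array_equal(corrupted, truth)
```

In other words, the same transpose that is required on one path is exactly the operation that breaks the other.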
This bug is hard to reproduce because it only manifests under a specific combination of conditions: the value returned by `num_splits_heuristic` must be greater than 1. This value is sensitive to the GPU's SM count as well as the batch size, head count, and sequence length of the workload.
Related Issues
vllm-project/vllm#17580
vllm-project/vllm#17652
vllm-project/vllm#17886
vllm-project/vllm#18345
vllm-project/vllm#22103
I successfully reproduced the issue using the sample program provided in vllm-project/vllm#22103 on an L20 machine, and the changes in this commit fix the problem.
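For readers trying to hit the split_kv path themselves, the split decision can be approximated in Python. This is a paraphrase of FA2's `num_splits_heuristic` from my reading of `flash_api.cpp`; the constants and structure are approximate, and the example inputs below are hypothetical:

```python
import math

def ceildiv(a: int, b: int) -> int:
    return (a + b - 1) // b

def num_splits_heuristic(batch_nheads_mblocks: int, num_sms: int,
                         num_n_blocks: int, max_splits: int) -> int:
    """Approximate paraphrase of FA2's split heuristic (not the exact code)."""
    # Enough work to nearly fill the SMs: no split needed.
    if batch_nheads_mblocks >= 0.8 * num_sms:
        return 1
    max_splits = min(max_splits, num_sms, num_n_blocks)

    # A split count is only eligible if it changes the K/V blocks per split.
    def eligible(s: int) -> bool:
        return s == 1 or ceildiv(num_n_blocks, s) != ceildiv(num_n_blocks, s - 1)

    # Efficiency = how evenly the resulting waves fill the SMs.
    eff = {s: (batch_nheads_mblocks * s / num_sms)
              / math.ceil(batch_nheads_mblocks * s / num_sms)
           for s in range(1, max_splits + 1) if eligible(s)}
    best = max(eff.values())
    # Smallest split count within 85% of the best efficiency wins.
    for s in sorted(eff):
        if eff[s] >= 0.85 * best:
            return s
    return 1

# A small decode batch on a many-SM GPU triggers splits (hypothetical numbers):
print(num_splits_heuristic(4, 92, 32, 128))    # > 1: split_kv path taken
print(num_splits_heuristic(128, 92, 32, 128))  # 1: enough work already
```

This makes the fragility concrete: whether the buggy combine path runs depends jointly on the hardware's SM count and the workload's shape, which is why the bug surfaced only on certain GPU/batch combinations.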