diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index 0feaee19642..2c890d47a03 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -765,7 +765,18 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
 //     q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
     int64_t lse_size_before[] = {num_heads, batch_size, max_seqlen_q};
     int64_t lse_size_after[] = {num_heads * max_seqlen_q, batch_size};
-    softmax_lse = softmax_lse.reshape(lse_size_before).transpose(1, 2).reshape(lse_size_after);
+
+    if (params.num_splits > 1) {
+        // When KV-split is enabled (num_splits > 1), the LSE is first computed as partial
+        // results in lse_accum tensors; an additional kernel, combine_attn_seqk_parallel,
+        // then reduces these partials into the final LSE. That kernel already produces the
+        // LSE in the canonical [seqlen_q, h, b] layout, so it can be used as is.
+        softmax_lse = softmax_lse.reshape(lse_size_after);
+    } else {
+        // The standard forward kernel produces the LSE in a [b, h, seqlen_q] layout,
+        // which must be transposed into the canonical [seqlen_q, h, b] layout.
+        softmax_lse = softmax_lse.reshape(lse_size_before).transpose(1, 2).reshape(lse_size_after);
+    }
 }
 
 return {out, softmax_lse};
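To sanity-check the two layout paths, here is a minimal standalone libtorch sketch (not part of the patch). It emulates the two producer memory orders implied by `lse_size_before`/`lse_size_after` and verifies that both branches of the patch yield the same canonical 2-D LSE tensor. The concrete sizes are made up for illustration, and the kernel outputs are simulated with `permute` rather than produced by the real kernels.

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
    // Illustrative sizes only; the names mirror the variables in mha_varlen_fwd.
    const int64_t batch_size = 2, num_heads = 3, max_seqlen_q = 4;

    // Reference LSE values indexed as (h, s, b), i.e. the canonical ordering.
    auto ref = torch::arange(num_heads * max_seqlen_q * batch_size, torch::kFloat)
                   .reshape({num_heads, max_seqlen_q, batch_size});

    // Emulate the standard forward kernel: the same values, but laid out in
    // (h, b, s) memory order, matching lse_size_before.
    auto std_buf = ref.permute({0, 2, 1}).contiguous().flatten();

    // Emulate the KV-split combine kernel: values already in canonical order.
    auto split_buf = ref.contiguous().flatten();

    // else-branch of the patch: reinterpret, transpose, then flatten to 2-D.
    auto lse_std = std_buf
        .reshape({num_heads, batch_size, max_seqlen_q})    // lse_size_before
        .transpose(1, 2)
        .reshape({num_heads * max_seqlen_q, batch_size});  // lse_size_after

    // if-branch of the patch: the buffer is already canonical, reshape suffices.
    auto lse_split = split_buf.reshape({num_heads * max_seqlen_q, batch_size});

    std::cout << std::boolalpha
              << torch::equal(lse_std, lse_split) << "\n";  // prints: true
}
```

Note that the else-branch relies on `reshape` materializing a copy when the transposed view is non-contiguous, which is exactly what the patched code does with `lse_size_before`/`lse_size_after`.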