Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion csrc/flash_attn/flash_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,18 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
// q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
int64_t lse_size_before[] = {num_heads, batch_size, max_seqlen_q};
int64_t lse_size_after[] = {num_heads * max_seqlen_q, batch_size};
softmax_lse = softmax_lse.reshape(lse_size_before).transpose(1, 2).reshape(lse_size_after);


if (params.num_splits > 1){
// When KV-split is enabled (num_splits > 1), LSE is first computed partially through lse_accum tensors. Then, an additional kernel, combine_attn_seqk_parallel, reduces these partials into the final LSE.
// This kernel produces LSE in a [seqlen_q, h, b] layout which can be directly used as it is already in the canonical form.
softmax_lse = softmax_lse.reshape(lse_size_after);
}else{
// The standard forward kernel produces LSE in a [b, h, seqlen_q] layout.
// It must be transposed to the canonical [seqlen_q, h, b] layout.
softmax_lse = softmax_lse.reshape(lse_size_before).transpose(1, 2).reshape(lse_size_after);
}

}

return {out, softmax_lse};
Expand Down