10 changes: 8 additions & 2 deletions csrc/flash_attn/flash_api.cpp
@@ -726,7 +726,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
 
     const int batch_size = sizes[0];
     int seqlen_q = sizes[1];
+    const int seqlen_q_og = seqlen_q;
     int num_heads = sizes[2];
+    const int num_heads_og = num_heads;
     const int head_size_og = sizes[3];
 
     const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
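Context for the two `_og` copies: earlier in `mha_fwd_kvcache`, when each query carries a single token but there are more query heads than KV heads (MQA/GQA decoding), the function swaps the query head-group dimension into the sequence dimension for better kernel parallelism, overwriting `seqlen_q` and `num_heads` in the process. The snippet below is a minimal sketch of that upstream `seqlenq_ngroups_swapped` path, not verbatim code; the simplified condition and the `ngroups` name are illustrative.

// Sketch (simplified, not verbatim): fold the query head groups into the
// sequence dimension for decode-time MQA/GQA.
const bool seqlenq_ngroups_swapped =
    seqlen_q == 1 && num_heads > num_heads_k;  // real condition has more guards
if (seqlenq_ngroups_swapped) {
    const int ngroups = num_heads / num_heads_k;
    // [batch, 1, num_heads, head_size] -> [batch, ngroups, num_heads_k, head_size]
    q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
    seqlen_q = ngroups;       // clobbers the caller-visible value...
    num_heads = num_heads_k;  // ...which is why seqlen_q_og / num_heads_og are saved above
}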
@@ -784,8 +786,12 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
         CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
-        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
-        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
+        CHECK_SHAPE(out, batch_size, seqlen_q_og, num_heads_og, head_size_og);
+        if (head_size_og % 8 != 0) {
+            out = torch::empty_like(q_padded);
+        } else if (seqlenq_ngroups_swapped) {
+            out = out.reshape({batch_size, num_heads, seqlen_q, head_size_og}).transpose(1, 2);
+        }
     } else {
         out = torch::empty_like(q_padded);
     }
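Why the fix matters: before this change, a caller-supplied `out` was shape-checked against the already-swapped `seqlen_q` and `num_heads`, so a correctly shaped [batch, seqlen_q_og, num_heads_og, head_size] buffer was rejected on the swapped path, and the kernel would have written into it in the wrong layout besides. The new code checks against the original shape and, when the swap is active, views the caller's buffer in the swapped layout so the kernel can write into it directly. Below is a sketch of the matching epilogue that undoes the swap before returning, assuming the diff's variable names; the exact upstream return code may differ.

// Sketch: restore the caller-visible layout at the end of the function.
// At this point seqlen_q == ngroups and num_heads == num_heads_k, so
// undoing the transpose and flattening recovers the original shape
// (seqlen_q_og is 1 on this path).
if (seqlenq_ngroups_swapped) {
    // [batch, ngroups, num_heads_k, head_size] -> [batch, 1, num_heads_og, head_size]
    out = out.transpose(1, 2)
             .reshape({batch_size, seqlen_q_og, num_heads_og, head_size_og});
}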