From c970e4868a5925785635e656e2658756984fc61d Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 22 May 2024 15:10:13 -0700 Subject: [PATCH 1/3] Don't use kwargs in autograd functions --- vllm_flash_attn/flash_attn_interface.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm_flash_attn/flash_attn_interface.py b/vllm_flash_attn/flash_attn_interface.py index 6a5f16cf7a8..3a81a302745 100644 --- a/vllm_flash_attn/flash_attn_interface.py +++ b/vllm_flash_attn/flash_attn_interface.py @@ -709,7 +709,7 @@ def flash_attn_qkvpacked_func( alibi_slopes, deterministic, return_attn_probs, - out=out, + out, ) @@ -786,7 +786,7 @@ def flash_attn_kvpacked_func( alibi_slopes, deterministic, return_attn_probs, - out=out, + out, ) @@ -863,7 +863,7 @@ def flash_attn_func( alibi_slopes, deterministic, return_attn_probs, - out=out, + out, ) @@ -928,7 +928,7 @@ def flash_attn_varlen_qkvpacked_func( alibi_slopes, deterministic, return_attn_probs, - out=out, + out, ) @@ -1019,7 +1019,7 @@ def flash_attn_varlen_kvpacked_func( alibi_slopes, deterministic, return_attn_probs, - out=out, + out, ) @@ -1112,7 +1112,7 @@ def flash_attn_varlen_func( deterministic, return_attn_probs, block_table, - out=out, + out, ) From 2425fa22bbb8165832927b47a11556bdc519a5c3 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 22 May 2024 15:49:36 -0700 Subject: [PATCH 2/3] Fix out tensor input --- csrc/flash_attn/flash_api.cpp | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index e9b04414187..97ade65aa79 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -293,6 +293,15 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } if (is_causal) { window_size_right = 0; } + // do only checks here + if (out_.has_value()) { + auto out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); + } + // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value(); @@ -321,12 +330,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size at::Tensor out; if (out_.has_value()) { out = out_.value(); - TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); - CHECK_DEVICE(out); - TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); - CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og); if (seqlenq_ngroups_swapped) { - out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); + out = out.reshape({batch_size, num_heads, seqlen_q, head_size_og}).transpose(1, 2); } if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } } else { From 7ee22d998223c01758e923bb5af6a3083576b337 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 22 May 2024 15:53:26 -0700 Subject: [PATCH 3/3] Revert --- csrc/flash_attn/flash_api.cpp | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 97ade65aa79..e9b04414187 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -293,15 +293,6 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } if (is_causal) { window_size_right = 0; } - // do only checks here - if (out_.has_value()) { - auto out = out_.value(); - TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); - CHECK_DEVICE(out); - TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); - } - // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value(); @@ -330,8 +321,12 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size at::Tensor out; if (out_.has_value()) { out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og); if (seqlenq_ngroups_swapped) { - out = out.reshape({batch_size, num_heads, seqlen_q, head_size_og}).transpose(1, 2); + out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); } if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } } else {