44 changes: 40 additions & 4 deletions vllm_flash_attn/flash_attn_interface.py
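In short: each public attention entry point in `flash_attn_interface.py` (`flash_attn_func`, `flash_attn_qkvpacked_func`, `flash_attn_kvpacked_func`, their `varlen` variants, and `flash_attn_with_kvcache`) gains an optional `out=None` parameter, keyword-only on the public functions. The value is threaded through the `autograd.Function.forward` wrappers into `_flash_attn_forward` / `_flash_attn_varlen_forward` and handed to `flash_attn_cuda.fwd` / `flash_attn_cuda.varlen_fwd` where a hard-coded `None` used to be, so callers can supply a preallocated output buffer. A hedged usage sketch follows the diff.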
@@ -44,15 +44,15 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):


def _flash_attn_forward(
-q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax, *, out=None
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
q,
k,
v,
-None,
+out,
alibi_slopes,
dropout_p,
softmax_scale,
@@ -80,14 +80,16 @@ def _flash_attn_varlen_forward(
alibi_slopes,
return_softmax,
block_table,
+*,
+out=None
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
q,
k,
v,
-None,
+out,
cu_seqlens_q,
cu_seqlens_k,
None,
@@ -220,6 +222,8 @@ def forward(
alibi_slopes,
deterministic,
return_softmax,
+*,
+out=None,
):
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
@@ -233,6 +237,7 @@ def forward(
window_size=window_size,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
+out=out,
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
ctx.dropout_p = dropout_p
@@ -284,6 +289,8 @@ def forward(
alibi_slopes,
deterministic,
return_softmax,
+*,
+out=None,
):
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
@@ -302,6 +309,7 @@ def forward(
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
block_table=None,
+out=out,
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
ctx.dropout_p = dropout_p
@@ -357,6 +365,7 @@ def forward(
alibi_slopes,
deterministic,
return_softmax,
+out=None,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
@@ -370,6 +379,7 @@ def forward(
window_size=window_size,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
+out=out,
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
ctx.dropout_p = dropout_p
@@ -426,6 +436,7 @@ def forward(
alibi_slopes,
deterministic,
return_softmax,
+out=None,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
@@ -444,6 +455,7 @@ def forward(
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
block_table=None,
+out=out,
)
ctx.save_for_backward(
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
@@ -505,6 +517,7 @@ def forward(
alibi_slopes,
deterministic,
return_softmax,
+out=None,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
@@ -518,6 +531,7 @@ def forward(
window_size=window_size,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
+out=out,
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
ctx.dropout_p = dropout_p
@@ -575,6 +589,7 @@ def forward(
deterministic,
return_softmax,
block_table,
+out=None,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
@@ -593,6 +608,7 @@ def forward(
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
block_table=block_table,
+out=out,
)
ctx.save_for_backward(
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
@@ -648,6 +664,8 @@ def flash_attn_qkvpacked_func(
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
+*,
+out=None,
):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -691,6 +709,7 @@ def flash_attn_qkvpacked_func(
alibi_slopes,
deterministic,
return_attn_probs,
+out=out,
)


@@ -704,6 +723,8 @@ def flash_attn_kvpacked_func(
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
+*,
+out=None,
):
"""dropout_p should be set to 0.0 during evaluation
If K, V are already stacked into 1 tensor, this function will be faster than
@@ -765,6 +786,7 @@ def flash_attn_kvpacked_func(
alibi_slopes,
deterministic,
return_attn_probs,
+out=out,
)


@@ -779,6 +801,8 @@ def flash_attn_func(
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
+*,
+out=None,
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -839,6 +863,7 @@ def flash_attn_func(
alibi_slopes,
deterministic,
return_attn_probs,
+out=out,
)


@@ -853,6 +878,8 @@ def flash_attn_varlen_qkvpacked_func(
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
+*,
+out=None,
):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -901,6 +928,7 @@ def flash_attn_varlen_qkvpacked_func(
alibi_slopes,
deterministic,
return_attn_probs,
+out=out,
)


@@ -918,6 +946,8 @@ def flash_attn_varlen_kvpacked_func(
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
+*,
+out=None,
):
"""dropout_p should be set to 0.0 during evaluation
If K, V are already stacked into 1 tensor, this function will be faster than
@@ -989,6 +1019,7 @@ def flash_attn_varlen_kvpacked_func(
alibi_slopes,
deterministic,
return_attn_probs,
+out=out,
)


@@ -1008,6 +1039,8 @@ def flash_attn_varlen_func(
deterministic=False,
return_attn_probs=False,
block_table=None,
+*,
+out=None,
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1079,6 +1112,7 @@ def flash_attn_varlen_func(
deterministic,
return_attn_probs,
block_table,
+out=out,
)


@@ -1099,6 +1133,8 @@ def flash_attn_with_kvcache(
rotary_interleaved=True,
alibi_slopes=None,
num_splits=0,
+*,
+out=None,
):
"""
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -1206,7 +1242,7 @@ def flash_attn_with_kvcache(
cache_batch_idx,
block_table,
alibi_slopes,
-None,
+out,
softmax_scale,
causal,
window_size[0],
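A minimal usage sketch of the new parameter, assuming a build of `vllm_flash_attn` that includes this change and an available CUDA device. The shapes and dtype follow the `(batch, seqlen, nheads, headdim)` / fp16 conventions from the module's docstrings; the reuse-across-decode-steps framing is illustrative, not taken from this PR:

```python
import torch

from vllm_flash_attn.flash_attn_interface import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 8, 128
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Before this change the forward pass always allocated a fresh output tensor
# (the wrappers passed a hard-coded None to flash_attn_cuda.fwd). Now a caller
# can preallocate the buffer once, e.g. to reuse it across decode iterations,
# and hand it in via the new keyword-only `out=`.
out = torch.empty_like(q)
result = flash_attn_func(q, k, v, causal=True, out=out)

# `_flash_attn_forward` passes `out` straight through to flash_attn_cuda.fwd,
# so the attention output should land in the supplied buffer (here headdim is
# already a multiple of 8, so no padded intermediate should be needed).
```

The same `out=` keyword applies to the `varlen` and kv-cache entry points, e.g. `flash_attn_with_kvcache(q, k_cache, v_cache, ..., out=out)`.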