[Misc][Bugfix] FA3 support to ViT MHA layer (vllm-project#12435)
Signed-off-by: Roger Wang <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
ywang96 and Isotr0py authored Jan 26, 2025
Parent: 324960a · Commit: 2a0309a
Showing 1 changed file with 22 additions and 3 deletions.
vllm/attention/layer.py
@@ -251,9 +251,28 @@ def forward(
                 _Backend.FLASH_ATTN,
                 _Backend.FLASH_ATTN_VLLM_V1,
         }:
-            from vllm.vllm_flash_attn import flash_attn_func
-
-            out = flash_attn_func(query, key, value, softmax_scale=self.scale)
+            from vllm.vllm_flash_attn import flash_attn_varlen_func
+
+            cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
+                                        step=q_len,
+                                        dtype=torch.int32,
+                                        device=query.device)
+            cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
+                                        step=kv_len,
+                                        dtype=torch.int32,
+                                        device=key.device)
+
+            out = flash_attn_varlen_func(
+                query.flatten(0, 1),
+                key.flatten(0, 1),
+                value.flatten(0, 1),
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=q_len,
+                max_seqlen_k=kv_len,
+                softmax_scale=self.scale,
+            )
+            out = out.reshape(bsz, q_len, -1)
         elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops

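For readers unfamiliar with the varlen interface: with a batch of equal-length sequences (as in this ViT MHA path, where every image yields the same number of tokens), cu_seqlens is just an arithmetic progression with step equal to the sequence length, and the (bsz, seq_len, ...) tensors are flattened so that sequence i occupies the rows between consecutive boundaries. Below is a minimal standalone sketch of that packing; the shapes and names (packed_q, etc.) are illustrative and not taken from the vLLM source, and the flash_attn_varlen_func call itself is omitted since it needs a GPU build of FlashAttention.

import torch

# Toy shapes, chosen only for illustration.
bsz, q_len, num_heads, head_dim = 2, 3, 4, 8

query = torch.randn(bsz, q_len, num_heads, head_dim)

# Cumulative sequence-length boundaries: [0, q_len, 2*q_len, ..., bsz*q_len].
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32)
print(cu_seqlens_q)  # tensor([0, 3, 6], dtype=torch.int32)

# Packing: merge the batch and sequence dims, so sequence i occupies rows
# cu_seqlens_q[i]:cu_seqlens_q[i + 1] of the packed tensor.
packed_q = query.flatten(0, 1)
assert packed_q.shape == (bsz * q_len, num_heads, head_dim)

# flash_attn_varlen_func consumes the packed q/k/v plus cu_seqlens_q/k and
# max_seqlen_q/k, and its output is reshaped back to (bsz, q_len, -1),
# as in the diff above.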