diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index a90bb4fbf5ab3..db682b4ac63b0 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -251,9 +251,28 @@ def forward(
                 _Backend.FLASH_ATTN,
                 _Backend.FLASH_ATTN_VLLM_V1,
         }:
-            from vllm.vllm_flash_attn import flash_attn_func
-
-            out = flash_attn_func(query, key, value, softmax_scale=self.scale)
+            from vllm.vllm_flash_attn import flash_attn_varlen_func
+
+            cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
+                                        step=q_len,
+                                        dtype=torch.int32,
+                                        device=query.device)
+            cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
+                                        step=kv_len,
+                                        dtype=torch.int32,
+                                        device=key.device)
+
+            out = flash_attn_varlen_func(
+                query.flatten(0, 1),
+                key.flatten(0, 1),
+                value.flatten(0, 1),
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=q_len,
+                max_seqlen_k=kv_len,
+                softmax_scale=self.scale,
+            )
+            out = out.reshape(bsz, q_len, -1)
         elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops