-
Notifications
You must be signed in to change notification settings - Fork 486
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[dlinfer] add DlinferFlashAttention to fix qwen vl (#2952)
- Loading branch information
1 parent
a0a7728
commit 3a98ae9
Showing
4 changed files
with
134 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from torch import Tensor | ||
|
||
from ..flash_attention import FlashAttentionBuilder, FlashAttentionImpl | ||
|
||
|
||
class DlinferFlashAttentionImpl(FlashAttentionImpl): | ||
"""dlinfer flash attention implementation.""" | ||
|
||
def __init__( | ||
self, | ||
num_heads: int, | ||
head_dim: int, | ||
scale: float = None, | ||
num_kv_heads: int = None, | ||
v_head_dim: int = None, | ||
causal: bool = True, | ||
sliding_window: int = None, | ||
logical_softcapping: float = None, | ||
): | ||
if scale is None: | ||
scale = 1.0 / (head_dim**0.5) | ||
if num_kv_heads is None: | ||
num_kv_heads = num_heads | ||
if v_head_dim is None: | ||
v_head_dim = head_dim | ||
self.num_heads = num_heads | ||
self.head_dim = head_dim | ||
self.scale = scale | ||
self.num_kv_heads = num_kv_heads | ||
self.v_head_dim = v_head_dim | ||
self.causal = causal | ||
self.sliding_window = sliding_window | ||
self.logical_softcapping = logical_softcapping | ||
from lmdeploy.pytorch.kernels.dlinfer import flash_attention_fwd | ||
self.flash_attention_fwd = flash_attention_fwd | ||
|
||
def forward(self, | ||
query: Tensor, | ||
key: Tensor, | ||
value: Tensor, | ||
q_start_loc: Tensor, | ||
q_seqlens: Tensor, | ||
kv_start_loc: Tensor, | ||
kv_seqlens: Tensor, | ||
max_q_seqlen: int = None): | ||
"""forward.""" | ||
q_shape = query.shape | ||
o_shape = q_shape[:-1] + (self.v_head_dim, ) | ||
out = query.new_empty(o_shape) | ||
self.flash_attention_fwd( | ||
query, | ||
key, | ||
value, | ||
out, | ||
q_start_loc=q_start_loc, | ||
q_seqlens=q_seqlens, | ||
kv_start_loc=kv_start_loc, | ||
kv_seqlens=kv_seqlens, | ||
max_q_seqlen=max_q_seqlen, | ||
window_size=self.sliding_window, | ||
sm_scale=self.scale, | ||
logit_softcapping=self.logical_softcapping, | ||
causal=self.causal, | ||
) | ||
return out | ||
|
||
|
||
class DlinferFlashAttentionBuilder(FlashAttentionBuilder): | ||
"""dlinfer attention builder.""" | ||
|
||
@staticmethod | ||
def build( | ||
num_heads: int, | ||
head_dim: int, | ||
scale: float = None, | ||
num_kv_heads: int = None, | ||
v_head_dim: int = None, | ||
causal: bool = True, | ||
sliding_window: int = None, | ||
logical_softcapping: float = None, | ||
**kwargs, | ||
) -> FlashAttentionImpl: | ||
"""build.""" | ||
return DlinferFlashAttentionImpl( | ||
num_heads=num_heads, | ||
head_dim=head_dim, | ||
scale=scale, | ||
num_kv_heads=num_kv_heads, | ||
v_head_dim=v_head_dim, | ||
causal=causal, | ||
sliding_window=sliding_window, | ||
logical_softcapping=logical_softcapping, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import dlinfer.ops as ext_ops | ||
from dlinfer.utils.type_annotation import Tensor | ||
|
||
|
||
def flash_attention_fwd( | ||
query_states: Tensor, | ||
key_states: Tensor, | ||
value_states: Tensor, | ||
attn_output: Tensor, | ||
q_start_loc: Tensor, | ||
q_seqlens: Tensor, | ||
kv_start_loc: Tensor, | ||
kv_seqlens: Tensor, | ||
max_q_seqlen: int = None, | ||
window_size: int = None, | ||
sm_scale: float = None, | ||
logit_softcapping: float = None, | ||
causal: bool = True, | ||
): | ||
num_q_heads = query_states.shape[1] | ||
num_kv_heads = value_states.shape[1] | ||
return ext_ops.prefill_attention( | ||
query_states, | ||
key_states, | ||
value_states, | ||
q_start_loc, | ||
q_seqlens, | ||
max_q_seqlen, | ||
num_q_heads, | ||
num_kv_heads, | ||
attn_mask=None, | ||
softmax_scale=sm_scale, | ||
attn_output=attn_output, | ||
) |