Commit

[dlinfer] add DlinferFlashAttention to fix qwen vl (#2952)
Reinerzhou authored Dec 26, 2024
1 parent a0a7728 commit 3a98ae9
Showing 4 changed files with 134 additions and 0 deletions.
94 changes: 94 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/flash_attention.py
@@ -0,0 +1,94 @@
# Copyright (c) OpenMMLab. All rights reserved.
from torch import Tensor

from ..flash_attention import FlashAttentionBuilder, FlashAttentionImpl


class DlinferFlashAttentionImpl(FlashAttentionImpl):
"""dlinfer flash attention implementation."""

def __init__(
self,
num_heads: int,
head_dim: int,
scale: float = None,
num_kv_heads: int = None,
v_head_dim: int = None,
causal: bool = True,
sliding_window: int = None,
logical_softcapping: float = None,
):
if scale is None:
scale = 1.0 / (head_dim**0.5)
if num_kv_heads is None:
num_kv_heads = num_heads
if v_head_dim is None:
v_head_dim = head_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.scale = scale
self.num_kv_heads = num_kv_heads
self.v_head_dim = v_head_dim
self.causal = causal
self.sliding_window = sliding_window
self.logical_softcapping = logical_softcapping
from lmdeploy.pytorch.kernels.dlinfer import flash_attention_fwd
self.flash_attention_fwd = flash_attention_fwd

def forward(self,
query: Tensor,
key: Tensor,
value: Tensor,
q_start_loc: Tensor,
q_seqlens: Tensor,
kv_start_loc: Tensor,
kv_seqlens: Tensor,
max_q_seqlen: int = None):
"""forward."""
q_shape = query.shape
o_shape = q_shape[:-1] + (self.v_head_dim, )
out = query.new_empty(o_shape)
self.flash_attention_fwd(
query,
key,
value,
out,
q_start_loc=q_start_loc,
q_seqlens=q_seqlens,
kv_start_loc=kv_start_loc,
kv_seqlens=kv_seqlens,
max_q_seqlen=max_q_seqlen,
window_size=self.sliding_window,
sm_scale=self.scale,
logit_softcapping=self.logical_softcapping,
causal=self.causal,
)
return out


class DlinferFlashAttentionBuilder(FlashAttentionBuilder):
"""dlinfer attention builder."""

@staticmethod
def build(
num_heads: int,
head_dim: int,
scale: float = None,
num_kv_heads: int = None,
v_head_dim: int = None,
causal: bool = True,
sliding_window: int = None,
logical_softcapping: float = None,
**kwargs,
) -> FlashAttentionImpl:
"""build."""
return DlinferFlashAttentionImpl(
num_heads=num_heads,
head_dim=head_dim,
scale=scale,
num_kv_heads=num_kv_heads,
v_head_dim=v_head_dim,
causal=causal,
sliding_window=sliding_window,
logical_softcapping=logical_softcapping,
)
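
For context, a minimal sketch of how the new builder might be used. The import path and the build() signature come from the diff above, but the parameter values are illustrative assumptions, and instantiating the impl requires the dlinfer package at runtime:

# Illustrative sketch only: the values below are assumptions for the example,
# not taken from this commit; requires dlinfer to be installed.
from lmdeploy.pytorch.backends.dlinfer.flash_attention import (
    DlinferFlashAttentionBuilder)

attn = DlinferFlashAttentionBuilder.build(num_heads=32, head_dim=128,
                                          num_kv_heads=8, causal=True)
# Omitted arguments fall back to the defaults handled in __init__ above:
# scale = 1 / sqrt(head_dim), v_head_dim = head_dim, no sliding window.
print(attn.scale, attn.v_head_dim)  # ~0.0884, 128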
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/op_backend.py
@@ -25,6 +25,9 @@ def get_layer_impl_builder(cls, layer_type: OpType):
if layer_type == OpType.PagedAttention:
from .attention import DlinferAttentionBuilder
return DlinferAttentionBuilder
elif layer_type == OpType.FlashAttention:
from .flash_attention import DlinferFlashAttentionBuilder
return DlinferFlashAttentionBuilder
elif layer_type == OpType.ApplyRotaryEmb:
from .apply_rotary_emb import DlinferApplyRotaryEmbBuilder
return DlinferApplyRotaryEmbBuilder
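
This dispatch is how a layer that requests OpType.FlashAttention picks up the new builder. A sketch of that lookup, assuming the enclosing backend class is named DlinferOpsBackend and that OpType is importable from lmdeploy.pytorch.backends.base (neither name is visible in this hunk):

# Sketch under assumptions: the class name and import paths are inferred,
# since the diff only shows the new OpType.FlashAttention branch.
from lmdeploy.pytorch.backends.base import OpType
from lmdeploy.pytorch.backends.dlinfer.op_backend import DlinferOpsBackend

builder = DlinferOpsBackend.get_layer_impl_builder(OpType.FlashAttention)
impl = builder.build(num_heads=32, head_dim=128, num_kv_heads=8)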
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/kernels/dlinfer/__init__.py
@@ -3,6 +3,7 @@
from .apply_rotary_pos_emb import apply_rotary_pos_emb
from .awq_kernels import awq_linear
from .fill_kv_cache import fill_kv_cache
from .flash_attention import flash_attention_fwd
from .fused_moe import fused_moe
from .linear import linear
from .moe_gating_topk_softmax import moe_gating_topk_softmax
@@ -16,6 +17,7 @@
'fill_kv_cache',
'fused_moe',
'paged_attention_fwd',
'flash_attention_fwd',
'linear',
'moe_gating_topk_softmax',
'multinomial_sampling',
35 changes: 35 additions & 0 deletions lmdeploy/pytorch/kernels/dlinfer/flash_attention.py
@@ -0,0 +1,35 @@
# Copyright (c) OpenMMLab. All rights reserved.
import dlinfer.ops as ext_ops
from dlinfer.utils.type_annotation import Tensor


def flash_attention_fwd(
query_states: Tensor,
key_states: Tensor,
value_states: Tensor,
attn_output: Tensor,
q_start_loc: Tensor,
q_seqlens: Tensor,
kv_start_loc: Tensor,
kv_seqlens: Tensor,
max_q_seqlen: int = None,
window_size: int = None,
sm_scale: float = None,
logit_softcapping: float = None,
causal: bool = True,
):
num_q_heads = query_states.shape[1]
num_kv_heads = value_states.shape[1]
return ext_ops.prefill_attention(
query_states,
key_states,
value_states,
q_start_loc,
q_seqlens,
max_q_seqlen,
num_q_heads,
num_kv_heads,
attn_mask=None,
softmax_scale=sm_scale,
attn_output=attn_output,
)
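
For illustration, a minimal direct call to this kernel. The [total_tokens, num_heads, head_dim] layout is inferred from the shape[1] reads above, the sizes are made up, and an actual run needs a device supported by dlinfer's prefill_attention:

# Sketch only: layout inferred from the shape[1] reads above; sizes are
# arbitrary and the call only runs on a dlinfer-supported device.
import torch

from lmdeploy.pytorch.kernels.dlinfer import flash_attention_fwd

total_tokens, num_heads, num_kv_heads, head_dim = 16, 32, 8, 128
q = torch.randn(total_tokens, num_heads, head_dim)
k = torch.randn(total_tokens, num_kv_heads, head_dim)
v = torch.randn(total_tokens, num_kv_heads, head_dim)
out = q.new_empty(total_tokens, num_heads, head_dim)

flash_attention_fwd(q, k, v, out,
                    q_start_loc=torch.tensor([0]),
                    q_seqlens=torch.tensor([16]),
                    kv_start_loc=torch.tensor([0]),
                    kv_seqlens=torch.tensor([16]),
                    max_q_seqlen=16,
                    sm_scale=head_dim**-0.5,
                    causal=True)
# `out` now holds the attention output for the single 16-token sequence.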
