[dlinfer] feat: add DlinferFlashAttention to support qwen vl. #2952

Merged · 1 commit · Dec 26, 2024
lmdeploy/pytorch/backends/dlinfer/flash_attention.py (94 additions, 0 deletions)
@@ -0,0 +1,94 @@
# Copyright (c) OpenMMLab. All rights reserved.
from torch import Tensor

from ..flash_attention import FlashAttentionBuilder, FlashAttentionImpl


class DlinferFlashAttentionImpl(FlashAttentionImpl):
    """dlinfer flash attention implementation."""

    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        scale: float = None,
        num_kv_heads: int = None,
        v_head_dim: int = None,
        causal: bool = True,
        sliding_window: int = None,
        logical_softcapping: float = None,
    ):
        # fall back to the usual defaults when optional arguments are omitted
        if scale is None:
            scale = 1.0 / (head_dim**0.5)
        if num_kv_heads is None:
            num_kv_heads = num_heads
        if v_head_dim is None:
            v_head_dim = head_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        self.v_head_dim = v_head_dim
        self.causal = causal
        self.sliding_window = sliding_window
        self.logical_softcapping = logical_softcapping
        from lmdeploy.pytorch.kernels.dlinfer import flash_attention_fwd
        self.flash_attention_fwd = flash_attention_fwd

    def forward(self,
                query: Tensor,
                key: Tensor,
                value: Tensor,
                q_start_loc: Tensor,
                q_seqlens: Tensor,
                kv_start_loc: Tensor,
                kv_seqlens: Tensor,
                max_q_seqlen: int = None):
        """forward."""
        # output keeps the query layout but uses the value head dim as its last dim
        q_shape = query.shape
        o_shape = q_shape[:-1] + (self.v_head_dim, )
        out = query.new_empty(o_shape)
        self.flash_attention_fwd(
            query,
            key,
            value,
            out,
            q_start_loc=q_start_loc,
            q_seqlens=q_seqlens,
            kv_start_loc=kv_start_loc,
            kv_seqlens=kv_seqlens,
            max_q_seqlen=max_q_seqlen,
            window_size=self.sliding_window,
            sm_scale=self.scale,
            logit_softcapping=self.logical_softcapping,
            causal=self.causal,
        )
        return out


class DlinferFlashAttentionBuilder(FlashAttentionBuilder):
    """dlinfer attention builder."""

    @staticmethod
    def build(
        num_heads: int,
        head_dim: int,
        scale: float = None,
        num_kv_heads: int = None,
        v_head_dim: int = None,
        causal: bool = True,
        sliding_window: int = None,
        logical_softcapping: float = None,
        **kwargs,
    ) -> FlashAttentionImpl:
        """build."""
        return DlinferFlashAttentionImpl(
            num_heads=num_heads,
            head_dim=head_dim,
            scale=scale,
            num_kv_heads=num_kv_heads,
            v_head_dim=v_head_dim,
            causal=causal,
            sliding_window=sliding_window,
            logical_softcapping=logical_softcapping,
        )
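
As a quick illustration of the new interface, the sketch below builds the implementation and runs it on two packed sequences. The shapes, dtypes, and values are illustrative assumptions only, and actually executing it requires dlinfer and a dlinfer-enabled device; it is not part of the PR.

import torch

from lmdeploy.pytorch.backends.dlinfer.flash_attention import DlinferFlashAttentionBuilder

# Illustrative only: two sequences of lengths 3 and 5 packed into a single
# [total_tokens, num_heads, head_dim] tensor; running this needs dlinfer.
num_heads, head_dim = 16, 128
attn = DlinferFlashAttentionBuilder.build(num_heads=num_heads,
                                          head_dim=head_dim,
                                          causal=False)

q_seqlens = torch.tensor([3, 5], dtype=torch.int32)
q_start_loc = torch.tensor([0, 3], dtype=torch.int32)
query = torch.randn(8, num_heads, head_dim, dtype=torch.float16)
key = torch.randn(8, num_heads, head_dim, dtype=torch.float16)
value = torch.randn(8, num_heads, head_dim, dtype=torch.float16)

out = attn.forward(query, key, value,
                   q_start_loc=q_start_loc, q_seqlens=q_seqlens,
                   kv_start_loc=q_start_loc, kv_seqlens=q_seqlens,
                   max_q_seqlen=5)
# out has shape [8, num_heads, head_dim], since v_head_dim defaults to head_dim
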
lmdeploy/pytorch/backends/dlinfer/op_backend.py (3 additions, 0 deletions)
@@ -25,6 +25,9 @@ def get_layer_impl_builder(cls, layer_type: OpType):
        if layer_type == OpType.PagedAttention:
            from .attention import DlinferAttentionBuilder
            return DlinferAttentionBuilder
        elif layer_type == OpType.FlashAttention:
            from .flash_attention import DlinferFlashAttentionBuilder
            return DlinferFlashAttentionBuilder
        elif layer_type == OpType.ApplyRotaryEmb:
            from .apply_rotary_emb import DlinferApplyRotaryEmbBuilder
            return DlinferApplyRotaryEmbBuilder
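
With this branch in place, a flash-attention builder can be resolved through the dlinfer op backend in the same way as the other layer types. A minimal sketch follows; the DlinferOpsBackend class name and the OpType import path are assumptions based on the surrounding code, not shown in this diff.

# Assumed entry points; verify the actual class/module names in the repo.
from lmdeploy.pytorch.backends.base import OpType
from lmdeploy.pytorch.backends.dlinfer.op_backend import DlinferOpsBackend

builder = DlinferOpsBackend.get_layer_impl_builder(OpType.FlashAttention)
impl = builder.build(num_heads=16, head_dim=128, causal=False)
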
lmdeploy/pytorch/kernels/dlinfer/__init__.py (2 additions, 0 deletions)
@@ -3,6 +3,7 @@
from .apply_rotary_pos_emb import apply_rotary_pos_emb
from .awq_kernels import awq_linear
from .fill_kv_cache import fill_kv_cache
from .flash_attention import flash_attention_fwd
from .fused_moe import fused_moe
from .linear import linear
from .moe_gating_topk_softmax import moe_gating_topk_softmax
@@ -16,6 +17,7 @@
    'fill_kv_cache',
    'fused_moe',
    'paged_attention_fwd',
    'flash_attention_fwd',
    'linear',
    'moe_gating_topk_softmax',
    'multinomial_sampling',
lmdeploy/pytorch/kernels/dlinfer/flash_attention.py (35 additions, 0 deletions)
@@ -0,0 +1,35 @@
# Copyright (c) OpenMMLab. All rights reserved.
import dlinfer.ops as ext_ops
from dlinfer.utils.type_annotation import Tensor


def flash_attention_fwd(
    query_states: Tensor,
    key_states: Tensor,
    value_states: Tensor,
    attn_output: Tensor,
    q_start_loc: Tensor,
    q_seqlens: Tensor,
    kv_start_loc: Tensor,
    kv_seqlens: Tensor,
    max_q_seqlen: int = None,
    window_size: int = None,
    sm_scale: float = None,
    logit_softcapping: float = None,
    causal: bool = True,
):
    # head counts are taken from dim 1 of the packed q/v tensors
    num_q_heads = query_states.shape[1]
    num_kv_heads = value_states.shape[1]
    return ext_ops.prefill_attention(
        query_states,
        key_states,
        value_states,
        q_start_loc,
        q_seqlens,
        max_q_seqlen,
        num_q_heads,
        num_kv_heads,
        attn_mask=None,
        softmax_scale=sm_scale,
        attn_output=attn_output,
    )
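
The kernel wrapper can also be called directly. It accepts kv_start_loc, kv_seqlens, window_size, logit_softcapping, and causal to match the backend interface, but only the query-side metadata, head counts, softmax scale, and output buffer are forwarded to ext_ops.prefill_attention. A minimal calling sketch, with illustrative shapes and assuming dlinfer and a dlinfer-enabled device:

import torch

from lmdeploy.pytorch.kernels.dlinfer import flash_attention_fwd

# Illustrative packed layout: 8 tokens, 16 heads, head_dim 128.
q = torch.randn(8, 16, 128, dtype=torch.float16)
k = torch.randn(8, 16, 128, dtype=torch.float16)
v = torch.randn(8, 16, 128, dtype=torch.float16)
out = q.new_empty(8, 16, 128)

flash_attention_fwd(q, k, v, out,
                    q_start_loc=torch.tensor([0, 3], dtype=torch.int32),
                    q_seqlens=torch.tensor([3, 5], dtype=torch.int32),
                    kv_start_loc=torch.tensor([0, 3], dtype=torch.int32),
                    kv_seqlens=torch.tensor([3, 5], dtype=torch.int32),
                    max_q_seqlen=5,
                    sm_scale=1.0 / 128**0.5,
                    causal=False)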