[dlinfer]change llm op interface of paged_prefill_attention. (#2977)
* [dlinfer]modify interface to support camb multi-batch-conv

* [dlinfer]change order for paged_prefill
JackWeiw authored Jan 13, 2025
1 parent f066384 commit 39af9c8
Showing 2 changed files with 11 additions and 0 deletions.
4 changes: 4 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/attention.py
@@ -16,6 +16,7 @@ class DlinferAttentionMetadata(AttentionMetadata):
max_q_seq_len: int = 1
max_kv_seq_len: int = 1
quant_meta: Dict = None
+cu_seq_lens_kv: Optional[Tensor] = None


class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]):
@@ -79,6 +80,8 @@ def forward(
max_q_seq_len = attn_metadata.max_q_seq_len
max_kv_seq_len = attn_metadata.max_kv_seq_len
quant_bits = attn_metadata.quant_policy
+cu_seq_lens_kv = attn_metadata.cu_seq_lens_kv

if attn_metadata.quant_meta is not None:
k_scales_zeros = [
next(attn_metadata.quant_meta['k_scales']),
@@ -128,6 +131,7 @@ def forward(
q_start_loc=q_start_loc,
q_seqlens=q_seqlens,
kv_seqlens=kv_seqlens,
+cu_seq_lens_kv=cu_seq_lens_kv,
max_q_seq_len=max_q_seq_len,
max_kv_seq_len=max_kv_seq_len,
is_decoding=is_decoding,
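For context: by the usual varlen-attention convention, `cu_seq_lens_kv` is the cumulative (prefix-sum) form of `kv_seqlens`, i.e. `[0, l0, l0+l1, ...]`. The diff above only adds the field to `DlinferAttentionMetadata` and forwards it in the `paged_attention_fwd` call; it does not show where the tensor is built. Below is a minimal sketch of how such a tensor could be derived from `kv_seqlens` — the helper name `build_cu_seq_lens_kv`, the leading zero, and the `int64` dtype are assumptions for illustration, not code from this commit.

```python
import torch
from torch import Tensor


def build_cu_seq_lens_kv(kv_seqlens: Tensor) -> Tensor:
    """Hypothetical helper: prefix-sum kv_seqlens into [0, l0, l0+l1, ...]."""
    cu = kv_seqlens.new_zeros(kv_seqlens.numel() + 1, dtype=torch.int64)
    cu[1:] = torch.cumsum(kv_seqlens, dim=0)
    return cu


# Example: three sequences with KV lengths 5, 3 and 7.
kv_seqlens = torch.tensor([5, 3, 7], dtype=torch.int32)
print(build_cu_seq_lens_kv(kv_seqlens))  # tensor([ 0,  5,  8, 15])
```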
7 changes: 7 additions & 0 deletions lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
@@ -15,7 +15,9 @@ def prefill_attention(
q_start_loc: Tensor,
q_seq_len: Tensor,
kv_seq_len: Tensor,
+cu_seq_lens_kv: Tensor,
max_q_seq_len: int,
+max_kv_seq_len: int,
block_size: int,
attn_mask: Sequence[Optional[Tensor]],
is_unpaged_prefill: Optional[bool],
@@ -51,7 +53,9 @@ def prefill_attention(
q_start_loc,
q_seq_len,
kv_seq_len,
+cu_seq_lens_kv,
max_q_seq_len,
+max_kv_seq_len,
num_q_heads,
num_kv_heads,
attn_mask,
@@ -105,6 +109,7 @@ def paged_attention_fwd(
q_start_loc: Tensor,
q_seqlens: Tensor,
kv_seqlens: Tensor,
+cu_seq_lens_kv: Tensor,
max_q_seq_len: int,
max_kv_seq_len: int,
is_decoding: bool,
@@ -127,7 +132,9 @@ def paged_attention_fwd(
q_start_loc,
q_seqlens,
kv_seqlens,
+cu_seq_lens_kv,
max_q_seq_len,
+max_kv_seq_len,
block_size,
attn_mask,
is_unpaged_prefill,
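Why a kernel would want the cumulative form: with a packed ("varlen") KV layout, the slice `cu_seq_lens_kv[i]:cu_seq_lens_kv[i+1]` addresses the KV tokens of batch element `i` directly, which is presumably what the camb multi-batch path mentioned in the commit message relies on. The sketch below is purely illustrative; the packed layout and shapes are assumptions, not the dlinfer/camb kernel API.

```python
import torch

# Assumed packed layout: KV tokens of all sequences concatenated along dim 0.
kv_seqlens = torch.tensor([5, 3, 7])
cu_seq_lens_kv = torch.tensor([0, 5, 8, 15])
packed_kv = torch.randn(int(kv_seqlens.sum()), 8, 64)  # [total_kv_tokens, num_kv_heads, head_dim]

for i in range(kv_seqlens.numel()):
    start, end = int(cu_seq_lens_kv[i]), int(cu_seq_lens_kv[i + 1])
    kv_i = packed_kv[start:end]  # KV slice belonging to batch element i
    assert kv_i.shape[0] == int(kv_seqlens[i])
```

With per-request lengths alone, a kernel has to recompute these offsets on the fly; passing the prefix sums alongside `kv_seqlens` avoids that and follows the familiar flash-attention-style `cu_seqlens` convention.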
