feat: change infer_ext ops function param order (#2)
CyCle1024 authored and jinminxi104 committed Aug 20, 2024
1 parent f170fb8 commit 41d9985
Showing 1 changed file with 3 additions and 3 deletions: lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py
@@ -26,7 +26,6 @@ def flash_context_attention(
     for i in range(batch):
         if torch.equal(q_seq_len[i], kv_seq_len[i]):
             ext_ops.context_attention(
-                attn_output,
                 query_states,
                 key_states,
                 value_states,
@@ -35,13 +34,13 @@ def flash_context_attention(
                 num_q_heads,
                 num_kv_heads,
                 context.attention_mask[i:i + 1],
+                attn_output=attn_output,
             )
         else:
             key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
             value_cache = value_cache.reshape(1, kv_cache_len,
                                               num_kv_heads * dim)
             ext_ops.paged_prefill_attention(
-                attn_output,
                 query_states,
                 key_cache,
                 value_cache,
@@ -53,14 +52,14 @@ def flash_context_attention(
                 num_q_heads,
                 num_kv_heads,
                 context.attention_mask[i:i + 1],
+                attn_output=attn_output,
             )


 def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
                           block_offsets, block_size):
     num_kv_heads, num_q_heads = k_cache.shape[1], q.shape[1]
     ext_ops.paged_decode_attention(
-        attn_output.view(q.shape),
         q,
         k_cache,
         v_cache,
@@ -69,6 +68,7 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
         kv_seq_len,
         num_q_heads,
         num_kv_heads,
+        attn_output=attn_output.view(q.shape),
     )
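
Net effect of the commit: the attn_output buffer moves from being the first positional argument of the infer_ext attention ops (ext_ops.context_attention, ext_ops.paged_prefill_attention, ext_ops.paged_decode_attention) to a trailing keyword argument. Below is a minimal, self-contained sketch of the new calling convention; the real ext_ops kernels come from the infer_ext backend and their full signatures are not shown in this diff, so the stand-in function and tensor shapes here are illustrative assumptions only.

# Hypothetical stand-in for an infer_ext attention op, used only to show the
# positional-to-keyword change for the output buffer; not the real ext_ops API.
import torch


def fake_paged_decode_attention(q, k_cache, v_cache, *, attn_output):
    # Stand-in kernel: writes a dummy result into the caller-provided buffer.
    attn_output.copy_(torch.zeros_like(q))
    return attn_output


q = torch.randn(4, 8, 64)
k_cache = torch.randn(16, 8, 64)
v_cache = torch.randn(16, 8, 64)
attn_output = torch.empty_like(q)

# After this commit, inputs are passed positionally and the output buffer by
# keyword, mirroring `attn_output=attn_output.view(q.shape)` in the diff above.
fake_paged_decode_attention(q, k_cache, v_cache,
                            attn_output=attn_output.view(q.shape))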


