diff --git a/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py b/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py
index eed147b623..c2730ae25b 100644
--- a/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py
+++ b/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py
@@ -26,7 +26,6 @@ def flash_context_attention(
     for i in range(batch):
         if torch.equal(q_seq_len[i], kv_seq_len[i]):
             ext_ops.context_attention(
-                attn_output,
                 query_states,
                 key_states,
                 value_states,
@@ -35,13 +34,13 @@ def flash_context_attention(
                 num_q_heads,
                 num_kv_heads,
                 context.attention_mask[i:i + 1],
+                attn_output=attn_output,
             )
         else:
             key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
             value_cache = value_cache.reshape(1, kv_cache_len,
                                               num_kv_heads * dim)
             ext_ops.paged_prefill_attention(
-                attn_output,
                 query_states,
                 key_cache,
                 value_cache,
@@ -53,6 +52,7 @@ def flash_context_attention(
                 num_q_heads,
                 num_kv_heads,
                 context.attention_mask[i:i + 1],
+                attn_output=attn_output,
             )
 
 
@@ -60,7 +60,6 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
                           block_offsets, block_size):
     num_kv_heads, num_q_heads = k_cache.shape[1], q.shape[1]
     ext_ops.paged_decode_attention(
-        attn_output.view(q.shape),
        q,
        k_cache,
        v_cache,
@@ -69,6 +68,7 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
        kv_seq_len,
        num_q_heads,
        num_kv_heads,
+        attn_output=attn_output.view(q.shape),
    )