[BugFix] Fix prefix cache performance and accuracy on Ascend #13573
@@ -43,9 +43,13 @@ class ForwardMetadata:
     seq_lens: Optional[torch.Tensor] = None
     actual_seq_lengths_q: Optional[torch.Tensor] = None
 
+    # prefix cache
+    prefix_lens: Optional[torch.Tensor] = None
+    flatten_prefix_block_tables: Optional[torch.Tensor] = None
+
 
 class AscendAttnMaskBuilder:
-    def __init__(self, model_runner: ModelRunner, device, use_fia):
+    def __init__(self, model_runner: ModelRunner, device, use_fia, use_mla):
         """
         Initialize the AscendAttnMaskBuilder class.
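The two new fields carry the prefix-cache bookkeeping for MLA extend batches: `prefix_lens` holds the per-request number of tokens already present in the KV cache, and `flatten_prefix_block_tables` holds the KV-pool page indices of those cached tokens, concatenated across the batch. A purely illustrative example of what they might contain (values and page size are hypothetical, not taken from this diff):

```python
import torch

# Hypothetical batch of two requests with page_size = 128:
# request 0 has 256 cached prefix tokens (2 pages), request 1 has 128 (1 page).
prefix_lens = torch.tensor([256, 128])  # kept on CPU, one entry per request
# Page indices holding those prefixes, concatenated across the batch:
flatten_prefix_block_tables = torch.tensor([7, 12, 3], dtype=torch.int32)  # on-device
```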
@@ -76,6 +80,13 @@ def __init__(self, model_runner: ModelRunner, device, use_fia):
         self.mix_mask_cache = self.generate_attn_mask(mixed_chunk_cache_len, "mix")
         self.mix_seq_len_cached = self.mix_mask_cache.shape[0]
 
+        if use_mla:
+            # Initialize RingMla mask
+            ringmla_mask_len = 512
+            self.ringmla_mask = self.generate_attn_mask(
+                ringmla_mask_len, "norm", torch.bfloat16
+            ).to(self.device)
+
     @staticmethod
     def generate_mask_flag(max_seq_len):
         """
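`generate_attn_mask` itself is outside this diff; here it is asked for a 512-long "norm" mask in bfloat16, which together with the kernel's `mask_type_triu` setting below implies a standard causal mask. A plain-PyTorch sketch under that assumption (the real helper's signature and fill value are not shown; it may use a flag value rather than `-inf`):

```python
import torch

def make_causal_mask(seq_len: int, dtype=torch.bfloat16) -> torch.Tensor:
    # Mask out strictly upper-triangular positions, i.e. future tokens.
    mask = torch.zeros(seq_len, seq_len, dtype=dtype)
    return mask.masked_fill(
        torch.ones(seq_len, seq_len, dtype=torch.bool).triu(diagonal=1), float("-inf")
    )

ringmla_mask = make_causal_mask(512)  # 512 matches ringmla_mask_len above
```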
@@ -216,6 +227,7 @@ def __init__(self, model_runner: ModelRunner):
         if self.use_mla:
             self.kv_lora_rank = model_runner.model_config.kv_lora_rank
             self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+            self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
             self.q_head_dim = (
                 self.qk_rope_head_dim + model_runner.model_config.qk_nope_head_dim
             )
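`qk_nope_head_dim` is now cached on the backend because the prefix path needs it to split the `kv_b_proj` output (see `forward_extend` below). For orientation, the MLA head dimensions relate as follows; the concrete numbers are DeepSeek-style examples and an assumption, not values read from this diff:

```python
# Illustrative MLA dimensions (DeepSeek-style assumption):
qk_nope_head_dim = 128  # non-rotary part of each q/k head
qk_rope_head_dim = 64   # rotary part of each q/k head
q_head_dim = qk_nope_head_dim + qk_rope_head_dim  # 192, matching the sum above
```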
@@ -229,14 +241,16 @@ def __init__(self, model_runner: ModelRunner):
                 model_runner.server_args.speculative_num_draft_tokens
             )
         self.ascend_attn_mask_builder = AscendAttnMaskBuilder(
-            model_runner, self.device, self.use_fia
+            model_runner, self.device, self.use_fia, self.use_mla
         )
         self.mask, self.fia_mask, self.mtp_mask, self.mix_mask = (
             self.ascend_attn_mask_builder.mask,
             self.ascend_attn_mask_builder.fia_mask,
             self.ascend_attn_mask_builder.mtp_mask,
             self.ascend_attn_mask_builder.mix_mask_cache,
         )
+        if self.use_mla:
+            self.ringmla_mask = self.ascend_attn_mask_builder.ringmla_mask
 
     def get_verify_buffers_to_fill_after_draft(self):
         """
@@ -279,6 +293,33 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_target_verify():
             self.forward_metadata.seq_lens_cpu_int += self.speculative_num_draft_tokens
 
+        if (
+            self.use_mla
+            and forward_batch.forward_mode.is_extend()
+            and sum(forward_batch.extend_prefix_lens_cpu) > 0
+        ):
+            self.forward_metadata.prefix_lens = forward_batch.extend_prefix_lens.to(
+                "cpu"
+            )
+            seq_prefix_lens = self.forward_metadata.prefix_lens.tolist()
+            self.forward_metadata.flatten_prefix_block_tables = torch.empty(
+                0, dtype=torch.int32
+            ).to(self.device)
+            for req_idx, seq_len in zip(
+                forward_batch.req_pool_indices.tolist(), seq_prefix_lens
+            ):
+                req_indices = forward_batch.req_to_token_pool.req_to_token[req_idx]
+                req_prefix_block_tables = (
+                    req_indices[:seq_len][:: self.page_size] // self.page_size
+                )
+                self.forward_metadata.flatten_prefix_block_tables = torch.cat(
+                    (
+                        self.forward_metadata.flatten_prefix_block_tables,
+                        torch.flatten(req_prefix_block_tables),
+                    )
+                )
+
         if forward_batch.forward_mode.is_mixed():
             self.mix_mask = self.ascend_attn_mask_builder.update_mask(
                 self.forward_metadata
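The loop above converts each request's token-slot mapping into page indices: taking every `page_size`-th slot of the cached prefix and integer-dividing by `page_size` yields the page numbers, which are then concatenated across requests. The same arithmetic run standalone, with hypothetical values:

```python
import torch

page_size = 4
# Hypothetical req_to_token row: token position -> slot in the paged KV pool.
req_indices = torch.tensor([20, 21, 22, 23, 8, 9, 10, 11, 40, 41, 42, 43])
prefix_len = 8  # the first 8 tokens of this request are cached

blocks = req_indices[:prefix_len][::page_size] // page_size
print(blocks)  # tensor([5, 2]): pages 5 and 2 hold the cached prefix
```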
@@ -590,15 +631,99 @@ def forward_extend(
                 enable_gqa=use_gqa,
                 causal=causal,
             )
+        elif sum(forward_batch.extend_prefix_lens_cpu) > 0:
> **Collaborator:** Check whether this feature supports MTP.
>
> **Author:** This feature can be used together with MTP. Since the KV cache in the MTP stage is relatively small, enabling the prefix cache there is not necessary for now.
+            q, k, v = [
+                data[: forward_batch.num_token_non_padded_cpu] for data in [q, k, v]
+            ]
+            q_nope, q_rope = q.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
+            k_nope, k_rope = k.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
+
+            # 1st, compute extend tokens to get attn_output and attn_lse
+            num_tokens = q_nope.size(0)
+            attn_output = torch.zeros(
+                num_tokens,
+                layer.tp_q_head_num,
+                layer.v_head_dim,
+                dtype=q_nope.dtype,
+                device=q_nope.device,
+            )
+            attn_lse = torch.zeros(
+                layer.tp_q_head_num,
+                num_tokens,
+                dtype=torch.float32,
+                device=q_nope.device,
+            )
+            torch_npu.atb.npu_ring_mla(
+                q_nope=q_nope,
+                q_rope=q_rope,
+                k_nope=k_nope,
+                k_rope=k_rope,
+                value=v,
+                mask=self.ringmla_mask,
+                seqlen=self.forward_metadata.extend_seq_lens_cpu_int,
+                head_num=layer.tp_q_head_num,
+                kv_head_num=layer.tp_k_head_num,
+                pre_out=None,
+                prev_lse=None,
+                qk_scale=layer.scaling,
+                kernel_type="kernel_type_high_precision",
+                mask_type="mask_type_triu",
+                calc_type="calc_type_first_ring",
+                output=attn_output,
+                softmax_lse=attn_lse,
+            )
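`calc_type_first_ring` makes the kernel return not just the attention output over the extend tokens but also the log-sum-exp of the scaled scores (`softmax_lse`), which is what lets the second call fold in more keys later. An unfused reference for what those two outputs mean (not the NPU implementation; shapes simplified to a single head):

```python
import torch

def attn_with_lse(q, k, v, scale):
    # q: [n_q, d], k/v: [n_k, d]. Returns the attention output and the
    # log-sum-exp of the scaled scores per query row.
    scores = (q @ k.transpose(-1, -2)) * scale
    lse = torch.logsumexp(scores, dim=-1)    # [n_q]
    out = torch.softmax(scores, dim=-1) @ v  # [n_q, d]
    return out, lse
```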
+            # 2nd, load history kvcache(kv_a and k_pe) and calculate k_nope
+            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            v_buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+            kv_cached = torch.index_select(
+                k_buffer, 0, self.forward_metadata.flatten_prefix_block_tables
+            )
+            k_rope_cached = torch.index_select(
+                v_buffer, 0, self.forward_metadata.flatten_prefix_block_tables
+            ).flatten(0, 1)
+
+            assert layer.kv_b_proj is not None
+            kv = layer.kv_b_proj(kv_cached)[0].view(
+                -1, layer.tp_k_head_num, self.qk_nope_head_dim + layer.v_head_dim
+            )
+            k_nope, v = kv.split([self.qk_nope_head_dim, layer.v_head_dim], dim=-1)
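The cached key buffer stores the compressed MLA latent (`kv_a`, width `kv_lora_rank`) rather than full per-head keys, so `kv_b_proj` must up-project it back into per-head `k_nope` and `v`. A self-contained shape check of that step, with DeepSeek-style dimensions as assumptions:

```python
import torch

# Assumed MLA dimensions (illustrative, not from this diff):
kv_lora_rank, qk_nope_head_dim, v_head_dim, tp_k_head_num = 512, 128, 128, 16

kv_cached = torch.randn(64, kv_lora_rank)  # 64 cached prefix tokens
kv_b_proj = torch.nn.Linear(
    kv_lora_rank, tp_k_head_num * (qk_nope_head_dim + v_head_dim), bias=False
)
kv = kv_b_proj(kv_cached).view(-1, tp_k_head_num, qk_nope_head_dim + v_head_dim)
k_nope, v = kv.split([qk_nope_head_dim, v_head_dim], dim=-1)
print(k_nope.shape, v.shape)  # torch.Size([64, 16, 128]) torch.Size([64, 16, 128])
```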
+            # 3rd, compute history kv to attn_out
+            k_rope = k_rope_cached.expand(-1, layer.tp_k_head_num, -1)
+            seq_len = torch.stack(
+                [
+                    self.forward_metadata.extend_seq_lens_cpu_int,
+                    self.forward_metadata.prefix_lens,
+                ]
+            )
+            torch_npu.atb.npu_ring_mla(
+                q_nope=q_nope,
+                q_rope=q_rope,
+                k_nope=k_nope,
+                k_rope=k_rope,
+                value=v,
+                mask=self.ringmla_mask,
+                seqlen=seq_len,
+                head_num=layer.tp_q_head_num,
+                kv_head_num=layer.tp_k_head_num,
+                pre_out=attn_output,
+                prev_lse=attn_lse,
+                qk_scale=layer.scaling,
+                kernel_type="kernel_type_high_precision",
+                mask_type="no_mask",
+                calc_type="calc_type_default",
+                output=attn_output,
+                softmax_lse=attn_lse,
+            )
+            attn_output = attn_output.reshape(
+                [-1, layer.tp_q_head_num, layer.v_head_dim]
+            )
         else:
             assert (
                 layer.qk_head_dim != layer.v_head_dim
             ), "FIA only supports qk_head_dim != v_head_dim"
 
             # Wait for the KV transfer to complete before performing attention computation.
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
 
             num_token_padding = q.shape[0]
             q, k, v = [
                 data[: forward_batch.num_token_non_padded_cpu] for data in [q, k, v]
             ]
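Together, the two `npu_ring_mla` calls form a two-pass attention: pass 1 attends the extend tokens to themselves under a causal mask; pass 2 attends the same queries to the cached prefix KV with no mask and folds the first pass back in via `pre_out`/`prev_lse`. The merge the kernel performs internally is the standard log-sum-exp combination of partial attention over disjoint key sets; a reference sketch of that math (shapes elided: in the code above `attn_lse` is `[heads, tokens]` while `attn_output` is `[tokens, heads, v_head_dim]`, so broadcasting would need a transpose):

```python
import torch

def merge_attn(out1, lse1, out2, lse2):
    # Combine two softmax-attention partials computed over disjoint key
    # sets, given each partial's log-sum-exp normalizer.
    lse = torch.logaddexp(lse1, lse2)  # combined normalizer
    w1 = torch.exp(lse1 - lse)         # weight of the first partial
    w2 = torch.exp(lse2 - lse)         # weight of the second partial
    return w1 * out1 + w2 * out2, lse
```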
> **Collaborator:** Check whether only MLA models need these changes.
>
> **Author:** Confirmed that only MLA models need this; the `self.use_mla` check was added accordingly.