[Perf] Deepseekv3 performance optimization for eager mode #598
Changes from all commits: 7bba874, 3d0d3ec, b836b85, bcd36d5, a61cf81, 89fe930, a25cb95, 3cfa7a3, e7b3435, 0729e3d, dc70465, 2ede312, 11a3199, 0694979, 326f22d
```diff
@@ -55,7 +55,7 @@ class AscendMLAPrefillMetadata:
     input_positions: torch.Tensor
     block_table: torch.Tensor
     max_query_len: int
-    max_context_len: int
+    max_seq_lens: int


 @dataclass
@@ -65,6 +65,7 @@ class AscendMLADecodeMetadata:
     input_positions: torch.Tensor
     block_table: torch.Tensor
     seq_lens: torch.Tensor
+    max_seq_lens: int


 @dataclass
```
```diff
@@ -131,11 +132,6 @@ def __init__(self,
         self.runner = runner
         scheduler_config = runner.scheduler_config
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
-        # self.attn_mask = None
-        # if AscendMLAMetadataBuilder._attn_mask_builder is None:
-        #     AscendMLAMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
-        #         128, self.runner.model_config.dtype
-        #     )

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
```
```diff
@@ -222,12 +218,14 @@ def build(self,
                                                    num_reqs]
         seq_lens = seq_lens_cpu
         max_query_len = query_lens.max().item()
-        max_context_len = seq_lens.max().item()
+        max_seq_lens = seq_lens.max().item()

         prefill_metadata = None
         if self._num_prefills > 0:
             reqs_start = self._num_decodes  # prefill_start
             tokens_start = self._num_decode_tokens
+            max_query_len = query_lens[tokens_start:].max().item()
+            max_seq_lens = seq_lens[tokens_start:].max().item()

             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
@@ -236,15 +234,17 @@ def build(self,
                 input_positions=input_positions[tokens_start:],
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
-                max_context_len=max_context_len,
+                max_seq_lens=max_seq_lens,
             )

         decode_metadata = None
         if self._num_decodes > 0:
+            max_seq_lens = seq_lens[:self._num_decodes].max().item()
             decode_metadata = AscendMLADecodeMetadata(
                 input_positions=input_positions[:self._num_decode_tokens],
                 block_table=block_table[:self._num_decode_tokens, ...],
-                seq_lens=seq_lens[:self._num_decode_tokens])
+                seq_lens=seq_lens[:self._num_decode_tokens],
+                max_seq_lens=max_seq_lens)

         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
```
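Since reorder_batch places all decode requests ahead of the prefills, the per-bucket maxima above reduce to plain slicing of host-side tensors. A minimal illustration with made-up lengths (assuming each decode request contributes a single token, which is why _num_decode_tokens can also index the request dimension):

```python
import torch

# Batch after reorder_batch: 2 decode requests first, then 2 prefills.
num_decodes = 2
num_decode_tokens = 2  # one token per decode request
seq_lens = torch.tensor([128, 75, 40, 90])  # CPU tensor, one entry per request

max_seq_lens = seq_lens.max().item()                     # whole batch -> 128
decode_max = seq_lens[:num_decodes].max().item()         # decode bucket -> 128
prefill_max = seq_lens[num_decode_tokens:].max().item()  # prefill bucket -> 90
```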
```diff
@@ -306,12 +306,18 @@ def __init__(
         self.qk_rope_head_dim = qk_rope_head_dim
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
+        # TODO: below padding should be removed after kernel is ready
+        # we found npu_flash_attention can only works on 128 divisible head_dim, we pad it to target size here
+        # and slice the final result to guarantee its functionality.
+        self.padding_head_dim = (
+            (self.qk_nope_head_dim + self.qk_rope_head_dim - 1) // 128 +
+            1) * 128

         # Hack for V1 for now to avoid torch library overhead (since we are
         # already inside an attention custom op), pull out the forward
         # method from the rotary embedding and call it directly
         # TODO(lucas): we should probably find a cleaner way to do this
-        self.rotary_emb = rotary_emb.forward_native
+        self.rotary_emb = rotary_emb

         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
```

Contributor: In prefill we use MHA for computation, so head_dim = nope_dim + rope_dim (192), while in decode the absorbed and move_elision strategies are adopted, so head_dim = nope_dim and no padding is needed. Am I right?

Collaborator (Author): You are definitely right. This padding dim is used in prefill to pad the tensor, not just for v_head_dim vs (qk_rope + qk_nope), but also for the 128-divisible head_dim alignment requirement of the npu_flash_attention kernel.
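The padding arithmetic in this hunk rounds the MHA head_dim up to the next multiple of 128. A quick sanity check, using DeepSeek-V3's MLA dimensions (qk_nope_head_dim=128, qk_rope_head_dim=64) for illustration:

```python
qk_nope_head_dim = 128  # illustrative: DeepSeek-V3 MLA nope dim
qk_rope_head_dim = 64   # illustrative: DeepSeek-V3 MLA rope dim

# Round (nope + rope) up to the next multiple of 128, since the
# npu_flash_attention kernel only accepts 128-divisible head dims.
padding_head_dim = ((qk_nope_head_dim + qk_rope_head_dim - 1) // 128 + 1) * 128
assert padding_head_dim == 256  # prefill head_dim 192 is padded to 256
```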
```diff
@@ -409,37 +415,73 @@ def _forward_prefill(
     ) -> torch.Tensor:
         assert attn_metadata.prefill is not None

-        # TODO: enable this compute for flash attention computation
-        # kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
-        #     -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
-        # k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-        # key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
-        # v_padded = torch.nn.functional.pad(v, [0, query.shape[-1] - v.shape[-1]],
-        #                                    value=0)
         num_tokens = query.size(0)
-        attn_output = torch.empty(num_tokens,
-                                  self.num_heads,
-                                  self.v_head_dim,
-                                  dtype=query.dtype,
-                                  device=query.device)
-        # current requests is chunked in prefill, disable flash attention with chunked prefill
-        vanilla_chunked_prefill_mla(
-            output=attn_output,
-            query=query,
-            kv_cache=kv_c_and_k_pe_cache,
-            block_tables=attn_metadata.prefill.block_table,
-            query_lens=attn_metadata.prefill.query_lens,
-            context_lens=attn_metadata.prefill.context_lens,
-            kv_b_proj=self.kv_b_proj,
-            max_query_len=attn_metadata.prefill.max_query_len,
-            max_context_len=attn_metadata.prefill.max_context_len,
-            nope_dim=self.qk_nope_head_dim,
-            rope_dim=self.qk_rope_head_dim,
-            v_head_dim=self.v_head_dim,
-            scale=self.scale,
-            alibi_slopes=None,
-            causal=True)
-        attn_output = attn_output.view(
+        attn_output = None
+        # Here is only 2 possibility of input, ChunkedPrefill or PrefillOnly
+        if attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
+            attn_output = torch.empty(num_tokens,
+                                      self.num_heads * self.v_head_dim,
+                                      dtype=query.dtype,
+                                      device=query.device)
+            # current requests is chunked in prefill, disable flash attention with chunked prefill
+            vanilla_chunked_prefill_mla(
+                output=attn_output,
+                query=query,
+                kv_cache=kv_c_and_k_pe_cache,
+                block_tables=attn_metadata.prefill.block_table,
+                query_lens=attn_metadata.prefill.query_lens,
+                context_lens=attn_metadata.prefill.context_lens,
+                kv_b_proj=self.kv_b_proj,
+                max_query_len=attn_metadata.prefill.max_query_len,
+                max_context_len=attn_metadata.prefill.max_seq_lens,
+                nope_dim=self.qk_nope_head_dim,
+                rope_dim=self.qk_rope_head_dim,
+                v_head_dim=self.v_head_dim,
+                scale=self.scale,
+                alibi_slopes=None,
+                causal=True)
+        elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
+            attn_output = torch.empty(num_tokens,
+                                      self.num_heads,
+                                      self.padding_head_dim,
+                                      dtype=query.dtype,
+                                      device=query.device)
+            k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
+                -1, self.num_heads,
+                self.qk_nope_head_dim + self.v_head_dim).split(
+                    [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
+                            dim=-1)
+            pad_query = torch.nn.functional.pad(query, [
+                0, self.padding_head_dim - self.qk_rope_head_dim -
+                self.qk_nope_head_dim
+            ],
+                                                value=0)
+            pad_key = torch.nn.functional.pad(key, [
+                0, self.padding_head_dim - self.qk_rope_head_dim -
+                self.qk_nope_head_dim
+            ],
+                                              value=0)
+            pad_value = torch.nn.functional.pad(
+                value, [0, self.padding_head_dim - self.v_head_dim], value=0)
+            torch_npu._npu_flash_attention(
+                query=pad_query,
+                key=pad_key,
+                value=pad_value,
+                mask=attn_metadata.attn_mask,
+                seq_len=attn_metadata.prefill.context_lens,
+                scale_value=self.scale,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_heads,
+                out=attn_output)
+            attn_output = attn_output.view(
+                -1, self.num_heads,
+                self.padding_head_dim)[:, :, :self.v_head_dim]
+        else:
+            raise RuntimeError(
+                "Unexpected path reached, AscendMLAImpl should only have PrefillOnly and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
+            )
+        attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
         return self.o_proj(attn_output)[0]
```
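The PrefillOnly branch relies on the fact that zero-padding the last dimension of q, k, and v leaves the attention result unchanged: padded q/k columns contribute nothing to the q·kᵀ scores, and the padded v columns are sliced off afterwards. A minimal sketch of that pad-then-slice pattern, with torch.nn.functional.scaled_dot_product_attention standing in for torch_npu._npu_flash_attention and illustrative shapes (not the module's real state):

```python
import torch
import torch.nn.functional as F

num_tokens, num_heads = 16, 8
nope_dim, rope_dim, v_head_dim = 128, 64, 128   # illustrative MLA dims
head_dim = nope_dim + rope_dim                   # 192, not 128-divisible
padded_dim = ((head_dim - 1) // 128 + 1) * 128   # 256

q = torch.randn(num_tokens, num_heads, head_dim)
k = torch.randn(num_tokens, num_heads, head_dim)
v = torch.randn(num_tokens, num_heads, v_head_dim)

# Zero-pad the last dim so q/k/v all reach the kernel-friendly padded_dim.
pad_q = F.pad(q, [0, padded_dim - head_dim])
pad_k = F.pad(k, [0, padded_dim - head_dim])
pad_v = F.pad(v, [0, padded_dim - v_head_dim])

# SDPA stands in for the NPU kernel; the scale matches the unpadded head_dim.
out = F.scaled_dot_product_attention(
    pad_q.transpose(0, 1), pad_k.transpose(0, 1), pad_v.transpose(0, 1),
    is_causal=True, scale=head_dim ** -0.5).transpose(0, 1)

# Slice the padded tail off the output to recover the real v_head_dim.
out = out[..., :v_head_dim]
assert out.shape == (num_tokens, num_heads, v_head_dim)
```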
```diff
@@ -457,7 +499,7 @@ def _forward_decode(

         q = torch.cat([q_nope, q_pe], dim=-1)
         num_tokens = q.size(0)
-        attn_output = torch.randn(
+        attn_output = torch.empty(
             [num_tokens, self.num_heads, self.kv_lora_rank],
             dtype=q.dtype,
             device=q.device)
```
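The torch.randn-to-torch.empty change removes pure overhead: the buffer is completely overwritten by the decode attention kernel, so sampling random values into it first wasted work on every decode step. A rough CPU-only sketch of the cost difference (illustrative shape, not a rigorous benchmark):

```python
import time
import torch

shape = (256, 16, 512)  # illustrative [num_tokens, num_heads, kv_lora_rank]

t0 = time.perf_counter()
for _ in range(1000):
    torch.randn(shape)   # allocates AND fills with sampled values
t1 = time.perf_counter()
for _ in range(1000):
    torch.empty(shape)   # allocates only; contents stay uninitialized
t2 = time.perf_counter()
print(f"randn: {t1 - t0:.3f}s  empty: {t2 - t1:.3f}s")
```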
```diff
@@ -522,8 +564,10 @@ def forward(
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
-                attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
-                decode_k_pe)
+                attn_metadata.decode.input_positions,
+                decode_q_pe.contiguous(),
+                decode_k_pe,
+                max_seq_len=attn_metadata.decode.max_seq_lens)

         if has_prefill:
             assert attn_metadata.prefill is not None
```
```diff
@@ -533,7 +577,9 @@ def forward(

             prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
                 attn_metadata.prefill.input_positions,
-                prefill_q_pe.contiguous(), prefill_k_pe)
+                prefill_q_pe.contiguous(),
+                prefill_k_pe,
+                max_seq_len=attn_metadata.prefill.max_seq_lens)

         if kv_cache.numel() > 0:
             key = torch.cat([
```
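Both rotary-embedding call sites now pass the host-side max_seq_lens, so the op does not have to derive the bound from the positions tensor on device. A hand-rolled rotate-half RoPE sketch of how such a hint can be used; the signature and cache layout here are assumptions for illustration, not the actual vllm-ascend op:

```python
import torch

def rope_forward(positions: torch.Tensor,  # [num_tokens] int64 position ids
                 x: torch.Tensor,          # [num_tokens, num_heads, rope_dim]
                 cos_cache: torch.Tensor,  # [max_model_len, rope_dim], freqs
                 sin_cache: torch.Tensor,  # duplicated across both halves
                 max_seq_len: int) -> torch.Tensor:
    # The host-known max_seq_len bounds the cache view without touching the
    # device; calling positions.max().item() instead would force a sync.
    cos = cos_cache[:max_seq_len][positions].unsqueeze(1)  # [T, 1, rope_dim]
    sin = sin_cache[:max_seq_len][positions].unsqueeze(1)
    # Rotate-half formulation of RoPE.
    x1, x2 = x.chunk(2, dim=-1)
    return x * cos + torch.cat((-x2, x1), dim=-1) * sin
```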
Reviewer: Is query_lens a device tensor? If so, there are many D2H transfers here; is this operation necessary?

Author: query_lens is actually a CPU tensor, so no D2H operation happens here; you can refer to line 220.
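A minimal illustration of why this matters: .item() on a CPU tensor is just a memory read, while on a device tensor it blocks until queued kernels finish and then copies the scalar back to the host (a D2H transfer).

```python
import torch

seq_lens_cpu = torch.tensor([33, 17, 256])   # host tensor
max_seq_lens = seq_lens_cpu.max().item()     # pure CPU work, no device sync

# If seq_lens lived on the accelerator instead, each .item() would stall the
# stream and copy the scalar back (illustrative, "npu" device assumed):
# seq_lens_npu = seq_lens_cpu.to("npu")
# max_seq_lens = seq_lens_npu.max().item()   # device sync + D2H copy
```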