[Bugfix][Performance] Better MTP support when using FlashMLA (#24045)
Changes from all commits: af963eb, e39e008, fbd4dee, 8d5e8a6, f02168a, 2fa1ed4, d1dcc97
```diff
@@ -184,8 +184,23 @@ def _forward_decode(
     q = torch.cat(q, dim=-1)

+    assert isinstance(q, torch.Tensor)
+    batch_size = attn_metadata.decode.seq_lens.shape[0]
+    total_tokens = q.shape[0]
+    num_heads = q.shape[1]
+    head_dim = q.shape[2]
+
+    # support uniform batch
+    if total_tokens % batch_size == 0:
+        seq_len = total_tokens // batch_size
+        q = q.view(batch_size, seq_len, num_heads, head_dim)
+    else:
+        raise ValueError(
+            f"total_tokens={total_tokens}, batch_size={batch_size}. "
+            f"Expected uniform batches with seq_len=1 or seq_len=2.")
+
     o, lse = flash_mla_with_kvcache(
-        q=q.unsqueeze(1),  # Add seqlen dim of 1 (decode)
+        q=q,
         k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
         block_table=attn_metadata.decode.block_table,
         cache_seqlens=attn_metadata.decode.seq_lens,
@@ -199,4 +214,6 @@ def _forward_decode(
         descale_k=layer._k_scale.reshape(1),
     )

+    o = o.view(total_tokens, num_heads, self.kv_lora_rank)
+
     return o, lse
```

**Collaborator** (inline, on the new batch reshape): Could you refactor this into a utility function? It will likely need to be called in each backend that supports this feature (FlashInfer-MLA at least), so it will be nice to be able to reuse the logic.
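A utility of the kind the reviewer asks for might look like the sketch below. The function name is hypothetical, and NumPy stands in for torch so the example runs standalone; the reshape logic mirrors the PR's uniform-batch check.

```python
import numpy as np

def reshape_decode_query_for_uniform_batch(q, batch_size):
    """Reshape a flattened decode query of shape (total_tokens, num_heads,
    head_dim) into (batch_size, seq_len, num_heads, head_dim), assuming
    every request in the decode batch has the same query length.

    Hypothetical helper name; NumPy stands in for torch here.
    """
    total_tokens, num_heads, head_dim = q.shape
    if total_tokens % batch_size != 0:
        # Non-uniform batch: the flat token count cannot be evenly split.
        raise ValueError(
            f"total_tokens={total_tokens}, batch_size={batch_size}. "
            "Expected a uniform decode batch.")
    seq_len = total_tokens // batch_size
    return q.reshape(batch_size, seq_len, num_heads, head_dim)

# 4 requests, 2 query tokens each (MTP decode) -> 8 flattened tokens
q = np.zeros((8, 16, 576))
q4d = reshape_decode_query_for_uniform_batch(q, batch_size=4)
print(q4d.shape)  # -> (4, 2, 16, 576)
```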
```diff
@@ -645,35 +645,41 @@ def subclass_attention_backend(
 def split_decodes_and_prefills(
     common_attn_metadata: CommonAttentionMetadata,
     decode_threshold: int = 1,
+    require_uniform: bool = False,
 ) -> tuple[int, int, int, int]:
     """
     Assuming a reordered batch, finds the boundary between prefill and decode
     requests.

     Args:
         common_attn_metadata: CommonAttentionMetadata object containing the
             batch metadata.
         decode_threshold: The maximum query length to be considered a decode.
+        require_uniform: If True, only selects decode requests with the same
+            query length for uniform batching.
+            If False, selects all decode requests regardless of
+            length variation.

     Returns:
         num_decodes: The number of decode requests.
         num_prefills: The number of prefill requests.
         num_decode_tokens: The number of tokens in the decode requests.
         num_prefill_tokens: The number of tokens in the prefill requests.
     """
+    if require_uniform:
+        return split_decodes_and_prefills_uniform(common_attn_metadata,
```
**Collaborator:** Instead of a separate function, couldn't we just do something like:

but we have to drop that for #24845 anyways.

**Collaborator:** Oh sorry, I see you want to handle the

**Collaborator:** (and still dropping

**Collaborator:** @LucasWilkinson I think the current implementation is probably correct. In the case of

To handle this more thoroughly you would have to modify the batch reordering code. This PR doesn't, and only does a best-effort pass to read uniform decodes from the front, falling back to prefills if there's a mismatch. I think that is fine for now.

Edit: to make the example a better counterexample.

**Collaborator:** Oh ya, sorry, I'm not doubting the correctness of the current implementation, sorry for the confusion! I was just suggesting we can modify the existing implementation and do:

instead of the current

(and remove

then we wouldn't need the separate function and could achieve the same effect with a lot less code (and it would be vectorized).

**Collaborator:** @LucasWilkinson It's not clear to me why this is doable. You're talking about a modification to

**Collaborator:** Oh, it's because `is_prefill` is fed into `argmax` to find the split point, which should return the index of the first prefill and ignore any subsequent decodes.
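The `argmax` trick the reviewer describes can be seen in a small example. NumPy stands in for torch here (the values are illustrative): `argmax` over a boolean mask returns the index of the first `True`, i.e. the first prefill.

```python
import numpy as np

# Query length per request in the (reordered) batch: decodes first,
# then prefills. Illustrative values with decode_threshold = 2.
query_lens = np.array([2, 2, 2, 7, 9])
is_prefill = query_lens > 2  # [False, False, False, True, True]

# argmax returns the first index of the maximum value, so on a boolean
# mask it gives the first True: the split point between decodes and
# prefills. (The real code first checks that at least one prefill
# exists, since argmax over an all-False mask would return 0.)
first_prefill = int(np.argmax(is_prefill))
print(first_prefill)  # -> 3
```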
```diff
+                                                  decode_threshold)
+
     max_query_len = common_attn_metadata.max_query_len
     num_reqs = common_attn_metadata.num_reqs
     num_tokens = common_attn_metadata.num_actual_tokens
     query_start_loc = common_attn_metadata.query_start_loc_cpu

     if max_query_len <= decode_threshold:
         return num_reqs, 0, num_tokens, 0

     query_lens = query_start_loc[1:] - query_start_loc[:-1]
     is_prefill = query_lens > decode_threshold
     if not torch.any(is_prefill):
         return num_reqs, 0, num_tokens, 0

     first_prefill = is_prefill.int().argmax(dim=-1).item()
     assert torch.all(query_lens[first_prefill:] > decode_threshold)
     assert torch.all(query_lens[:first_prefill] <= decode_threshold)

@@ -684,6 +690,53 @@ def split_decodes_and_prefills(
     return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
```
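The four return values can be worked through with plain Python on a small, illustrative `query_start_loc` (the numbers are made up, not taken from the PR): query lengths come from adjacent differences of the cumulative offsets, and the token counts come from indexing those offsets at the split point.

```python
# query_start_loc holds the cumulative offset of each request's query
# tokens; lengths [2, 2, 2, 7, 9] with decode_threshold = 2.
query_start_loc = [0, 2, 4, 6, 13, 22]
query_lens = [b - a for a, b in zip(query_start_loc, query_start_loc[1:])]

decode_threshold = 2
# Index of the first request whose query length exceeds the threshold;
# if none does, every request is a decode.
first_prefill = next(
    (i for i, n in enumerate(query_lens) if n > decode_threshold),
    len(query_lens))

num_decodes = first_prefill
num_prefills = len(query_lens) - num_decodes
num_decode_tokens = query_start_loc[first_prefill]
num_prefill_tokens = query_start_loc[-1] - num_decode_tokens

print(num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
# -> 3 2 6 16
```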
```diff
+
+
+def split_decodes_and_prefills_uniform(
+    common_attn_metadata: CommonAttentionMetadata,
+    decode_threshold: int = 1,
+) -> tuple[int, int, int, int]:
+    """
+    Similar to split_decodes_and_prefills but ensures decode batch is uniform.
+    Only selects decode requests with the same query length.
+    """
+    num_reqs = common_attn_metadata.num_reqs
+    num_tokens = common_attn_metadata.num_actual_tokens
+    query_start_loc = common_attn_metadata.query_start_loc_cpu
+    query_lens = query_start_loc[1:] - query_start_loc[:-1]
+    # find all candidates that satisfy the threshold
+    decode_candidates = query_lens <= decode_threshold
+
+    if not torch.any(decode_candidates):
+        return 0, num_reqs, 0, num_tokens
+
+    first_len = None
+    first_prefill = 0
+
+    # find the longest continuous uniform sequence from the front
+    for i in range(num_reqs):
+        current_len = query_lens[i].item()
+        if current_len > decode_threshold:
+            # prefill request, stop
+            break
+        if first_len is None:
+            # the first decode request
+            first_len = current_len
+            first_prefill = 1
+        elif current_len == first_len:
+            # same length, continue
+            first_prefill = i + 1
+        else:
+            # different length, stop
+            break
+
+    num_decodes = first_prefill
+    num_prefills = num_reqs - num_decodes
+    num_decode_tokens = query_start_loc[first_prefill].item(
+    ) if first_prefill < len(query_start_loc) else num_tokens
+    num_prefill_tokens = num_tokens - num_decode_tokens
+
+    return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
```
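The front scan above can be expressed as a plain-Python sketch that takes a list of ints instead of a torch tensor (function name hypothetical), which makes the "best-effort uniform prefix" behavior discussed in the review easy to check:

```python
def count_uniform_decodes(query_lens, decode_threshold=1):
    """Return how many requests from the front of the (reordered) batch
    form a uniform decode prefix: all with the same query length, none
    exceeding decode_threshold. Everything after the prefix is treated
    as prefill. Plain-Python sketch of the scan in
    split_decodes_and_prefills_uniform.
    """
    first_len = None
    num_decodes = 0
    for i, current_len in enumerate(query_lens):
        if current_len > decode_threshold:
            break  # first prefill request: stop
        if first_len is None:
            first_len = current_len  # first decode fixes the target length
            num_decodes = 1
        elif current_len == first_len:
            num_decodes = i + 1  # same length: extend the uniform prefix
        else:
            break  # length mismatch: rest falls back to the prefill path
    return num_decodes

# MTP with one draft token: decodes have query length 2
print(count_uniform_decodes([2, 2, 2, 7, 9], decode_threshold=2))  # -> 3
# A mismatched length cuts the prefix short
print(count_uniform_decodes([2, 1, 2], decode_threshold=2))        # -> 1
```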
|
|
```diff
 def reorder_batch_to_split_decodes_and_prefills(
     input_batch: "InputBatch",
     scheduler_output: "SchedulerOutput",
```
|
||
**Collaborator:** I think this might have negative consequences for backends which do not have kernel support for spec-friendly decodes. If so, we might want to have a per-backend flag to modulate when we apply this. Something like:
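The reviewer's snippet is not preserved in this capture. One hypothetical shape such a per-backend flag could take is sketched below; all class and attribute names are illustrative, not vLLM's actual API.

```python
class AttentionBackendStub:
    """Illustrative stand-in for an attention backend base class.

    supports_uniform_spec_decode is a hypothetical per-backend flag of
    the kind the reviewer suggests; it is not an actual vLLM attribute.
    """
    supports_uniform_spec_decode: bool = False

class FlashMLABackendStub(AttentionBackendStub):
    # A backend with kernel support for multi-token (spec-friendly)
    # decodes could opt in to uniform decode batching.
    supports_uniform_spec_decode = True

def pick_require_uniform(backend: AttentionBackendStub) -> bool:
    # Only request uniform decode splitting when the backend's kernels
    # can take advantage of it; others keep the existing behavior.
    return backend.supports_uniform_spec_decode

print(pick_require_uniform(FlashMLABackendStub()))   # -> True
print(pick_require_uniform(AttentionBackendStub()))  # -> False
```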