-
-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[Bugfix][MLA] Add logits size budget to sparse indexer prefill chunking #36178
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
76b77a5
46eed97
5c4a4c8
7607bd5
58a1f48
0a3cef6
62a5ee6
33d0a22
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,7 @@ | |
|
|
||
| import torch | ||
|
|
||
| import vllm.envs as envs | ||
| from vllm.config import VllmConfig | ||
| from vllm.logger import init_logger | ||
| from vllm.platforms import current_platform | ||
|
|
@@ -22,14 +23,62 @@ | |
| ) | ||
| from vllm.v1.attention.backends.utils import ( | ||
| split_decodes_and_prefills, | ||
| split_prefill_chunks, | ||
| ) | ||
| from vllm.v1.kv_cache_interface import AttentionSpec | ||
| from vllm.v1.worker.cp_utils import get_total_cp_world_size | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
||
def split_indexer_prefill_chunks(
    seq_lens_cpu: torch.Tensor,
    query_lens_cpu: torch.Tensor,
    workspace_size: int,
    max_logits_bytes: int,
    request_offset: int = 0,
) -> list[tuple[slice, slice]]:
    """
    Split prefill requests into chunks for the sparse indexer, respecting:

    - N constraint: total_seq_lens <= workspace_size (existing O(N) workspace)
    - Logits constraint: M * N * 4 <= max_logits_bytes, where M is the number
      of query tokens and N the total sequence length of the chunk (fp32
      logits, hence the factor of 4 bytes per element)

    When a single request-level chunk still exceeds the logits budget,
    sub-chunks on the query dimension (M) to bound peak memory.

    Args:
        seq_lens_cpu: per-request KV sequence lengths (CPU tensor).
        query_lens_cpu: per-request query lengths (CPU tensor, same length
            as ``seq_lens_cpu``).
        workspace_size: maximum total sequence length (N) per chunk.
        max_logits_bytes: byte budget for the M x N fp32 logits buffer.
        request_offset: added to the request indices of the returned
            ``req_slice`` so callers can address a batch whose decode
            requests precede the prefills.

    Returns:
        List of (req_slice, query_slice) tuples. ``req_slice`` indexes
        requests (offset by ``request_offset``); ``query_slice`` indexes
        query tokens relative to the chunk's first query token.
    """
    chunks: list[tuple[slice, slice]] = []
    n = len(seq_lens_cpu)
    # fp32 logits -> 4 bytes per element.
    max_logits_elems = max_logits_bytes // 4
    end = 0

    while end < n:
        start, chunk_m, chunk_n = end, 0, 0

        # Greedily pack whole requests while both budgets hold.
        while end < n:
            q, s = query_lens_cpu[end].item(), seq_lens_cpu[end].item()
            new_m, new_n = chunk_m + q, chunk_n + s
            if new_n <= workspace_size and new_m * new_n <= max_logits_elems:
                chunk_m, chunk_n = new_m, new_n
                end += 1
            else:
                break

        # A single request can exceed the budget, requiring sub-chunking
        # on the query dimension.
        if end == start:
            chunk_m, chunk_n = query_lens_cpu[end].item(), seq_lens_cpu[end].item()
            end += 1

        req_slice = slice(start + request_offset, end + request_offset)
        # Largest query sub-chunk that keeps M * N within the logits budget.
        # Both branches are clamped to >= 1 so the range() step below is
        # never zero: a degenerate request with query_len == 0 and
        # seq_len == 0 previously produced range(0, 0, 0) -> ValueError;
        # now it simply yields no sub-chunks.
        if chunk_n > 0:
            max_q = max(1, max_logits_elems // chunk_n)
        else:
            max_q = max(1, chunk_m)
        for q_off in range(0, chunk_m, max_q):
            sub_m = min(max_q, chunk_m - q_off)
            chunks.append((req_slice, slice(q_off, q_off + sub_m)))

    return chunks
|
|
||
|
|
||
| class DeepseekV32IndexerBackend(AttentionBackend): | ||
| @staticmethod | ||
| def get_name() -> str: | ||
|
|
@@ -81,6 +130,7 @@ class DeepseekV32IndexerPrefillChunkMetadata: | |
| token_start: int | ||
| token_end: int | ||
| num_reqs: int | ||
| skip_kv_gather: bool = False | ||
|
|
||
|
|
||
| @dataclass | ||
|
|
@@ -271,43 +321,51 @@ def __init__(self, *args, **kwargs): | |
| ) | ||
|
|
||
    def build_one_prefill_chunk(
        self,
        req_slice: slice,
        query_slice: slice,
        query_start_loc_cpu,
        seq_lens_cpu,
        block_table,
        skip_kv_gather: bool = False,
    ) -> DeepseekV32IndexerPrefillChunkMetadata:
        """
        Build the metadata for one prefill chunk of the sparse indexer.

        Args:
            req_slice: which requests of the batch belong to this chunk
                (indices already include any decode-request offset).
            query_slice: sub-range of this chunk's query tokens, relative to
                the chunk's first query token; used to sub-chunk a single
                oversized request on the query dimension.
            query_start_loc_cpu: cumulative query-token start offsets per
                request (CPU tensor) — assumed to have num_reqs + 1 entries
                so that ``[start : stop + 1]`` is valid; TODO confirm.
            seq_lens_cpu: per-request KV sequence lengths (CPU tensor).
            block_table: per-request block-table rows; sliced by req_slice.
            skip_kv_gather: True for query sub-chunks after the first one of
                the same request group (the caller passes
                ``query_slice.start > 0``), carried through to the metadata.

        Returns:
            A populated ``DeepseekV32IndexerPrefillChunkMetadata``.
        """
        # Query start offsets rebased to this chunk's first token.
        prefill_query_start_loc = (
            query_start_loc_cpu[req_slice.start : req_slice.stop + 1]
            - query_start_loc_cpu[req_slice.start]
        )
        # NOTE(review): kv_spans_from_batches is recomputed for every query
        # sub-chunk of the same request group. Per the PR discussion this
        # runs once per forward pass and is overlapped by async scheduling,
        # so the redundant work was deemed acceptable.
        cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
            prefill_query_start_loc, seq_lens_cpu[req_slice], self.device
        )
        token_start = query_start_loc_cpu[req_slice.start].item()
        total_seq_lens = seq_lens_cpu[req_slice].sum()
        num_reqs = req_slice.stop - req_slice.start
        # Map every KV token of the chunk to its request index within the
        # chunk (0..num_reqs-1).
        seq_idx = torch.arange(0, num_reqs, dtype=torch.int32)
        token_to_seq = torch.repeat_interleave(seq_idx, seq_lens_cpu[req_slice]).to(
            self.device
        )
        # The greedy chunker guarantees this; a single request longer than
        # the workspace would be a scheduling bug upstream.
        assert total_seq_lens <= self.max_prefill_buffer_size
        # Cumulative sequence lengths with a leading zero, as int32 on device.
        cu_seq_lens = (
            torch.cat(
                [
                    torch.zeros(1, dtype=torch.int32),
                    seq_lens_cpu[req_slice].cumsum(dim=0),
                ]
            )
            .to(torch.int32)
            .to(self.device)
        )

        return DeepseekV32IndexerPrefillChunkMetadata(
            # KV-span bounds are per query token, so only the sub-chunk's
            # rows are kept.
            cu_seqlen_ks=cu_seqlen_ks[query_slice],
            cu_seqlen_ke=cu_seqlen_ke[query_slice],
            cu_seq_lens=cu_seq_lens,
            token_to_seq=token_to_seq,
            total_seq_lens=total_seq_lens,
            block_table=block_table[req_slice],
            # Global token range covered by this (sub-)chunk: the request
            # group's first token plus the query sub-range offsets.
            token_start=token_start + query_slice.start,
            token_end=token_start + query_slice.stop,
            num_reqs=num_reqs,
            skip_kv_gather=skip_kv_gather,
        )
|
|
||
| def build( | ||
|
|
@@ -333,20 +391,27 @@ def build( | |
|
|
||
| prefill_metadata = None | ||
| if num_prefills > 0: | ||
| chunk_seq_ids = split_prefill_chunks( | ||
| prefill_query_lens_cpu = torch.diff( | ||
| query_start_loc_cpu[num_decodes : num_decodes + num_prefills + 1] | ||
| ) | ||
| max_logits_bytes = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024 | ||
| chunk_specs = split_indexer_prefill_chunks( | ||
| common_attn_metadata.seq_lens_cpu[num_decodes:], | ||
| prefill_query_lens_cpu, | ||
| self.max_prefill_buffer_size, | ||
| max_logits_bytes, | ||
| request_offset=num_decodes, | ||
| ) | ||
| chunks = [ | ||
| self.build_one_prefill_chunk( | ||
| reqs_start, | ||
| reqs_end, | ||
| req_slice, | ||
| query_slice, | ||
| query_start_loc_cpu, | ||
| common_attn_metadata.seq_lens_cpu, | ||
| common_attn_metadata.block_table_tensor, | ||
| skip_kv_gather=query_slice.start > 0, | ||
| ) | ||
| for reqs_start, reqs_end in chunk_seq_ids | ||
| for req_slice, query_slice in chunk_specs | ||
| ] | ||
| prefill_metadata = DeepseekV32IndexerPrefillMetadata( | ||
| chunks=chunks, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.