diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py
index 3f6faf51de6d..c49ccd24e3ad 100644
--- a/tests/v1/attention/test_sparse_mla_backends.py
+++ b/tests/v1/attention/test_sparse_mla_backends.py
@@ -42,6 +42,7 @@
     FlashMLASparseBackend,
     triton_convert_req_index_to_global_index,
 )
+from vllm.v1.attention.backends.mla.indexer import split_indexer_prefill_chunks
 from vllm.v1.attention.backends.utils import split_prefill_chunks
 from vllm.v1.attention.ops import flashmla
 
@@ -716,6 +717,81 @@ def test_split_prefill_chunks(seq_lens, max_buf, expected):
     assert out == expected
 
 
+@pytest.mark.parametrize(
+    "seq_lens,query_lens,workspace_size,max_logits_bytes,expected",
+    [
+        # Logits constraint triggers split (M*N exceeds budget)
+        # req0: M=10, N=100 -> 1000 elems (4000 bytes) - fits in 5000
+        # req1: adding M=10, N=100 -> new_M=20, new_N=200 -> 4000 elems > 1250
+        (
+            torch.tensor([100, 100, 100]),
+            torch.tensor([10, 10, 10]),
+            1000,  # workspace allows all
+            5000,  # 1250 float32 elems -> forces split
+            [
+                (slice(0, 1), slice(0, 10)),
+                (slice(1, 2), slice(0, 10)),
+                (slice(2, 3), slice(0, 10)),
+            ],
+        ),
+        # Both constraints satisfied - all fit in one chunk
+        (
+            torch.tensor([10, 10, 10]),
+            torch.tensor([5, 5, 5]),
+            100,
+            10000,  # 2500 elems, M*N = 15*30 = 450 < 2500
+            [(slice(0, 3), slice(0, 15))],
+        ),
+        # Workspace constraint triggers first
+        (
+            torch.tensor([50, 50, 50]),
+            torch.tensor([1, 1, 1]),
+            50,  # workspace only fits one at a time
+            1000000,  # logits budget is huge
+            [
+                (slice(0, 1), slice(0, 1)),
+                (slice(1, 2), slice(0, 1)),
+                (slice(2, 3), slice(0, 1)),
+            ],
+        ),
+        # Greedy filling: first two fit, third doesn't
+        # req0: M=5, N=10 -> 50 elems
+        # req0+1: M=10, N=20 -> 200 elems <= 250
+        # req0+1+2: M=15, N=30 -> 450 elems > 250
+        (
+            torch.tensor([10, 10, 10]),
+            torch.tensor([5, 5, 5]),
+            100,
+            1000,  # 250 elems
+            [(slice(0, 2), slice(0, 10)), (slice(2, 3), slice(0, 5))],
+        ),
+    ],
+)
+def test_split_indexer_prefill_chunks(
+    seq_lens, query_lens, workspace_size, max_logits_bytes, expected
+):
+    out = split_indexer_prefill_chunks(
+        seq_lens,
+        query_lens,
+        workspace_size,
+        max_logits_bytes,
+    )
+    assert out == expected
+
+
+def test_split_indexer_prefill_chunks_single_request_overflow():
+    """Test that a single request exceeding the budget is sub-chunked on the query dim."""
+    seq_lens = torch.tensor([1000, 50])
+    query_lens = torch.tensor([100, 5])
+
+    out = split_indexer_prefill_chunks(seq_lens, query_lens, 2000, 1000)
+    # max_logits_elems = 250, N=1000 -> max_q = 1 -> 100 query sub-chunks
+    expected = [(slice(0, 1), slice(i, i + 1)) for i in range(100)]
+    # req1: M=5, N=50 -> 250 elems fits budget
+    expected.append((slice(1, 2), slice(0, 5)))
+    assert out == expected
+
+
 def test_triton_convert_returns_valid_counts():
     """Test that return_valid_counts correctly counts non-negative indices."""
     device = torch.device("cuda")
diff --git a/vllm/envs.py b/vllm/envs.py
index fb68fc4b1d37..0a4ce9b7cad5 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -55,6 +55,7 @@
     VLLM_CPU_INT4_W4A8: bool = True
    VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
     VLLM_XLA_CHECK_RECOMPILATION: bool = False
+    VLLM_SPARSE_INDEXER_MAX_LOGITS_MB: int = 512
     VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto"
     VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
     VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True
@@ -861,6 +862,12 @@ def _get_or_set_default() -> str:
     ),
     # Enable SPMD mode for TPU backend.
     "VLLM_XLA_USE_SPMD": lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))),
+    # Maximum size (in MB) for the logits tensor in sparse MLA indexer prefill chunks.
+    # Bounds the [M, N] float32 logits tensor to prevent CUDA OOM.
+    # Default: 512 MB
+    "VLLM_SPARSE_INDEXER_MAX_LOGITS_MB": lambda: int(
+        os.getenv("VLLM_SPARSE_INDEXER_MAX_LOGITS_MB", "512")
+    ),
     # If set, the OpenAI API server will stay alive even after the underlying
     # AsyncLLMEngine errors and stops serving requests
     "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH": lambda: bool(
diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py
index 496b457a15e2..ca148536f327 100644
--- a/vllm/model_executor/layers/sparse_attn_indexer.py
+++ b/vllm/model_executor/layers/sparse_attn_indexer.py
@@ -4,6 +4,7 @@
 
 import torch
 
+import vllm.envs as envs
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
@@ -51,6 +52,14 @@ def sparse_attn_indexer(
             ((total_seq_lens, head_dim), torch.float8_e4m3fn),
             ((total_seq_lens, 4), torch.uint8),
         )
+
+        # Dummy allocation to simulate peak logits tensor memory during inference.
+        # Allocated with 1-byte elements, so elements == bytes.
+        max_logits_elems = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
+        _ = torch.empty(
+            max_logits_elems, dtype=torch.uint8, device=hidden_states.device
+        )
+
         return sparse_attn_indexer_fake(
             hidden_states,
             k_cache_prefix,
@@ -101,13 +110,16 @@ def sparse_attn_indexer(
         for chunk in prefill_metadata.chunks:
             k_fp8 = k_fp8_full[: chunk.total_seq_lens]
             k_scale = k_scale_full[: chunk.total_seq_lens]
-            ops.cp_gather_indexer_k_quant_cache(
-                kv_cache,
-                k_fp8,
-                k_scale,
-                chunk.block_table,
-                chunk.cu_seq_lens,
-            )
+
+            if not chunk.skip_kv_gather:
+                ops.cp_gather_indexer_k_quant_cache(
+                    kv_cache,
+                    k_fp8,
+                    k_scale,
+                    chunk.block_table,
+                    chunk.cu_seq_lens,
+                )
+
             logits = fp8_mqa_logits(
                 q_fp8[chunk.token_start : chunk.token_end],
                 (k_fp8, k_scale.view(torch.float32).flatten()),
diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py
index 5deac4d2f108..927583a0f17b 100644
--- a/vllm/v1/attention/backends/mla/indexer.py
+++ b/vllm/v1/attention/backends/mla/indexer.py
@@ -4,6 +4,7 @@
 
 import torch
 
+import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -22,7 +23,6 @@
 )
 from vllm.v1.attention.backends.utils import (
     split_decodes_and_prefills,
-    split_prefill_chunks,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.cp_utils import get_total_cp_world_size
@@ -30,6 +30,55 @@
 logger = init_logger(__name__)
 
 
+def split_indexer_prefill_chunks(
+    seq_lens_cpu: torch.Tensor,
+    query_lens_cpu: torch.Tensor,
+    workspace_size: int,
+    max_logits_bytes: int,
+    request_offset: int = 0,
+) -> list[tuple[slice, slice]]:
+    """
+    Split prefill requests into chunks for the sparse indexer, respecting:
+    - N constraint: total_seq_lens <= workspace_size (existing O(N) workspace)
+    - Logits constraint: M * N * 4 <= max_logits_bytes
+
+    When a single request-level chunk still exceeds the logits budget, it is
+    sub-chunked on the query dimension (M) to bound peak memory.
+
+    Returns a list of (req_slice, query_slice) tuples.
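+
+    Example (mirrors the unit tests): seq_lens=[10, 10, 10], query_lens=[5, 5, 5],
+    workspace_size=100 and max_logits_bytes=1000 (250 float32 elements) yield
+    [(slice(0, 2), slice(0, 10)), (slice(2, 3), slice(0, 5))].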
+    """
+    chunks: list[tuple[slice, slice]] = []
+    n = len(seq_lens_cpu)
+    max_logits_elems = max_logits_bytes // 4
+    end = 0
+
+    while end < n:
+        start, chunk_m, chunk_n = end, 0, 0
+
+        while end < n:
+            q, s = query_lens_cpu[end].item(), seq_lens_cpu[end].item()
+            new_m, new_n = chunk_m + q, chunk_n + s
+            if new_n <= workspace_size and new_m * new_n <= max_logits_elems:
+                chunk_m, chunk_n = new_m, new_n
+                end += 1
+            else:
+                break
+
+        # A single request can exceed the budget, requiring sub-chunking
+        # on the query dimension.
+        if end == start:
+            chunk_m, chunk_n = query_lens_cpu[end].item(), seq_lens_cpu[end].item()
+            end += 1
+
+        req_slice = slice(start + request_offset, end + request_offset)
+        max_q = max(1, max_logits_elems // chunk_n) if chunk_n > 0 else chunk_m
+        for q_off in range(0, chunk_m, max_q):
+            sub_m = min(max_q, chunk_m - q_off)
+            chunks.append((req_slice, slice(q_off, q_off + sub_m)))
+
+    return chunks
+
+
 class DeepseekV32IndexerBackend(AttentionBackend):
     @staticmethod
     def get_name() -> str:
@@ -81,6 +130,7 @@ class DeepseekV32IndexerPrefillChunkMetadata:
     token_start: int
     token_end: int
     num_reqs: int
+    # Set on query sub-chunks after the first one of a request-level chunk;
+    # their KV was already gathered by the first sub-chunk, so the gather in
+    # sparse_attn_indexer can be skipped.
+    skip_kv_gather: bool = False
 
 
 @dataclass
@@ -271,43 +321,51 @@ def __init__(self, *args, **kwargs):
         )
 
     def build_one_prefill_chunk(
-        self, reqs_start, reqs_end, query_start_loc_cpu, seq_lens_cpu, block_table
-    ):
+        self,
+        req_slice: slice,
+        query_slice: slice,
+        query_start_loc_cpu,
+        seq_lens_cpu,
+        block_table,
+        skip_kv_gather: bool = False,
+    ) -> DeepseekV32IndexerPrefillChunkMetadata:
         prefill_query_start_loc = (
-            query_start_loc_cpu[reqs_start : reqs_end + 1]
-            - query_start_loc_cpu[reqs_start]
+            query_start_loc_cpu[req_slice.start : req_slice.stop + 1]
+            - query_start_loc_cpu[req_slice.start]
         )
         cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
-            prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], self.device
+            prefill_query_start_loc, seq_lens_cpu[req_slice], self.device
+        )
+        token_start = query_start_loc_cpu[req_slice.start].item()
+        total_seq_lens = seq_lens_cpu[req_slice].sum()
+        num_reqs = req_slice.stop - req_slice.start
+        seq_idx = torch.arange(0, num_reqs, dtype=torch.int32)
+        token_to_seq = torch.repeat_interleave(seq_idx, seq_lens_cpu[req_slice]).to(
+            self.device
         )
-        token_start = query_start_loc_cpu[reqs_start].item()
-        token_end = query_start_loc_cpu[reqs_end].item()
-        total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
-        seq_idx = torch.arange(0, reqs_end - reqs_start, dtype=torch.int32)
-        token_to_seq = torch.repeat_interleave(
-            seq_idx, seq_lens_cpu[reqs_start:reqs_end]
-        ).to(self.device)
         assert total_seq_lens <= self.max_prefill_buffer_size
         cu_seq_lens = (
             torch.cat(
                 [
                     torch.zeros(1, dtype=torch.int32),
-                    seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0),
+                    seq_lens_cpu[req_slice].cumsum(dim=0),
                 ]
             )
             .to(torch.int32)
            .to(self.device)
         )
+
         return DeepseekV32IndexerPrefillChunkMetadata(
-            cu_seqlen_ks=cu_seqlen_ks,
-            cu_seqlen_ke=cu_seqlen_ke,
+            cu_seqlen_ks=cu_seqlen_ks[query_slice],
+            cu_seqlen_ke=cu_seqlen_ke[query_slice],
             cu_seq_lens=cu_seq_lens,
             token_to_seq=token_to_seq,
             total_seq_lens=total_seq_lens,
-            block_table=block_table[reqs_start:reqs_end],
-            token_start=token_start,
-            token_end=token_end,
-            num_reqs=reqs_end - reqs_start,
+            block_table=block_table[req_slice],
+            token_start=token_start + query_slice.start,
+            token_end=token_start + query_slice.stop,
+            num_reqs=num_reqs,
+            skip_kv_gather=skip_kv_gather,
         )
 
     def build(
@@ -333,20 +391,27 @@ def build(
 
         prefill_metadata = None
         if num_prefills > 0:
-            chunk_seq_ids = split_prefill_chunks(
+            prefill_query_lens_cpu = torch.diff(
+                query_start_loc_cpu[num_decodes : num_decodes + num_prefills + 1]
+            )
+            max_logits_bytes = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
+            chunk_specs = split_indexer_prefill_chunks(
                 common_attn_metadata.seq_lens_cpu[num_decodes:],
+                prefill_query_lens_cpu,
                 self.max_prefill_buffer_size,
+                max_logits_bytes,
                 request_offset=num_decodes,
             )
             chunks = [
                 self.build_one_prefill_chunk(
-                    reqs_start,
-                    reqs_end,
+                    req_slice,
+                    query_slice,
                     query_start_loc_cpu,
                     common_attn_metadata.seq_lens_cpu,
                     common_attn_metadata.block_table_tensor,
+                    skip_kv_gather=query_slice.start > 0,
                 )
-                for reqs_start, reqs_end in chunk_seq_ids
+                for req_slice, query_slice in chunk_specs
             ]
             prefill_metadata = DeepseekV32IndexerPrefillMetadata(
                 chunks=chunks,
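
Illustration (not part of the patch): a minimal sketch of how the new helper and the
VLLM_SPARSE_INDEXER_MAX_LOGITS_MB budget fit together, using the import path exercised
by the tests above; the request shapes are invented for the example.

    import torch

    from vllm.v1.attention.backends.mla.indexer import split_indexer_prefill_chunks

    # Default budget: 512 MB of float32 logits, i.e. 134,217,728 elements.
    max_logits_bytes = 512 * 1024 * 1024

    chunks = split_indexer_prefill_chunks(
        seq_lens_cpu=torch.tensor([100, 100, 100]),
        query_lens_cpu=torch.tensor([10, 10, 10]),
        workspace_size=1000,
        max_logits_bytes=max_logits_bytes,
    )
    # Both constraints hold for the whole batch (M=30, N=300), so all three
    # requests share a single chunk: [(slice(0, 3), slice(0, 30))]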