Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions tests/v1/attention/test_sparse_mla_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
FlashMLASparseBackend,
triton_convert_req_index_to_global_index,
)
from vllm.v1.attention.backends.mla.indexer import split_indexer_prefill_chunks
from vllm.v1.attention.backends.utils import split_prefill_chunks
from vllm.v1.attention.ops import flashmla

Expand Down Expand Up @@ -716,6 +717,81 @@ def test_split_prefill_chunks(seq_lens, max_buf, expected):
assert out == expected


@pytest.mark.parametrize(
    "seq_lens,query_lens,workspace_size,max_logits_bytes,expected",
    [
        # Logits constraint triggers split (M*N exceeds budget)
        # req0: M=10, N=100 -> 1000 elems (4000 bytes) - fits in 5000
        # req1: adding M=10, N=100 -> new_M=20, new_N=200 -> 4000 elems > 1250
        (
            torch.tensor([100, 100, 100]),
            torch.tensor([10, 10, 10]),
            1000,  # workspace allows all
            5000,  # 1250 float32 elems -> forces split
            [
                (slice(0, 1), slice(0, 10)),
                (slice(1, 2), slice(0, 10)),
                (slice(2, 3), slice(0, 10)),
            ],
        ),
        # Both constraints satisfied - all fit in one chunk
        (
            torch.tensor([10, 10, 10]),
            torch.tensor([5, 5, 5]),
            100,
            10000,  # 2500 elems, M*N = 15*30 = 450 < 2500
            [(slice(0, 3), slice(0, 15))],
        ),
        # Workspace constraint triggers first
        (
            torch.tensor([50, 50, 50]),
            torch.tensor([1, 1, 1]),
            50,  # workspace only fits one at a time
            1000000,  # logits budget is huge
            [
                (slice(0, 1), slice(0, 1)),
                (slice(1, 2), slice(0, 1)),
                (slice(2, 3), slice(0, 1)),
            ],
        ),
        # Greedy filling: first two fit, third doesn't
        # req0: M=5, N=10 -> 50 elems
        # req0+1: M=10, N=20 -> 200 elems <= 250
        # req0+1+2: M=15, N=30 -> 450 elems > 250
        (
            torch.tensor([10, 10, 10]),
            torch.tensor([5, 5, 5]),
            100,
            1000,  # 250 elems
            [(slice(0, 2), slice(0, 10)), (slice(2, 3), slice(0, 5))],
        ),
    ],
)
def test_split_indexer_prefill_chunks(
    seq_lens, query_lens, workspace_size, max_logits_bytes, expected
):
    """Chunking honors both the workspace (N) and logits (M*N*4-byte) budgets.

    Each parametrized case pins the exact (req_slice, query_slice) output
    for one constraint scenario; see the per-case comments above for the
    M/N arithmetic behind each expectation.
    """
    out = split_indexer_prefill_chunks(
        seq_lens,
        query_lens,
        workspace_size,
        max_logits_bytes,
    )
    assert out == expected


def test_split_indexer_prefill_chunks_single_request_overflow():
    """A lone request that exceeds the logits budget is sub-chunked on M."""
    # Budget: 1000 bytes -> 250 float32 elems. req0 has N=1000 keys, so the
    # per-sub-chunk query cap is max(1, 250 // 1000) == 1, giving 100 pieces.
    result = split_indexer_prefill_chunks(
        torch.tensor([1000, 50]),
        torch.tensor([100, 5]),
        2000,
        1000,
    )

    want = [(slice(0, 1), slice(q, q + 1)) for q in range(100)]
    # req1 (M=5, N=50 -> exactly 250 elems) fits the budget in one piece.
    want += [(slice(1, 2), slice(0, 5))]
    assert result == want


def test_triton_convert_returns_valid_counts():
"""Test that return_valid_counts correctly counts non-negative indices."""
device = torch.device("cuda")
Expand Down
7 changes: 7 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
VLLM_CPU_INT4_W4A8: bool = True
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
VLLM_XLA_CHECK_RECOMPILATION: bool = False
VLLM_SPARSE_INDEXER_MAX_LOGITS_MB: int = 512
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto"
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True
Expand Down Expand Up @@ -861,6 +862,12 @@ def _get_or_set_default() -> str:
),
# Enable SPMD mode for TPU backend.
"VLLM_XLA_USE_SPMD": lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))),
# Maximum size (in MB) for logits tensor in sparse MLA indexer prefill chunks.
# Bounds the [M, N] float32 logits tensor to prevent CUDA OOM.
# Default: 512 MB
"VLLM_SPARSE_INDEXER_MAX_LOGITS_MB": lambda: int(
os.getenv("VLLM_SPARSE_INDEXER_MAX_LOGITS_MB", "512")
),
# If set, the OpenAI API server will stay alive even after the underlying
# AsyncLLMEngine errors and stops serving requests
"VLLM_KEEP_ALIVE_ON_ENGINE_DEATH": lambda: bool(
Expand Down
26 changes: 19 additions & 7 deletions vllm/model_executor/layers/sparse_attn_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch

import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
Expand Down Expand Up @@ -51,6 +52,14 @@ def sparse_attn_indexer(
((total_seq_lens, head_dim), torch.float8_e4m3fn),
((total_seq_lens, 4), torch.uint8),
)

# Dummy allocation to simulate peak logits-tensor memory during inference.
# FP8 elements so elements == bytes
max_logits_elems = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
_ = torch.empty(
max_logits_elems, dtype=torch.uint8, device=hidden_states.device
)

return sparse_attn_indexer_fake(
hidden_states,
k_cache_prefix,
Expand Down Expand Up @@ -101,13 +110,16 @@ def sparse_attn_indexer(
for chunk in prefill_metadata.chunks:
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens]
ops.cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
k_scale,
chunk.block_table,
chunk.cu_seq_lens,
)

if not chunk.skip_kv_gather:
ops.cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
k_scale,
chunk.block_table,
chunk.cu_seq_lens,
)

logits = fp8_mqa_logits(
q_fp8[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32).flatten()),
Expand Down
113 changes: 89 additions & 24 deletions vllm/v1/attention/backends/mla/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
Expand All @@ -22,14 +23,62 @@
)
from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills,
split_prefill_chunks,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.cp_utils import get_total_cp_world_size

logger = init_logger(__name__)


def split_indexer_prefill_chunks(
    seq_lens_cpu: torch.Tensor,
    query_lens_cpu: torch.Tensor,
    workspace_size: int,
    max_logits_bytes: int,
    request_offset: int = 0,
) -> list[tuple[slice, slice]]:
    """
    Split prefill requests into chunks for the sparse indexer, respecting:
    - N constraint: total_seq_lens <= workspace_size (existing O(N) workspace)
    - Logits constraint: M * N * 4 <= max_logits_bytes

    When a single request-level chunk still exceeds the logits budget,
    sub-chunks on the query dimension (M) to bound peak memory.

    Args:
        seq_lens_cpu: per-request KV sequence lengths (N contributions).
        query_lens_cpu: per-request query lengths (M contributions).
        workspace_size: cap on a chunk's summed sequence lengths.
        max_logits_bytes: byte budget for the float32 [M, N] logits tensor.
        request_offset: added to request indices in the returned slices
            (used to skip decode requests that precede the prefills).

    Returns:
        List of (req_slice, query_slice) tuples. query_slice indexes into
        the chunk's local query dimension.
    """
    chunks: list[tuple[slice, slice]] = []
    n = len(seq_lens_cpu)
    # Logits are materialized as float32 -> 4 bytes per element.
    max_logits_elems = max_logits_bytes // 4
    end = 0

    while end < n:
        start, chunk_m, chunk_n = end, 0, 0

        # Greedily accumulate whole requests while both budgets hold.
        while end < n:
            q, s = query_lens_cpu[end].item(), seq_lens_cpu[end].item()
            new_m, new_n = chunk_m + q, chunk_n + s
            if new_n <= workspace_size and new_m * new_n <= max_logits_elems:
                chunk_m, chunk_n = new_m, new_n
                end += 1
            else:
                break

        # A single request can exceed the budget, requiring sub-chunking
        # on the query dimension.
        # NOTE(review): a single request with seq_len > workspace_size is
        # also force-accepted here — presumably the workspace is sized for
        # the max model length; confirm against the caller's sizing.
        if end == start:
            chunk_m, chunk_n = query_lens_cpu[end].item(), seq_lens_cpu[end].item()
            end += 1

        req_slice = slice(start + request_offset, end + request_offset)
        # Largest query sub-chunk keeping max_q * chunk_n within the budget.
        # Clamp to >= 1 in every branch: a degenerate request with both
        # query_len == 0 and seq_len == 0 would otherwise yield max_q == 0
        # and range(0, chunk_m, 0) raises ValueError (step must be nonzero).
        if chunk_n > 0:
            max_q = max(1, max_logits_elems // chunk_n)
        else:
            max_q = max(1, chunk_m)
        for q_off in range(0, chunk_m, max_q):
            sub_m = min(max_q, chunk_m - q_off)
            chunks.append((req_slice, slice(q_off, q_off + sub_m)))

    return chunks


class DeepseekV32IndexerBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
Expand Down Expand Up @@ -81,6 +130,7 @@ class DeepseekV32IndexerPrefillChunkMetadata:
token_start: int
token_end: int
num_reqs: int
skip_kv_gather: bool = False


@dataclass
Expand Down Expand Up @@ -271,43 +321,51 @@ def __init__(self, *args, **kwargs):
)

def build_one_prefill_chunk(
self, reqs_start, reqs_end, query_start_loc_cpu, seq_lens_cpu, block_table
):
self,
req_slice: slice,
query_slice: slice,
query_start_loc_cpu,
seq_lens_cpu,
block_table,
skip_kv_gather: bool = False,
) -> DeepseekV32IndexerPrefillChunkMetadata:
prefill_query_start_loc = (
query_start_loc_cpu[reqs_start : reqs_end + 1]
- query_start_loc_cpu[reqs_start]
query_start_loc_cpu[req_slice.start : req_slice.stop + 1]
- query_start_loc_cpu[req_slice.start]
)
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kv_spans_from_batches would calculate multiple times for the same request right

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Copy Markdown
Collaborator Author

@LucasWilkinson LucasWilkinson Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch! but this is once per forward pass and is overlapped due to async scheduling so i dont think avoiding the redundant work here is critical

prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], self.device
prefill_query_start_loc, seq_lens_cpu[req_slice], self.device
)
token_start = query_start_loc_cpu[req_slice.start].item()
total_seq_lens = seq_lens_cpu[req_slice].sum()
num_reqs = req_slice.stop - req_slice.start
seq_idx = torch.arange(0, num_reqs, dtype=torch.int32)
token_to_seq = torch.repeat_interleave(seq_idx, seq_lens_cpu[req_slice]).to(
self.device
)
token_start = query_start_loc_cpu[reqs_start].item()
token_end = query_start_loc_cpu[reqs_end].item()
total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
seq_idx = torch.arange(0, reqs_end - reqs_start, dtype=torch.int32)
token_to_seq = torch.repeat_interleave(
seq_idx, seq_lens_cpu[reqs_start:reqs_end]
).to(self.device)
assert total_seq_lens <= self.max_prefill_buffer_size
cu_seq_lens = (
torch.cat(
[
torch.zeros(1, dtype=torch.int32),
seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0),
seq_lens_cpu[req_slice].cumsum(dim=0),
]
)
.to(torch.int32)
.to(self.device)
)

return DeepseekV32IndexerPrefillChunkMetadata(
cu_seqlen_ks=cu_seqlen_ks,
cu_seqlen_ke=cu_seqlen_ke,
cu_seqlen_ks=cu_seqlen_ks[query_slice],
cu_seqlen_ke=cu_seqlen_ke[query_slice],
cu_seq_lens=cu_seq_lens,
token_to_seq=token_to_seq,
total_seq_lens=total_seq_lens,
block_table=block_table[reqs_start:reqs_end],
token_start=token_start,
token_end=token_end,
num_reqs=reqs_end - reqs_start,
block_table=block_table[req_slice],
token_start=token_start + query_slice.start,
token_end=token_start + query_slice.stop,
num_reqs=num_reqs,
skip_kv_gather=skip_kv_gather,
)

def build(
Expand All @@ -333,20 +391,27 @@ def build(

prefill_metadata = None
if num_prefills > 0:
chunk_seq_ids = split_prefill_chunks(
prefill_query_lens_cpu = torch.diff(
query_start_loc_cpu[num_decodes : num_decodes + num_prefills + 1]
)
max_logits_bytes = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
chunk_specs = split_indexer_prefill_chunks(
common_attn_metadata.seq_lens_cpu[num_decodes:],
prefill_query_lens_cpu,
self.max_prefill_buffer_size,
max_logits_bytes,
request_offset=num_decodes,
)
chunks = [
self.build_one_prefill_chunk(
reqs_start,
reqs_end,
req_slice,
query_slice,
query_start_loc_cpu,
common_attn_metadata.seq_lens_cpu,
common_attn_metadata.block_table_tensor,
skip_kv_gather=query_slice.start > 0,
)
for reqs_start, reqs_end in chunk_seq_ids
for req_slice, query_slice in chunk_specs
]
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
chunks=chunks,
Expand Down
Loading