[Bugfix][MLA] Add logits size budget to sparse indexer prefill chunking#36178

Open
LucasWilkinson wants to merge 8 commits into main from lucas/sparse-indexer-logits-budget

Conversation

@LucasWilkinson
Collaborator

@LucasWilkinson LucasWilkinson commented Mar 5, 2026

Alternative to #35488, credit to @haosdent

Summary

  • Adds a logits tensor size constraint to sparse MLA indexer prefill chunking to prevent CUDA OOM
  • Introduces VLLM_SPARSE_INDEXER_MAX_LOGITS_MB env var (default 512 MB) to bound the [M, N] float32 logits tensor
  • Replaces split_prefill_chunks with split_indexer_prefill_chunks that respects both workspace and logits size constraints
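
To make the budget concrete, here is a minimal sketch of the size arithmetic behind the logits constraint. The helper names `max_logits_bytes` and `logits_bytes` are hypothetical and not taken from the PR, and the sketch assumes "MB" means 1024 * 1024 bytes, which the description does not state.

```python
import os

BYTES_PER_FLOAT32 = 4


def max_logits_bytes(default_mb: int = 512) -> int:
    """Read the budget from the environment (hypothetical helper).

    Assumption: "MB" is interpreted as 1024 * 1024 bytes.
    """
    mb = int(os.environ.get("VLLM_SPARSE_INDEXER_MAX_LOGITS_MB", default_mb))
    return mb * 1024 * 1024


def logits_bytes(num_query_tokens: int, total_seq_len: int) -> int:
    """Size in bytes of an [M, N] float32 logits tensor."""
    return num_query_tokens * total_seq_len * BYTES_PER_FLOAT32


budget = 512 * 1024 * 1024  # the stated default, under the MiB assumption
print(logits_bytes(8192, 8192) <= budget)    # 256 MiB fits
print(logits_bytes(16384, 16384) <= budget)  # 1 GiB does not fit
```

Under these assumptions the default budget allows at most 134,217,728 logits elements (M * N) per chunk.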

Test plan

  • Added unit tests for split_indexer_prefill_chunks covering various constraint scenarios
  • Run existing MLA tests: pytest tests/v1/attention/test_sparse_mla_backends.py

🤖 Generated with Claude Code

@mergify mergify bot added the v1 and bug labels Mar 5, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a logits size constraint to the sparse MLA indexer's prefill chunking logic to prevent out-of-memory errors. A new environment variable VLLM_SPARSE_INDEXER_MAX_LOGITS_MB is added to control this budget. The core change is the new split_indexer_prefill_chunks function, which correctly chunks requests based on both workspace size and the new logits size constraint. The implementation is robust, handling cases where a single request might exceed the budget to avoid getting stuck. The accompanying unit tests are comprehensive and cover various scenarios, ensuring the correctness of the new logic. Overall, this is a solid improvement for memory management in the sparse indexer.
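
The chunking behaviour described in this review can be illustrated with a simplified greedy sketch. This is not the PR's implementation: the function name `split_chunks_sketch`, its parameters, and the choice to group whole requests only are all assumptions made for illustration. It closes a chunk when adding the next request would exceed either the workspace limit (total sequence length) or the logits limit (M * N elements), and it places a request that exceeds a limit on its own into a single-request chunk so the loop always makes progress.

```python
def split_chunks_sketch(query_lens, seq_lens, max_workspace_tokens, max_logits_elems):
    """Greedily group consecutive requests into [start, end) chunks.

    Hypothetical illustration only; the real function's signature and
    behaviour may differ.
    """
    chunks = []
    start = 0   # index of the first request in the current chunk
    m = n = 0   # running query tokens (M) and sequence length (N)
    for i, (q, s) in enumerate(zip(query_lens, seq_lens)):
        new_m, new_n = m + q, n + s
        over = new_n > max_workspace_tokens or new_m * new_n > max_logits_elems
        if over and i > start:
            # Close the current chunk and start a new one at request i.
            chunks.append((start, i))
            start = i
            new_m, new_n = q, s
        # If the chunk was empty, the request is accepted even when it is
        # over budget, so an oversized request cannot stall the loop.
        m, n = new_m, new_n
    if start < len(query_lens):
        chunks.append((start, len(query_lens)))
    return chunks


# Three requests of 4 query tokens and 10 KV tokens each, with room for
# 200 logits elements: the first two fit together, the third does not.
print(split_chunks_sketch([4, 4, 4], [10, 10, 10], 100, 200))
```

Grouping whole requests keeps the sketch short; a real implementation could also split a single long request across several chunks, which this sketch does not attempt.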

@mergify

mergify bot commented Mar 12, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 12, 2026
@LucasWilkinson LucasWilkinson marked this pull request as ready for review March 24, 2026 14:32
@LucasWilkinson LucasWilkinson force-pushed the lucas/sparse-indexer-logits-budget branch from d82abae to f950710 Compare March 24, 2026 14:51
@mergify mergify bot removed the needs-rebase label Mar 24, 2026
@LucasWilkinson LucasWilkinson added the ready label Mar 24, 2026
query_start_loc_cpu[req_slice.start : req_slice.stop + 1]
- query_start_loc_cpu[req_slice.start]
)
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
Contributor

kv_spans_from_batches would be computed multiple times for the same request, right?

Collaborator

+1

Collaborator Author

@LucasWilkinson LucasWilkinson Mar 31, 2026

Good catch! But this happens once per forward pass and is overlapped thanks to async scheduling, so I don't think avoiding the redundant work here is critical.

@mergify

mergify bot commented Mar 25, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 25, 2026

The sparse MLA indexer allocates a [M, N] float32 logits tensor during
prefill, where M is total query tokens and N is total sequence length.
For long sequences or large batches, this can exceed GPU memory.

This adds a new constraint to split_indexer_prefill_chunks that bounds
M*N*4 bytes to VLLM_SPARSE_INDEXER_MAX_LOGITS_MB (default 512 MB),
preventing CUDA OOM during prefill.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@LucasWilkinson LucasWilkinson force-pushed the lucas/sparse-indexer-logits-budget branch from f950710 to 0a3cef6 Compare March 31, 2026 20:24
@mergify mergify bot removed the needs-rebase label Mar 31, 2026
@LucasWilkinson LucasWilkinson added this to the v0.19.0 cherry picks milestone Mar 31, 2026
Collaborator

@MatthewBonanni MatthewBonanni left a comment

LGTM, thanks for the fix!

Labels

bug, ready, v1

4 participants