[WIP][Bugfix] Fix CUDA OOM in sparse_attn_indexer prefill with high concurrency#35488

Closed
haosdent wants to merge 1 commit into vllm-project:main from haosdent:fix-34553

Conversation

@haosdent
Contributor

@haosdent haosdent commented Feb 27, 2026

Purpose

Fix CUDA OOM in sparse_attn_indexer during prefill with high concurrency (e.g., 48–64 concurrent requests with ISL=10240 on 8×H200 serving GLM-5-FP8).

Fixes: #34553

Root Cause

fp8_mqa_logits allocates a [M, N] float32 logits tensor where M = total query tokens and N = total KV seq lengths. The existing split_prefill_chunks only bounds N (workspace size) but not the M×N product (logits size). At high concurrency, N grows to ~300K+ while M stays ~8192, causing a ~9.89 GiB allocation that OOMs.
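
A back-of-envelope sketch of the allocation (illustrative, not vLLM code) shows why bounding N alone does not bound the logits tensor:

```python
# Illustrative sketch: estimate the [M, N] float32 logits allocation made by
# fp8_mqa_logits and show that capping N (workspace) alone does not cap M*N.

def logits_bytes(m_tokens: int, n_kv: int, dtype_size: int = 4) -> int:
    """Size in bytes of the [M, N] float32 logits tensor."""
    return m_tokens * n_kv * dtype_size

# Low concurrency: N stays small, so the allocation is modest.
low = logits_bytes(8192, 32_768)      # exactly 1 GiB

# High concurrency: N grows to ~324K with M still ~8192 -> ~9.89 GiB,
# matching the OOM reported in #34553.
high = logits_bytes(8192, 324_096)

print(f"low:  {low / 2**30:.2f} GiB")
print(f"high: {high / 2**30:.2f} GiB")
```

The M and N values here are taken from the numbers quoted above; the helper name is hypothetical.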

A secondary issue: per-request KV slicing in a naive fix breaks the TMA (Tensor Memory Accelerator) alignment required by DeepGEMM's cuTensorMapEncodeTiled, causing sporadic CUDA_ERROR_INVALID_VALUE.

Solution

Metadata-level chunk splitting (inspired by #36178): Add split_indexer_prefill_chunks() in the metadata builder that considers both the workspace (N) and logits (M×N×4) constraints when splitting prefill requests into chunks. This prevents the logits tensor from ever exceeding the memory budget.

Query-only sub-chunking safety net: For single-request overflow (one request's M×N exceeds budget), sub-chunk only the query dimension while keeping the full KV buffer intact. This avoids per-request KV slicing and preserves TMA alignment.
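
A rough sketch of these two levels, with illustrative names and signatures (not vLLM's actual API):

```python
# Hypothetical sketch of the two-level strategy. Level 1 greedily packs
# prefill requests into chunks that respect both the workspace (N) and
# logits (M*N*4 bytes) budgets. Level 2 sub-chunks only the query
# dimension of a single oversized request, leaving the KV buffer whole so
# TMA alignment is preserved.

def split_prefill_requests(seq_lens, workspace_n, logits_budget_bytes):
    """Level 1: group requests so each chunk satisfies both constraints."""
    chunks, cur = [], []
    cur_n = cur_m = 0
    for m in seq_lens:           # assume query_len == kv_len for pure prefill
        n = m
        fits_n = cur_n + n <= workspace_n
        fits_mn = (cur_m + m) * (cur_n + n) * 4 <= logits_budget_bytes
        if cur and not (fits_n and fits_mn):
            chunks.append(cur)
            cur, cur_n, cur_m = [], 0, 0
        cur.append(m)            # a single oversized request still lands in
        cur_n += n               # its own chunk; level 2 handles that case
        cur_m += m
    if cur:
        chunks.append(cur)
    return chunks

def query_subchunks(m, n, logits_budget_bytes):
    """Level 2: split only M when one request's M*N*4 exceeds the budget."""
    max_m = max(1, logits_budget_bytes // (n * 4))
    return [(start, min(start + max_m, m)) for start in range(0, m, max_m)]
```

For example, with the 512 MiB default budget, a single 64k-token request would be processed in 32 query sub-chunks of 2048 rows each, while the KV buffer stays intact.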

Key changes:

  1. vllm/envs.py: Register VLLM_SPARSE_INDEXER_MAX_LOGITS_MB (default 512 MB) environment variable
  2. vllm/v1/attention/backends/mla/indexer.py: Add split_indexer_prefill_chunks() that respects both N and M×N constraints; update build() to use it
  3. vllm/model_executor/layers/sparse_attn_indexer.py: Add helper functions for prefill logits computation with query sub-chunking fallback; use vllm.envs instead of os.environ
  4. tests/v1/attention/test_sparse_mla_backends.py: Add parametrized tests for split_indexer_prefill_chunks

Test Plan

  • python -m pytest tests/v1/attention/test_sparse_mla_backends.py::test_split_indexer_prefill_chunks -x -v — 5/5 pass
  • python -m pytest tests/v1/attention/test_sparse_mla_backends.py::test_split_prefill_chunks -x -v — existing tests still pass
  • Full validation on SM90+ hardware: serve GLM-5-FP8 with high concurrency and confirm no OOM or CUDA errors

@mergify mergify bot added nvidia v1 bug Something isn't working labels Feb 27, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a fix for a potential CUDA OOM error in the sparse attention indexer's prefill path, which can occur with high concurrency. The fix involves adding a memory budget for the logits tensor and implementing a per-request processing path with query sub-chunking when this budget is exceeded. This prevents the allocation of an excessively large logits tensor for the entire batch. The changes are well-structured, with a fast path for low-concurrency scenarios to avoid overhead. The use of CPU tensors for boundary lookups is a good optimization to prevent GPU synchronization within loops.

I've found one potential issue where the per-request processing logic might still lead to an OOM if a single request's sequence length is very large and the logits memory budget is configured to be small. I've added a specific comment with a suggestion to make the code more robust against this edge case.

@mergify

mergify bot commented Mar 1, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @haosdent.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 1, 2026
@cjackal
Contributor

cjackal commented Mar 3, 2026

Having a designated memory budget for the indexer looks promising, but after #35271 (adding a non-DeepGEMM branch for mqa logits) this PR branch has merge conflicts. Would you mind rebasing once so that further testing goes more smoothly?

And I'd also like to humbly ping @LucasWilkinson for visibility, as this PR falls under the DSV3.2 improvement meta-issue #31473.

@haosdent
Contributor Author

haosdent commented Mar 4, 2026

Sorry for the delay @cjackal, I just rebased.
However, I am not sure whether my solution makes sense or whether other vLLM maintainers have better approaches, so I will keep this as a draft until we align on the solution to address the issue.

@haosdent haosdent force-pushed the fix-34553 branch 2 times, most recently from 8da9540 to 391d7a3 Compare March 4, 2026 12:53
@cjackal
Contributor

cjackal commented Mar 4, 2026

Thanks for the prompt rebase! I just tested this PR with the following throughput benchmark, which previously OOMed at sparse_attn_indexer.py, running on an H200 x 8 worker node:

vllm serve zai-org/GLM-5-FP8 \
  --gpu-memory-utilization 0.92 \
  --tensor-parallel-size 8 \
  --enable-expert-parallel \
  --all2all-backend deepep_low_latency \
  --max-num-batched-tokens 32768 \
  --max-num-seqs 16

vllm bench serve --backend openai --model zai-org/GLM-5-FP8 \
  --dataset-name random \
  --random-input-len 16384 \
  --random-output-len 512 \
  --random-range-ratio 0 \
  --ignore-eos \
  --num-prompts 128

Benchmark result w/ this PR:

============ Serving Benchmark Result ============
Successful requests:                     128      
Failed requests:                         0        
Request rate configured (RPS):           1.00     
Benchmark duration (s):                  403.43   
Total input tokens:                      2097152  
Total generated tokens:                  65536    
Request throughput (req/s):              0.32     
Output token throughput (tok/s):         162.45   
Peak output token throughput (tok/s):    400.00   
Peak concurrent requests:                20.00    
Total token throughput (tok/s):          5360.80  
---------------Time to First Token----------------
Mean TTFT (ms):                          8771.62  
Median TTFT (ms):                        8144.13  
P75 TTFT (ms):                           12550.71 
P90 TTFT (ms):                           14301.31 
-----Time per Output Token (excl. 1st token)-----
Mean TPOT (ms):                          79.06    
Median TPOT (ms):                        79.82    
P75 TPOT (ms):                           87.32    
P90 TPOT (ms):                           90.93    
==================================================

It now does not OOM, so I think it achieves its initial goal. It runs longer without crashing, but it seems invalid values can be passed to fp8_mqa_logits in rare cases; I got the following traceback when that happened:

...
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]   File "/app/.venv/lib/python3.12/site-packages/vllm/model_executor/layers/sparse_attn_indexer.py", line 240, in sparse_attn_indexer
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]     _prefill_chunk_logits_and_topk(
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]   File "/app/.venv/lib/python3.12/site-packages/vllm/model_executor/layers/sparse_attn_indexer.py", line 147, in _prefill_chunk_logits_and_topk
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]     _prefill_logits_and_topk(
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]   File "/app/.venv/lib/python3.12/site-packages/vllm/model_executor/layers/sparse_attn_indexer.py", line 71, in _prefill_logits_and_topk
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]     logits = _compute_prefill_mqa_logits(q_fp8, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]   File "/app/.venv/lib/python3.12/site-packages/vllm/model_executor/layers/sparse_attn_indexer.py", line 55, in _compute_prefill_mqa_logits
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]     return fp8_mqa_logits(
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]            ^^^^^^^^^^^^^^^
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]   File "/app/.venv/lib/python3.12/site-packages/vllm/utils/deep_gemm.py", line 270, in fp8_mqa_logits

(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]     return _fp8_mqa_logits_impl(
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]            ^^^^^^^^^^^^^^^^^^^^^
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924] RuntimeError: CUDA driver error (csrc/apis/../jit_kernels/impls/runtime_utils.hpp:199): 1 (CUDA_ERROR_INVALID_VALUE, invalid argument)
...

It sounds to me like the new per-request branch may be passing an invalid argument, but it might be a DeepGEMM bug as well. In any case, I think this PR is much more resilient under heavy load.

@haosdent haosdent force-pushed the fix-34553 branch 2 times, most recently from 550ab61 to 0ce7e89 Compare March 5, 2026 02:57
@haosdent
Contributor Author

haosdent commented Mar 5, 2026

but it seems invalid values can be passed to fp8_mqa_logits in rare cases; I got the following traceback in that case:

Thanks @cjackal, it looks like the logits I passed in didn't follow TMA alignment. I have fixed it; please try again.

@cjackal
Contributor

cjackal commented Mar 5, 2026

Thanks for the fix; it seems the CUDA_ERROR_INVALID_VALUE error is gone now. I can run the throughput benchmark + gsm8k without errors.

Collaborator

@LucasWilkinson LucasWilkinson left a comment

@haosdent thanks for the contribution, and for tracking this down! I think we can use the existing chunking loop, i.e. something like #36178.

@haosdent
Contributor Author

haosdent commented Mar 6, 2026

@LucasWilkinson your solution looks much more elegant! How about I close mine in favor of yours, #36178?

I have also updated this PR, referring to your changes.

@haosdent haosdent force-pushed the fix-34553 branch 5 times, most recently from 50ea7c3 to 35e1bf8 Compare March 6, 2026 07:51
@cjackal
Contributor

cjackal commented Mar 11, 2026

Hi, I have tested the new version on top of e568cf8, but it seems the OOM at fp8_mqa_logits is back:

...
(Worker pid=960) (Worker_TP1_EP1 pid=960) ERROR 03-11 23:55:13 [multiproc_executor.py:932]   File "<eval_with_key>.420", line 5, in forward
(Worker pid=960) (Worker_TP1_EP1 pid=960) ERROR 03-11 23:55:13 [multiproc_executor.py:932]     sparse_attn_indexer = torch.ops.vllm.sparse_attn_indexer(x_22, 'model.layers.1.self_attn.indexer.k_cache', l_self_modules_layers_modules_1_modules_self_attn_modules_mla_attn_modules_indexer_modules_indexer_op_modules_k_cache_kv_cache_0_, q_fp8_1, k_3, weights_3, 128, 'ue8m0', 2048, 128, 202752, 8110080, l_self_modules_layers_modules_0_modules_self_attn_modules_mla_attn_modules_indexer_modules_indexer_op_topk_indices_buffer); x_22 = l_self_modules_layers_modules_1_modules_self_attn_modules_mla_attn_modules_indexer_modules_indexer_op_modules_k_cache_kv_cache_0_ = q_fp8_1 = k_3 = weights_3 = l_self_modules_layers_modules_0_modules_self_attn_modules_mla_attn_modules_indexer_modules_indexer_op_topk_indices_buffer = sparse_attn_indexer = None
(Worker pid=960) (Worker_TP1_EP1 pid=960) ERROR 03-11 23:55:13 [multiproc_executor.py:932]                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker pid=960) (Worker_TP1_EP1 pid=960) ERROR 03-11 23:55:13 [multiproc_executor.py:932]   File "/app/.venv/lib/python3.12/site-packages/torch/_ops.py", line 1209, in __call__
(Worker pid=960) (Worker_TP1_EP1 pid=960) ERROR 03-11 23:55:13 [multiproc_executor.py:932]     return self._op(*args, **kwargs)
(Worker pid=960) (Worker_TP1_EP1 pid=960) ERROR 03-11 23:55:13 [multiproc_executor.py:932]            ^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker pid=960) (Worker_TP1_EP1 pid=960) ERROR 03-11 23:55:13 [multiproc_executor.py:932]   File "/app/.venv/lib/python3.12/site-packages/vllm/model_executor/layers/sparse_attn_indexer.py", line 117, in sparse_attn_indexer
(Worker pid=960) (Worker_TP1_EP1 pid=960) ERROR 03-11 23:55:13 [multiproc_executor.py:932]     logits = fp8_mqa_logits(
(Worker pid=960) (Worker_TP1_EP1 pid=960) ERROR 03-11 23:55:13 [multiproc_executor.py:932]              ^^^^^^^^^^^^^^^
(Worker pid=960) (Worker_TP1_EP1 pid=960) ERROR 03-11 23:55:13 [multiproc_executor.py:932]   File "/app/.venv/lib/python3.12/site-packages/vllm/utils/deep_gemm.py", line 270, in fp8_mqa_logits
(Worker pid=960) (Worker_TP1_EP1 pid=960) ERROR 03-11 23:55:13 [multiproc_executor.py:932]     return _fp8_mqa_logits_impl(
(Worker pid=960) (Worker_TP1_EP1 pid=960) ERROR 03-11 23:55:13 [multiproc_executor.py:932]            ^^^^^^^^^^^^^^^^^^^^^
(Worker pid=960) (Worker_TP1_EP1 pid=960) ERROR 03-11 23:55:13 [multiproc_executor.py:932] torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 8.03 GiB. GPU 1 has a total capacity of 139.81 GiB of which 2.63 GiB is free. Including non-PyTorch memory, this process has 137.17 GiB memory in use. Of the allocated memory 121.52 GiB is allocated by PyTorch, with 346.00 MiB allocated in private pools (e.g. CUDA Graphs), and 8.33 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
...

I can consistently reproduce the error above with a single 64k input message without noisy neighbors.

@haosdent
Contributor Author

haosdent commented Mar 12, 2026

Hi, I have tested the new version atop of e568cf8, but it seems OOM at fp8_mqa_logits is back:

This should be because we removed the sub-chunking loop, so with a single 64k request:

  • M (query_len) = 65,536
  • N (seq_len) = 65,536
  • Logits = M × N × 4 bytes = 65,536 × 65,536 × 4 = 16 GiB

We may need to add back the loop to mitigate this, do you have any suggestions? @LucasWilkinson
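
The arithmetic above can be checked, along with how many query sub-chunks the 512 MiB default budget (VLLM_SPARSE_INDEXER_MAX_LOGITS_MB) would require for such a request:

```python
# Back-of-envelope check for a single 64k-token prefill request, assuming
# the 512 MiB default from VLLM_SPARSE_INDEXER_MAX_LOGITS_MB.

M = N = 65_536
full_logits_gib = M * N * 4 / 2**30   # 16.0 GiB if allocated in one shot
budget = 512 * 2**20                  # 512 MiB logits budget
rows_per_chunk = budget // (N * 4)    # 2048 query rows fit per sub-chunk
num_chunks = -(-M // rows_per_chunk)  # ceiling division -> 32 sub-chunks

print(full_logits_gib, rows_per_chunk, num_chunks)
```

So restoring the query-dimension loop would turn the single 16 GiB allocation into 32 sequential 512 MiB ones.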

@mergify

mergify bot commented Mar 12, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @haosdent.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

…rency

Two-level defense against OOM from the [M, N] float32 logits tensor in
fp8_mqa_logits:

1. Metadata-level: split_indexer_prefill_chunks splits prefill requests
   into chunks respecting both workspace (N) and logits (M*N) budgets.

2. Kernel-level: sub-chunk the query dimension (M) inside
   sparse_attn_indexer when a single request's M*N exceeds the budget.
   This handles long-context single requests (e.g. 64k tokens) that
   cannot be split across metadata chunks.

Both share the VLLM_SPARSE_INDEXER_MAX_LOGITS_MB env var (default 512MB).

Fixes: vllm-project#34553
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: haosdent <haosdent@gmail.com>
@LucasWilkinson
Collaborator

Added sub-chunking to #36178, PTAL.

@haosdent
Contributor Author

Closing mine in favor of #36178.

@haosdent haosdent closed this Mar 25, 2026
@github-project-automation github-project-automation bot moved this from In review to Done in NVIDIA Mar 25, 2026

Labels

bug Something isn't working nvidia v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

[Bug]: GLM-5 FP8 on H200 CUDA OOM in sparse_attn_indexer at High Concurrency

3 participants