[WIP][Bugfix] Fix CUDA OOM in sparse_attn_indexer prefill with high concurrency #35488

haosdent wants to merge 1 commit into vllm-project:main
Conversation
Code Review
This pull request introduces a fix for a potential CUDA OOM error in the sparse attention indexer's prefill path, which can occur with high concurrency. The fix involves adding a memory budget for the logits tensor and implementing a per-request processing path with query sub-chunking when this budget is exceeded. This prevents the allocation of an excessively large logits tensor for the entire batch. The changes are well-structured, with a fast path for low-concurrency scenarios to avoid overhead. The use of CPU tensors for boundary lookups is a good optimization to prevent GPU synchronization within loops.
I've found one potential issue where the per-request processing logic might still lead to an OOM if a single request's sequence length is very large and the logits memory budget is configured to be small. I've added a specific comment with a suggestion to make the code more robust against this edge case.
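The budget-gated dispatch the review describes can be sketched as follows. This is a simplified illustration, not the PR's actual code: `LOGITS_BUDGET_BYTES` and both function names are hypothetical.

```python
LOGITS_BUDGET_BYTES = 512 * 1024 * 1024  # illustrative 512 MB budget


def logits_bytes(num_q_tokens: int, num_kv_tokens: int) -> int:
    # fp8_mqa_logits materializes a [M, N] float32 logits tensor.
    return num_q_tokens * num_kv_tokens * 4


def choose_prefill_path(total_q: int, total_kv: int) -> str:
    # Fast path: one batched call when the whole batch's logits fit the
    # budget, avoiding per-request overhead at low concurrency.
    if logits_bytes(total_q, total_kv) <= LOGITS_BUDGET_BYTES:
        return "batched"
    # Otherwise fall back to per-request processing with query sub-chunking.
    return "per_request"
```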
This pull request has merge conflicts that must be resolved before it can be merged.
Having a designated memory budget for the indexer looks promising, but after #35271 (which added a non-deepgemm branch for the MQA logits) this PR branch has merge conflicts. Would you mind rebasing once so that further testing goes more smoothly? I'd also like to humbly ping @LucasWilkinson for visibility, as this PR is part of the DSV3.2 improvement meta-issue #31473.
Sorry for the delay @cjackal, I just rebased.
Force-pushed from 8da9540 to 391d7a3
Thanks for the prompt rebase! I just tested this PR with the following throughput benchmark, which previously OOM'd:

```shell
vllm serve zai-org/GLM-5-FP8 \
  --gpu-memory-utilization 0.92 \
  --tensor-parallel-size 8 \
  --enable-expert-parallel \
  --all2all-backend deepep_low_latency \
  --max-num-batched-tokens 32768 \
  --max-num-seqs 16

vllm bench serve --backend openai --model zai-org/GLM-5-FP8 \
  --dataset-name random \
  --random-input-len 16384 \
  --random-output-len 512 \
  --random-range-ratio 0 \
  --ignore-eos \
  --num-prompts 128
```

Benchmark result w/ this PR:

```
============ Serving Benchmark Result ============
Successful requests:                     128
Failed requests:                         0
Request rate configured (RPS):           1.00
Benchmark duration (s):                  403.43
Total input tokens:                      2097152
Total generated tokens:                  65536
Request throughput (req/s):              0.32
Output token throughput (tok/s):         162.45
Peak output token throughput (tok/s):    400.00
Peak concurrent requests:                20.00
Total token throughput (tok/s):          5360.80
---------------Time to First Token----------------
Mean TTFT (ms):                          8771.62
Median TTFT (ms):                        8144.13
P75 TTFT (ms):                           12550.71
P90 TTFT (ms):                           14301.31
-----Time per Output Token (excl. 1st token)-----
Mean TPOT (ms):                          79.06
Median TPOT (ms):                        79.82
P75 TPOT (ms):                           87.32
P90 TPOT (ms):                           90.93
==================================================
```

It now does not OOM, so I think it achieves its initial goal. It runs longer without crashing, but it seems invalid values can be passed to ...
```
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]   File "/app/.venv/lib/python3.12/site-packages/vllm/model_executor/layers/sparse_attn_indexer.py", line 240, in sparse_attn_indexer
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]     _prefill_chunk_logits_and_topk(
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]   File "/app/.venv/lib/python3.12/site-packages/vllm/model_executor/layers/sparse_attn_indexer.py", line 147, in _prefill_chunk_logits_and_topk
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]     _prefill_logits_and_topk(
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]   File "/app/.venv/lib/python3.12/site-packages/vllm/model_executor/layers/sparse_attn_indexer.py", line 71, in _prefill_logits_and_topk
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]     logits = _compute_prefill_mqa_logits(q_fp8, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]   File "/app/.venv/lib/python3.12/site-packages/vllm/model_executor/layers/sparse_attn_indexer.py", line 55, in _compute_prefill_mqa_logits
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]     return fp8_mqa_logits(
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]            ^^^^^^^^^^^^^^^
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]   File "/app/.venv/lib/python3.12/site-packages/vllm/utils/deep_gemm.py", line 270, in fp8_mqa_logits
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]     return _fp8_mqa_logits_impl(
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924]            ^^^^^^^^^^^^^^^^^^^^^
(Worker pid=734) (Worker_TP1_EP1 pid=734) ERROR 03-05 00:20:25 [multiproc_executor.py:924] RuntimeError: CUDA driver error (csrc/apis/../jit_kernels/impls/runtime_utils.hpp:199): 1 (CUDA_ERROR_INVALID_VALUE, invalid argument)
...
```

Sounds to me like the new per-request branch may be passing an invalid argument, but it might be a DeepGEMM bug as well. In any case, this PR is much more resilient under heavy load, I think.
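As a quick sanity check, the aggregate throughput figures in the benchmark output above are mutually consistent (all values taken from the table; small deviations come from the rounded duration):

```python
# Totals taken from the Serving Benchmark Result above.
total_input_tokens = 2_097_152
total_output_tokens = 65_536
duration_s = 403.43
num_requests = 128

request_throughput = num_requests / duration_s                 # ~0.32 req/s
output_throughput = total_output_tokens / duration_s           # ~162.4 tok/s
total_throughput = (total_input_tokens + total_output_tokens) / duration_s  # ~5360 tok/s
```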
Force-pushed from 550ab61 to 0ce7e89
Thanks @cjackal, it looks like it's because the logits tensor I passed in didn't follow TMA alignment. I have fixed it, please try again.
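The alignment failure mode mentioned here can be illustrated with plain offset arithmetic. The 16-byte figure and the helper below are illustrative only, not DeepGEMM's exact requirement:

```python
TMA_ALIGN = 16  # illustrative alignment requirement in bytes


def row_view_addr(base: int, row: int, n_cols: int, elem_size: int = 4) -> int:
    # Byte address of row `row` inside a larger [M, N] float32 buffer.
    return base + row * n_cols * elem_size


# With N = 300_001 columns, the row stride (1_200_004 bytes) is not a
# multiple of 16, so most row views start at a misaligned address even
# though the buffer's base is allocator-aligned.
misaligned = row_view_addr(0, 3, 300_001) % TMA_ALIGN != 0
# Allocating each sub-chunk's logits tensor fresh (instead of slicing a
# shared buffer) keeps every base pointer allocator-aligned.
```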
Thanks for the fix, it seems the CUDA_ERROR_INVALID_VALUE error is gone now. I can run the throughput benchmark + gsm8k without errors.
@LucasWilkinson your solution looks much more elegant! How about I close mine in favor of yours, #36178? I also referred to your changes and updated this PR.
Force-pushed from 50ea7c3 to 35e1bf8
Hi, I have tested the new version atop e568cf8, but it still seems to OOM. I can consistently reproduce the error above with a single 64k input message, without noisy neighbors.
This should be because we removed the sub-chunking loop; with a single 64k request, the whole [M, N] logits tensor must be allocated at once and cannot be split at the batch level. We may need to add back the loop to mitigate this. Do you have any suggestions? @LucasWilkinson
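Concretely (numbers assumed from the 64k repro above), a single 64k-token prefill produces a logits tensor that no batch-level split can shrink:

```python
# One 64k-token request: M = N = 65536 for the [M, N] float32 logits tensor.
m = n = 64 * 1024
logits_gib = m * n * 4 / 1024**3  # total allocation in GiB

# Far above a 512 MB budget, and un-splittable at the metadata level since
# it is a single request -- hence the need for the query sub-chunking loop.
budget_gib = 512 / 1024
```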
This pull request has merge conflicts that must be resolved before it can be merged.
…rency

Two-level defense against OOM from the [M, N] float32 logits tensor in fp8_mqa_logits:

1. Metadata-level: split_indexer_prefill_chunks splits prefill requests into chunks respecting both workspace (N) and logits (M*N) budgets.
2. Kernel-level: sub-chunk the query dimension (M) inside sparse_attn_indexer when a single request's M*N exceeds the budget. This handles long-context single requests (e.g. 64k tokens) that cannot be split across metadata chunks.

Both share the VLLM_SPARSE_INDEXER_MAX_LOGITS_MB env var (default 512 MB).

Fixes: vllm-project#34553
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: haosdent <haosdent@gmail.com>
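A minimal sketch of the kernel-level query sub-chunking described in the commit message above. The function name and `compute_chunk` callback are hypothetical simplifications of what `sparse_attn_indexer` does:

```python
def prefill_with_query_subchunks(m, n, budget_bytes, compute_chunk):
    """Split only the query dimension so each [m_chunk, n] float32 logits
    tensor fits the budget; the full KV buffer stays intact, which is what
    preserves the TMA alignment the kernel expects."""
    max_m = max(1, budget_bytes // (n * 4))  # largest query chunk that fits
    out = []
    for start in range(0, m, max_m):
        out.append(compute_chunk(start, min(start + max_m, m)))
    return out
```

With m = n = 65536 and a 512 MB budget, this yields 32 query sub-chunks of 2048 tokens each.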
Added sub-chunking to #36178, PTAL.
Closing mine in favor of #36178.
Purpose
Fix CUDA OOM in sparse_attn_indexer during prefill with high concurrency (e.g., 48–64 concurrent requests with ISL=10240 on 8×H200 serving GLM-5-FP8).

Fixes: #34553
Root Cause
fp8_mqa_logits allocates a [M, N] float32 logits tensor, where M = total query tokens and N = total KV sequence length. The existing split_prefill_chunks only bounds N (the workspace size) but not the M×N product (the logits size). At high concurrency, N grows to ~300K+ while M stays ~8192, causing a ~9.89 GiB allocation that OOMs.

A secondary issue: per-request KV slicing in a naive fix breaks the TMA (Tensor Memory Access) alignment required by DeepGEMM's cuTensorMapEncodeTiled, causing sporadic CUDA_ERROR_INVALID_VALUE.

Solution
1. Metadata-level chunk splitting (inspired by #36178): Add split_indexer_prefill_chunks() in the metadata builder, which considers both the workspace (N) and logits (M×N×4) constraints when splitting prefill requests into chunks. This prevents the logits tensor from ever exceeding the memory budget.
2. Query-only sub-chunking safety net: For single-request overflow (one request's M×N exceeds the budget), sub-chunk only the query dimension while keeping the full KV buffer intact. This avoids per-request KV slicing and preserves TMA alignment.
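A simplified sketch of the two-constraint splitting logic: the function below is a hypothetical reduction of split_indexer_prefill_chunks(), operating on plain per-request lengths rather than attention metadata.

```python
def split_requests(q_lens, kv_lens, max_n, max_logits_bytes):
    """Greedily pack prefill requests into chunks so that each chunk's total
    KV length stays under max_n AND its [M, N] float32 logits tensor stays
    under max_logits_bytes."""
    chunks, cur, m, n = [], [], 0, 0
    for q, kv in zip(q_lens, kv_lens):
        fits_n = n + kv <= max_n
        fits_logits = (m + q) * (n + kv) * 4 <= max_logits_bytes
        if cur and not (fits_n and fits_logits):
            chunks.append(cur)  # close the current chunk and start fresh
            cur, m, n = [], 0, 0
        cur.append((q, kv))
        m, n = m + q, n + kv
    if cur:
        chunks.append(cur)
    return chunks
```

Note that a single over-budget request still forms its own chunk here; that case is exactly what the query-only sub-chunking safety net covers.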
Key changes:
- vllm/envs.py: Register the VLLM_SPARSE_INDEXER_MAX_LOGITS_MB (default 512 MB) environment variable
- vllm/v1/attention/backends/mla/indexer.py: Add split_indexer_prefill_chunks(), which respects both the N and M×N constraints; update build() to use it
- vllm/model_executor/layers/sparse_attn_indexer.py: Add helper functions for prefill logits computation with a query sub-chunking fallback; use vllm.envs instead of os.environ
- tests/v1/attention/test_sparse_mla_backends.py: Add parametrized tests for split_indexer_prefill_chunks

Test Plan

- python -m pytest tests/v1/attention/test_sparse_mla_backends.py::test_split_indexer_prefill_chunks -x -v: 5/5 pass
- python -m pytest tests/v1/attention/test_sparse_mla_backends.py::test_split_prefill_chunks -x -v: existing tests still pass