
[Bugfix] Zero-init MLA attention output buffers to prevent NaN from CUDA graph padding#37442

Merged
tlrmchlsmth merged 4 commits into vllm-project:main from elvircrn:zero-init-attn-output-pr
Mar 19, 2026
Conversation

@elvircrn
Contributor

@elvircrn elvircrn commented Mar 18, 2026

Summary

When running CUDA graph decode with padding (e.g. batch of 1024 with 1 real request), unused slots have seq_lens=0. The MLA decode kernels (both CUTLASS MLA and FlashInfer TRT-LLM MLA) skip writing output for these slots, leaving stale data in the output buffer. If that stale data contains NaN (from a previous iteration or uninitialized memory), it propagates to real tokens via downstream per-tensor FP8 quantization (amax over the entire batch).

This was observed in production on GB200 (SM100) with DeepSeek-R1 NVFP4 causing intermittent NaN outputs.

Root cause

Two mechanisms produce NaN in padding slots:

  1. Kernel skip (seq_lens=0): The kernel writes nothing for padding rows, so whatever was in the output buffer persists.
  2. TMA tile overread (seq_lens=1): MLA kernels read KV cache in 128-entry TMA tiles. With seq_lens=1, entries 1–127 are read from uninitialized KV cache pages. If those contain NaN, softmax produces NaN. (Related: FlashAttention3 forward producing NaN output when NaN exist in parts of input data that it should not be reading Dao-AILab/flash-attention#1974)

While vLLM already zero-inits KV cache pages (fixing mechanism 2), mechanism 1 requires the output buffer itself to be clean.
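Mechanism 1 can be demonstrated without a GPU. The sketch below (pure Python, illustrative names only) mimics a decode kernel that skips rows with seq_len=0, then takes a per-tensor amax over the whole batch the way the FP8 quantization path does:

```python
import math

def amax(buf):
    # Per-tensor absolute max, as used to derive the FP8 scale.
    vals = [abs(x) for row in buf for x in row]
    # Hardware reductions propagate NaN; emulate that explicitly here.
    return math.nan if any(math.isnan(v) for v in vals) else max(vals)

def fake_decode_kernel(out, seq_lens, value=0.5):
    # Mimics the MLA decode kernel: rows with seq_len == 0 are skipped entirely.
    for i, seq_len in enumerate(seq_lens):
        if seq_len > 0:
            out[i] = [value] * len(out[i])

seq_lens = [7, 0, 0, 0]  # one real request plus three CUDA-graph padding slots

# Stale buffer: padding rows keep whatever was there before (here, NaN).
dirty = [[math.nan] * 2 for _ in range(4)]
fake_decode_kernel(dirty, seq_lens)
print(amax(dirty))  # nan -> the whole batch's FP8 scale is poisoned

# Zero-initialized buffer: padding rows stay clean.
clean = [[0.0] * 2 for _ in range(4)]
fake_decode_kernel(clean, seq_lens)
print(amax(clean))  # 0.5
```

The real token's output is identical in both cases; only the scale computed over the padded batch differs.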

Fix

Pre-allocate output buffers with torch.zeros once per layer (cached as instance attributes), then reuse via slicing on each call. This:

  • Prevents NaN contamination — padding slots always contain zeros
  • Zero runtime cost — no per-call memset; the buffer is allocated once and reused, compatible with CUDA graph capture/replay
  • Applies to both backends — CUTLASS MLA (cutlass_mla.py) and FlashInfer MLA (flashinfer_mla.py)
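The allocate-once-then-slice pattern can be sketched as follows. This is an illustration, not the actual vLLM code (which allocates with `torch.zeros` on the GPU); the class and attribute names here are hypothetical:

```python
class MLADecodeImplSketch:
    """Sketch of the cached zero-init output buffer pattern."""

    MAX_BATCH = 8  # stands in for the max CUDA graph capture size
    DIM = 4        # stands in for num_heads * kv_lora_rank

    def __init__(self):
        self._decode_out = None  # allocated lazily, once per layer

    def _out_buffer(self, batch_size):
        if self._decode_out is None:
            # One-time allocation, zero-filled (torch.zeros in the real code).
            self._decode_out = [[0.0] * self.DIM for _ in range(self.MAX_BATCH)]
        # Per-call slicing reuses the same rows: no per-call memset, and the
        # stable storage makes it compatible with CUDA graph capture/replay.
        return self._decode_out[:batch_size]

impl = MLADecodeImplSketch()
a = impl._out_buffer(4)
b = impl._out_buffer(4)
print(a[0] is b[0])  # True: same underlying storage on every call
```

Because the kernel overwrites every real row on each call, only the padding rows rely on the initial zeros.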

Test plan

  • Verified fix eliminates NaN on GB200 with DeepSeek-R1 NVFP4 (1 real + 1023 padding tokens, ~1200 seq_len)
  • Minimal repro confirms: dirty output buffer + seq_lens=0 → NaN propagation; zero-init output → clean
  • Performance benchmark shows zero overhead (168 µs/iter with and without fix)
  • Pre-commit hooks pass (ruff, mypy, etc.)

🤖 Generated with Claude Code

@elvircrn elvircrn requested a review from pavanimajety as a code owner March 18, 2026 14:28
@mergify mergify bot added nvidia v1 bug Something isn't working labels Mar 18, 2026
@elvircrn elvircrn force-pushed the zero-init-attn-output-pr branch from dafe6ad to 9fb8044 Compare March 18, 2026 14:32
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a critical bug causing NaN propagation in MLA attention when using CUDA graph decoding with padding. The fix involves pre-allocating and zero-initializing output buffers to prevent stale data from contaminating results. The changes in vllm/v1/attention/backends/mla/cutlass_mla.py are well-implemented. However, the implementation in vllm/v1/attention/backends/mla/flashinfer_mla.py contains a critical bug in the shape and dtype of the pre-allocated buffer, which will cause runtime errors. My review provides a correction for this issue.

@elvircrn elvircrn force-pushed the zero-init-attn-output-pr branch from 9fb8044 to c8dfe2f Compare March 18, 2026 14:35
…graph padding

When using CUDA graph capture with batch padding, padding slots with
seq_lens=0 are skipped by the attention kernel, leaving uninitialized
output. This NaN/garbage in padding rows can propagate to real tokens
through downstream operations.

Fix: pre-allocate output buffers with zero-init, reuse across calls.
The buffer is lazily allocated on first use and reused on subsequent
calls, so the zeroing cost is paid once (not per CUDA graph replay).

Changes:
- CUTLASS MLA: cache output+LSE buffers as class properties, zero-init
  on allocation, slice to batch size each call
- FlashInfer MLA: same pattern, pass pre-zeroed out= tensor to
  trtllm_batch_decode_with_kv_cache_mla
- CUTLASS workspace: zero-init at allocation (was torch.empty)

Related: Dao-AILab/flash-attention#1974

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
@elvircrn elvircrn force-pushed the zero-init-attn-output-pr branch from c8dfe2f to fe46b00 Compare March 18, 2026 14:53
@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 18, 2026
elvircrn and others added 3 commits March 18, 2026 19:11
…decode)

The output buffer was 3D (B, H, kv_lora_rank) but q is 4D
(B, q_len_per_req, H, D) for spec decode. The kernel wrote past the
buffer causing CUDBG_EXCEPTION_WARP_ILLEGAL_ADDRESS.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
FlashInfer expects out as 3D (batch, num_heads, kv_lora_rank).
For multi-token queries (spec decode), let the kernel allocate
its own output buffer.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
@MatthewBonanni
Collaborator

Implemented a workaround so that we can fix the q_len > 1 case: zero out the padding slots after the kernel call. This adds some overhead; it will be eliminated once a fix is upstreamed to FlashInfer.

Minimal reproducer here, verifies the fix works: https://gist.github.com/MatthewBonanni/5c425568a2b880edcd7b5f03a8048e2d

============================================================
MLA Decode NaN Propagation Reproducer
============================================================
Batch: 1 real + 127 padding
Seq len: 512
All cases start with NaN-poisoned buffers.

--- CUTLASS MLA ---
  fix OFF  amax=     nan  padding_nan= True  BUG
  fix ON   amax=  0.4453  padding_nan=False  OK

--- FlashInfer MLA (q_len=1) ---
  fix OFF  amax=     nan  padding_nan= True  BUG
  fix ON   amax=  0.3770  padding_nan=False  OK

--- FlashInfer MLA (q_len=4) ---
  fix OFF  amax=     nan  padding_nan= True  BUG
  fix ON   amax=  0.5000  padding_nan=False  OK

============================================================
fix OFF: kernel skips padding slots, NaN persists,
         per-tensor amax=NaN -> FP8 quant fails.
fix ON:  zero-init buffer (q_len=1) or zero padding
         after kernel (q_len>1) -> amax is clean.
============================================================
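The q_len > 1 workaround described above (zeroing padding slots after the kernel returns, since the kernel allocates its own output there) can be sketched like this; function and parameter names are illustrative, not the actual vLLM code:

```python
def zero_padding_rows(out, num_actual_tokens):
    # Workaround sketch for q_len > 1: the kernel owns the output allocation,
    # so padding rows are cleared after the call instead of pre-zeroed.
    for i in range(num_actual_tokens, len(out)):
        out[i] = [0.0] * len(out[i])
    return out

# 2 real tokens followed by 2 padding rows holding stale kernel output.
out = [[0.3, 0.3], [0.4, 0.4], [float("nan"), 9e9], [float("nan"), 9e9]]
zero_padding_rows(out, num_actual_tokens=2)
print(out[2], out[3])  # [0.0, 0.0] [0.0, 0.0]
```

This trades a small per-call write over the padding region for correctness, which is the overhead mentioned above.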

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Mar 18, 2026
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) March 18, 2026 23:44
@tlrmchlsmth tlrmchlsmth self-assigned this Mar 18, 2026
@tlrmchlsmth tlrmchlsmth merged commit ef2c4f7 into vllm-project:main Mar 19, 2026
60 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Mar 19, 2026
khluu pushed a commit that referenced this pull request Mar 19, 2026
…UDA graph padding (#37442)

Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
(cherry picked from commit ef2c4f7)
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026
…UDA graph padding (vllm-project#37442)

Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
ibifrost pushed a commit to ibifrost/vllm that referenced this pull request Mar 20, 2026
…UDA graph padding (vllm-project#37442)

Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
(cherry picked from commit ef2c4f7)
@hjjq
Contributor

hjjq commented Mar 21, 2026

Hi @elvircrn The commands below hit OOM on 4xGB200. Reverting this PR makes the OOM go away.

  export VLLM_USE_FLASHINFER_MOE_FP4=1
  export VLLM_USE_NCCL_SYMM_MEM=1
  export NCCL_NVLS_ENABLE=1
  export NCCL_CUMEM_ENABLE=1

  python3 -m vllm.entrypoints.openai.api_server \
    --host 0.0.0.0 --port 8000 \
    --model nvidia/DeepSeek-R1-0528-NVFP4 \
    --trust-remote-code --no-enable-prefix-caching \
    --dtype auto --kv-cache-dtype fp8 \
    --tensor-parallel-size 1 --data-parallel-size 4 \
    --max-num-seqs 1024 --max-model-len 10240 \
    --gpu-memory-utilization 0.9 \
    --max-num-batched-tokens 8192 \
    --enable-expert-parallel --async-scheduling \
    --compilation_config.max_cudagraph_capture_size 2048 \
    --compilation_config.cudagraph_mode FULL_DECODE_ONLY
  vllm bench serve \
    --host localhost --port 8000 \
    --tokenizer nvidia/DeepSeek-R1-0528-NVFP4 \
    --trust-remote-code --dataset-name random \
    --random-input-len 2000 --random-output-len 1000 \
    --ignore-eos --max-concurrency 128 \
    --request-rate inf --num-prompts 640

Server crashes with:
  MemoryError: CUDA out of memory. Tried to allocate 3.94 GiB. GPU has 184.00 GiB total, 2.38 GiB free.

Some Claude analysis:

The commit adds a persistent `_decode_out` tensor to each `FlashInferMLAImpl` (and `CutlassMLAImpl`) instance to prevent NaN contamination from CUDA graph padding slots. Since each attention layer creates its own impl instance, DeepSeek-R1 (61 layers)
allocates 61 separate buffers of size `max_batch * num_heads * kv_lora_rank * dtype_size` = ~256 MiB each, totaling ~15 GiB of new persistent GPU memory.

This memory is **not accounted for during memory profiling** because:

1. `profile_run()` calls `_dummy_run(max_num_tokens, is_profile=True)`
2. `is_profile=True` sets `force_eager=True` (`gpu_model_runner.py:5078`)
3. This results in `cudagraph_runtime_mode = NONE`
4. Attention metadata is only built when `force_attention or cudagraph_runtime_mode == FULL` (`gpu_model_runner.py:5136`); neither is true during profiling
5. So `forward_mqa()` is never called, and the lazy `_decode_out` allocation never triggers

The buffers materialize later during CUDA graph capture/serving (after KV cache is already sized), consuming memory that was budgeted for runtime allocations. The first large MoE kernel allocation (3.94 GiB) then OOMs.

Note: sharing a single buffer across layers would reduce the overhead from ~15 GiB to ~256 MiB, which fits within existing headroom, but it still wouldn't be accounted for during profiling. The cleanest fix is to not persist the buffer at all: just zero the padding slots after the kernel returns (as the `q_len > 1` path already does), which avoids both the memory overhead and the profiling gap.
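The memory figures in the analysis can be sanity-checked arithmetically. The shapes below are assumptions consistent with DeepSeek-R1 MLA decode (num_heads=128, kv_lora_rank=512, bf16, and a max capture batch of 2048), not values read from the code:

```python
# Per-layer persistent buffer: max_batch * num_heads * kv_lora_rank * dtype_size
max_batch, num_heads, kv_lora_rank, dtype_size = 2048, 128, 512, 2
num_layers = 61  # DeepSeek-R1

per_layer_bytes = max_batch * num_heads * kv_lora_rank * dtype_size
total_bytes = per_layer_bytes * num_layers

print(per_layer_bytes / 2**20)            # 256.0 MiB per layer
print(round(total_bytes / 2**30, 2))      # 15.25 GiB across all 61 layers
```

Under these assumptions the numbers land exactly on the ~256 MiB per layer and ~15 GiB total quoted above.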

elvircrn added a commit to elvircrn/vllm that referenced this pull request Mar 22, 2026
The NaN fix (PR vllm-project#37442) allocated a persistent _decode_out buffer
per attention layer. For DeepSeek-R1 (61 layers), this totals ~15 GiB
of GPU memory that is also not accounted for during profiling, causing
OOM when KV cache + buffers exceed available memory.

Fix: use a single module-level buffer shared across all layers.
Memory drops from ~15 GiB to ~256 MiB. The buffer is only written
by one layer at a time (sequential forward pass), so sharing is safe.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
elvircrn added a commit to elvircrn/vllm that referenced this pull request Mar 22, 2026
The NaN fix (PR vllm-project#37442) allocated a persistent _decode_out buffer
per attention layer. For DeepSeek-R1 (61 layers), this totals ~15 GiB
of GPU memory that is also not accounted for during profiling, causing
OOM when KV cache + buffers exceed available memory.

Fix: use a single module-level buffer shared across all layers.
Memory drops from ~15 GiB to ~256 MiB. The buffer is only written
by one layer at a time (sequential forward pass), so sharing is safe.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
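The module-level sharing described in the commit message above can be sketched as follows (names hypothetical): one buffer reused by every layer instead of one per impl instance, which is safe because layers execute sequentially within a forward pass:

```python
# Single shared buffer for all attention layers, allocated on first use.
_SHARED_DECODE_OUT = None

def get_decode_out(batch_size, max_batch=8, dim=4):
    global _SHARED_DECODE_OUT
    if _SHARED_DECODE_OUT is None:
        # One allocation total, instead of one per layer.
        _SHARED_DECODE_OUT = [[0.0] * dim for _ in range(max_batch)]
    return _SHARED_DECODE_OUT[:batch_size]

layer0_out = get_decode_out(4)
layer1_out = get_decode_out(4)
print(layer0_out[0] is layer1_out[0])  # True: all layers share one buffer
```

Each layer's kernel fully overwrites the real rows before the next layer reads its own output, so no cross-layer data survives in the shared region.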
varun-sundar-rabindranath pushed a commit to neuralmagic/vllm that referenced this pull request Mar 22, 2026
…N from CUDA graph padding (vllm-project#37442)"

This reverts commit ef2c4f7.

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
elvircrn added a commit to elvircrn/vllm that referenced this pull request Mar 23, 2026
…N from CUDA graph padding (vllm-project#37442)"

This reverts commit ef2c4f7.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
maoxx241 pushed a commit to maoxx241/vllm that referenced this pull request Mar 24, 2026
…UDA graph padding (vllm-project#37442)

Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
(cherry picked from commit ef2c4f7)
varun-sundar-rabindranath pushed a commit to neuralmagic/vllm that referenced this pull request Mar 25, 2026
…N from CUDA graph padding (vllm-project#37442)"

This reverts commit ef2c4f7.

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
SouthWest7 pushed a commit to SouthWest7/vllm that referenced this pull request Mar 27, 2026
…UDA graph padding (vllm-project#37442)

Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
…UDA graph padding (vllm-project#37442)

Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
elvircrn added a commit to elvircrn/vllm that referenced this pull request Mar 27, 2026
…N from CUDA graph padding (vllm-project#37442)"

This reverts commit ef2c4f7.

The zero-init workaround is unnecessary — the NaN was caused by a
different issue (int64 expert IDs in the routing simulator). Reverting
to restore the original torch.empty allocation which avoids the
overhead of pre-allocated zero-init buffers.

Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
varun-sundar-rabindranath pushed a commit to neuralmagic/vllm that referenced this pull request Mar 27, 2026
…N from CUDA graph padding (vllm-project#37442)"

This reverts commit ef2c4f7.

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Monishver11 pushed a commit to Monishver11/vllm that referenced this pull request Mar 27, 2026
…UDA graph padding (vllm-project#37442)

Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026
…UDA graph padding (vllm-project#37442)

Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
tlrmchlsmth pushed a commit to tlrmchlsmth/vllm that referenced this pull request Mar 28, 2026
…N from CUDA graph padding (vllm-project#37442)"

This reverts commit ef2c4f7.

The zero-init workaround is unnecessary — the NaN was caused by a
different issue (int64 expert IDs in the routing simulator). Reverting
to restore the original torch.empty allocation which avoids the
overhead of pre-allocated zero-init buffers.

Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
elvircrn added a commit to elvircrn/vllm that referenced this pull request Mar 28, 2026
…N from CUDA graph padding (vllm-project#37442)"

This reverts commit ef2c4f7.

The zero-init workaround is unnecessary — the NaN was caused by a
different issue (int64 expert IDs in the routing simulator). Reverting
to restore the original torch.empty allocation which avoids the
overhead of pre-allocated zero-init buffers.

Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
elvircrn added a commit to elvircrn/vllm that referenced this pull request Mar 29, 2026
…N from CUDA graph padding (vllm-project#37442)"

This reverts commit ef2c4f7.

The zero-init workaround is unnecessary — the NaN was caused by a
different issue (int64 expert IDs in the routing simulator). Reverting
to restore the original torch.empty allocation which avoids the
overhead of pre-allocated zero-init buffers.

Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
tlrmchlsmth added a commit to tlrmchlsmth/vllm that referenced this pull request Mar 30, 2026
Cherry-pick d4a41a9: Revert "Zero-init MLA attention output buffers
to prevent NaN from CUDA graph padding (vllm-project#37442)"

Apply PR vllm-project#38148: Fix NaN from stale FP4 scale padding in
create_fp4_scale_tensor

Signed-off-by: Travis Stephens <travis@anthropic.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
elvircrn added a commit to elvircrn/vllm that referenced this pull request Mar 30, 2026
…N from CUDA graph padding (vllm-project#37442)"

This reverts commit ef2c4f7.

The zero-init workaround is unnecessary — the NaN was caused by a
different issue (int64 expert IDs in the routing simulator). Reverting
to restore the original torch.empty allocation which avoids the
overhead of pre-allocated zero-init buffers.

Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
vrdn-23 pushed a commit to vrdn-23/vllm that referenced this pull request Mar 30, 2026
…UDA graph padding (vllm-project#37442)

Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
Labels

bug Something isn't working nvidia ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

Status: Done

5 participants