
[Bugfix] Share MLA decode output buffer across layers to fix OOM #37799

Open
elvircrn wants to merge 1 commit into vllm-project:main from elvircrn:fix/shared-mla-decode-buffer

Conversation

@elvircrn
Contributor

Summary

PR #37442 added a per-layer _decode_out buffer to prevent NaN from CUDA graph padding slots. For DeepSeek-R1 (61 layers), this allocates ~15 GiB of persistent GPU memory that is not accounted for during profiling (the buffer is lazily allocated in forward_mqa, which is never called during profile_run with is_profile=True).

This causes OOM on memory-constrained configs (e.g. DeepSeek-R1 NVFP4 with DP4 on 4×GB200):

torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.94 GiB.
GPU has 184.00 GiB total, 2.38 GiB free.
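As a sanity check on the numbers, the ~15 GiB figure is consistent with 61 layers each holding a ~256 MiB buffer (back-of-the-envelope arithmetic only; the per-layer size is taken from the PR description, not measured here):

```python
# Rough arithmetic: 61 layers, each lazily allocating its own _decode_out.
layers = 61
per_layer_mib = 256  # approximate per-layer buffer size per the PR description
total_gib = layers * per_layer_mib / 1024
print(f"~{total_gib:.2f} GiB of persistent buffers")  # ~15.25 GiB
```

Since none of this is visible to `profile_run`, the KV-cache sizing step over-commits memory by roughly that amount.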

Fix

Make the _decode_out buffer a module-level singleton shared across all layers instead of per-instance. Since the forward pass is sequential (one layer at a time), the buffer is never used concurrently. Memory drops from ~15 GiB to ~256 MiB.

Applies to both FlashInferMLAImpl and CutlassMLAImpl.

Reproducer (from #37442)

python3 -m vllm.entrypoints.openai.api_server \
  --model nvidia/DeepSeek-R1-0528-NVFP4 \
  --trust-remote-code --no-enable-prefix-caching \
  --dtype auto --kv-cache-dtype fp8 \
  --tensor-parallel-size 1 --data-parallel-size 4 \
  --max-num-seqs 1024 --max-model-len 10240 \
  --gpu-memory-utilization 0.9 \
  --max-num-batched-tokens 8192 \
  --enable-expert-parallel

Test plan

  • Verified OOM no longer occurs on 4×GB200 with DeepSeek-R1 NVFP4 DP4
  • NaN prevention still works (buffer is still zero-initialized, just shared)
  • CI

Fixes regression from #37442.

🤖 Generated with Claude Code

The NaN fix (PR vllm-project#37442) allocated a persistent _decode_out buffer
per attention layer. For DeepSeek-R1 (61 layers), this totals ~15 GiB
of GPU memory that is also not accounted for during profiling, causing
OOM when KV cache + buffers exceed available memory.

Fix: use a single module-level buffer shared across all layers.
Memory drops from ~15 GiB to ~256 MiB. The buffer is only written
by one layer at a time (sequential forward pass), so sharing is safe.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@elvircrn elvircrn requested a review from pavanimajety as a code owner March 22, 2026 10:27
@mergify mergify bot added the nvidia, v1, and bug (Something isn't working) labels Mar 22, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The code changes introduce a global output buffer to reduce memory allocation. The reviewer identified two critical issues: a potential race condition due to the lack of synchronization when accessing the global buffer in cutlass_mla.py, and an incomplete check for buffer reallocation in flashinfer_mla.py that does not account for the number of heads or LoRA rank, potentially leading to shape mismatch errors.

Comment on lines +229 to +236
+ global g_cutlass_decode_out
  if (
-     self._decode_out is None
-     or self._decode_out.shape[0] < B_q
-     or self._decode_out.dtype != dtype
+     g_cutlass_decode_out is None
+     or g_cutlass_decode_out.shape[0] < B_q
+     or g_cutlass_decode_out.dtype != dtype
  ):
-     self._decode_out = q_nope.new_zeros((B_q, MAX_HEADS, D_latent), dtype=dtype)
- out = self._decode_out[:B_q]
+     g_cutlass_decode_out = q_nope.new_zeros((B_q, MAX_HEADS, D_latent), dtype=dtype)
+ out = g_cutlass_decode_out[:B_q]
Contributor


critical

The use of a global mutable variable g_cutlass_decode_out without any synchronization mechanism can lead to race conditions in a multi-threaded environment. If multiple forward passes are executed concurrently in different threads (e.g., with multiple adapters), one thread might be reading g_cutlass_decode_out while another is reallocating it. This could cause crashes or silent data corruption.

To ensure thread safety, this critical section should be protected by a lock. You'll need to add import threading and define a lock at the module level (e.g., g_cutlass_decode_out_lock = threading.Lock()), then use it to guard access to the shared buffer.

Comment on lines 215 to 219
  if (
-     self._decode_out is None
-     or self._decode_out.shape[0] < B
-     or self._decode_out.dtype != dtype
+     g_fi_decode_out is None
+     or g_fi_decode_out.shape[0] < B
+     or g_fi_decode_out.dtype != dtype
  ):
Contributor


critical

The check for re-allocating the shared g_fi_decode_out buffer is incomplete. It only considers the batch size (B) and dtype, but not the number of heads (q.shape[2]) or self.kv_lora_rank. If different models or layers with different num_heads or kv_lora_rank are used in the same process, an incorrectly sized buffer could be reused, leading to shape mismatch errors or silent correctness issues inside the trtllm_batch_decode_with_kv_cache_mla kernel.

To prevent this, the check should be extended to also verify that the number of heads and the LoRA rank match the existing buffer's dimensions.

            if (
                g_fi_decode_out is None
                or g_fi_decode_out.shape[0] < B
                or g_fi_decode_out.shape[1] != q.shape[2]
                or g_fi_decode_out.shape[2] != self.kv_lora_rank
                or g_fi_decode_out.dtype != dtype
            ):

@mergify

mergify bot commented Mar 22, 2026

Hi @elvircrn, the pre-commit checks have failed. Please run:

uv pip install 'pre-commit>=4.5.1'
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@elvircrn elvircrn marked this pull request as draft March 22, 2026 10:53
@elvircrn elvircrn marked this pull request as ready for review March 22, 2026 14:19
@mergify

mergify bot commented Mar 22, 2026

Hi @elvircrn, the pre-commit checks have failed. Please run:

uv pip install 'pre-commit>=4.5.1'
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10


Labels

bug (Something isn't working), nvidia, v1
