[Bugfix] Share MLA decode output buffer across layers to fix OOM #37799
elvircrn wants to merge 1 commit into vllm-project:main
Conversation
The NaN fix (PR vllm-project#37442) allocated a persistent _decode_out buffer per attention layer. For DeepSeek-R1 (61 layers), this totals ~15 GiB of GPU memory that is also not accounted for during profiling, causing OOM when KV cache + buffers exceed available memory.

Fix: use a single module-level buffer shared across all layers. Memory drops from ~15 GiB to ~256 MiB. The buffer is only written by one layer at a time (sequential forward pass), so sharing is safe.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Code Review
The change introduces a global output buffer shared across layers to reduce memory allocation. The review identified two critical issues: a potential race condition from unsynchronized access to the global buffer in cutlass_mla.py, and an incomplete reallocation check in flashinfer_mla.py that does not account for the number of heads or the LoRA rank, potentially leading to shape-mismatch errors.
+ global g_cutlass_decode_out
  if (
-     self._decode_out is None
-     or self._decode_out.shape[0] < B_q
-     or self._decode_out.dtype != dtype
+     g_cutlass_decode_out is None
+     or g_cutlass_decode_out.shape[0] < B_q
+     or g_cutlass_decode_out.dtype != dtype
  ):
-     self._decode_out = q_nope.new_zeros((B_q, MAX_HEADS, D_latent), dtype=dtype)
- out = self._decode_out[:B_q]
+     g_cutlass_decode_out = q_nope.new_zeros((B_q, MAX_HEADS, D_latent), dtype=dtype)
+ out = g_cutlass_decode_out[:B_q]
The use of a global mutable variable g_cutlass_decode_out without any synchronization mechanism can lead to race conditions in a multi-threaded environment. If multiple forward passes are executed concurrently in different threads (e.g., with multiple adapters), one thread might be reading g_cutlass_decode_out while another is reallocating it. This could cause crashes or silent data corruption.
To ensure thread safety, this critical section should be protected by a lock. You'll need to add import threading and define a lock at the module level (e.g., g_cutlass_decode_out_lock = threading.Lock()), then use it to guard access to the shared buffer.
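A minimal sketch of the suggested locking pattern. The buffer and lock names follow the review's suggestion; get_decode_out and the dimensions are hypothetical, and a nested list stands in for the torch tensor so the sketch runs without a GPU:

```python
import threading
from typing import Optional

# Illustrative dimensions only; the real code uses MAX_HEADS / D_latent
# from the attention implementation.
MAX_HEADS, D_LATENT = 8, 16

# Module-level shared buffer and its guard, per the review's suggestion.
g_cutlass_decode_out: Optional[list] = None
g_cutlass_decode_out_lock = threading.Lock()


def get_decode_out(B_q: int) -> list:
    """Return the first B_q rows of the shared buffer, reallocating it
    under the lock when it is missing or too small."""
    global g_cutlass_decode_out
    with g_cutlass_decode_out_lock:
        if g_cutlass_decode_out is None or len(g_cutlass_decode_out) < B_q:
            g_cutlass_decode_out = [
                [[0.0] * D_LATENT for _ in range(MAX_HEADS)]
                for _ in range(B_q)
            ]
        return g_cutlass_decode_out[:B_q]
```

Holding the lock across both the check and the reallocation prevents one thread from reading a buffer another thread is replacing; a smaller follow-up request then reuses the existing larger buffer without reallocating.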
  if (
-     self._decode_out is None
-     or self._decode_out.shape[0] < B
-     or self._decode_out.dtype != dtype
+     g_fi_decode_out is None
+     or g_fi_decode_out.shape[0] < B
+     or g_fi_decode_out.dtype != dtype
  ):
The check for re-allocating the shared g_fi_decode_out buffer is incomplete. It only considers the batch size (B) and dtype, but not the number of heads (q.shape[2]) or self.kv_lora_rank. If different models or layers with different num_heads or kv_lora_rank are used in the same process, an incorrectly sized buffer could be reused, leading to shape mismatch errors or silent correctness issues inside the trtllm_batch_decode_with_kv_cache_mla kernel.
To prevent this, the check should be extended to also verify that the number of heads and the LoRA rank match the existing buffer's dimensions.
if (
g_fi_decode_out is None
or g_fi_decode_out.shape[0] < B
or g_fi_decode_out.shape[1] != q.shape[2]
or g_fi_decode_out.shape[2] != self.kv_lora_rank
or g_fi_decode_out.dtype != dtype
):
Hi @elvircrn, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.
Summary
PR #37442 added a per-layer _decode_out buffer to prevent NaN from CUDA graph padding slots. For DeepSeek-R1 (61 layers), this allocates ~15 GiB of persistent GPU memory that is not accounted for during profiling (the buffer is lazily allocated in forward_mqa, which is never called during profile_run with is_profile=True). This causes OOM on memory-constrained configs (e.g. DeepSeek-R1 NVFP4 with DP4 on 4×GB200).
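The ~15 GiB figure is consistent with one bf16 buffer of shape (B_q, MAX_HEADS, D_latent) per layer. A back-of-the-envelope check, where the concrete values (B_q = 2048, MAX_HEADS = 128, D_latent = 512) are assumptions chosen to match the stated totals rather than numbers taken from the PR:

```python
# Assumed illustrative values: max decode batch 2048, 128 heads,
# kv_lora_rank (D_latent) 512, bfloat16 (2 bytes), 61 layers.
B_q, MAX_HEADS, D_LATENT = 2048, 128, 512
BYTES_PER_EL = 2
NUM_LAYERS = 61

# Size of one layer's buffer -- also the cost of the single shared buffer.
per_layer = B_q * MAX_HEADS * D_LATENT * BYTES_PER_EL

# Total cost when every layer keeps its own buffer.
total = per_layer * NUM_LAYERS

print(per_layer / 2**20)  # MiB for one (shared) buffer
print(total / 2**30)      # GiB across all 61 per-layer buffers
```

Under these assumptions the shared buffer costs 256 MiB while 61 per-layer copies cost ~15.25 GiB, matching the numbers quoted in the PR.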
Fix
Make the _decode_out buffer a module-level singleton shared across all layers instead of per-instance. Since the forward pass is sequential (one layer at a time), the buffer is never used concurrently. Memory drops from ~15 GiB to ~256 MiB. Applies to both FlashInferMLAImpl and CutlassMLAImpl.

Reproducer (from #37442)
Test plan
Fixes regression from #37442.
🤖 Generated with Claude Code