[Fix] Share decode output buffer across MLA layers to reduce memory #37805
xueliangyang-oeuler wants to merge 4 commits into vllm-project:main
Conversation
The NaN fix (PR vllm-project#37442) allocated a persistent `_decode_out` buffer per attention layer. For DeepSeek-R1 (61 layers), this totals ~15 GiB of GPU memory that is not accounted for during profiling, causing OOM when the KV cache plus buffers exceed available memory.

This fix uses a single module-level buffer shared across all layers. Memory drops from ~15 GiB to ~256 MiB. The buffer is only written by one layer at a time (the forward pass is sequential), so sharing is safe.

Changes:
- CUTLASS MLA: add `SM100DecodeOutBuffer` class with shared buffer
- FlashInfer MLA: add `FlashInferDecodeOutBuffer` class with shared buffer
- Both implementations now use a `g_*_decode_out` global instance

Fixes vllm-project#37777

Signed-off-by: xueliangyang-oeuler <yxl546827391@gmail.com>
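The ~15 GiB and ~256 MiB figures follow directly from the buffer shape. A back-of-the-envelope check (the max batch size of 2048, 128 heads, and bf16 element size are illustrative assumptions, not values taken from this PR; `D_latent = 512` and the 61-layer count are from the PR):

```python
# Back-of-the-envelope memory check: per-layer buffers vs. one shared buffer.
# Assumed values (illustrative): max decode batch 2048, 128 heads, bf16.
BYTES_PER_ELEM = 2   # bf16
MAX_BATCH = 2048     # assumed max decode batch size
MAX_HEADS = 128      # assumed
D_LATENT = 512       # from the buffer code
NUM_LAYERS = 61      # DeepSeek-R1

per_layer_bytes = MAX_BATCH * MAX_HEADS * D_LATENT * BYTES_PER_ELEM
per_layer_mib = per_layer_bytes / 2**20
total_gib = NUM_LAYERS * per_layer_bytes / 2**30

print(f"one buffer:            {per_layer_mib:.0f} MiB")    # 256 MiB
print(f"61 per-layer buffers:  {total_gib:.2f} GiB")        # 15.25 GiB
print(f"shared across layers:  {per_layer_mib:.0f} MiB")    # a single allocation
```

Under these assumptions each buffer is exactly 256 MiB, so 61 per-layer copies cost ~15.25 GiB while a single shared copy costs 256 MiB, matching the numbers in the description.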
Code Review
This pull request introduces a significant memory optimization by sharing the decode output buffer across all MLA layers, which drastically reduces GPU memory usage for models with many layers. The implementation correctly replaces per-layer buffers with a global shared one. My review focuses on improving the buffer reallocation strategy within the new buffer management classes. The current implementation can lead to inefficient reallocations when request characteristics fluctuate. I've suggested changes to ensure the shared buffers only grow, which will improve performance and robustness.
```python
class SM100DecodeOutBuffer:
    """Shared decode output buffer for all CutlassMLAImpl instances.

    The NaN fix (PR #37442) allocated a persistent _decode_out buffer per
    attention layer. For DeepSeek-R1 (61 layers), this totals ~15 GiB of
    GPU memory that is not accounted for during profiling, causing OOM.

    This shared buffer reduces memory from ~15 GiB to ~256 MiB since all
    layers write sequentially (no concurrent access in forward pass).
    """

    def __init__(self):
        self._buffer: torch.Tensor | None = None
        self._dtype: torch.dtype | None = None

    def get(self, batch_size: int, dtype: torch.dtype, device: torch.device):
        D_latent = 512
        if (
            self._buffer is None
            or self._buffer.shape[0] < batch_size
            or self._buffer.dtype != dtype
            or self._buffer.device != device
        ):
            self._buffer = torch.zeros(
                (batch_size, MAX_HEADS, D_latent), dtype=dtype, device=device
            )
        return self._buffer[:batch_size]
```
The current implementation of `SM100DecodeOutBuffer.get` can be inefficient. When the buffer needs to be reallocated due to a change in dtype or device, it is sized to the current `batch_size`, which might be small; a subsequent call with a larger `batch_size` then triggers another reallocation. To avoid these potentially frequent reallocations, the buffer should always grow to the maximum `batch_size` seen so far for a given dtype and device.
```python
class SM100DecodeOutBuffer:
    """Shared decode output buffer for all CutlassMLAImpl instances.

    The NaN fix (PR #37442) allocated a persistent _decode_out buffer per
    attention layer. For DeepSeek-R1 (61 layers), this totals ~15 GiB of
    GPU memory that is not accounted for during profiling, causing OOM.

    This shared buffer reduces memory from ~15 GiB to ~256 MiB since all
    layers write sequentially (no concurrent access in forward pass).
    """

    def __init__(self):
        self._buffer: torch.Tensor | None = None
        self._max_batch_size = 0

    def get(self, batch_size: int, dtype: torch.dtype, device: torch.device):
        D_latent = 512
        realloc = False
        if (
            self._buffer is None
            or self._buffer.dtype != dtype
            or self._buffer.device != device
        ):
            self._max_batch_size = 0
            realloc = True
        if batch_size > self._max_batch_size:
            self._max_batch_size = batch_size
            realloc = True
        if realloc:
            self._buffer = torch.zeros(
                (self._max_batch_size, MAX_HEADS, D_latent), dtype=dtype, device=device
            )
        return self._buffer[:batch_size]
```

```python
class FlashInferDecodeOutBuffer:
    """Shared decode output buffer for all FlashInferMLAImpl instances.

    The NaN fix (PR #37442) allocated a persistent _decode_out buffer per
    attention layer. For DeepSeek-R1 (61 layers), this totals ~15 GiB of
    GPU memory that is not accounted for during profiling, causing OOM.

    This shared buffer reduces memory from ~15 GiB to ~256 MiB since all
    layers write sequentially (no concurrent access in forward pass).
    """

    def __init__(self):
        self._buffer: torch.Tensor | None = None

    def get(
        self,
        batch_size: int,
        num_heads: int,
        kv_lora_rank: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        if (
            self._buffer is None
            or self._buffer.shape[0] < batch_size
            or self._buffer.shape[1] < num_heads
            or self._buffer.shape[2] < kv_lora_rank
            or self._buffer.dtype != dtype
            or self._buffer.device != device
        ):
            self._buffer = torch.zeros(
                (batch_size, num_heads, kv_lora_rank), dtype=dtype, device=device
            )
        return self._buffer[:batch_size, :num_heads, :kv_lora_rank]
```
The `get` method in `FlashInferDecodeOutBuffer` reallocates the buffer using the current dimensions, which can lead to inefficient reallocations if dimensions fluctuate. For instance, a request with large `num_heads` and small `batch_size` followed by one with small `num_heads` and large `batch_size` would cause two reallocations, with the buffer shrinking and then growing again. A more robust approach is to track the maximum size seen for each dimension and only ever grow the buffer, preventing performance degradation from frequent memory operations.
```python
class FlashInferDecodeOutBuffer:
    """Shared decode output buffer for all FlashInferMLAImpl instances.

    The NaN fix (PR #37442) allocated a persistent _decode_out buffer per
    attention layer. For DeepSeek-R1 (61 layers), this totals ~15 GiB of
    GPU memory that is not accounted for during profiling, causing OOM.

    This shared buffer reduces memory from ~15 GiB to ~256 MiB since all
    layers write sequentially (no concurrent access in forward pass).
    """

    def __init__(self):
        self._buffer: torch.Tensor | None = None
        self._max_batch_size = 0
        self._max_num_heads = 0
        self._max_kv_lora_rank = 0

    def get(
        self,
        batch_size: int,
        num_heads: int,
        kv_lora_rank: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        realloc = False
        if (
            self._buffer is None
            or self._buffer.dtype != dtype
            or self._buffer.device != device
        ):
            self._max_batch_size = 0
            self._max_num_heads = 0
            self._max_kv_lora_rank = 0
            realloc = True
        if batch_size > self._max_batch_size:
            self._max_batch_size = batch_size
            realloc = True
        if num_heads > self._max_num_heads:
            self._max_num_heads = num_heads
            realloc = True
        if kv_lora_rank > self._max_kv_lora_rank:
            self._max_kv_lora_rank = kv_lora_rank
            realloc = True
        if realloc:
            self._buffer = torch.zeros(
                (self._max_batch_size, self._max_num_heads, self._max_kv_lora_rank),
                dtype=dtype,
                device=device,
            )
        return self._buffer[:batch_size, :num_heads, :kv_lora_rank]
```
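To see why the grow-only policy helps, here is a small dependency-free sketch (the request sizes are hypothetical, not from the PR) that counts how many reallocations each policy performs when dimensions fluctuate in opposition:

```python
def count_reallocs(requests, grow_only):
    """Count buffer reallocations over a sequence of (batch, heads, rank)
    requests. Naive policy: allocate exactly the current dims whenever any
    dim doesn't fit. Grow-only policy: track per-dimension maxima and only
    ever grow the buffer."""
    shape = None
    max_dims = (0, 0, 0)
    n = 0
    for dims in requests:
        if grow_only:
            new_max = tuple(max(m, d) for m, d in zip(max_dims, dims))
            if shape is None or new_max != max_dims:
                max_dims = new_max
                shape = max_dims
                n += 1
        else:
            if shape is None or any(s < d for s, d in zip(shape, dims)):
                shape = dims
                n += 1
    return n

# Hypothetical workload: num_heads and batch_size alternate large/small.
reqs = [(8, 128, 512), (64, 16, 512), (8, 128, 512), (64, 16, 512)]
print(count_reallocs(reqs, grow_only=False))  # → 4 (naive policy thrashes)
print(count_reallocs(reqs, grow_only=True))   # → 2 (settles after growing once)
```

The naive policy reallocates on every request here because each allocation is sized exactly to the current dims, while the grow-only policy stabilizes once the buffer covers the per-dimension maxima, which is the behavior the suggested change implements.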
Hi @xueliangyang-oeuler. JFYI #37815.