
[Fix] Share decode output buffer across MLA layers to reduce memory#37805

Closed
xueliangyang-oeuler wants to merge 4 commits into vllm-project:main from xueliangyang-oeuler:feat/check-issues-0322

Conversation

@xueliangyang-oeuler xueliangyang-oeuler commented Mar 22, 2026

The NaN fix (PR #37442) allocated a persistent _decode_out buffer per attention layer. For DeepSeek-R1 (61 layers), this totals ~15 GiB of GPU memory that is not accounted for during profiling, causing OOM when KV cache + buffers exceed available memory.

This fix uses a single module-level buffer shared across all layers. Memory drops from ~15 GiB to ~256 MiB. The buffer is only written by one layer at a time (sequential forward pass), so sharing is safe.
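The memory figures can be sanity-checked with back-of-the-envelope arithmetic. This sketch assumes a maximum batch size of 1024, MAX_HEADS = 128, D_latent = 512, and 4-byte floats; these values are illustrative, not taken from the PR's profiling run:

```python
# Sanity check of the ~15 GiB vs ~256 MiB figures in the description.
# Assumed sizes (illustrative): batch=1024, MAX_HEADS=128, D_latent=512, fp32.
batch, max_heads, d_latent, itemsize = 1024, 128, 512, 4
per_buffer = batch * max_heads * d_latent * itemsize  # bytes for one buffer
layers = 61  # DeepSeek-R1 attention layers

print(per_buffer / 2**20)           # one shared buffer: 256.0 MiB
print(layers * per_buffer / 2**30)  # one buffer per layer: 15.25 GiB
```

Under these assumptions, a single shared buffer is 256 MiB, while 61 per-layer copies sum to just over 15 GiB, matching the description.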

Changes:

  • CUTLASS MLA: Add SM100DecodeOutBuffer class with shared buffer
  • FlashInfer MLA: Add FlashInferDecodeOutBuffer class with shared buffer
  • Both implementations now use a single module-level g_*_decode_out instance

Fixes #37777

Purpose

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.


@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant memory optimization by sharing the decode output buffer across all MLA layers, which drastically reduces GPU memory usage for models with many layers. The implementation correctly replaces per-layer buffers with a global shared one. My review focuses on improving the buffer reallocation strategy within the new buffer management classes. The current implementation can lead to inefficient reallocations when request characteristics fluctuate. I've suggested changes to ensure the shared buffers only grow, which will improve performance and robustness.

Comment on lines +102 to +128
class SM100DecodeOutBuffer:
    """Shared decode output buffer for all CutlassMLAImpl instances.

    The NaN fix (PR #37442) allocated a persistent _decode_out buffer per
    attention layer. For DeepSeek-R1 (61 layers), this totals ~15 GiB of
    GPU memory that is not accounted for during profiling, causing OOM.

    This shared buffer reduces memory from ~15 GiB to ~256 MiB since all
    layers write sequentially (no concurrent access in forward pass).
    """

    def __init__(self):
        self._buffer: torch.Tensor | None = None
        self._dtype: torch.dtype | None = None

    def get(self, batch_size: int, dtype: torch.dtype, device: torch.device):
        D_latent = 512
        if (
            self._buffer is None
            or self._buffer.shape[0] < batch_size
            or self._buffer.dtype != dtype
            or self._buffer.device != device
        ):
            self._buffer = torch.zeros(
                (batch_size, MAX_HEADS, D_latent), dtype=dtype, device=device
            )
        return self._buffer[:batch_size]

Severity: high

The current implementation of SM100DecodeOutBuffer.get can be inefficient. When the buffer needs to be reallocated due to a change in dtype or device, it uses the current batch_size, which might be small. A subsequent call with a larger batch_size would trigger another reallocation. To avoid these potentially frequent reallocations, the buffer should always grow to the maximum batch_size seen so far for a given dtype and device.

class SM100DecodeOutBuffer:
    """Shared decode output buffer for all CutlassMLAImpl instances.

    The NaN fix (PR #37442) allocated a persistent _decode_out buffer per
    attention layer. For DeepSeek-R1 (61 layers), this totals ~15 GiB of
    GPU memory that is not accounted for during profiling, causing OOM.

    This shared buffer reduces memory from ~15 GiB to ~256 MiB since all
    layers write sequentially (no concurrent access in forward pass).
    """

    def __init__(self):
        self._buffer: torch.Tensor | None = None
        self._max_batch_size = 0

    def get(self, batch_size: int, dtype: torch.dtype, device: torch.device):
        D_latent = 512
        realloc = False
        if (
            self._buffer is None
            or self._buffer.dtype != dtype
            or self._buffer.device != device
        ):
            self._max_batch_size = 0
            realloc = True

        if batch_size > self._max_batch_size:
            self._max_batch_size = batch_size
            realloc = True

        if realloc:
            self._buffer = torch.zeros(
                (self._max_batch_size, MAX_HEADS, D_latent), dtype=dtype, device=device
            )
        return self._buffer[:batch_size]
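The grow-only policy above can be exercised without torch. This is a minimal numpy stand-in (class and constant names are illustrative, mirroring but not identical to the suggestion) showing that a smaller follow-up request returns a view of the existing storage instead of reallocating:

```python
import numpy as np

# Torch-free sketch of the grow-only decode buffer; numpy stands in for
# torch, and GrowOnlyDecodeOut is a hypothetical name for illustration.
MAX_HEADS, D_LATENT = 128, 512  # assumed constants

class GrowOnlyDecodeOut:
    def __init__(self):
        self._buffer = None
        self._max_batch = 0

    def get(self, batch_size, dtype=np.float32):
        # A dtype change invalidates the cached buffer entirely.
        if self._buffer is None or self._buffer.dtype != dtype:
            self._max_batch = 0
        # Only grow: reallocate when the request exceeds capacity.
        if batch_size > self._max_batch or self._buffer is None:
            self._max_batch = max(self._max_batch, batch_size)
            self._buffer = np.zeros((self._max_batch, MAX_HEADS, D_LATENT), dtype)
        return self._buffer[:batch_size]

buf = GrowOnlyDecodeOut()
a = buf.get(32)  # allocates (32, 128, 512)
b = buf.get(8)   # no reallocation: a view into the same storage
```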

Comment on lines +106 to +139
class FlashInferDecodeOutBuffer:
    """Shared decode output buffer for all FlashInferMLAImpl instances.

    The NaN fix (PR #37442) allocated a persistent _decode_out buffer per
    attention layer. For DeepSeek-R1 (61 layers), this totals ~15 GiB of
    GPU memory that is not accounted for during profiling, causing OOM.

    This shared buffer reduces memory from ~15 GiB to ~256 MiB since all
    layers write sequentially (no concurrent access in forward pass).
    """

    def __init__(self):
        self._buffer: torch.Tensor | None = None

    def get(
        self,
        batch_size: int,
        num_heads: int,
        kv_lora_rank: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        if (
            self._buffer is None
            or self._buffer.shape[0] < batch_size
            or self._buffer.shape[1] < num_heads
            or self._buffer.shape[2] < kv_lora_rank
            or self._buffer.dtype != dtype
            or self._buffer.device != device
        ):
            self._buffer = torch.zeros(
                (batch_size, num_heads, kv_lora_rank), dtype=dtype, device=device
            )
        return self._buffer[:batch_size, :num_heads, :kv_lora_rank]

Severity: high

The get method in FlashInferDecodeOutBuffer re-allocates the buffer using the current dimensions, which can lead to inefficient reallocations if dimensions fluctuate. For instance, a request with large num_heads and small batch_size followed by one with small num_heads and large batch_size would cause two reallocations where the buffer size might shrink and then grow again. A more robust approach is to track the maximum size for each dimension and only grow the buffer, preventing performance degradation from frequent memory operations.

class FlashInferDecodeOutBuffer:
    """Shared decode output buffer for all FlashInferMLAImpl instances.

    The NaN fix (PR #37442) allocated a persistent _decode_out buffer per
    attention layer. For DeepSeek-R1 (61 layers), this totals ~15 GiB of
    GPU memory that is not accounted for during profiling, causing OOM.

    This shared buffer reduces memory from ~15 GiB to ~256 MiB since all
    layers write sequentially (no concurrent access in forward pass).
    """

    def __init__(self):
        self._buffer: torch.Tensor | None = None
        self._max_batch_size = 0
        self._max_num_heads = 0
        self._max_kv_lora_rank = 0

    def get(
        self,
        batch_size: int,
        num_heads: int,
        kv_lora_rank: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        realloc = False
        if (
            self._buffer is None
            or self._buffer.dtype != dtype
            or self._buffer.device != device
        ):
            self._max_batch_size = 0
            self._max_num_heads = 0
            self._max_kv_lora_rank = 0
            realloc = True

        if batch_size > self._max_batch_size:
            self._max_batch_size = batch_size
            realloc = True

        if num_heads > self._max_num_heads:
            self._max_num_heads = num_heads
            realloc = True

        if kv_lora_rank > self._max_kv_lora_rank:
            self._max_kv_lora_rank = kv_lora_rank
            realloc = True

        if realloc:
            self._buffer = torch.zeros(
                (self._max_batch_size, self._max_num_heads, self._max_kv_lora_rank),
                dtype=dtype,
                device=device,
            )
        return self._buffer[:batch_size, :num_heads, :kv_lora_rank]
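To make the fluctuating-dimensions point concrete, here is a torch-free sketch (the helper names are hypothetical) counting reallocations under each policy for an alternating (batch_size, num_heads) workload:

```python
def allocs_current_dims(requests):
    """Original policy: reallocate at the *current* request dims whenever
    any dim of the cached buffer is too small (buffer can shrink)."""
    shape, n = (0, 0), 0
    for req in requests:
        if any(cap < need for cap, need in zip(shape, req)):
            shape, n = req, n + 1
    return n

def allocs_grow_only(requests):
    """Suggested policy: grow each dim to the max seen so far."""
    shape, n = (0, 0), 0
    for req in requests:
        new = tuple(max(cap, need) for cap, need in zip(shape, req))
        if new != shape:
            shape, n = new, n + 1
    return n

# Alternating (batch_size, num_heads) requests keep the original policy
# thrashing (4 allocations), while grow-only settles after 2.
reqs = [(1, 128), (64, 16), (1, 128), (64, 16)]
print(allocs_current_dims(reqs))  # 4
print(allocs_grow_only(reqs))     # 2
```

After the second request, the grow-only buffer is (64, 128), large enough for every later request in the sequence.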

@varun-sundar-rabindranath

Hi @xueliangyang-oeuler . JFYI #37815 .

@hmellor hmellor added the closed-as-slop Pull request determined to be low effort and agent generated label Mar 25, 2026
@hmellor hmellor closed this Mar 25, 2026
@github-project-automation github-project-automation bot moved this to Done in NVIDIA Mar 25, 2026

Labels

closed-as-slop (Pull request determined to be low effort and agent generated), nvidia, v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

[Bug]: [OOM] DeepSeek-R1 Out of Memory

4 participants