Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 37 additions & 13 deletions vllm/v1/attention/backends/mla/cutlass_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,38 @@ def ensure_size(self, attn_metadata: MLACommonMetadata, num_kv_splits: int):

g_sm100_workspace = SM100Workspace(128 * 1024 * 1024) # 128MB


class SM100DecodeOutBuffer:
"""Shared decode output buffer for all CutlassMLAImpl instances.

The NaN fix (PR #37442) allocated a persistent _decode_out buffer per
attention layer. For DeepSeek-R1 (61 layers), this totals ~15 GiB of
GPU memory that is not accounted for during profiling, causing OOM.

This shared buffer reduces memory from ~15 GiB to ~256 MiB since all
layers write sequentially (no concurrent access in forward pass).
"""

def __init__(self):
self._buffer: torch.Tensor | None = None
self._dtype: torch.dtype | None = None

def get(self, batch_size: int, dtype: torch.dtype, device: torch.device):
D_latent = 512
if (
self._buffer is None
or self._buffer.shape[0] < batch_size
or self._buffer.dtype != dtype
or self._buffer.device != device
):
self._buffer = torch.zeros(
(batch_size, MAX_HEADS, D_latent), dtype=dtype, device=device
)
return self._buffer[:batch_size]
Comment on lines +102 to +128
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation of SM100DecodeOutBuffer.get can be inefficient. When the buffer needs to be reallocated due to a change in dtype or device, it uses the current batch_size, which might be small. A subsequent call with a larger batch_size would trigger another reallocation. To avoid these potentially frequent reallocations, the buffer should always grow to the maximum batch_size seen so far for a given dtype and device.

class SM100DecodeOutBuffer:
    """Shared decode output buffer for all CutlassMLAImpl instances.

    The NaN fix (PR #37442) allocated a persistent _decode_out buffer per
    attention layer. For DeepSeek-R1 (61 layers), this totals ~15 GiB of
    GPU memory that is not accounted for during profiling, causing OOM.

    This shared buffer reduces memory from ~15 GiB to ~256 MiB since all
    layers write sequentially (no concurrent access in forward pass).
    """

    def __init__(self):
        self._buffer: torch.Tensor | None = None
        self._max_batch_size = 0

    def get(self, batch_size: int, dtype: torch.dtype, device: torch.device):
        D_latent = 512
        realloc = False
        if (
            self._buffer is None
            or self._buffer.dtype != dtype
            or self._buffer.device != device
        ):
            self._max_batch_size = 0
            realloc = True

        if batch_size > self._max_batch_size:
            self._max_batch_size = batch_size
            realloc = True

        if realloc:
            self._buffer = torch.zeros(
                (self._max_batch_size, MAX_HEADS, D_latent), dtype=dtype, device=device
            )
        return self._buffer[:batch_size]



g_sm100_decode_out = SM100DecodeOutBuffer()

MAX_HEADS = 128


Expand Down Expand Up @@ -162,10 +194,9 @@ def __init__(
# Share workspace buffer across all executions
self._workspace = g_sm100_workspace

# Pre-allocated output buffer, lazily sized on first call.
# Zero-init once to prevent NaN in padding slots (seq_lens=0)
# from contaminating downstream per-tensor reductions.
self._decode_out: torch.Tensor | None = None
# Use shared decode output buffer across all layers to reduce memory.
# Each layer writes sequentially, so sharing is safe.
self._decode_out_buffer = g_sm100_decode_out

def _sm100_cutlass_mla_decode(
self,
Expand Down Expand Up @@ -223,15 +254,8 @@ def _sm100_cutlass_mla_decode(
if is_quantized_kv_cache(self.kv_cache_dtype)
else q_nope.dtype
)
# Reuse pre-allocated zero-init output buffer to avoid a memset
# kernel on every CUDA graph replay.
if (
self._decode_out is None
or self._decode_out.shape[0] < B_q
or self._decode_out.dtype != dtype
):
self._decode_out = q_nope.new_zeros((B_q, MAX_HEADS, D_latent), dtype=dtype)
out = self._decode_out[:B_q]
# Use shared buffer to reduce memory footprint.
out = self._decode_out_buffer.get(B_q, dtype, q_nope.device)
lse = (
torch.empty((B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device)
if self.need_to_return_lse_for_decode
Expand Down
63 changes: 46 additions & 17 deletions vllm/v1/attention/backends/mla/flashinfer_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,45 @@ def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
)


class FlashInferDecodeOutBuffer:
"""Shared decode output buffer for all FlashInferMLAImpl instances.

The NaN fix (PR #37442) allocated a persistent _decode_out buffer per
attention layer. For DeepSeek-R1 (61 layers), this totals ~15 GiB of
GPU memory that is not accounted for during profiling, causing OOM.

This shared buffer reduces memory from ~15 GiB to ~256 MiB since all
layers write sequentially (no concurrent access in forward pass).
"""

def __init__(self):
self._buffer: torch.Tensor | None = None

def get(
self,
batch_size: int,
num_heads: int,
kv_lora_rank: int,
dtype: torch.dtype,
device: torch.device,
):
if (
self._buffer is None
or self._buffer.shape[0] < batch_size
or self._buffer.shape[1] < num_heads
or self._buffer.shape[2] < kv_lora_rank
or self._buffer.dtype != dtype
or self._buffer.device != device
):
self._buffer = torch.zeros(
(batch_size, num_heads, kv_lora_rank), dtype=dtype, device=device
)
return self._buffer[:batch_size, :num_heads, :kv_lora_rank]
Comment on lines +106 to +139
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The get method in FlashInferDecodeOutBuffer re-allocates the buffer using the current dimensions, which can lead to inefficient reallocations if dimensions fluctuate. For instance, a request with large num_heads and small batch_size followed by one with small num_heads and large batch_size would cause two reallocations where the buffer size might shrink and then grow again. A more robust approach is to track the maximum size for each dimension and only grow the buffer, preventing performance degradation from frequent memory operations.

class FlashInferDecodeOutBuffer:
    """Shared decode output buffer for all FlashInferMLAImpl instances.

    The NaN fix (PR #37442) allocated a persistent _decode_out buffer per
    attention layer. For DeepSeek-R1 (61 layers), this totals ~15 GiB of
    GPU memory that is not accounted for during profiling, causing OOM.

    This shared buffer reduces memory from ~15 GiB to ~256 MiB since all
    layers write sequentially (no concurrent access in forward pass).
    """

    def __init__(self):
        self._buffer: torch.Tensor | None = None
        self._max_batch_size = 0
        self._max_num_heads = 0
        self._max_kv_lora_rank = 0

    def get(
        self,
        batch_size: int,
        num_heads: int,
        kv_lora_rank: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        realloc = False
        if (
            self._buffer is None
            or self._buffer.dtype != dtype
            or self._buffer.device != device
        ):
            self._max_batch_size = 0
            self._max_num_heads = 0
            self._max_kv_lora_rank = 0
            realloc = True

        if batch_size > self._max_batch_size:
            self._max_batch_size = batch_size
            realloc = True

        if num_heads > self._max_num_heads:
            self._max_num_heads = num_heads
            realloc = True

        if kv_lora_rank > self._max_kv_lora_rank:
            self._max_kv_lora_rank = kv_lora_rank
            realloc = True

        if realloc:
            self._buffer = torch.zeros(
                (self._max_batch_size, self._max_num_heads, self._max_kv_lora_rank),
                dtype=dtype,
                device=device,
            )
        return self._buffer[:batch_size, :num_heads, :kv_lora_rank]



g_fi_decode_out = FlashInferDecodeOutBuffer()


class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
def __init__(
self,
Expand Down Expand Up @@ -152,10 +191,9 @@ def __init__(
self.bmm1_scale: float | None = None
self.bmm2_scale: float | None = None

# Pre-allocated output buffer, lazily sized on first call.
# Zero-init once to prevent NaN in padding slots (seq_lens=0)
# from contaminating downstream per-tensor reductions.
self._decode_out: torch.Tensor | None = None
# Use shared decode output buffer across all layers to reduce memory.
# Each layer writes sequentially, so sharing is safe.
self._decode_out_buffer = g_fi_decode_out

def forward_mqa(
self,
Expand Down Expand Up @@ -209,19 +247,10 @@ def forward_mqa(
if is_quantized_kv_cache(self.kv_cache_dtype)
else q.dtype
)
if (
self._decode_out is None
or self._decode_out.shape[0] < B
or self._decode_out.dtype != dtype
):
self._decode_out = torch.zeros(
B,
q.shape[2],
self.kv_lora_rank,
dtype=dtype,
device=q.device,
)
out_kwargs["out"] = self._decode_out[:B]
# Use shared buffer to reduce memory footprint.
out_kwargs["out"] = self._decode_out_buffer.get(
B, q.shape[2], self.kv_lora_rank, dtype, q.device
)

o = trtllm_batch_decode_with_kv_cache_mla(
query=q,
Expand Down
Loading