diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py
index fd4d9ab84274..adc2fe112b7d 100644
--- a/vllm/v1/attention/backends/mla/cutlass_mla.py
+++ b/vllm/v1/attention/backends/mla/cutlass_mla.py
@@ -98,6 +98,38 @@ def ensure_size(self, attn_metadata: MLACommonMetadata, num_kv_splits: int):
 
 
 g_sm100_workspace = SM100Workspace(128 * 1024 * 1024)  # 128MB
+
+class SM100DecodeOutBuffer:
+    """Shared decode output buffer for all CutlassMLAImpl instances.
+
+    The NaN fix (PR #37442) allocated a persistent _decode_out buffer per
+    attention layer. For DeepSeek-R1 (61 layers), this totals ~15 GiB of
+    GPU memory that is not accounted for during profiling, causing OOM.
+
+    This shared buffer reduces memory from ~15 GiB to ~256 MiB since all
+    layers write sequentially (no concurrent access in forward pass).
+    """
+
+    def __init__(self):
+        self._buffer: torch.Tensor | None = None
+        self._dtype: torch.dtype | None = None
+
+    def get(self, batch_size: int, dtype: torch.dtype, device: torch.device):
+        D_latent = 512
+        if (
+            self._buffer is None
+            or self._buffer.shape[0] < batch_size
+            or self._buffer.dtype != dtype
+            or self._buffer.device != device
+        ):
+            self._buffer = torch.zeros(
+                (batch_size, MAX_HEADS, D_latent), dtype=dtype, device=device
+            )
+        return self._buffer[:batch_size]
+
+
+g_sm100_decode_out = SM100DecodeOutBuffer()
+
 MAX_HEADS = 128
 
 
@@ -162,10 +194,9 @@ def __init__(
         # Share workspace buffer across all executions
         self._workspace = g_sm100_workspace
 
-        # Pre-allocated output buffer, lazily sized on first call.
-        # Zero-init once to prevent NaN in padding slots (seq_lens=0)
-        # from contaminating downstream per-tensor reductions.
-        self._decode_out: torch.Tensor | None = None
+        # Use shared decode output buffer across all layers to reduce memory.
+        # Each layer writes sequentially, so sharing is safe.
+        self._decode_out_buffer = g_sm100_decode_out
 
     def _sm100_cutlass_mla_decode(
         self,
@@ -223,15 +254,8 @@ def _sm100_cutlass_mla_decode(
             if is_quantized_kv_cache(self.kv_cache_dtype)
             else q_nope.dtype
         )
-        # Reuse pre-allocated zero-init output buffer to avoid a memset
-        # kernel on every CUDA graph replay.
-        if (
-            self._decode_out is None
-            or self._decode_out.shape[0] < B_q
-            or self._decode_out.dtype != dtype
-        ):
-            self._decode_out = q_nope.new_zeros((B_q, MAX_HEADS, D_latent), dtype=dtype)
-        out = self._decode_out[:B_q]
+        # Use shared buffer to reduce memory footprint.
+        out = self._decode_out_buffer.get(B_q, dtype, q_nope.device)
         lse = (
             torch.empty((B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device)
             if self.need_to_return_lse_for_decode
diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py
index 16d01bd338ca..1c2e0b415b17 100644
--- a/vllm/v1/attention/backends/mla/flashinfer_mla.py
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -103,6 +103,45 @@ def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
 
 
         )
+
+class FlashInferDecodeOutBuffer:
+    """Shared decode output buffer for all FlashInferMLAImpl instances.
+
+    The NaN fix (PR #37442) allocated a persistent _decode_out buffer per
+    attention layer. For DeepSeek-R1 (61 layers), this totals ~15 GiB of
+    GPU memory that is not accounted for during profiling, causing OOM.
+
+    This shared buffer reduces memory from ~15 GiB to ~256 MiB since all
+    layers write sequentially (no concurrent access in forward pass).
+    """
+
+    def __init__(self):
+        self._buffer: torch.Tensor | None = None
+
+    def get(
+        self,
+        batch_size: int,
+        num_heads: int,
+        kv_lora_rank: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ):
+        if (
+            self._buffer is None
+            or self._buffer.shape[0] < batch_size
+            or self._buffer.shape[1] < num_heads
+            or self._buffer.shape[2] < kv_lora_rank
+            or self._buffer.dtype != dtype
+            or self._buffer.device != device
+        ):
+            self._buffer = torch.zeros(
+                (batch_size, num_heads, kv_lora_rank), dtype=dtype, device=device
+            )
+        return self._buffer[:batch_size, :num_heads, :kv_lora_rank]
+
+
+g_fi_decode_out = FlashInferDecodeOutBuffer()
+
 class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
     def __init__(
         self,
@@ -152,10 +191,9 @@ def __init__(
         self.bmm1_scale: float | None = None
         self.bmm2_scale: float | None = None
 
-        # Pre-allocated output buffer, lazily sized on first call.
-        # Zero-init once to prevent NaN in padding slots (seq_lens=0)
-        # from contaminating downstream per-tensor reductions.
-        self._decode_out: torch.Tensor | None = None
+        # Use shared decode output buffer across all layers to reduce memory.
+        # Each layer writes sequentially, so sharing is safe.
+        self._decode_out_buffer = g_fi_decode_out
 
     def forward_mqa(
         self,
@@ -209,19 +247,10 @@ def forward_mqa(
             if is_quantized_kv_cache(self.kv_cache_dtype)
             else q.dtype
         )
-        if (
-            self._decode_out is None
-            or self._decode_out.shape[0] < B
-            or self._decode_out.dtype != dtype
-        ):
-            self._decode_out = torch.zeros(
-                B,
-                q.shape[2],
-                self.kv_lora_rank,
-                dtype=dtype,
-                device=q.device,
-            )
-        out_kwargs["out"] = self._decode_out[:B]
+        # Use shared buffer to reduce memory footprint.
+        out_kwargs["out"] = self._decode_out_buffer.get(
+            B, q.shape[2], self.kv_lora_rank, dtype, q.device
+        )
 
         o = trtllm_batch_decode_with_kv_cache_mla(
             query=q,