diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py
index fd4d9ab84274..bedfc02d603b 100644
--- a/vllm/v1/attention/backends/mla/cutlass_mla.py
+++ b/vllm/v1/attention/backends/mla/cutlass_mla.py
@@ -27,6 +27,10 @@
 
 logger = init_logger(__name__)
 
+# Shared across all CutlassMLAImpl instances (all layers) to avoid
+# 61× per-layer allocation (~15 GiB → ~256 MiB for DeepSeek-R1).
+g_cutlass_decode_out: torch.Tensor | None = None
+
 
 class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
     # enable full CUDA Graph support for decode-only capture
@@ -162,10 +166,6 @@ def __init__(
         # Share workspace buffer across all executions
         self._workspace = g_sm100_workspace
 
-        # Pre-allocated output buffer, lazily sized on first call.
-        # Zero-init once to prevent NaN in padding slots (seq_lens=0)
-        # from contaminating downstream per-tensor reductions.
-        self._decode_out: torch.Tensor | None = None
 
     def _sm100_cutlass_mla_decode(
         self,
@@ -223,15 +223,17 @@
             if is_quantized_kv_cache(self.kv_cache_dtype)
             else q_nope.dtype
        )
-        # Reuse pre-allocated zero-init output buffer to avoid a memset
-        # kernel on every CUDA graph replay.
+        # Reuse a single zero-init output buffer shared across all layers
+        # to prevent NaN in padding slots. Shared buffer reduces memory
+        # from ~15 GiB (per-layer) to ~256 MiB for DeepSeek-R1.
+        global g_cutlass_decode_out
         if (
-            self._decode_out is None
-            or self._decode_out.shape[0] < B_q
-            or self._decode_out.dtype != dtype
+            g_cutlass_decode_out is None
+            or g_cutlass_decode_out.shape[0] < B_q
+            or g_cutlass_decode_out.dtype != dtype
         ):
-            self._decode_out = q_nope.new_zeros((B_q, MAX_HEADS, D_latent), dtype=dtype)
-        out = self._decode_out[:B_q]
+            g_cutlass_decode_out = q_nope.new_zeros((B_q, MAX_HEADS, D_latent), dtype=dtype)
+        out = g_cutlass_decode_out[:B_q]
         lse = (
             torch.empty((B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device)
             if self.need_to_return_lse_for_decode
diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py
index 3de0dcdd8c01..1260e11ab51b 100644
--- a/vllm/v1/attention/backends/mla/flashinfer_mla.py
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -29,6 +29,10 @@
 
 FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
 
+# Shared across all FlashInferMLAImpl instances (all layers) to avoid
+# 61× per-layer allocation (~15 GiB → ~256 MiB for DeepSeek-R1).
+g_fi_decode_out: torch.Tensor | None = None
+
 
 class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
     _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
@@ -152,10 +156,6 @@
         self.bmm1_scale: float | None = None
         self.bmm2_scale: float | None = None
 
-        # Pre-allocated output buffer, lazily sized on first call.
-        # Zero-init once to prevent NaN in padding slots (seq_lens=0)
-        # from contaminating downstream per-tensor reductions.
-        self._decode_out: torch.Tensor | None = None
 
     def forward_mqa(
         self,
@@ -192,15 +192,18 @@
         if self.kv_cache_dtype.startswith("fp8"):
             self.bmm2_scale *= layer._k_scale_float
 
-        # Reuse pre-allocated zero-init output buffer to avoid a memset
-        # kernel on every CUDA graph replay.
-        # q is 4D: (batch, q_len_per_req, num_heads, head_dim)
+        # Reuse a single zero-init output buffer shared across all layers
+        # to prevent NaN in padding slots (seq_lens=0) from contaminating
+        # downstream per-tensor reductions. Shared buffer reduces memory
+        # from ~15 GiB (per-layer) to ~256 MiB for DeepSeek-R1.
+        #
         # FlashInfer has a bug where out= validation hardcodes 3D shape
         # (batch, num_heads, kv_lora_rank), but the kernel writes 4D
         # (batch, q_len, num_heads, kv_lora_rank) when q_len > 1.
         # So we can only pass out= for single-token decode (q_len == 1).
         # For q_len > 1, we zero padding slots after the kernel returns.
         # TODO: upstream fix to FlashInfer
+        global g_fi_decode_out
         B, q_len_per_req = q.shape[0], q.shape[1]
         out_kwargs: dict[str, torch.Tensor] = {}
         if q_len_per_req == 1:
@@ -210,18 +213,18 @@
                 else q.dtype
             )
             if (
-                self._decode_out is None
-                or self._decode_out.shape[0] < B
-                or self._decode_out.dtype != dtype
+                g_fi_decode_out is None
+                or g_fi_decode_out.shape[0] < B
+                or g_fi_decode_out.dtype != dtype
             ):
-                self._decode_out = torch.zeros(
+                g_fi_decode_out = torch.zeros(
                     B,
                     q.shape[2],
                     self.kv_lora_rank,
                     dtype=dtype,
                     device=q.device,
                 )
-            out_kwargs["out"] = self._decode_out[:B]
+            out_kwargs["out"] = g_fi_decode_out[:B]
         o = trtllm_batch_decode_with_kv_cache_mla(
             query=q,