diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index fd4d9ab84274..b01ce2be2418 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -162,11 +162,6 @@ def __init__( # Share workspace buffer across all executions self._workspace = g_sm100_workspace - # Pre-allocated output buffer, lazily sized on first call. - # Zero-init once to prevent NaN in padding slots (seq_lens=0) - # from contaminating downstream per-tensor reductions. - self._decode_out: torch.Tensor | None = None - def _sm100_cutlass_mla_decode( self, q_nope: torch.Tensor, @@ -223,15 +218,7 @@ def _sm100_cutlass_mla_decode( if is_quantized_kv_cache(self.kv_cache_dtype) else q_nope.dtype ) - # Reuse pre-allocated zero-init output buffer to avoid a memset - # kernel on every CUDA graph replay. - if ( - self._decode_out is None - or self._decode_out.shape[0] < B_q - or self._decode_out.dtype != dtype - ): - self._decode_out = q_nope.new_zeros((B_q, MAX_HEADS, D_latent), dtype=dtype) - out = self._decode_out[:B_q] + out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype) lse = ( torch.empty((B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device) if self.need_to_return_lse_for_decode diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index 16d01bd338ca..dc5b72bacdbf 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -21,7 +21,6 @@ AttentionLayer, AttentionType, MultipleOf, - is_quantized_kv_cache, ) from vllm.v1.attention.backends.utils import KVCacheLayoutType @@ -152,11 +151,6 @@ def __init__( self.bmm1_scale: float | None = None self.bmm2_scale: float | None = None - # Pre-allocated output buffer, lazily sized on first call. - # Zero-init once to prevent NaN in padding slots (seq_lens=0) - # from contaminating downstream per-tensor reductions. - self._decode_out: torch.Tensor | None = None - def forward_mqa( self, q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], @@ -192,37 +186,6 @@ def forward_mqa( if self.kv_cache_dtype.startswith("fp8"): self.bmm2_scale *= layer._k_scale_float - # Reuse pre-allocated zero-init output buffer to avoid a memset - # kernel on every CUDA graph replay. - # q is 4D: (batch, q_len_per_req, num_heads, head_dim) - # FlashInfer has a bug where out= validation hardcodes 3D shape - # (batch, num_heads, kv_lora_rank), but the kernel writes 4D - # (batch, q_len, num_heads, kv_lora_rank) when q_len > 1. - # So we can only pass out= for single-token decode (q_len == 1). - # For q_len > 1, we zero padding slots after the kernel returns. - # TODO: upstream fix to FlashInfer - B, q_len_per_req = q.shape[0], q.shape[1] - out_kwargs: dict[str, torch.Tensor] = {} - if q_len_per_req == 1: - dtype = ( - torch.bfloat16 - if is_quantized_kv_cache(self.kv_cache_dtype) - else q.dtype - ) - if ( - self._decode_out is None - or self._decode_out.shape[0] < B - or self._decode_out.dtype != dtype - ): - self._decode_out = torch.zeros( - B, - q.shape[2], - self.kv_lora_rank, - dtype=dtype, - device=q.device, - ) - out_kwargs["out"] = self._decode_out[:B] - o = trtllm_batch_decode_with_kv_cache_mla( query=q, kv_cache=kv_c_and_k_pe_cache.unsqueeze(1), @@ -235,15 +198,8 @@ def forward_mqa( max_seq_len=attn_metadata.max_seq_len, bmm1_scale=self.bmm1_scale, bmm2_scale=self.bmm2_scale, - **out_kwargs, ) - # For q_len > 1, we can't pass out= so we work around by zeroing padding slots - if not out_kwargs: - num_real = attn_metadata.num_decodes - if num_real < o.shape[0]: - o[num_real:] = 0 - # Flatten the output for consistent shape o = o.view(-1, o.shape[-2], o.shape[-1])