diff --git a/tests/kernels/attention/test_mla_zero_out_decode_padding.py b/tests/kernels/attention/test_mla_zero_out_decode_padding.py
new file mode 100644
index 000000000000..86fbe810dde8
--- /dev/null
+++ b/tests/kernels/attention/test_mla_zero_out_decode_padding.py
@@ -0,0 +1,60 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+
+import pytest
+import torch
+
+from vllm.v1.attention.backends.mla.utils import zero_out_decode_padding
+
+
+def _assert_zero_out_matches_ref(
+    *,
+    num_tokens: int,
+    num_cols: int,
+    pad_positions: tuple[int, ...],
+    dtype: torch.dtype = torch.bfloat16,
+    num_heads: int = 3,
+) -> None:
+    # Use the num_heads parameter (was hardcoded to 3, silently ignoring it).
+    out = torch.randn(num_tokens, num_heads, num_cols, dtype=dtype, device="cuda")
+    seq_lens = torch.ones(num_tokens, dtype=torch.int32, device="cuda")
+    if pad_positions:
+        pad_indices = torch.tensor(pad_positions, dtype=torch.long, device="cuda")
+        seq_lens[pad_indices] = 0
+        # Match production behavior: padded rows may contain NaNs.
+        out[pad_indices] = torch.nan
+
+    ref = out.clone()
+    ref[seq_lens == 0] = 0
+
+    zero_out_decode_padding(out, seq_lens)
+    torch.testing.assert_close(out, ref, atol=0, rtol=0)
+
+
+@pytest.mark.parametrize("num_tokens", [1, 2, 3, 8])
+@pytest.mark.parametrize("num_cols", [257, 1024, 1500])
+def test_zero_out_padding_exhaustive(num_tokens: int, num_cols: int):
+    if num_tokens == 1:
+        _assert_zero_out_matches_ref(
+            num_tokens=1,
+            num_cols=num_cols,
+            pad_positions=(),
+        )
+        return
+
+    for pad_start in range(1, num_tokens):
+        _assert_zero_out_matches_ref(
+            num_tokens=num_tokens,
+            num_cols=num_cols,
+            pad_positions=tuple(range(pad_start, num_tokens)),
+        )
+
+
+@pytest.mark.parametrize("num_tokens", [4, 5, 10, 13, 25] + list(range(55, 64)))
+@pytest.mark.parametrize("num_cols", [257])
+def test_zero_out_padding(num_tokens: int, num_cols: int) -> None:
+    _assert_zero_out_matches_ref(
+        num_tokens=num_tokens,
+        num_cols=num_cols,
+        pad_positions=(num_tokens - 2, num_tokens - 1),
+    )
diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py
index fd4d9ab84274..b01ce2be2418 100644
--- a/vllm/v1/attention/backends/mla/cutlass_mla.py
+++ b/vllm/v1/attention/backends/mla/cutlass_mla.py
@@ -162,11 +162,6 @@ def __init__(
         # Share workspace buffer across all executions
         self._workspace = g_sm100_workspace
 
-        # Pre-allocated output buffer, lazily sized on first call.
-        # Zero-init once to prevent NaN in padding slots (seq_lens=0)
-        # from contaminating downstream per-tensor reductions.
-        self._decode_out: torch.Tensor | None = None
-
     def _sm100_cutlass_mla_decode(
         self,
         q_nope: torch.Tensor,
@@ -223,15 +218,7 @@ def _sm100_cutlass_mla_decode(
             if is_quantized_kv_cache(self.kv_cache_dtype)
             else q_nope.dtype
         )
-        # Reuse pre-allocated zero-init output buffer to avoid a memset
-        # kernel on every CUDA graph replay.
-        if (
-            self._decode_out is None
-            or self._decode_out.shape[0] < B_q
-            or self._decode_out.dtype != dtype
-        ):
-            self._decode_out = q_nope.new_zeros((B_q, MAX_HEADS, D_latent), dtype=dtype)
-        out = self._decode_out[:B_q]
+        out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype)
         lse = (
             torch.empty((B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device)
             if self.need_to_return_lse_for_decode
diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py
index 16d01bd338ca..c3020b12d278 100644
--- a/vllm/v1/attention/backends/mla/flashinfer_mla.py
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -21,8 +21,8 @@
     AttentionLayer,
     AttentionType,
     MultipleOf,
-    is_quantized_kv_cache,
 )
+from vllm.v1.attention.backends.mla.utils import zero_out_decode_padding
 from vllm.v1.attention.backends.utils import KVCacheLayoutType
 
 logger = init_logger(__name__)
@@ -152,11 +152,6 @@ def __init__(
         self.bmm1_scale: float | None = None
         self.bmm2_scale: float | None = None
 
-        # Pre-allocated output buffer, lazily sized on first call.
-        # Zero-init once to prevent NaN in padding slots (seq_lens=0)
-        # from contaminating downstream per-tensor reductions.
-        self._decode_out: torch.Tensor | None = None
-
     def forward_mqa(
         self,
         q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
@@ -192,37 +187,6 @@ def forward_mqa(
         if self.kv_cache_dtype.startswith("fp8"):
             self.bmm2_scale *= layer._k_scale_float
 
-        # Reuse pre-allocated zero-init output buffer to avoid a memset
-        # kernel on every CUDA graph replay.
-        # q is 4D: (batch, q_len_per_req, num_heads, head_dim)
-        # FlashInfer has a bug where out= validation hardcodes 3D shape
-        # (batch, num_heads, kv_lora_rank), but the kernel writes 4D
-        # (batch, q_len, num_heads, kv_lora_rank) when q_len > 1.
-        # So we can only pass out= for single-token decode (q_len == 1).
-        # For q_len > 1, we zero padding slots after the kernel returns.
-        # TODO: upstream fix to FlashInfer
-        B, q_len_per_req = q.shape[0], q.shape[1]
-        out_kwargs: dict[str, torch.Tensor] = {}
-        if q_len_per_req == 1:
-            dtype = (
-                torch.bfloat16
-                if is_quantized_kv_cache(self.kv_cache_dtype)
-                else q.dtype
-            )
-            if (
-                self._decode_out is None
-                or self._decode_out.shape[0] < B
-                or self._decode_out.dtype != dtype
-            ):
-                self._decode_out = torch.zeros(
-                    B,
-                    q.shape[2],
-                    self.kv_lora_rank,
-                    dtype=dtype,
-                    device=q.device,
-                )
-            out_kwargs["out"] = self._decode_out[:B]
-
         o = trtllm_batch_decode_with_kv_cache_mla(
             query=q,
             kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
@@ -235,14 +199,15 @@ def forward_mqa(
             max_seq_len=attn_metadata.max_seq_len,
             bmm1_scale=self.bmm1_scale,
             bmm2_scale=self.bmm2_scale,
-            **out_kwargs,
         )
-
-        # For q_len > 1, we can't pass out= so we work around by zeroing padding slots
-        if not out_kwargs:
-            num_real = attn_metadata.num_decodes
-            if num_real < o.shape[0]:
-                o[num_real:] = 0
+        # FlashInfer MLA kernels introduce NaNs in padded regions in
+        # some cases. We need to zero out the padded regions to avoid
+        # NaNs in the output.
+        assert o.size(0) == attn_metadata.decode.seq_lens.size(0), (
+            f"output shape {o.size()} != "
+            f"seq_lens shape {attn_metadata.decode.seq_lens.size()}"
+        )
+        o = zero_out_decode_padding(o, attn_metadata.decode.seq_lens)
 
         # Flatten the output for consistent shape
         o = o.view(-1, o.shape[-2], o.shape[-1])
diff --git a/vllm/v1/attention/backends/mla/utils.py b/vllm/v1/attention/backends/mla/utils.py
new file mode 100644
index 000000000000..3306ef651ade
--- /dev/null
+++ b/vllm/v1/attention/backends/mla/utils.py
@@ -0,0 +1,70 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+
+from vllm.triton_utils import HAS_TRITON, tl, triton
+
+_DEFAULT_BLOCK_SIZE = 1024
+
+
+if HAS_TRITON:
+
+    @triton.jit
+    def _zero_out_decode_padding_kernel(
+        out_ptr,
+        seq_lens_ptr,
+        row_stride,
+        num_cols,
+        BLOCK_SIZE: tl.constexpr,
+    ) -> None:
+        # One program per row; rows with seq_len != 0 are left untouched.
+        row = tl.program_id(0)
+
+        if tl.load(seq_lens_ptr + row) != 0:
+            return
+
+        # Sweep the row in BLOCK_SIZE-wide chunks, masking the tail.
+        col_offsets = tl.arange(0, BLOCK_SIZE)
+        out_ptrs = out_ptr + row * row_stride + col_offsets
+        for c in tl.range(0, tl.cdiv(num_cols, BLOCK_SIZE)):
+            mask = col_offsets + c * BLOCK_SIZE < num_cols
+            tl.store(
+                out_ptrs,
+                tl.zeros([BLOCK_SIZE], dtype=out_ptr.dtype.element_ty),
+                mask=mask,
+            )
+            out_ptrs += BLOCK_SIZE
+
+
+def _zero_out_decode_padding_triton(
+    out: torch.Tensor,
+    seq_lens: torch.Tensor,
+) -> None:
+    """Zero rows in `out` where `seq_lens == 0` using a Triton kernel."""
+    if not out.is_cuda or not seq_lens.is_cuda:
+        raise ValueError("out and seq_lens must be CUDA tensors.")
+    if out.size(0) != seq_lens.numel():
+        raise ValueError(
+            f"out.size(0) ({out.size(0)}) must match seq_lens.numel() ({seq_lens.numel()})."
+        )
+    if not out.is_contiguous():
+        raise ValueError("out must be contiguous.")
+
+    BLOCK_SIZE = _DEFAULT_BLOCK_SIZE
+
+    out_2d = out.view(out.size(0), -1)
+    grid = (out_2d.size(0),)
+    _zero_out_decode_padding_kernel[grid](
+        out_2d,
+        seq_lens,
+        out_2d.stride(0),
+        out_2d.size(1),
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+
+
+def zero_out_decode_padding(out: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
+    # In-place: Triton fast path on CUDA, plain indexed fill otherwise.
+    if HAS_TRITON:
+        _zero_out_decode_padding_triton(out, seq_lens)
+    else:
+        out[seq_lens == 0] = 0
+    return out