Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions tests/kernels/attention/test_mla_zero_out_decode_padding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import pytest
import torch

from vllm.v1.attention.backends.mla.utils import zero_out_decode_padding


def _assert_zero_out_matches_ref(
    *,
    num_tokens: int,
    num_cols: int,
    pad_positions: tuple[int, ...],
    dtype: torch.dtype = torch.bfloat16,
    num_heads: int = 3,
) -> None:
    """Check ``zero_out_decode_padding`` against an eager reference.

    Builds a random ``(num_tokens, num_heads, num_cols)`` CUDA tensor, marks
    the rows at ``pad_positions`` as padding (``seq_lens == 0``) and fills
    them with NaNs to mimic what the production kernels may leave behind,
    then verifies the padded rows are zeroed exactly while all other rows
    are left untouched.

    Args:
        num_tokens: Number of decode slots (rows) in the output tensor.
        num_cols: Size of the innermost dimension.
        pad_positions: Row indices to treat as padding.
        dtype: Element dtype of the output tensor.
        num_heads: Size of the middle (head) dimension.
    """
    # Bug fix: honor the num_heads parameter instead of a hard-coded 3.
    out = torch.randn(num_tokens, num_heads, num_cols, dtype=dtype, device="cuda")
    seq_lens = torch.ones(num_tokens, dtype=torch.int32, device="cuda")
    if pad_positions:
        pad_indices = torch.tensor(pad_positions, dtype=torch.long, device="cuda")
        seq_lens[pad_indices] = 0
        # Match production behavior: padded rows may contain NaNs.
        out[pad_indices] = torch.nan

    ref = out.clone()
    ref[seq_lens == 0] = 0

    zero_out_decode_padding(out, seq_lens)
    # Exact comparison: padding must be zeroed bit-for-bit, real rows untouched.
    torch.testing.assert_close(out, ref, atol=0, rtol=0)


@pytest.mark.parametrize("num_tokens", [1, 2, 3, 8])
@pytest.mark.parametrize("num_cols", [257, 1024, 1500])
def test_zero_out_padding_exhaustive(num_tokens: int, num_cols: int):
    """Sweep every possible padded-tail length for small batch sizes."""
    if num_tokens == 1:
        # A single token cannot have a padded tail; only the no-padding
        # case is meaningful here.
        _assert_zero_out_matches_ref(
            num_tokens=1,
            num_cols=num_cols,
            pad_positions=(),
        )
        return

    for pad_start in range(1, num_tokens):
        _assert_zero_out_matches_ref(
            num_tokens=num_tokens,
            num_cols=num_cols,
            # Idiom fix: tuple(range(...)) — no intermediate list needed.
            pad_positions=tuple(range(pad_start, num_tokens)),
        )


@pytest.mark.parametrize("num_tokens", [4, 5, 10, 13, 25] + list(range(55, 64)))
@pytest.mark.parametrize("num_cols", [257])
def test_zero_out_padding(num_tokens: int, num_cols: int) -> None:
    """Mark the last two decode slots as padding and check they are zeroed."""
    padded_tail = (num_tokens - 2, num_tokens - 1)
    _assert_zero_out_matches_ref(
        num_tokens=num_tokens,
        num_cols=num_cols,
        pad_positions=padded_tail,
    )
15 changes: 1 addition & 14 deletions vllm/v1/attention/backends/mla/cutlass_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,6 @@ def __init__(
# Share workspace buffer across all executions
self._workspace = g_sm100_workspace

# Pre-allocated output buffer, lazily sized on first call.
# Zero-init once to prevent NaN in padding slots (seq_lens=0)
# from contaminating downstream per-tensor reductions.
self._decode_out: torch.Tensor | None = None

def _sm100_cutlass_mla_decode(
self,
q_nope: torch.Tensor,
Expand Down Expand Up @@ -223,15 +218,7 @@ def _sm100_cutlass_mla_decode(
if is_quantized_kv_cache(self.kv_cache_dtype)
else q_nope.dtype
)
# Reuse pre-allocated zero-init output buffer to avoid a memset
# kernel on every CUDA graph replay.
if (
self._decode_out is None
or self._decode_out.shape[0] < B_q
or self._decode_out.dtype != dtype
):
self._decode_out = q_nope.new_zeros((B_q, MAX_HEADS, D_latent), dtype=dtype)
out = self._decode_out[:B_q]
out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype)
lse = (
torch.empty((B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device)
if self.need_to_return_lse_for_decode
Expand Down
53 changes: 9 additions & 44 deletions vllm/v1/attention/backends/mla/flashinfer_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
AttentionLayer,
AttentionType,
MultipleOf,
is_quantized_kv_cache,
)
from vllm.v1.attention.backends.mla.utils import zero_out_decode_padding
from vllm.v1.attention.backends.utils import KVCacheLayoutType

logger = init_logger(__name__)
Expand Down Expand Up @@ -152,11 +152,6 @@ def __init__(
self.bmm1_scale: float | None = None
self.bmm2_scale: float | None = None

# Pre-allocated output buffer, lazily sized on first call.
# Zero-init once to prevent NaN in padding slots (seq_lens=0)
# from contaminating downstream per-tensor reductions.
self._decode_out: torch.Tensor | None = None

def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
Expand Down Expand Up @@ -192,37 +187,6 @@ def forward_mqa(
if self.kv_cache_dtype.startswith("fp8"):
self.bmm2_scale *= layer._k_scale_float

# Reuse pre-allocated zero-init output buffer to avoid a memset
# kernel on every CUDA graph replay.
# q is 4D: (batch, q_len_per_req, num_heads, head_dim)
# FlashInfer has a bug where out= validation hardcodes 3D shape
# (batch, num_heads, kv_lora_rank), but the kernel writes 4D
# (batch, q_len, num_heads, kv_lora_rank) when q_len > 1.
# So we can only pass out= for single-token decode (q_len == 1).
# For q_len > 1, we zero padding slots after the kernel returns.
# TODO: upstream fix to FlashInfer
B, q_len_per_req = q.shape[0], q.shape[1]
out_kwargs: dict[str, torch.Tensor] = {}
if q_len_per_req == 1:
dtype = (
torch.bfloat16
if is_quantized_kv_cache(self.kv_cache_dtype)
else q.dtype
)
if (
self._decode_out is None
or self._decode_out.shape[0] < B
or self._decode_out.dtype != dtype
):
self._decode_out = torch.zeros(
B,
q.shape[2],
self.kv_lora_rank,
dtype=dtype,
device=q.device,
)
out_kwargs["out"] = self._decode_out[:B]

o = trtllm_batch_decode_with_kv_cache_mla(
query=q,
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
Expand All @@ -235,14 +199,15 @@ def forward_mqa(
max_seq_len=attn_metadata.max_seq_len,
bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale,
**out_kwargs,
)

# For q_len > 1, we can't pass out= so we work around by zeroing padding slots
if not out_kwargs:
num_real = attn_metadata.num_decodes
if num_real < o.shape[0]:
o[num_real:] = 0
# Flashinfer MLA kernels introduces NaNs in padded regions in
# some cases. We need to zero out the padded regions to avoid
# NaNs in the output.
assert o.size(0) == attn_metadata.decode.seq_lens.size(0), (
f"output shape {o.size()} != "
f"seq_lens shape {attn_metadata.decode.seq_lens.size()}"
)
o = zero_out_decode_padding(o, attn_metadata.decode.seq_lens)

# Flatten the output for consistent shape
o = o.view(-1, o.shape[-2], o.shape[-1])
Expand Down
70 changes: 70 additions & 0 deletions vllm/v1/attention/backends/mla/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.triton_utils import HAS_TRITON, tl, triton

_DEFAULT_BLOCK_SIZE = 1024


if HAS_TRITON:

    @triton.jit
    def _zero_out_decode_padding_kernel(
        out_ptr,  # pointer to the (num_rows, num_cols) output tensor
        seq_lens_ptr,  # pointer to per-row sequence lengths (0 == padding row)
        row_stride,  # elements between consecutive rows of `out`
        num_cols,  # number of elements per row to zero
        BLOCK_SIZE: tl.constexpr,  # columns written per inner-loop iteration
    ) -> None:
        # One program instance per row: the row is zeroed in full iff its
        # sequence length is 0 (i.e. the row is decode padding).
        row = tl.program_id(0)

        if tl.load(seq_lens_ptr + row) != 0:
            # Real (non-padding) row — leave its contents untouched.
            return

        col_offsets = tl.arange(0, BLOCK_SIZE)
        out_ptrs = out_ptr + row * row_stride + col_offsets
        # Sweep the row in BLOCK_SIZE-wide chunks; the mask guards the final
        # partial chunk when num_cols is not a multiple of BLOCK_SIZE.
        for c in tl.range(0, tl.cdiv(num_cols, BLOCK_SIZE)):
            mask = col_offsets + c * BLOCK_SIZE < num_cols
            tl.store(
                out_ptrs,
                tl.zeros([BLOCK_SIZE], dtype=out_ptr.dtype.element_ty),
                mask=mask,
            )
            out_ptrs += BLOCK_SIZE
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@LucasWilkinson @elvircrn the kernel can use a fresh pair of eyes. Thanks 🙌

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lgtm!



def _zero_out_decode_padding_triton(
out: torch.Tensor,
seq_lens: torch.Tensor,
) -> None:
"""Zero rows in `out` where `seq_lens == 0` using a Triton kernel."""
if not out.is_cuda or not seq_lens.is_cuda:
raise ValueError("out and seq_lens must be CUDA tensors.")
if out.size(0) != seq_lens.numel():
raise ValueError(
f"out.size(0) {out.size()} must matchseq_lens.numel() ({seq_lens.numel()})."
)
if not out.is_contiguous():
raise ValueError("out must be contiguous.")

BLOCK_SIZE = 1024

out_2d = out.view(out.size(0), -1)
grid = (out_2d.size(0),)
_zero_out_decode_padding_kernel[grid](
out_2d,
seq_lens,
out_2d.stride(0),
out_2d.size(1),
BLOCK_SIZE=BLOCK_SIZE,
)


def zero_out_decode_padding(out: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    """Zero the rows of `out` whose corresponding `seq_lens` entry is 0.

    Dispatches to the Triton kernel when Triton is available; otherwise
    falls back to an eager masked assignment. Mutates `out` in place and
    returns it for convenience.
    """
    if not HAS_TRITON:
        # Eager fallback: boolean-mask assignment over the padding rows.
        out[seq_lens == 0] = 0
        return out
    _zero_out_decode_padding_triton(out, seq_lens)
    return out
Loading