Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion vllm/v1/attention/backends/mla/cutlass_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ def __init__(
# Share workspace buffer across all executions
self._workspace = g_sm100_workspace

# Pre-allocated output buffer, lazily sized on first call.
# Zero-init once to prevent NaN in padding slots (seq_lens=0)
# from contaminating downstream per-tensor reductions.
self._decode_out: torch.Tensor | None = None

def _sm100_cutlass_mla_decode(
self,
q_nope: torch.Tensor,
Expand Down Expand Up @@ -218,7 +223,15 @@ def _sm100_cutlass_mla_decode(
if is_quantized_kv_cache(self.kv_cache_dtype)
else q_nope.dtype
)
out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype)
# Reuse pre-allocated zero-init output buffer to avoid a memset
# kernel on every CUDA graph replay.
if (
self._decode_out is None
or self._decode_out.shape[0] < B_q
or self._decode_out.dtype != dtype
):
self._decode_out = q_nope.new_zeros((B_q, MAX_HEADS, D_latent), dtype=dtype)
out = self._decode_out[:B_q]
lse = (
torch.empty((B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device)
if self.need_to_return_lse_for_decode
Expand Down
44 changes: 44 additions & 0 deletions vllm/v1/attention/backends/mla/flashinfer_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
AttentionLayer,
AttentionType,
MultipleOf,
is_quantized_kv_cache,
)
from vllm.v1.attention.backends.utils import KVCacheLayoutType

Expand Down Expand Up @@ -151,6 +152,11 @@ def __init__(
self.bmm1_scale: float | None = None
self.bmm2_scale: float | None = None

# Pre-allocated output buffer, lazily sized on first call.
# Zero-init once to prevent NaN in padding slots (seq_lens=0)
# from contaminating downstream per-tensor reductions.
self._decode_out: torch.Tensor | None = None

def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
Expand Down Expand Up @@ -181,6 +187,37 @@ def forward_mqa(
if self.bmm2_scale is None:
self.bmm2_scale = layer._v_scale_float

# Reuse pre-allocated zero-init output buffer to avoid a memset
# kernel on every CUDA graph replay.
# q is 4D: (batch, q_len_per_req, num_heads, head_dim)
# FlashInfer has a bug where out= validation hardcodes 3D shape
# (batch, num_heads, kv_lora_rank), but the kernel writes 4D
# (batch, q_len, num_heads, kv_lora_rank) when q_len > 1.
# So we can only pass out= for single-token decode (q_len == 1).
# For q_len > 1, we zero padding slots after the kernel returns.
# TODO: upstream fix to FlashInfer
B, q_len_per_req = q.shape[0], q.shape[1]
out_kwargs: dict[str, torch.Tensor] = {}
if q_len_per_req == 1:
dtype = (
torch.bfloat16
if is_quantized_kv_cache(self.kv_cache_dtype)
else q.dtype
)
if (
self._decode_out is None
or self._decode_out.shape[0] < B
or self._decode_out.dtype != dtype
):
self._decode_out = torch.zeros(
B,
q.shape[2],
self.kv_lora_rank,
dtype=dtype,
device=q.device,
)
out_kwargs["out"] = self._decode_out[:B]

o = trtllm_batch_decode_with_kv_cache_mla(
query=q,
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
Expand All @@ -193,8 +230,15 @@ def forward_mqa(
max_seq_len=attn_metadata.max_seq_len,
bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale,
**out_kwargs,
)

# For q_len > 1, we can't pass out= so we work around by zeroing padding slots
if not out_kwargs:
num_real = attn_metadata.num_decodes
if num_real < o.shape[0]:
o[num_real:] = 0

# Flatten the output for consistent shape
o = o.view(-1, o.shape[-2], o.shape[-1])

Expand Down
Loading