185 changes: 185 additions & 0 deletions tests/v1/attention/test_turboquant_spec_verify.py
@@ -0,0 +1,185 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for TurboQuant K+1 spec-verify routing fix (#40880).

Verifies that uniform-query batches with max_query_len > 1 (typical for
MTP num_speculative_tokens=K, where verify produces K+1 query length) are
routed through `triton_turboquant_decode_attention` instead of the
default `_prefill_attention` continuation branch.

The default branch contains a `query_start_loc.tolist()` GPU→CPU sync
that is incompatible with active CUDA stream capture and was the root
cause of the degenerate-token cascade reported in #40880.

These tests use the `synth_seq_lens` trick to construct the routing
arguments and verify shape, dtype, and cudagraph-safety properties.
A full end-to-end correctness test against the unpatched continuation
path requires GPU + a TurboQuant model checkpoint and is gated under
`@pytest.mark.cuda` + skip-if-no-tq-model.
"""
from __future__ import annotations

import pytest
import torch


pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="TurboQuant K+1 spec-verify routing requires CUDA",
)
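

# Illustrative sketch of the root-cause failure mode named in the module
# docstring (assumes PyTorch's default capture_error_mode, under which a
# GPU→CPU sync such as `.tolist()` raises RuntimeError during capture).
def test_tolist_sync_raises_during_capture():
    """Sketch: the `.tolist()` pattern of the default branch cannot capture."""
    device = torch.device("cuda")
    query_start_loc = torch.tensor([0, 4, 8], dtype=torch.int32, device=device)
    # Warmup outside capture so lazy CUDA initialization does not interfere.
    _ = query_start_loc + 1
    torch.cuda.synchronize()
    g = torch.cuda.CUDAGraph()
    with pytest.raises(RuntimeError):
        with torch.cuda.graph(g):
            # The GPU→CPU sync that the `_prefill_attention` continuation
            # branch performs on query_start_loc.
            _ = query_start_loc.tolist()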


def _synth_args(batch_size: int, k_plus_1: int, base_seq_lens: torch.Tensor,
                block_table: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Helper: build synth_seq_lens + synth_block_table for K+1 verify routing.

    Mirrors the pattern used in `TurboQuantAttentionImpl.forward()`.
    """
    device = base_seq_lens.device
    offs = torch.arange(k_plus_1, device=device, dtype=base_seq_lens.dtype)
    synth_seq_lens = (
        base_seq_lens[:batch_size, None] - k_plus_1 + 1 + offs[None, :]
    ).reshape(-1)
    synth_block_table = block_table[:batch_size].repeat_interleave(
        k_plus_1, dim=0,
    )
    return synth_seq_lens, synth_block_table


def test_synth_seq_lens_shape():
    """synth_seq_lens must be (B*K_PLUS_1,) and match the expected pattern."""
    device = torch.device("cuda")
    B = 2
    K_PLUS_1 = 4
    base_seq_lens = torch.tensor([100, 200], dtype=torch.int32, device=device)
    block_table = torch.zeros((B, 32), dtype=torch.int32, device=device)
    block_table[0, :] = torch.arange(32, device=device)
    block_table[1, :] = torch.arange(100, 132, device=device)

    synth_seq_lens, synth_block_table = _synth_args(
        B, K_PLUS_1, base_seq_lens, block_table)

    # Shapes
    assert synth_seq_lens.shape == (B * K_PLUS_1,), \
        f"expected ({B*K_PLUS_1},), got {synth_seq_lens.shape}"
    assert synth_block_table.shape == (B * K_PLUS_1, 32), \
        f"expected ({B*K_PLUS_1}, 32), got {synth_block_table.shape}"

    # Per-request synth_seq_lens pattern: base - K1 + 1, base - K1 + 2, ..., base
    # For req 0 (base=100): 97, 98, 99, 100
    # For req 1 (base=200): 197, 198, 199, 200
    expected = torch.tensor(
        [97, 98, 99, 100, 197, 198, 199, 200],
        dtype=torch.int32, device=device,
    )
    assert torch.equal(synth_seq_lens, expected), \
        f"synth_seq_lens mismatch:\nexpected={expected.tolist()}\ngot={synth_seq_lens.tolist()}"

    # Per-request block table is replicated K_PLUS_1 times
    for req in range(B):
        for offset in range(K_PLUS_1):
            assert torch.equal(
                synth_block_table[req * K_PLUS_1 + offset],
                block_table[req],
            ), f"block table replication mismatch at req={req} offset={offset}"


def test_synth_dtypes_preserved():
    """Synth args must preserve the dtype of source seq_lens / block_table."""
    device = torch.device("cuda")
    for seq_dtype in (torch.int32, torch.int64):
        base_seq_lens = torch.tensor([50], dtype=seq_dtype, device=device)
        block_table = torch.zeros((1, 4), dtype=torch.int32, device=device)
        synth_seq_lens, synth_block_table = _synth_args(
            1, 4, base_seq_lens, block_table)
        assert synth_seq_lens.dtype == seq_dtype
        assert synth_block_table.dtype == torch.int32


def test_synth_construction_no_cpu_sync():
    """Synth construction must be entirely on-GPU (no .item() / .tolist() sync).

    This is the property that makes the routing safe under cudagraph capture.
    We verify it empirically: the construction is captured inside a CUDA
    graph, and any hidden GPU→CPU sync would make the capture raise.
    """
    device = torch.device("cuda")
    base_seq_lens = torch.tensor([100, 200, 300], dtype=torch.int32, device=device)
    block_table = torch.zeros((3, 16), dtype=torch.int32, device=device)

    # Run inside a stream-captured region — should NOT raise
    g = torch.cuda.CUDAGraph()
    static_input_seq_lens = base_seq_lens.clone()
    static_input_block_table = block_table.clone()
    # Warmup
    _ = _synth_args(3, 4, static_input_seq_lens, static_input_block_table)
    torch.cuda.synchronize()

    # Capture
    with torch.cuda.graph(g):
        _ = _synth_args(3, 4, static_input_seq_lens, static_input_block_table)

    # If we got here without exception, _synth_args is cudagraph-safe.
    # Replay should also work
    g.replay()
    torch.cuda.synchronize()


def test_eligibility_predicate():
    """Verify the dispatch predicate matches the expected K+1 spec-verify shape."""
    # Mock metadata fields the predicate checks
    class FakeMeta:
        is_prefill: bool
        num_decodes: int
        max_query_len: int
        max_seq_len: int
        query_start_loc: torch.Tensor

    # Eligible: K+1=4, has prior cache, batch divisible
    m = FakeMeta()
    m.is_prefill = True
    m.num_decodes = 0
    m.max_query_len = 4
    m.max_seq_len = 1024
    m.query_start_loc = torch.zeros(3, dtype=torch.int32)  # B=2, B+1 = 3
    N = 8  # = B*K1 = 2*4
    eligible = (
        m.is_prefill and m.num_decodes == 0
        and 1 < m.max_query_len <= 16
        and m.max_seq_len > m.max_query_len
        and N > 0 and N % m.max_query_len == 0
        and m.query_start_loc is not None
    )
    assert eligible

    # NOT eligible: pure decode (max_query_len == 1)
    m.max_query_len = 1
    eligible = (
        m.is_prefill and m.num_decodes == 0
        and 1 < m.max_query_len <= 16
    )
    assert not eligible

    # NOT eligible: no prior cache (max_seq_len == max_query_len, fresh prefill)
    m.max_query_len = 4
    m.max_seq_len = 4
    eligible = (
        m.is_prefill and m.num_decodes == 0
        and 1 < m.max_query_len <= 16
        and m.max_seq_len > m.max_query_len
    )
    assert not eligible

    # NOT eligible: K+1 too large (>16, e.g., wrong spec-decode tree depth)
    m.max_query_len = 32
    m.max_seq_len = 1024
    eligible = (
        m.is_prefill and m.num_decodes == 0
        and 1 < m.max_query_len <= 16
    )
    assert not eligible
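
    # Illustrative extra negative case (a sketch beyond the original four
    # cases): the predicate also rejects ragged batches where N is not a
    # multiple of max_query_len, i.e. non-uniform query lengths.
    m.max_query_len = 4
    m.max_seq_len = 1024
    N = 10  # not a multiple of K+1 = 4 → not a uniform K+1 verify batch
    eligible = (
        m.is_prefill and m.num_decodes == 0
        and 1 < m.max_query_len <= 16
        and m.max_seq_len > m.max_query_len
        and N > 0 and N % m.max_query_len == 0
        and m.query_start_loc is not None
    )
    assert not eligible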


# An end-to-end correctness test (requires a Qwen3.6-A3B-FP8 checkpoint with
# TurboQuant enabled) would go here, gated under @pytest.mark.cuda +
# skip-if-no-model, matching the module docstring. This PR does not include
# such a model in CI; the empirical TPS data (75.6 vs 57.2 tok/s, +32%) is
# documented in the PR body and was measured on
# Sandermage/genesis-vllm-patches by the contributor.
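
# A hedged sketch of how such a gated test might be declared (the env var
# and skip condition below are placeholders, not things this PR defines):
#
#     TQ_MODEL = os.environ.get("VLLM_TQ_TEST_MODEL")
#
#     @pytest.mark.cuda
#     @pytest.mark.skipif(TQ_MODEL is None,
#                         reason="needs a TurboQuant model checkpoint")
#     def test_spec_verify_matches_unpatched_continuation():
#         ...  # compare routed vs default-branch outputs end to end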
86 changes: 86 additions & 0 deletions vllm/v1/attention/backends/turboquant_attn.py
@@ -413,6 +413,92 @@ def forward(
        num_decodes = attn_metadata.num_decodes
        num_decode_tokens = attn_metadata.num_decode_tokens

        # ════════════════════════════════════════════════════════════════════
        # Spec-decode K+1 verify path (fixes #40880).
        #
        # When speculative decoding is active (MTP num_speculative_tokens=K>0),
        # the verify pass produces uniform-query batches with
        # max_query_len = K+1 (e.g., K=3 → q_len=4 per request) where
        # max_seq_len > max_query_len (each request has prior cached KV).
        #
        # The default `_prefill_attention` continuation branch reads
        # `query_start_loc.tolist()`, which (1) forces a GPU→CPU sync
        # incompatible with active CUDA stream capture, and (2) was the
        # root cause of #40880 (degenerate token cascades on Qwen3.6-MoE
        # under MTP=3 + FULL_AND_PIECEWISE cudagraph).
        #
        # Fix: route uniform K+1 spec-verify batches through the
        # `triton_turboquant_decode_attention` kernel via the same
        # synth_seq_lens trick that `_continuation_prefill` uses internally.
        # The decode kernel handles compressed K+V cache lookup natively
        # and is cudagraph-safe (no CPU sync), so this restores
        # FULL_AND_PIECEWISE capture for spec-decode workloads.
        #
        # Empirical: +32% wall-clock TPS on Qwen3.6-35B-A3B-FP8 + MTP=3
        # + 2× RTX A5000 + TurboQuant k8v4 vs the PIECEWISE-downgraded
        # baseline. Cross-rig validation pending; tested ONLY on Ampere
        # SM 8.6 so far.
        # See `tests/v1/attention/test_turboquant_spec_verify.py`.
        # ════════════════════════════════════════════════════════════════════
        _spec_verify_eligible = (
            attn_metadata.is_prefill
            and num_decodes == 0
            and 1 < attn_metadata.max_query_len <= 16
            and attn_metadata.max_seq_len > attn_metadata.max_query_len
            and N > 0
            and (N % attn_metadata.max_query_len) == 0
            and attn_metadata.query_start_loc is not None
        )
        if _spec_verify_eligible:
            K_PLUS_1 = attn_metadata.max_query_len
            B = N // K_PLUS_1
            if attn_metadata.query_start_loc.shape[0] == B + 1:
                from vllm.v1.attention.ops.triton_turboquant_decode import (
                    triton_turboquant_decode_attention,
                )
                # Build synth args mirroring _continuation_prefill's pattern:
                #   synth_seq_lens[req*K1+i]    = base_seq_lens[req] - K1 + 1 + i
                #   synth_block_table[req*K1+i] = block_table[req]
                # All GPU ops — cudagraph-safe.
                _q_flat = q[:N].view(N, self.num_heads, self.head_size)
                _offs = torch.arange(
                    K_PLUS_1, device=q.device,
                    dtype=attn_metadata.seq_lens.dtype,
                )
                _synth_seq_lens = (
                    attn_metadata.seq_lens[:B, None] - K_PLUS_1 + 1 + _offs[None, :]
                ).reshape(-1)
                _synth_block_table = attn_metadata.block_table[:B].repeat_interleave(
                    K_PLUS_1, dim=0,
                )
                # Reuse cached decode buffers from the layer to avoid
                # per-call torch.empty allocations — these would break
                # CUDA graph replay (the very thing this PR restores).
                # Per gemini-code-assist review on this PR.
                _mid_o_buf = getattr(layer, "_tq_mid_o_buf", None)
                _output_buf = getattr(layer, "_tq_output_buf", None)
                _lse_buf = getattr(layer, "_tq_lse_buf", None)
                attn_out = triton_turboquant_decode_attention(
                    query=_q_flat,
                    kv_cache=kv_cache,
                    block_table=_synth_block_table,
                    seq_lens=_synth_seq_lens,
                    Pi=Pi,
                    centroids=centroids,
                    scale=self.scale,
                    mse_bits=self.tq_config.key_mse_bits,
                    key_packed_size=self.tq_config.key_packed_size,
                    value_quant_bits=self.tq_config.effective_value_quant_bits,
                    key_fp8=self.tq_config.key_fp8,
                    norm_correction=self.tq_config.norm_correction,
                    PiT=PiT,
                    mid_o_buf=_mid_o_buf,
                    output_buf=_output_buf,
                    lse_buf=_lse_buf,
                    buf_holder=layer,
                    max_num_kv_splits=self.max_num_kv_splits,
                )
                return attn_out
Comment on lines +451 to +500
Contributor (severity: high)
The new routing path for speculative verification does not reuse the cached decode buffers (mid_o_buf, output_buf, lse_buf) from the layer object, nor does it pass the buf_holder or max_num_kv_splits parameters to triton_turboquant_decode_attention.

This will cause the kernel to allocate new tensors on every call, which is a significant performance overhead in the hot path and, more importantly, breaks CUDA graph compatibility because dynamic allocations are not allowed during graph replay. Since this PR specifically aims to restore FULL_AND_PIECEWISE CUDA graph support, ensuring static buffer reuse is critical.

        if _spec_verify_eligible:
            K_PLUS_1 = attn_metadata.max_query_len
            B = N // K_PLUS_1
            if attn_metadata.query_start_loc.shape[0] == B + 1:
                from vllm.v1.attention.ops.triton_turboquant_decode import (
                    triton_turboquant_decode_attention,
                )
                # Build synth args mirroring _continuation_prefill's pattern:
                # synth_seq_lens[req*K1+i] = base_seq_lens[req] - K1 + 1 + i
                # synth_block_table[req*K1+i] = block_table[req]
                # All GPU ops — cudagraph-safe.
                _offs = torch.arange(
                    K_PLUS_1, device=q.device,
                    dtype=attn_metadata.seq_lens.dtype,
                )
                _synth_seq_lens = (
                    attn_metadata.seq_lens[:B, None] - K_PLUS_1 + 1 + _offs[None, :]
                ).reshape(-1)
                _synth_block_table = attn_metadata.block_table[:B].repeat_interleave(
                    K_PLUS_1, dim=0,
                )

                # Reuse cached decode buffers from the layer to avoid re-allocation
                # and ensure CUDA graph compatibility.
                mid_o_buf = getattr(layer, "_tq_mid_o_buf", None)
                output_buf = getattr(layer, "_tq_output_buf", None)
                lse_buf = getattr(layer, "_tq_lse_buf", None)

                attn_out = triton_turboquant_decode_attention(
                    query=q,
                    kv_cache=kv_cache,
                    block_table=_synth_block_table,
                    seq_lens=_synth_seq_lens,
                    Pi=Pi,
                    centroids=centroids,
                    scale=self.scale,
                    mse_bits=self.tq_config.key_mse_bits,
                    key_packed_size=self.tq_config.key_packed_size,
                    value_quant_bits=self.tq_config.effective_value_quant_bits,
                    key_fp8=self.tq_config.key_fp8,
                    norm_correction=self.tq_config.norm_correction,
                    PiT=PiT,
                    mid_o_buf=mid_o_buf,
                    output_buf=output_buf,
                    lse_buf=lse_buf,
                    buf_holder=layer,
                    max_num_kv_splits=self.max_num_kv_splits,
                )
                return attn_out
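
As a concrete illustration of the static-buffer reuse this comment asks for, a minimal sketch (hypothetical helper; in the actual kernel the allocation and caching live inside triton_turboquant_decode_attention via its buf_holder parameter):

    def get_static_buf(holder, name: str, shape, dtype, device):
        # Allocate once at warmup (outside graph capture); afterwards
        # return the same storage so CUDA graph replay sees stable
        # device pointers instead of fresh allocations.
        buf = getattr(holder, name, None)
        if buf is None:
            buf = torch.empty(shape, dtype=dtype, device=device)
            setattr(holder, name, buf)
        return buf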


        if not attn_metadata.is_prefill:
            # Pure decode batch — fast path
            attn_out = self._decode_attention(