[Bugfix][Spec-Decode] TurboQuant K+1 spec-verify routing (fixes #40880) #40914
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open · Sandermage wants to merge 2 commits into vllm-project:main from Sandermage:genesis-p67-multi-query-spec-decode-kernel
+271 −0
New file (185 lines):

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the TurboQuant K+1 spec-verify routing fix (#40880).

Verifies that uniform-query batches with max_query_len > 1 (typical for
MTP num_speculative_tokens=K, where verify produces query length K+1)
are routed through `triton_turboquant_decode_attention` instead of the
default `_prefill_attention` continuation branch.

The default branch contains a `query_start_loc.tolist()` GPU→CPU sync
that is incompatible with active CUDA stream capture and was the root
cause of the degenerate-token cascade reported in #40880.

These tests use the `synth_seq_lens` trick to construct the routing
arguments and verify shape, dtype, and cudagraph-safety properties.
A full end-to-end correctness test against the unpatched continuation
path requires a GPU + a TurboQuant model checkpoint and is gated under
`@pytest.mark.cuda` + skip-if-no-tq-model.
"""
from __future__ import annotations

import pytest
import torch

pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="TurboQuant K+1 spec-verify routing requires CUDA",
)


def _synth_args(
    batch_size: int,
    k_plus_1: int,
    base_seq_lens: torch.Tensor,
    block_table: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Helper: build synth_seq_lens + synth_block_table for K+1 verify routing.

    Mirrors the pattern used in `TurboQuantAttentionImpl.forward()`.
    """
    device = base_seq_lens.device
    offs = torch.arange(k_plus_1, device=device, dtype=base_seq_lens.dtype)
    synth_seq_lens = (
        base_seq_lens[:batch_size, None] - k_plus_1 + 1 + offs[None, :]
    ).reshape(-1)
    synth_block_table = block_table[:batch_size].repeat_interleave(
        k_plus_1, dim=0,
    )
    return synth_seq_lens, synth_block_table


def test_synth_seq_lens_shape():
    """synth_seq_lens must be (B*K_PLUS_1,) and equal the expected pattern."""
    device = torch.device("cuda")
    B = 2
    K_PLUS_1 = 4
    base_seq_lens = torch.tensor([100, 200], dtype=torch.int32, device=device)
    block_table = torch.zeros((B, 32), dtype=torch.int32, device=device)
    block_table[0, :] = torch.arange(32, device=device)
    block_table[1, :] = torch.arange(100, 132, device=device)

    synth_seq_lens, synth_block_table = _synth_args(
        B, K_PLUS_1, base_seq_lens, block_table)

    # Shapes
    assert synth_seq_lens.shape == (B * K_PLUS_1,), \
        f"expected ({B * K_PLUS_1},), got {synth_seq_lens.shape}"
    assert synth_block_table.shape == (B * K_PLUS_1, 32), \
        f"expected ({B * K_PLUS_1}, 32), got {synth_block_table.shape}"

    # Per-request synth_seq_lens pattern: base - K1 + 1, base - K1 + 2, ..., base
    # For req 0 (base=100): 97, 98, 99, 100
    # For req 1 (base=200): 197, 198, 199, 200
    expected = torch.tensor(
        [97, 98, 99, 100, 197, 198, 199, 200],
        dtype=torch.int32, device=device,
    )
    assert torch.equal(synth_seq_lens, expected), \
        f"synth_seq_lens mismatch:\nexpected={expected.tolist()}\ngot={synth_seq_lens.tolist()}"

    # Per-request block table is replicated K_PLUS_1 times
    for req in range(B):
        for offset in range(K_PLUS_1):
            assert torch.equal(
                synth_block_table[req * K_PLUS_1 + offset],
                block_table[req],
            ), f"block table replication mismatch at req={req} offset={offset}"


def test_synth_dtypes_preserved():
    """Synth args must preserve the dtype of the source seq_lens / block_table."""
    device = torch.device("cuda")
    for seq_dtype in (torch.int32, torch.int64):
        base_seq_lens = torch.tensor([50], dtype=seq_dtype, device=device)
        block_table = torch.zeros((1, 4), dtype=torch.int32, device=device)
        synth_seq_lens, synth_block_table = _synth_args(
            1, 4, base_seq_lens, block_table)
        assert synth_seq_lens.dtype == seq_dtype
        assert synth_block_table.dtype == torch.int32


def test_synth_construction_no_cpu_sync():
    """Synth construction must be entirely on-GPU (no .item() / .tolist() sync).

    This is the property that makes the routing safe under cudagraph capture.
    We verify it by capturing the construction inside a CUDA graph: any
    GPU→CPU sync during capture would raise.
    """
    device = torch.device("cuda")
    base_seq_lens = torch.tensor([100, 200, 300], dtype=torch.int32,
                                 device=device)
    block_table = torch.zeros((3, 16), dtype=torch.int32, device=device)

    # Run inside a stream-captured region; this should NOT raise
    g = torch.cuda.CUDAGraph()
    static_input_seq_lens = base_seq_lens.clone()
    static_input_block_table = block_table.clone()
    # Warmup
    _ = _synth_args(3, 4, static_input_seq_lens, static_input_block_table)
    torch.cuda.synchronize()

    # Capture
    with torch.cuda.graph(g):
        _ = _synth_args(3, 4, static_input_seq_lens, static_input_block_table)

    # If we got here without an exception, _synth_args is cudagraph-safe.
    # Replay should also work
    g.replay()
    torch.cuda.synchronize()


def test_eligibility_predicate():
    """Verify the dispatch predicate matches the expected K+1 spec-verify shape."""
    # Mock metadata fields the predicate checks
    class FakeMeta:
        is_prefill: bool
        num_decodes: int
        max_query_len: int
        max_seq_len: int
        query_start_loc: torch.Tensor

    # Eligible: K+1=4, has prior cache, batch divisible
    m = FakeMeta()
    m.is_prefill = True
    m.num_decodes = 0
    m.max_query_len = 4
    m.max_seq_len = 1024
    m.query_start_loc = torch.zeros(3, dtype=torch.int32)  # B=2, so B+1 = 3
    N = 8  # = B * K1 = 2 * 4
    eligible = (
        m.is_prefill and m.num_decodes == 0
        and 1 < m.max_query_len <= 16
        and m.max_seq_len > m.max_query_len
        and N > 0 and N % m.max_query_len == 0
        and m.query_start_loc is not None
    )
    assert eligible

    # NOT eligible: pure decode (max_query_len == 1)
    m.max_query_len = 1
    eligible = (
        m.is_prefill and m.num_decodes == 0
        and 1 < m.max_query_len <= 16
    )
    assert not eligible

    # NOT eligible: no prior cache (max_seq_len == max_query_len, fresh prefill)
    m.max_query_len = 4
    m.max_seq_len = 4
    eligible = (
        m.is_prefill and m.num_decodes == 0
        and 1 < m.max_query_len <= 16
        and m.max_seq_len > m.max_query_len
    )
    assert not eligible

    # NOT eligible: K+1 too large (>16, e.g., wrong spec-decode tree depth)
    m.max_query_len = 32
    m.max_seq_len = 1024
    eligible = (
        m.is_prefill and m.num_decodes == 0
        and 1 < m.max_query_len <= 16
    )
    assert not eligible


# An end-to-end correctness test (requires a Qwen3.6-A3B-FP8 checkpoint +
# TurboQuant model) would go here, gated under @pytest.mark.gpu +
# skip-if-no-model. This PR does not include such a model in CI; the
# empirical TPS data (75.6 vs 57.2 tok/s, +32%) is documented in the PR
# body and was measured on Sandermage/genesis-vllm-patches by the
# contributor.
```
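For readers following along, the dispatch these tests mirror would look roughly like the sketch below. This is not the PR's actual diff: the predicate and the synth-arg construction are transcribed from the tests above, while the function name `route_spec_verify`, its argument list, and the `meta` field layout (borrowed from the tests' `FakeMeta`) are illustrative assumptions.

```python
# Hypothetical sketch of the K+1 spec-verify routing decision. Metadata
# field names follow FakeMeta in the tests; seq_lens/block_table stand in
# for whatever the real attention metadata carries.
import torch


def route_spec_verify(meta, n_tokens: int, seq_lens: torch.Tensor,
                      block_table: torch.Tensor):
    """Return synth (seq_lens, block_table) for the decode kernel,
    or None to fall back to the default continuation branch."""
    eligible = (
        meta.is_prefill and meta.num_decodes == 0
        and 1 < meta.max_query_len <= 16           # K+1 verify window
        and meta.max_seq_len > meta.max_query_len  # prior KV cache exists
        and n_tokens > 0 and n_tokens % meta.max_query_len == 0
        and meta.query_start_loc is not None
    )
    if not eligible:
        return None

    k_plus_1 = meta.max_query_len
    batch_size = n_tokens // k_plus_1
    # Pure tensor ops, no .item()/.tolist(): safe under CUDA graph capture.
    offs = torch.arange(k_plus_1, device=seq_lens.device,
                        dtype=seq_lens.dtype)
    synth_seq_lens = (
        seq_lens[:batch_size, None] - k_plus_1 + 1 + offs[None, :]
    ).reshape(-1)
    synth_block_table = block_table[:batch_size].repeat_interleave(
        k_plus_1, dim=0)
    return synth_seq_lens, synth_block_table
```

On an eligible batch, the caller would hand the returned tensors to `triton_turboquant_decode_attention` in place of the `_prefill_attention` continuation; ineligible batches fall through to the default branch unchanged.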
Review comment:

The new routing path for speculative verification does not reuse the cached decode buffers (`mid_o_buf`, `output_buf`, `lse_buf`) from the `layer` object, nor does it pass the `buf_holder` or `max_num_kv_splits` parameters to `triton_turboquant_decode_attention`. This will cause the kernel to allocate new tensors on every call, which is a significant performance overhead in the hot path and, more importantly, breaks CUDA graph compatibility because dynamic allocations are not allowed during graph replay. Since this PR specifically aims to restore `FULL_AND_PIECEWISE` CUDA graph support, ensuring static buffer reuse is critical.
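A fix along the lines the reviewer suggests might look like the sketch below. The buffer and parameter names (`mid_o_buf`, `output_buf`, `lse_buf`, `buf_holder`, `max_num_kv_splits`) come from the review comment itself, but the shapes, dtypes, and the cache-on-the-layer pattern are illustrative assumptions, not the actual TurboQuant code.

```python
# Hypothetical sketch of static buffer reuse for the spec-verify branch.
# Shapes and dtypes are guesses for a split-K decode kernel; only the
# buffer names are taken from the review comment.
import torch


def get_decode_bufs(layer, n_tokens: int, num_heads: int, head_dim: int,
                    max_num_kv_splits: int, device, dtype):
    """Allocate the split-K scratch buffers once and cache them on `layer`
    so CUDA graph replays reuse the same storage instead of allocating
    per call."""
    if getattr(layer, "mid_o_buf", None) is None:
        # Per-split partial outputs and log-sum-exp for the reduction step.
        layer.mid_o_buf = torch.empty(
            (n_tokens, num_heads, max_num_kv_splits, head_dim),
            device=device, dtype=torch.float32)
        layer.lse_buf = torch.empty(
            (n_tokens, num_heads, max_num_kv_splits),
            device=device, dtype=torch.float32)
        layer.output_buf = torch.empty(
            (n_tokens, num_heads, head_dim), device=device, dtype=dtype)
    return layer.mid_o_buf, layer.lse_buf, layer.output_buf
```

The spec-verify branch would then presumably thread the cached buffers (e.g., via `buf_holder`) and `max_num_kv_splits` through to `triton_turboquant_decode_attention`, matching the existing decode path so that graph replays touch only pre-allocated storage.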