diff --git a/.buildkite/test_areas/lm_eval.yaml b/.buildkite/test_areas/lm_eval.yaml index 39029efe9cd9..a07d702cf3ce 100644 --- a/.buildkite/test_areas/lm_eval.yaml +++ b/.buildkite/test_areas/lm_eval.yaml @@ -91,6 +91,16 @@ steps: - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=evals/gsm8k/configs/moe-refactor-dp-ep/config-b200.txt +- label: LM Eval TurboQuant KV Cache + timeout_in_minutes: 75 + source_file_dependencies: + - vllm/model_executor/layers/quantization/turboquant/ + - vllm/v1/attention/backends/turboquant_attn.py + - vllm/v1/attention/ops/triton_turboquant_decode.py + - vllm/v1/attention/ops/triton_turboquant_store.py + commands: + - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=evals/gsm8k/configs/models-turboquant.txt + - label: GPQA Eval (GPT-OSS) (H100) timeout_in_minutes: 120 device: h100 diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index 242cc6b3b1ed..65e93657e780 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -178,6 +178,7 @@ Priority is **1 = highest** (tried first). | `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ❌ | ✅ | ❌ | Decoder, Encoder, Encoder Only | N/A | | `TREE_ATTN` | | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any | | `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2`, `int8_per_token_head`, `fp8_per_token_head` | %16 | Any | ✅ | ✅ | ❌ | All | Any | +| `TURBOQUANT` | | fp16, bf16 | `turboquant_k8v4`, `turboquant_4bit_nc`, `turboquant_k3v4_nc`, `turboquant_3bit_nc` | 16, 32, 64, 128 | Any | ❌ | ❌ | ❌ | Decoder | Any | > **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`. 
> diff --git a/pyproject.toml b/pyproject.toml index c4951b1a4ed3..f55dd9308bd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -170,6 +170,9 @@ eles = "eles" datas = "datas" ser = "ser" ure = "ure" +# Walsh-Hadamard Transform +wht = "wht" +WHT = "WHT" [tool.uv] no-build-isolation-package = ["torch"] diff --git a/tests/evals/gsm8k/configs/Qwen3-4B-TQ-k3v4nc.yaml b/tests/evals/gsm8k/configs/Qwen3-4B-TQ-k3v4nc.yaml new file mode 100644 index 000000000000..fedb74169606 --- /dev/null +++ b/tests/evals/gsm8k/configs/Qwen3-4B-TQ-k3v4nc.yaml @@ -0,0 +1,5 @@ +model_name: "Qwen/Qwen3-4B" +accuracy_threshold: 0.78 +num_questions: 1319 +num_fewshot: 5 +server_args: "--kv-cache-dtype turboquant_k3v4_nc --enforce-eager --max-model-len 4096" diff --git a/tests/evals/gsm8k/configs/Qwen3-4B-TQ-k8v4.yaml b/tests/evals/gsm8k/configs/Qwen3-4B-TQ-k8v4.yaml new file mode 100644 index 000000000000..9717333582b3 --- /dev/null +++ b/tests/evals/gsm8k/configs/Qwen3-4B-TQ-k8v4.yaml @@ -0,0 +1,5 @@ +model_name: "Qwen/Qwen3-4B" +accuracy_threshold: 0.80 +num_questions: 1319 +num_fewshot: 5 +server_args: "--kv-cache-dtype turboquant_k8v4 --enforce-eager --max-model-len 4096" diff --git a/tests/evals/gsm8k/configs/Qwen3-4B-TQ-t3nc.yaml b/tests/evals/gsm8k/configs/Qwen3-4B-TQ-t3nc.yaml new file mode 100644 index 000000000000..8ece18526257 --- /dev/null +++ b/tests/evals/gsm8k/configs/Qwen3-4B-TQ-t3nc.yaml @@ -0,0 +1,5 @@ +model_name: "Qwen/Qwen3-4B" +accuracy_threshold: 0.75 +num_questions: 1319 +num_fewshot: 5 +server_args: "--kv-cache-dtype turboquant_3bit_nc --enforce-eager --max-model-len 4096" diff --git a/tests/evals/gsm8k/configs/Qwen3-4B-TQ-t4nc.yaml b/tests/evals/gsm8k/configs/Qwen3-4B-TQ-t4nc.yaml new file mode 100644 index 000000000000..9b3a14f9b954 --- /dev/null +++ b/tests/evals/gsm8k/configs/Qwen3-4B-TQ-t4nc.yaml @@ -0,0 +1,5 @@ +model_name: "Qwen/Qwen3-4B" +accuracy_threshold: 0.80 +num_questions: 1319 +num_fewshot: 5 +server_args: "--kv-cache-dtype turboquant_4bit_nc 
--enforce-eager --max-model-len 4096" diff --git a/tests/evals/gsm8k/configs/models-turboquant.txt b/tests/evals/gsm8k/configs/models-turboquant.txt new file mode 100644 index 000000000000..518aac780b90 --- /dev/null +++ b/tests/evals/gsm8k/configs/models-turboquant.txt @@ -0,0 +1,4 @@ +Qwen3-4B-TQ-k8v4.yaml +Qwen3-4B-TQ-t4nc.yaml +Qwen3-4B-TQ-k3v4nc.yaml +Qwen3-4B-TQ-t3nc.yaml diff --git a/tests/quantization/test_turboquant.py b/tests/quantization/test_turboquant.py new file mode 100644 index 000000000000..78c137e67628 --- /dev/null +++ b/tests/quantization/test_turboquant.py @@ -0,0 +1,570 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for TurboQuant KV-cache quantization. + +Run: .venv/bin/python -m pytest tests/quantization/test_turboquant.py -v +""" + +import math + +import pytest +import torch + +from vllm.model_executor.layers.quantization.turboquant.centroids import ( + get_centroids, + solve_lloyd_max, +) +from vllm.model_executor.layers.quantization.turboquant.config import ( + TQ_PRESETS, + TurboQuantConfig, +) +from vllm.model_executor.layers.quantization.turboquant.quantizer import ( + generate_wht_signs, +) +from vllm.utils.math_utils import next_power_of_2 + +# ============================================================================ +# Helpers +# ============================================================================ + +ALL_PRESETS = list(TQ_PRESETS.keys()) + + +def _assert_strictly_sorted(seq, name="sequence"): + for i in range(len(seq) - 1): + assert seq[i] < seq[i + 1], f"{name} not sorted at index {i}" + + +def _is_power_of_2(n: int) -> bool: + return n > 0 and next_power_of_2(n) == n + + +# Expected concrete values for each preset at head_dim=128. 
+# fmt: off +PRESET_EXPECTED = { + "turboquant_k8v4": dict( + key_fp8=True, key_quant_bits=8, + key_mse_bits=0, value_quant_bits=4, + mse_bits=4, n_centroids=16, centroid_bits=4, + norm_correction=False, + key_packed_size=128, value_packed_size=68, + slot_size=196, slot_size_aligned=196, + ), + "turboquant_4bit_nc": dict( + key_fp8=False, key_quant_bits=4, + key_mse_bits=4, value_quant_bits=4, + mse_bits=4, n_centroids=16, centroid_bits=4, + norm_correction=True, + key_packed_size=66, value_packed_size=68, + slot_size=134, slot_size_aligned=134, + ), + "turboquant_k3v4_nc": dict( + key_fp8=False, key_quant_bits=3, + key_mse_bits=3, value_quant_bits=4, + mse_bits=3, n_centroids=8, centroid_bits=3, + norm_correction=True, + key_packed_size=50, value_packed_size=68, + slot_size=118, slot_size_aligned=118, + ), + "turboquant_3bit_nc": dict( + key_fp8=False, key_quant_bits=3, + key_mse_bits=3, value_quant_bits=3, + mse_bits=3, n_centroids=8, centroid_bits=3, + norm_correction=True, + key_packed_size=50, value_packed_size=52, + slot_size=102, slot_size_aligned=102, + ), +} +# fmt: on + + +# ============================================================================ +# Config tests (CPU-only, no dependencies beyond config.py) +# ============================================================================ + + +class TestTurboQuantConfig: + @pytest.mark.parametrize("preset", ALL_PRESETS) + def test_preset_parses(self, preset): + cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128) + assert isinstance(cfg, TurboQuantConfig) + + def test_invalid_preset_raises(self): + with pytest.raises(ValueError, match="Unknown TurboQuant"): + TurboQuantConfig.from_cache_dtype("turboquant_invalid", head_dim=128) + + # ---- Per-preset concrete value checks (table-driven) ---- + + @pytest.mark.parametrize("preset", ALL_PRESETS) + def test_key_mode(self, preset): + cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128) + exp = PRESET_EXPECTED[preset] + assert cfg.key_fp8 is 
exp["key_fp8"] + assert cfg.key_quant_bits == exp["key_quant_bits"] + assert cfg.key_mse_bits == exp["key_mse_bits"] + + @pytest.mark.parametrize("preset", ALL_PRESETS) + def test_value_mode(self, preset): + cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128) + exp = PRESET_EXPECTED[preset] + assert cfg.value_quant_bits == exp["value_quant_bits"] + + @pytest.mark.parametrize("preset", ALL_PRESETS) + def test_bits_and_centroids(self, preset): + cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128) + exp = PRESET_EXPECTED[preset] + assert cfg.mse_bits == exp["mse_bits"] + assert cfg.n_centroids == exp["n_centroids"] + assert cfg.centroid_bits == exp["centroid_bits"] + + @pytest.mark.parametrize("preset", ALL_PRESETS) + def test_norm_correction(self, preset): + cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128) + assert cfg.norm_correction is PRESET_EXPECTED[preset]["norm_correction"] + + @pytest.mark.parametrize("preset", ALL_PRESETS) + def test_packed_sizes(self, preset): + cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128) + exp = PRESET_EXPECTED[preset] + assert cfg.key_packed_size == exp["key_packed_size"] + assert cfg.value_packed_size == exp["value_packed_size"] + assert cfg.slot_size == exp["slot_size"] + assert cfg.slot_size_aligned == exp["slot_size_aligned"] + + # ---- Cross-preset structural invariants ---- + + @pytest.mark.parametrize("preset", ALL_PRESETS) + def test_slot_equals_key_plus_value(self, preset): + cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128) + assert cfg.slot_size == cfg.key_packed_size + cfg.value_packed_size + + @pytest.mark.parametrize("preset", ALL_PRESETS) + def test_padded_slot_is_even(self, preset): + cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128) + assert cfg.slot_size_aligned >= cfg.slot_size + assert cfg.slot_size_aligned % 2 == 0, ( + f"slot_size_aligned={cfg.slot_size_aligned} is not even" + ) + + @pytest.mark.parametrize("preset", ALL_PRESETS) + def 
test_key_value_packed_sizes_positive(self, preset): + cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128) + assert cfg.key_packed_size > 0 + assert cfg.value_packed_size > 0 + + @pytest.mark.parametrize("preset", ALL_PRESETS) + def test_n_centroids_is_2_to_mse_bits(self, preset): + cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128) + assert cfg.n_centroids == 2**cfg.mse_bits + + @pytest.mark.parametrize("preset", ALL_PRESETS) + def test_centroid_bits_always_positive(self, preset): + cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128) + assert cfg.centroid_bits > 0 + + @pytest.mark.parametrize("preset", ALL_PRESETS) + def test_mse_key_or_fp8_exclusive(self, preset): + """Each preset is either FP8 keys or MSE keys, never both.""" + cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128) + if cfg.key_fp8: + assert cfg.key_mse_bits == 0 + assert cfg.key_quant_bits == 8 + else: + assert cfg.key_mse_bits > 0 + assert cfg.key_quant_bits in (3, 4) + + @pytest.mark.parametrize("preset", ALL_PRESETS) + @pytest.mark.parametrize("head_dim", [64, 96, 128, 256]) + def test_all_presets_all_head_dims(self, preset, head_dim): + cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=head_dim) + assert cfg.head_dim == head_dim + assert cfg.slot_size == cfg.key_packed_size + cfg.value_packed_size + assert cfg.slot_size_aligned >= cfg.slot_size + assert cfg.slot_size_aligned % 2 == 0 + + # ---- Boundary skip layers ---- + + def test_boundary_skip_layers_basic(self): + layers = TurboQuantConfig.get_boundary_skip_layers(32) + assert layers == ["0", "1", "30", "31"] + + def test_boundary_skip_layers_zero(self): + assert TurboQuantConfig.get_boundary_skip_layers(32, 0) == [] + + def test_boundary_skip_layers_small_model(self): + layers = TurboQuantConfig.get_boundary_skip_layers(4) + assert layers == ["0", "1", "2", "3"] + + def test_boundary_skip_layers_cap_at_half(self): + layers = TurboQuantConfig.get_boundary_skip_layers(8, 10) + assert 
len(layers) == 8 + + +# ============================================================================ +# Centroids tests (CPU-only) +# ============================================================================ + + +class TestCentroids: + @pytest.mark.parametrize("bits,expected_n", [(2, 4), (3, 8), (4, 16)]) + def test_centroids_shape(self, bits, expected_n): + c = get_centroids(128, bits) + assert c.shape == (expected_n,) + + @pytest.mark.parametrize("bits", [2, 3, 4]) + def test_centroids_sorted(self, bits): + _assert_strictly_sorted(get_centroids(128, bits), "centroids") + + def test_centroids_cached(self): + c1 = get_centroids(128, 3) + c2 = get_centroids(128, 3) + assert c1 is c2, "get_centroids should return cached object" + + def test_centroids_different_dims_not_identical(self): + c64 = get_centroids(64, 3) + c128 = get_centroids(128, 3) + assert not torch.equal(c64, c128) + + @pytest.mark.parametrize("bits", [2, 3, 4]) + def test_centroids_symmetric_around_zero(self, bits): + """N(0, 1/d) is symmetric, so centroids should be ~symmetric.""" + c = get_centroids(128, bits) + assert abs(c.mean().item()) < 0.01, "Centroids not centered near 0" + assert abs(c[0].item() + c[-1].item()) < 0.01 + + @pytest.mark.parametrize("bits", [2, 3, 4]) + def test_centroids_within_4sigma(self, bits): + """All centroids should be within ~4 sigma of N(0, 1/d).""" + sigma = math.sqrt(1.0 / 128) + c = get_centroids(128, bits) + for i, val in enumerate(c): + assert abs(val.item()) < 4 * sigma, ( + f"Centroid {i}={val:.6f} outside 4*sigma={4 * sigma:.6f}" + ) + + +class TestLloydMax: + @pytest.mark.parametrize("bits,expected_n", [(2, 4), (3, 8), (4, 16)]) + def test_solve_shapes(self, bits, expected_n): + centroids, boundaries = solve_lloyd_max(128, bits) + assert centroids.shape == (expected_n,) + assert boundaries.shape == (expected_n - 1,) + + @pytest.mark.parametrize("bits", [2, 3, 4]) + def test_centroids_sorted(self, bits): + centroids, _ = solve_lloyd_max(128, bits) + 
_assert_strictly_sorted(centroids, "centroids") + + @pytest.mark.parametrize("bits", [2, 3, 4]) + def test_boundaries_sorted(self, bits): + _, boundaries = solve_lloyd_max(128, bits) + _assert_strictly_sorted(boundaries, "boundaries") + + @pytest.mark.parametrize("bits", [2, 3, 4]) + def test_boundaries_between_centroids(self, bits): + """Each boundary must lie between its adjacent centroids.""" + centroids, boundaries = solve_lloyd_max(128, bits) + for i in range(len(boundaries)): + assert centroids[i] < boundaries[i] < centroids[i + 1], ( + f"Boundary {i}={boundaries[i]:.6f} not between " + f"c[{i}]={centroids[i]:.6f} and c[{i + 1}]={centroids[i + 1]:.6f}" + ) + + @pytest.mark.parametrize("bits", [2, 3, 4]) + def test_boundaries_are_midpoints(self, bits): + """Lloyd-Max boundaries are midpoints of adjacent centroids.""" + centroids, boundaries = solve_lloyd_max(128, bits) + for i in range(len(boundaries)): + expected = (centroids[i] + centroids[i + 1]) / 2.0 + assert abs(boundaries[i].item() - expected.item()) < 1e-6 + + def test_solve_deterministic(self): + c1, b1 = solve_lloyd_max(128, 3) + c2, b2 = solve_lloyd_max(128, 3) + assert torch.equal(c1, c2) + assert torch.equal(b1, b2) + + def test_solve_dtype_float32(self): + centroids, boundaries = solve_lloyd_max(128, 3) + assert centroids.dtype == torch.float32 + assert boundaries.dtype == torch.float32 + + @pytest.mark.parametrize("bits", [3, 4]) + def test_centroids_match_scipy_reference(self, bits): + """Verify _trapz(n=200) centroids match scipy.integrate.quad reference. + + This ensures our scipy-free trapezoid integration doesn't silently + drift from the published Lloyd-Max quality. 
+ """ + pytest.importorskip("scipy") + from scipy.integrate import quad + + d = 128 + sigma2 = 1.0 / d + sigma = math.sqrt(sigma2) + + def pdf(x): + return (1.0 / math.sqrt(2 * math.pi * sigma2)) * math.exp( + -x * x / (2 * sigma2) + ) + + n_levels = 2**bits + lo, hi = -3.5 * sigma, 3.5 * sigma + ref_centroids = [lo + (hi - lo) * (i + 0.5) / n_levels for i in range(n_levels)] + for _ in range(200): + boundaries = [ + (ref_centroids[i] + ref_centroids[i + 1]) / 2.0 + for i in range(n_levels - 1) + ] + edges = [lo * 3] + boundaries + [hi * 3] + new_centroids = [] + for i in range(n_levels): + a, b = edges[i], edges[i + 1] + num, _ = quad(lambda x: x * pdf(x), a, b) + den, _ = quad(pdf, a, b) + new_centroids.append(num / den if den > 1e-15 else ref_centroids[i]) + if ( + max(abs(new_centroids[i] - ref_centroids[i]) for i in range(n_levels)) + < 1e-10 + ): + break + ref_centroids = new_centroids + + # Compare our _trapz centroids against scipy reference + our_centroids, _ = solve_lloyd_max(d, bits) + ref_t = torch.tensor(ref_centroids, dtype=torch.float32) + max_err = (our_centroids - ref_t).abs().max().item() + # _trapz(n=200) has ~O(h^2) error vs adaptive quad; 1e-3 is tight + # enough to catch regression while allowing trapezoid approximation. 
+ assert max_err < 1e-3, ( + f"d={d}, bits={bits}: max centroid error vs scipy = {max_err:.2e}" + ) + + +# ============================================================================ +# Rotation matrix tests (GPU required) +# ============================================================================ + +CUDA_AVAILABLE = torch.cuda.is_available() + + +def generate_rotation_matrix(d: int, seed: int, device: str = "cpu") -> torch.Tensor: + """Haar-distributed random orthogonal matrix via QR (test/benchmark only).""" + gen = torch.Generator(device="cpu") + gen.manual_seed(seed) + G = torch.randn(d, d, generator=gen, device="cpu", dtype=torch.float32) + Q, R = torch.linalg.qr(G) + diag_sign = torch.sign(torch.diag(R)) + diag_sign[diag_sign == 0] = 1.0 + Q = Q * diag_sign.unsqueeze(0) + return Q.to(device) + + +@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA not available") +class TestRotationMatrix: + """Tests for the QR-based rotation (standalone benchmarks only).""" + + @pytest.mark.parametrize("dim", [64, 96, 128, 256]) + def test_rotation_matrix_shape_and_orthogonal(self, dim): + Pi = generate_rotation_matrix(dim, seed=42, device="cuda") + assert Pi.shape == (dim, dim) + eye = Pi @ Pi.T + assert torch.allclose(eye, torch.eye(dim, device="cuda"), atol=1e-5), ( + f"Pi not orthogonal for dim={dim}" + ) + + def test_rotation_matrix_deterministic(self): + Pi1 = generate_rotation_matrix(128, seed=42) + Pi2 = generate_rotation_matrix(128, seed=42) + assert torch.equal(Pi1, Pi2) + + def test_rotation_matrix_different_seeds(self): + Pi1 = generate_rotation_matrix(128, seed=42) + Pi2 = generate_rotation_matrix(128, seed=99) + assert not torch.equal(Pi1, Pi2) + + def test_rotation_matrix_det_is_pm1(self): + """Orthogonal matrix determinant must be +1 or -1.""" + Pi = generate_rotation_matrix(128, seed=42, device="cuda") + det = torch.linalg.det(Pi) + assert abs(abs(det.item()) - 1.0) < 1e-4 + + +# 
============================================================================ +# WHT rotation tests (serving path: generate_wht_signs + _build_hadamard) +# ============================================================================ + + +def _build_hadamard(d: int, device: str = "cpu") -> torch.Tensor: +    """Reproduce the serving-path Hadamard construction.""" +    H = torch.tensor([[1.0]]) +    while H.shape[0] < d: +        H = torch.cat([torch.cat([H, H], 1), torch.cat([H, -H], 1)], 0) +    return (H / math.sqrt(d)).to(torch.device(device)) + + +@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA not available") +class TestWHTRotation: +    """Tests for the WHT rotation actually used in serving.""" + +    @pytest.mark.parametrize("dim", [64, 128, 256]) +    def test_wht_orthonormal(self, dim): +        """signs * H must be orthonormal: (signs*H) @ (signs*H)^T = I.""" +        signs = generate_wht_signs(dim, seed=42, device="cuda") +        H = _build_hadamard(dim, "cuda") +        PiT = (signs.unsqueeze(1) * H).contiguous() +        eye = PiT @ PiT.T +        assert torch.allclose(eye, torch.eye(dim, device="cuda"), atol=1e-5), ( +            f"WHT rotation not orthonormal for dim={dim}" +        ) + +    @pytest.mark.parametrize("dim", [64, 128, 256]) +    def test_wht_self_inverse(self, dim): +        """Pi @ PiT must be the identity: the inverse rotation undoes the rotation.""" +        signs = generate_wht_signs(dim, seed=42, device="cuda") +        H = _build_hadamard(dim, "cuda") +        PiT = (signs.unsqueeze(1) * H).contiguous() +        Pi = PiT.T.contiguous() +        # Pi @ PiT should be identity (rotation then inverse) +        result = Pi @ PiT +        assert torch.allclose(result, torch.eye(dim, device="cuda"), atol=1e-5), ( +            f"WHT rotation not self-inverse for dim={dim}" +        ) + +    def test_wht_signs_deterministic(self): +        """Same seed must produce identical signs.""" +        s1 = generate_wht_signs(128, seed=42) +        s2 = generate_wht_signs(128, seed=42) +        assert torch.equal(s1, s2) + +    def test_wht_signs_different_seeds(self): +        """Different seeds must produce different signs.""" +        s1 = generate_wht_signs(128,
seed=42) + s2 = generate_wht_signs(128, seed=99) + assert not torch.equal(s1, s2) + + def test_wht_signs_are_pm1(self): + """All sign values must be exactly +1 or -1.""" + signs = generate_wht_signs(128, seed=42) + assert torch.all(signs.abs() == 1.0) + + +# ============================================================================ +# Store → Decode round-trip test (GPU + Triton required) +# ============================================================================ + + +@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA not available") +class TestStoreDecodeRoundTrip: + """End-to-end: store KV into TQ cache, decode, compare vs fp16 ref.""" + + @pytest.mark.parametrize( + "preset", + ["turboquant_k8v4", "turboquant_4bit_nc"], + ) + def test_single_token_roundtrip(self, preset): + """Store 1 token, decode with query=key, check attention output. + + For a single token with query=key, attention output should equal + the value (softmax over single key = 1.0). Quantization error + means we check cosine similarity rather than exact equality. 
+ """ + from vllm.model_executor.layers.quantization.turboquant.centroids import ( + solve_lloyd_max, + ) + from vllm.v1.attention.ops.triton_turboquant_decode import ( + triton_turboquant_decode_attention, + ) + from vllm.v1.attention.ops.triton_turboquant_store import ( + triton_turboquant_store, + ) + + cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128) + D = 128 + Hk = 4 # num_kv_heads + Hq = 4 # num_q_heads (no GQA for simplicity) + B = 1 # single token + block_size = 16 + num_blocks = 1 + + device = torch.device("cuda") + + # Generate rotation + signs = generate_wht_signs(D, seed=42, device=device) + H = _build_hadamard(D, "cuda") + PiT = (signs.unsqueeze(1) * H).contiguous().float() + Pi = PiT.T.contiguous() + + # Generate centroids + centroids, _ = solve_lloyd_max(D, cfg.centroid_bits) + centroids = centroids.float().to(device) + c_sorted, _ = centroids.sort() + midpoints = ((c_sorted[:-1] + c_sorted[1:]) / 2).to(device) + + # Random K, V + torch.manual_seed(123) + key = torch.randn(B, Hk, D, device=device, dtype=torch.float16) + value = torch.randn(B, Hk, D, device=device, dtype=torch.float16) + + # Allocate KV cache + padded_slot = cfg.slot_size_aligned + kv_cache = torch.zeros( + num_blocks, + block_size, + Hk, + padded_slot, + device=device, + dtype=torch.uint8, + ) + slot_mapping = torch.tensor([0], device=device, dtype=torch.int32) + + # Store + triton_turboquant_store( + key, + value, + kv_cache, + slot_mapping, + PiT, + midpoints, + mse_bits=cfg.key_mse_bits, + key_packed_size=cfg.key_packed_size, + value_quant_bits=cfg.effective_value_quant_bits, + key_fp8=cfg.key_fp8, + ) + + # Decode: use key as query so attention = softmax([1]) * V = V + query = key.expand(B, Hq, D).contiguous().to(torch.float16) + block_table = torch.tensor([[0]], device=device, dtype=torch.int32) + seq_lens = torch.tensor([1], device=device, dtype=torch.int32) + + output = triton_turboquant_decode_attention( + query=query, + kv_cache=kv_cache, + 
block_table=block_table, + seq_lens=seq_lens, + Pi=Pi, + centroids=centroids, + scale=1.0 / math.sqrt(D), + mse_bits=cfg.key_mse_bits, + key_packed_size=cfg.key_packed_size, + value_quant_bits=cfg.effective_value_quant_bits, + key_fp8=cfg.key_fp8, + norm_correction=cfg.norm_correction, + PiT=PiT, + max_num_kv_splits=4, + ) + + # With single KV, output should approximate the stored value. + # Check per-head cosine similarity > threshold. + out_fp32 = output.float() + val_fp32 = value.expand(B, Hq, D).float() + for h in range(Hq): + cos_sim = torch.nn.functional.cosine_similarity( + out_fp32[0, h].unsqueeze(0), + val_fp32[0, h].unsqueeze(0), + ).item() + # FP8 keys should be very accurate; MSE keys have more error + threshold = 0.95 if cfg.key_fp8 else 0.85 + assert cos_sim > threshold, ( + f"Preset {preset} head {h}: cosine_sim={cos_sim:.4f} < {threshold}" + ) diff --git a/vllm/config/attention.py b/vllm/config/attention.py index 1da647a6d6ff..561367173d5f 100644 --- a/vllm/config/attention.py +++ b/vllm/config/attention.py @@ -27,6 +27,11 @@ class AttentionConfig: flash_attn_max_num_splits_for_cuda_graph: int = 32 """Flash Attention max number splits for cuda graph decode.""" + tq_max_kv_splits_for_cuda_graph: int = 32 + """TurboQuant max NUM_KV_SPLITS for cuda graph decode. 
+ Fixes the split count so grid dimensions are constant across captures, + and buffers can be pre-allocated to avoid inflating the memory estimate.""" + use_cudnn_prefill: bool = False """Whether to use cudnn prefill.""" diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 20721cc80923..47a655f22d53 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -24,6 +24,10 @@ "fp8_e5m2", "fp8_inc", "fp8_ds_mla", + "turboquant_k8v4", + "turboquant_4bit_nc", + "turboquant_k3v4_nc", + "turboquant_3bit_nc", "int8_per_token_head", "fp8_per_token_head", ] diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 03a460fbe95a..7028b12dab32 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1642,6 +1642,31 @@ def create_engine_config( kv_offloading_backend=self.kv_offloading_backend, ) + # TurboQuant: auto-skip first/last 2 layers (boundary protection). + # These layers are most sensitive to quantization error. + # Users can add extra layers via --kv-cache-dtype-skip-layers. + if resolved_cache_dtype.startswith("turboquant_"): + if model_config.is_hybrid: + raise NotImplementedError( + "TurboQuant KV cache is not supported for hybrid " + "(attention + Mamba) models. Boundary layer protection " + "requires uniform attention layers." 
+ ) + from vllm.model_executor.layers.quantization.turboquant.config import ( + TurboQuantConfig, + ) + + num_layers = model_config.hf_text_config.num_hidden_layers + boundary = TurboQuantConfig.get_boundary_skip_layers(num_layers) + existing = set(cache_config.kv_cache_dtype_skip_layers) + merged = sorted(existing | set(boundary), key=lambda x: int(x)) + cache_config.kv_cache_dtype_skip_layers = merged + logger.info( + "TQ: skipping layers %s for boundary protection (num_layers=%d)", + merged, + num_layers, + ) + ray_runtime_env = None if is_ray_initialized(): # Ray Serve LLM calls `create_engine_config` in the context @@ -1948,6 +1973,19 @@ def create_engine_config( self.attention_backend ) + # TurboQuant requires FlashAttention 2 — FA3 boundary layers assert + # FlashAttentionImpl which fails with TurboQuantAttentionImpl. + if resolved_cache_dtype.startswith("turboquant_") and ( + attention_config.flash_attn_version is None + or attention_config.flash_attn_version >= 3 + ): + logger.warning( + "TurboQuant is not yet compatible with FlashAttention >= 3. " + "Overriding flash_attn_version to 2. 
To silence this " + "warning, pass --attention-config.flash_attn_version=2" + ) + attention_config.flash_attn_version = 2 + # Mamba config overrides mamba_config = copy.deepcopy(self.mamba_config) # Convert string to enum if needed (CLI parsing returns a string) diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 5e8825e2baf6..a92e2f4ad188 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -379,6 +379,10 @@ def __init__( # Initialize KV cache quantization attributes _init_kv_cache_quant(self, quant_config, prefix) + # Initialize TurboQuant buffers (Pi, S, centroids) if tq cache dtype + if kv_cache_dtype.startswith("turboquant_"): + self._init_turboquant_buffers(kv_cache_dtype, head_size, prefix) + # for attn backends supporting query quantization self.query_quant = None if ( @@ -397,6 +401,67 @@ def __init__( else GroupShape.PER_TENSOR, ) + def _init_turboquant_buffers( + self, cache_dtype: str, head_size: int, prefix: str + ) -> None: + """Initialize TurboQuant rotation/projection matrices and centroids.""" + from vllm.model_executor.layers.quantization.turboquant.centroids import ( + get_centroids, + ) + from vllm.model_executor.layers.quantization.turboquant.config import ( + TurboQuantConfig, + ) + from vllm.model_executor.layers.quantization.turboquant.quantizer import ( + generate_wht_signs, + ) + + tq_config = TurboQuantConfig.from_cache_dtype(cache_dtype, head_size) + + # Each layer needs a unique rotation matrix so quantization errors + # don't correlate across layers. Stride must exceed max head_dim to + # ensure non-overlapping RNG streams between adjacent layers. 
+ _TQ_LAYER_SEED_STRIDE = 1337 + + from vllm.model_executor.models.utils import extract_layer_index + + layer_idx = extract_layer_index(prefix) + seed = tq_config.seed + layer_idx * _TQ_LAYER_SEED_STRIDE + + self.register_buffer( + "_tq_signs", + generate_wht_signs(head_size, seed=seed), + ) + self.register_buffer( + "_tq_centroids", + get_centroids(head_size, tq_config.centroid_bits), + ) + self._tq_config = tq_config + + # Pre-allocate decode intermediate buffers so model.to(device) moves + # them to GPU *before* the memory profiler runs. Without this the + # profiler gives all free memory to KV cache blocks and the first + # decode OOMs when these buffers are lazily allocated. + _vllm_cfg = get_current_vllm_config() + B = _vllm_cfg.scheduler_config.max_num_seqs + Hq = self.num_heads + S = _vllm_cfg.attention_config.tq_max_kv_splits_for_cuda_graph + D = head_size + self.register_buffer( + "_tq_mid_o_buf", + torch.empty(B, Hq, S, D + 1, dtype=torch.float32), + persistent=False, + ) + self.register_buffer( + "_tq_output_buf", + torch.empty(B, Hq, D, dtype=torch.float32), + persistent=False, + ) + self.register_buffer( + "_tq_lse_buf", + torch.empty(B, Hq, dtype=torch.float32), + persistent=False, + ) + def forward( self, query: torch.Tensor, @@ -544,6 +609,23 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: kv_quant_mode=quant_mode, sliding_window=self.sliding_window, ) + elif self.kv_cache_dtype.startswith("turboquant_"): + from vllm.model_executor.layers.quantization.turboquant.config import ( + TurboQuantConfig, + ) + from vllm.v1.kv_cache_interface import TQFullAttentionSpec + + tq_config = TurboQuantConfig.from_cache_dtype( + self.kv_cache_dtype, self.head_size + ) + return TQFullAttentionSpec( + block_size=block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + head_size_v=self.head_size, + dtype=self.kv_cache_torch_dtype, + tq_slot_size=tq_config.slot_size_aligned, + ) else: return FullAttentionSpec( 
block_size=block_size, diff --git a/vllm/model_executor/layers/quantization/turboquant/__init__.py b/vllm/model_executor/layers/quantization/turboquant/__init__.py new file mode 100644 index 000000000000..10ee032c9ecf --- /dev/null +++ b/vllm/model_executor/layers/quantization/turboquant/__init__.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""TurboQuant: Near-optimal KV-cache quantization for vLLM. + +PolarQuant compression: random rotation + per-coordinate Lloyd-Max +scalar quantization for keys, uniform quantization for values. + +Reference: "TurboQuant: Online Vector Quantization with Near-optimal +Distortion Rate" (ICLR 2026), Zandieh et al. +""" + +from vllm.model_executor.layers.quantization.turboquant.config import TurboQuantConfig + +__all__ = ["TurboQuantConfig"] diff --git a/vllm/model_executor/layers/quantization/turboquant/centroids.py b/vllm/model_executor/layers/quantization/turboquant/centroids.py new file mode 100644 index 000000000000..490265747c5b --- /dev/null +++ b/vllm/model_executor/layers/quantization/turboquant/centroids.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Lloyd-Max optimal scalar quantizer for TurboQuant. + +After rotating a d-dimensional unit vector by a random orthogonal matrix, +each coordinate approximately follows N(0, 1/d) for d >= 64. +We solve the Lloyd-Max conditions to find optimal centroids. + +Based on: turboquant-pytorch/lloyd_max.py (Zandieh et al.) 
+""" + +import math +from functools import lru_cache + +import torch + + +def _gaussian_pdf(x: float, sigma2: float) -> float: + return (1.0 / math.sqrt(2 * math.pi * sigma2)) * math.exp(-x * x / (2 * sigma2)) + + +def _trapz(f, a: float, b: float, n: int = 200) -> float: + """Trapezoidal numerical integration (replaces scipy.integrate.quad).""" + h = (b - a) / n + result = 0.5 * (f(a) + f(b)) + for i in range(1, n): + result += f(a + i * h) + return result * h + + +def solve_lloyd_max( + d: int, + bits: int, + max_iter: int = 200, + tol: float = 1e-10, +) -> tuple[torch.Tensor, torch.Tensor]: + """Solve Lloyd-Max optimal quantizer for N(0, 1/d) distribution. + + Args: + d: Vector dimension (determines variance = 1/d). + bits: Number of quantization bits. + max_iter: Maximum Lloyd-Max iterations. + tol: Convergence tolerance. + + Returns: + centroids: Sorted tensor of 2^bits optimal centroids. + boundaries: Sorted tensor of 2^bits - 1 decision boundaries. + """ + n_levels = 2**bits + sigma2 = 1.0 / d + sigma = math.sqrt(sigma2) + + def pdf(x): + return _gaussian_pdf(x, sigma2) + + lo, hi = -3.5 * sigma, 3.5 * sigma + centroids = [lo + (hi - lo) * (i + 0.5) / n_levels for i in range(n_levels)] + + for _ in range(max_iter): + boundaries = [ + (centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1) + ] + edges = [lo * 3] + boundaries + [hi * 3] + new_centroids = [] + for i in range(n_levels): + a, b = edges[i], edges[i + 1] + num = _trapz(lambda x: x * pdf(x), a, b) + den = _trapz(pdf, a, b) + new_centroids.append(num / den if den > 1e-15 else centroids[i]) + + if max(abs(new_centroids[i] - centroids[i]) for i in range(n_levels)) < tol: + break + centroids = new_centroids + + boundaries = [(centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)] + return ( + torch.tensor(centroids, dtype=torch.float32), + torch.tensor(boundaries, dtype=torch.float32), + ) + + +@lru_cache(maxsize=32) +def get_centroids(d: int, bits: int) -> torch.Tensor: + 
"""Get precomputed Lloyd-Max centroids (cached).""" + centroids, _ = solve_lloyd_max(d, bits) + return centroids diff --git a/vllm/model_executor/layers/quantization/turboquant/config.py b/vllm/model_executor/layers/quantization/turboquant/config.py new file mode 100644 index 000000000000..289bed120773 --- /dev/null +++ b/vllm/model_executor/layers/quantization/turboquant/config.py @@ -0,0 +1,185 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""TurboQuant configuration.""" + +import math +from dataclasses import dataclass + +# Named TQ presets: each maps to frozen config parameters. +# key_quant_bits: 8 = FP8 keys, 3-4 = MSE (Lloyd-Max) quantized keys. +# value_quant_bits: 3-4 = uniform quantized values. +TQ_PRESETS: dict[str, dict] = { + "turboquant_k8v4": { + "key_quant_bits": 8, + "value_quant_bits": 4, + "norm_correction": False, + }, + "turboquant_4bit_nc": { + "key_quant_bits": 4, + "value_quant_bits": 4, + "norm_correction": True, + }, + "turboquant_k3v4_nc": { + "key_quant_bits": 3, + "value_quant_bits": 4, + "norm_correction": True, + }, + "turboquant_3bit_nc": { + "key_quant_bits": 3, + "value_quant_bits": 3, + "norm_correction": True, + }, +} + + +@dataclass +class TurboQuantConfig: + """Configuration for TurboQuant KV-cache quantization. + + Uses PolarQuant (WHT rotation + Lloyd-Max scalar quantization) for keys + and uniform quantization for values. QJL is intentionally omitted — + community consensus (5+ independent groups) found it hurts attention + quality by amplifying variance through softmax. + + Named presets (use via --kv-cache-dtype): + turboquant_k8v4: FP8 keys + 4-bit values, 2.6x, +1.17% PPL + turboquant_4bit_nc: 4-bit MSE keys + 4-bit values + NC, 3.8x, +2.71% + turboquant_k3v4_nc: 3-bit MSE keys + 4-bit values + NC, ~3.5x, +10.63% + turboquant_3bit_nc: 3-bit MSE keys + 3-bit values + NC, 4.9x, +20.59% + + Args: + head_dim: Attention head dimension (e.g. 64, 96, 128). 
+ key_quant_bits: Bits for key quantization. 8 = FP8 keys (no + rotation/MSE). 3-4 = Lloyd-Max MSE quantized keys. + value_quant_bits: Bits per value dimension for uniform quantization. + 3 = 8 levels, 4 = 16 levels (default). + seed: Base seed for deterministic random matrix generation. + Actual seed per layer = seed + layer_idx * 1337. + norm_correction: Re-normalize centroid vectors to unit norm before + inverse rotation during dequant. Fixes quantization-induced norm + distortion, improving PPL by ~0.8% at 4-bit. + """ + + head_dim: int = 128 + key_quant_bits: int = 3 # 3-4 = MSE keys, 8 = FP8 keys + value_quant_bits: int = 4 # 3-4 = uniform quantized values + seed: int = 42 + norm_correction: bool = False + + @property + def key_fp8(self) -> bool: + """Whether keys are stored as FP8 — no rotation/quantization needed.""" + return self.key_quant_bits == 8 + + @property + def mse_bits(self) -> int: + """MSE quantizer bit-width (determines centroid count: 2^mse_bits). + + For MSE key modes, equals key_quant_bits. + For FP8 key mode, falls back to value_quant_bits (centroids are still + needed for continuation-prefill dequant and decode kernel params). + """ + if self.key_fp8: + return self.value_quant_bits + return self.key_quant_bits + + @property + def key_mse_bits(self) -> int: + """MSE bits actually used for key quantization (0 if FP8 keys).""" + if self.key_fp8: + return 0 + return self.key_quant_bits + + @property + def centroid_bits(self) -> int: + """Bits for centroid generation — always non-zero.""" + return self.mse_bits + + @property + def n_centroids(self) -> int: + return 2**self.mse_bits + + @property + def key_packed_size(self) -> int: + """Packed bytes for a single KEY vector. + + FP8 mode (key_quant_bits=8): + head_dim bytes (1 byte per element, no overhead). 
+ + TQ mode: + - MSE indices: ceil(head_dim * key_mse_bits / 8) bytes + - vec_norm: 2 bytes (float16) + """ + if self.key_fp8: + return self.head_dim # 1 byte per element + mse_bytes = math.ceil(self.head_dim * self.key_mse_bits / 8) + norm_bytes = 2 # vec_norm fp16 + return mse_bytes + norm_bytes + + @property + def effective_value_quant_bits(self) -> int: + """Actual bits used for value storage.""" + return self.value_quant_bits + + @property + def value_packed_size(self) -> int: + """Packed bytes for a single VALUE vector. + + Uniform quantization: ceil(head_dim * bits / 8) + 4 bytes (scale + zero fp16). + """ + data_bytes = math.ceil(self.head_dim * self.value_quant_bits / 8) + return data_bytes + 4 # +2 scale(fp16) +2 zero(fp16) + + @property + def slot_size(self) -> int: + """Total packed bytes per head per position (key + value combined). + + Layout: [key_packed | value_packed] + """ + return self.key_packed_size + self.value_packed_size + + @property + def slot_size_aligned(self) -> int: + """Slot size rounded up to next even number. + + Even-number is required so effective_head_size = slot_size_aligned // 2 + is integral. + """ + s = self.slot_size + return s + (s % 2) # round up to even + + @staticmethod + def get_boundary_skip_layers(num_layers: int, n: int = 2) -> list[str]: + """Get layer indices to skip TQ compression (boundary protection). + + Returns first N and last N layer indices as strings, suitable for + kv_cache_dtype_skip_layers. + """ + if n <= 0 or num_layers <= 0: + return [] + n = min(n, num_layers // 2) # don't skip more than half + first = list(range(n)) + last = list(range(num_layers - n, num_layers)) + # Deduplicate (if num_layers <= 2*n) + indices = sorted(set(first + last)) + return [str(i) for i in indices] + + @staticmethod + def from_cache_dtype(cache_dtype: str, head_dim: int) -> "TurboQuantConfig": + """Create config from a named preset. + + Valid presets: turboquant_k8v4, turboquant_4bit_nc, etc. 
+ """ + if cache_dtype not in TQ_PRESETS: + valid = ", ".join(TQ_PRESETS.keys()) + raise ValueError( + f"Unknown TurboQuant cache dtype: {cache_dtype!r}. " + f"Valid presets: {valid}" + ) + preset = TQ_PRESETS[cache_dtype] + return TurboQuantConfig( + head_dim=head_dim, + key_quant_bits=preset["key_quant_bits"], + value_quant_bits=preset["value_quant_bits"], + norm_correction=preset["norm_correction"], + ) diff --git a/vllm/model_executor/layers/quantization/turboquant/quantizer.py b/vllm/model_executor/layers/quantization/turboquant/quantizer.py new file mode 100644 index 000000000000..aea63c52bacf --- /dev/null +++ b/vllm/model_executor/layers/quantization/turboquant/quantizer.py @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""TurboQuant quantizer utilities. + +Serving path uses generate_wht_signs() for WHT rotation sign buffers. +Triton kernels handle all quantization, packing, and dequantization on GPU. +""" + +import torch + +_CPU = torch.device("cpu") + + +def generate_wht_signs(d: int, seed: int, device: torch.device = _CPU) -> torch.Tensor: + """Generate deterministic random ±1 signs for WHT rotation. + + Used with Walsh-Hadamard Transform for per-layer rotation randomization. + Same seed derivation as QR (per-layer via seed + layer_idx * stride). 
+ """ + gen = torch.Generator(device="cpu") + gen.manual_seed(seed) + bits = torch.randint(0, 2, (d,), generator=gen, device="cpu") + signs = bits.float() * 2 - 1 + return signs.to(device) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 045298dbb36a..bbbd9af0cf71 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -255,6 +255,11 @@ def get_valid_backends( valid_backends_priorities = [] invalid_reasons: dict[AttentionBackendEnum, tuple[int, list[str]]] = {} + # TurboQuant KV cache: route directly to TQ backend + kv_cache_dtype = attn_selector_config.kv_cache_dtype + if kv_cache_dtype is not None and kv_cache_dtype.startswith("turboquant_"): + return [(AttentionBackendEnum.TURBOQUANT, 0)], {} + backend_priorities = _get_backend_priorities( attn_selector_config.use_mla, device_capability, diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 892fc0c950be..f9a4f8bd7320 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -61,6 +61,12 @@ def get_attn_backend_cls( "only NHD layout is supported by XPU attention kernels." 
) + # TurboQuant KV cache: route directly to TQ backend + kv_cache_dtype = attn_selector_config.kv_cache_dtype + if kv_cache_dtype is not None and kv_cache_dtype.startswith("turboquant_"): + logger.info_once("Using TurboQuant attention backend.") + return AttentionBackendEnum.TURBOQUANT.get_path() + dtype = attn_selector_config.dtype if attn_selector_config.use_sparse: logger.info_once("Using XPU MLA Sparse backend.") diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index 60b40855b5d3..26e377de69cb 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -42,6 +42,10 @@ "fp8_per_token_head": torch.uint8, "fp8_inc": torch.float8_e4m3fn, "fp8_ds_mla": torch.uint8, + "turboquant_k8v4": torch.uint8, + "turboquant_4bit_nc": torch.uint8, + "turboquant_k3v4_nc": torch.uint8, + "turboquant_3bit_nc": torch.uint8, } TORCH_DTYPE_TO_NUMPY_DTYPE = { diff --git a/vllm/v1/attention/backends/registry.py b/vllm/v1/attention/backends/registry.py index 4744ead4f54b..f31edfafc38a 100644 --- a/vllm/v1/attention/backends/registry.py +++ b/vllm/v1/attention/backends/registry.py @@ -82,6 +82,7 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta): "RocmAiterUnifiedAttentionBackend" ) CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend" + TURBOQUANT = "vllm.v1.attention.backends.turboquant_attn.TurboQuantAttentionBackend" # Placeholder for third-party/custom backends - must be registered before use # set to None to avoid alias with other backend, whose value is an empty string CUSTOM = None diff --git a/vllm/v1/attention/backends/turboquant_attn.py b/vllm/v1/attention/backends/turboquant_attn.py new file mode 100644 index 000000000000..279fcb04ace4 --- /dev/null +++ b/vllm/v1/attention/backends/turboquant_attn.py @@ -0,0 +1,812 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""TurboQuant attention backend for vLLM. 

Prefill: Standard scaled dot-product attention on uncompressed K/V,
    then quantize K and store K+V into combined cache slot.
Decode: Compute TQ attention scores from compressed cache,
    dequantize the packed values, softmax + weighted sum.

Cache layout (no leading 2 dimension):
    (num_blocks, block_size, num_kv_heads, slot_size)
    where slot_size = key_packed_size + value_packed_size

Per-head per-position slot layout:
    [key_packed (kps bytes) | value_packed bytes]
    For turboquant_k3v4_nc head_dim=256: [98 bytes key | 132 bytes value] = 230
"""

import functools
import math
from dataclasses import dataclass
from typing import Any, ClassVar

import torch
import torch.nn.functional as F

from vllm.config import get_current_vllm_config
from vllm.config.cache import CacheDType
from vllm.triton_utils import triton
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionCGSupport,
    AttentionImpl,
    AttentionLayer,
    AttentionMetadata,
    AttentionMetadataBuilder,
    AttentionType,
    CommonAttentionMetadata,
    MultipleOf,
)
from vllm.v1.attention.backends.fa_utils import (
    is_flash_attn_varlen_func_available,
)
from vllm.v1.attention.backends.utils import split_decodes_and_prefills
from vllm.v1.attention.ops.triton_turboquant_decode import (
    _tq_full_dequant_kv,
    _use_fp8_e4b15,
    triton_turboquant_decode_attention,
)
from vllm.v1.attention.ops.triton_turboquant_store import triton_turboquant_store

_HAS_FLASH_ATTN = is_flash_attn_varlen_func_available()
if _HAS_FLASH_ATTN:
    from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func

# Continuation prefill: for small continuation chunks (q_len ≤ threshold),
# use the TQ decode kernel directly instead of full-dequant + flash_attn.
# do_kv_cache_update already stored all tokens to TQ cache, so the decode
# kernel can read them efficiently. This avoids O(cached_len) dequant work
# per continuation, eliminating the O(N²/chunk_size) collapse at long context.
_CONTINUATION_DECODE_THRESHOLD = 128


def _build_hadamard(d: int, device_str: str) -> torch.Tensor:
    """Orthonormal Hadamard matrix (Sylvester construction), cached per (d, device).

    A precomputed D×D matrix enables matmul-based WHT — a single cuBLAS GEMM
    instead of log2(D) butterfly kernel launches. 64KB for D=128.

    Raises:
        ValueError: if *d* is not a positive power of two (Sylvester
            construction only exists for power-of-two orders).
    """
    # Normalize device string so "cuda" and "cuda:0" hit the same cache entry.
    return _build_hadamard_cached(d, str(torch.device(device_str)))


@functools.cache
def _build_hadamard_cached(d: int, device_str: str) -> torch.Tensor:
    # Sylvester doubling only yields Hadamard matrices of power-of-two order.
    # For any other d the loop would silently overshoot to the next power of
    # two and the /sqrt(d) scaling would be wrong, surfacing later as an
    # opaque shape mismatch against the (d,)-sized sign buffer. Fail fast
    # with a clear error instead.
    if d <= 0 or d & (d - 1) != 0:
        raise ValueError(
            f"Hadamard rotation requires a power-of-two head size, got {d}"
        )
    H = torch.tensor([[1.0]])
    while H.shape[0] < d:
        H = torch.cat([torch.cat([H, H], 1), torch.cat([H, -H], 1)], 0)
    # Scale by 1/sqrt(d) so H is orthonormal (H @ H.T == I) and self-inverse.
    return (H / math.sqrt(d)).to(torch.device(device_str))
"""Combined K+V cache shape — no leading 2 dimension. + + Standard attention backends use (2, num_blocks, block_size, num_kv_heads, + head_dim) with a leading 2 to separate K and V. TurboQuant packs K+V + into a single interleaved slot per head per position, so the cache is: + + (num_blocks, block_size, num_kv_heads, slot_size_aligned) + + Each slot = [key_packed | value_packed | padding]. + This is safe because TQ has its own get_kv_cache_shape override and + never shares cache tensors with other backends. Layers that fall back + to native dtype via kv_cache_dtype_skip_layers get their own + standard-shaped cache allocation. + + head_size is the model's real head_dim. slot_size_aligned is computed + from the TQ config to ensure correct cache allocation for all head dims. + """ + from vllm.model_executor.layers.quantization.turboquant.config import ( + TurboQuantConfig, + ) + + tq_config = TurboQuantConfig.from_cache_dtype(cache_dtype_str, head_size) + return (num_blocks, block_size, num_kv_heads, tq_config.slot_size_aligned) + + @classmethod + def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: + if kv_cache_dtype is None: + return False + return kv_cache_dtype.startswith("turboquant_") + + @classmethod + def supports_head_size(cls, head_size: int) -> bool: + # head_size from spec is effective_head_size (padded_slot//2), + # not the model's actual head_dim. Accept any positive value. 
+ return head_size > 0 + + +@dataclass +class TurboQuantMetadata(AttentionMetadata): + """Metadata for TurboQuant attention.""" + + seq_lens: torch.Tensor # (num_reqs,) — total context length per request + slot_mapping: torch.Tensor # (num_tokens,) — cache slot for each token + block_table: torch.Tensor # (num_reqs, max_num_blocks) + query_start_loc: torch.Tensor # (num_reqs + 1,) — cu_seqlens for queries + num_actual_tokens: int = 0 # actual tokens (excluding padding) + max_query_len: int = 0 # longest query in batch + max_seq_len: int = 0 # longest context in batch + is_prefill: bool = False + num_decodes: int = 0 # number of decode requests (first in batch) + num_decode_tokens: int = 0 # tokens from decode requests + + +class TurboQuantMetadataBuilder(AttentionMetadataBuilder[TurboQuantMetadata]): + """Builds TurboQuantMetadata from scheduler output.""" + + _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + + def __init__(self, kv_cache_spec, layer_names, vllm_config, device): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self._init_reorder_batch_threshold(1, supports_spec_as_decode=False) + + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata + ) -> TurboQuantMetadata: + attn_metadata = self.build(0, common_attn_metadata) + # Set seq_lens to 1 so CUDA graph capture is fast + # (real seq_lens are filled at replay time). + attn_metadata.seq_lens.fill_(1) + return attn_metadata + + def build(self, common_prefix_len, common_attn_metadata, fast_build=False): + """Build TurboQuantMetadata from common attention metadata.""" + cam = common_attn_metadata + + # With reorder_batch_threshold=1, the model runner guarantees + # decodes come first in the batch. split_decodes_and_prefills + # finds the boundary (operates on CPU tensors — no GPU sync). 
+ assert self.reorder_batch_threshold is not None + num_decodes, num_prefills, num_decode_tokens, _ = split_decodes_and_prefills( + cam, decode_threshold=self.reorder_batch_threshold + ) + + return TurboQuantMetadata( + seq_lens=cam.seq_lens, + slot_mapping=cam.slot_mapping, + block_table=cam.block_table_tensor, + query_start_loc=cam.query_start_loc, + num_actual_tokens=cam.num_actual_tokens, + max_query_len=cam.max_query_len, + max_seq_len=cam.max_seq_len, + is_prefill=(cam.max_query_len > 1), + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + ) + + +class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]): + """TurboQuant attention implementation. + + Vectorized PyTorch: batch quantize/store, vectorized bit-unpack + decode with einsum scores and value gather. + """ + + supports_quant_query_input: bool = False + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int | None = None, + alibi_slopes: list[float] | None = None, + sliding_window: int | None = None, + kv_cache_dtype: str = "auto", + logits_soft_cap: float | None = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: str | None = None, + **kwargs, + ): + self.num_heads = num_heads + self.head_size = head_size + self.scale = scale + self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads + self.num_kv_groups = num_heads // self.num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + from vllm.model_executor.layers.quantization.turboquant.config import ( + TurboQuantConfig, + ) + + self.tq_config = TurboQuantConfig.from_cache_dtype(kv_cache_dtype, head_size) + + # Pre-compute kernel constants from config (avoid repeated arithmetic) + cfg = self.tq_config + self._mse_bytes = ( + math.ceil(head_size * cfg.key_mse_bits / 8) + if not cfg.key_fp8 + else head_size + ) + self._val_data_bytes = math.ceil(head_size * cfg.effective_value_quant_bits / 8) + self._n_centroids = cfg.n_centroids if not 
cfg.key_fp8 else 1 + + # Fixed NUM_KV_SPLITS (grid dims must be constant for cudagraph, + # and benchmarks show no regression vs dynamic in eager mode). + vllm_config = get_current_vllm_config() + self.max_num_kv_splits = ( + vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph + ) + + def _ensure_on_device(self, layer, device): + """One-time derivation of TQ buffers (rotation matrices, midpoints). + + Registered buffers (_tq_signs, _tq_centroids) are already on the + correct device via register_buffer + model.to(device). + """ + if not hasattr(layer, "_tq_cached"): + D = layer._tq_signs.shape[0] + signs = layer._tq_signs.to(device=device, dtype=torch.float32) + + # WHT rotation: orthonormal + self-inverse, enabling future + # in-kernel butterfly fusion and trivial inverse for continuation. + H = _build_hadamard(D, str(device)) + layer._tq_PiT = (signs.unsqueeze(1) * H).contiguous() + layer._tq_Pi = layer._tq_PiT.T.contiguous() + + c = layer._tq_centroids.to(device=device, dtype=torch.float32) + # Precompute midpoints for threshold-based quantization + c_sorted, _ = c.sort() + layer._tq_midpoints = (c_sorted[:-1] + c_sorted[1:]) / 2 + # Decode buffers (_tq_mid_o_buf, _tq_output_buf, _tq_lse_buf) + # are pre-allocated via register_buffer in Attention.__init__ + # and moved to GPU by model.to(device) — no allocation needed + # here. The memory profiler sees them before KV cache sizing. + layer._tq_cached = True + + def do_kv_cache_update( + self, + layer: torch.nn.Module, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> None: + """Store compressed K/V into the combined TQ cache. + + Called as a separate custom op (unified_kv_cache_update) BEFORE + the attention forward, matching FlashAttention's split pattern. + slot_mapping is already sliced to num_actual_tokens by the caller. 
+ """ + N = slot_mapping.shape[0] + if N <= 0: + return + + device = key.device + self._ensure_on_device(layer, device) + + k = key[:N].view(N, self.num_kv_heads, self.head_size) + v = value[:N].view(N, self.num_kv_heads, self.head_size) + self._store_kv(k, v, kv_cache, slot_mapping, layer) + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: "TurboQuantMetadata", + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, + ) -> torch.Tensor: + num_tokens = query.shape[0] + + if output is None: + output = torch.zeros( + num_tokens, + self.num_heads * self.head_size, + dtype=query.dtype, + device=query.device, + ) + + if attn_metadata is None: + return output.fill_(0) + + # Slice to actual tokens + N = attn_metadata.num_actual_tokens + if N <= 0: + return output.fill_(0) + + q = query[:N].view(N, self.num_heads, self.head_size) + + # Get TQ buffers, ensure on device (one-time migration). + # Use Any-typed alias for dynamic _tq_* attrs set by _ensure_on_device. + tq_layer: Any = layer + device = q.device + self._ensure_on_device(tq_layer, device) + Pi = tq_layer._tq_Pi + PiT = tq_layer._tq_PiT + centroids = tq_layer._tq_centroids + + # Compute attention (KV cache was already updated by do_kv_cache_update) + # With reorder_batch_threshold=1, decodes come first in the batch. + # num_decodes/num_decode_tokens from metadata give the split point. 
+ num_decodes = attn_metadata.num_decodes + num_decode_tokens = attn_metadata.num_decode_tokens + + if not attn_metadata.is_prefill: + # Pure decode batch — fast path + attn_out = self._decode_attention( + q, kv_cache, attn_metadata, Pi, centroids, PiT, layer + ) + elif num_decodes == 0: + # Pure prefill batch + k = key[:N].view(N, self.num_kv_heads, self.head_size) + v = value[:N].view(N, self.num_kv_heads, self.head_size) + attn_out = self._prefill_attention( + q, + k, + v, + kv_cache, + attn_metadata, + Pi, + centroids, + PiT, + layer=layer, + ) + else: + # Mixed batch: decodes first (guaranteed by reorder_batch). + attn_out = torch.zeros( + N, self.num_heads, self.head_size, device=device, dtype=q.dtype + ) + + # --- Decode portion (first num_decodes requests) --- + # Use full-batch max_seq_len as safe upper bound (no GPU sync). + decode_meta = TurboQuantMetadata( + seq_lens=attn_metadata.seq_lens[:num_decodes], + slot_mapping=attn_metadata.slot_mapping[:num_decode_tokens], + block_table=attn_metadata.block_table[:num_decodes], + query_start_loc=attn_metadata.query_start_loc[: num_decodes + 1], + num_actual_tokens=num_decode_tokens, + max_query_len=1, + max_seq_len=attn_metadata.max_seq_len, + is_prefill=False, + ) + attn_out[:num_decode_tokens] = self._decode_attention( + q[:num_decode_tokens], kv_cache, decode_meta, Pi, centroids, PiT, layer + ) + + # --- Prefill portion (remaining requests) --- + # CRITICAL: use prefill-specific max_seq_len so flash_attn's + # fast path (max_query_len == max_seq_len) triggers for + # first-chunk prefills. Using full-batch max_seq_len breaks + # this because decode requests inflate max_seq_len. 
+ prefill_seq_lens = attn_metadata.seq_lens[num_decodes:] + # Use CPU-side max to avoid GPU→CPU sync from .item() + prefill_max_seq = max(attn_metadata.seq_lens[num_decodes:].tolist()) + prefill_qsl = ( + attn_metadata.query_start_loc[num_decodes:] - num_decode_tokens + ) + prefill_meta = TurboQuantMetadata( + seq_lens=prefill_seq_lens, + slot_mapping=attn_metadata.slot_mapping[num_decode_tokens:N], + block_table=attn_metadata.block_table[num_decodes:], + query_start_loc=prefill_qsl, + num_actual_tokens=N - num_decode_tokens, + max_query_len=attn_metadata.max_query_len, + max_seq_len=prefill_max_seq, + is_prefill=True, + ) + k = key[:N].view(N, self.num_kv_heads, self.head_size) + v = value[:N].view(N, self.num_kv_heads, self.head_size) + attn_out[num_decode_tokens:] = self._prefill_attention( + q[num_decode_tokens:], + k[num_decode_tokens:], + v[num_decode_tokens:], + kv_cache, + prefill_meta, + Pi, + centroids, + PiT, + layer=layer, + ) + + # Write into output buffer: attn_out is (N, Hq, D) + # output may be 2D (N, Hq*D) or 3D (N, Hq, D) + if output.ndim == 3: + output[:N] = attn_out.to(output.dtype) + else: + output[:N] = attn_out.reshape(N, -1).to(output.dtype) + return output + + # ------------------------------------------------------------------ # + # Store K/V into combined cache (vectorized) # + # ------------------------------------------------------------------ # + def _store_kv( + self, + key: torch.Tensor, # (N, Hk, D) + value: torch.Tensor, # (N, Hk, D) + kv_cache: torch.Tensor, # (num_blocks, block_size, Hk, slot_size) + slot_mapping: torch.Tensor, + layer: Any, + ): + """Quantize + store via fused Triton kernel.""" + triton_turboquant_store( + key, + value, + kv_cache, + slot_mapping, + layer._tq_PiT, + layer._tq_midpoints, + mse_bits=self.tq_config.key_mse_bits, + key_packed_size=self.tq_config.key_packed_size, + value_quant_bits=self.tq_config.effective_value_quant_bits, + key_fp8=self.tq_config.key_fp8, + ) + + # 
------------------------------------------------------------------ # + # Prefill: SDPA on raw Q/K/V with causal mask # + # ------------------------------------------------------------------ # + def _prefill_attention( + self, + query: torch.Tensor, # (N, Hq, D) + key: torch.Tensor, # (N, Hk, D) + value: torch.Tensor, # (N, Hk, D) + kv_cache: torch.Tensor, # (num_blocks, block_size, Hk, slot_size) + attn_metadata: TurboQuantMetadata, + Pi: torch.Tensor, + centroids: torch.Tensor, + PiT: torch.Tensor | None = None, + layer: Any = None, + ) -> torch.Tensor: + N, Hq, D = query.shape + + # Fast path: use flash_attn for first-chunk prefills (all K/V in batch). + # max_query_len == max_seq_len means no request has prior cached KV. + # Both are Python ints — no GPU sync. + if _HAS_FLASH_ATTN and attn_metadata.max_query_len == attn_metadata.max_seq_len: + output = torch.empty(N, Hq, D, device=query.device, dtype=query.dtype) + flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=attn_metadata.query_start_loc, + cu_seqlens_k=attn_metadata.query_start_loc, + max_seqlen_q=attn_metadata.max_query_len, + max_seqlen_k=attn_metadata.max_query_len, + softmax_scale=self.scale, + causal=True, + out=output, + ) + return output + + # Continuation or no flash_attn: per-request attention. + # For continuation chunks (seq_len > q_len), we must attend to + # previously cached K/V from the TQ cache, not just the current + # chunk's raw K/V. + Hk = key.shape[1] + use_gqa = Hk < Hq + query_start_loc = attn_metadata.query_start_loc + num_reqs = query_start_loc.shape[0] - 1 + + output = torch.zeros(N, Hq, D, device=query.device, dtype=query.dtype) + + # Convert to Python lists once (single CPU-GPU sync) instead of + # per-request .item() calls that each force a sync. + qsl = query_start_loc.tolist() + seq_lens_list = attn_metadata.seq_lens.tolist() + + # Pre-allocate cu_seqlens for single-request flash_attn calls + # to avoid per-request host→device tensor creation. 
+ _cu_2 = torch.zeros(2, device=query.device, dtype=torch.int32) + + for i in range(num_reqs): + q_start = qsl[i] + q_end = qsl[i + 1] + q_len = q_end - q_start + if q_len <= 0: + continue + + seq_len = seq_lens_list[i] + q_seq = query[q_start:q_end] # (q_len, Hq, D) + k_seq = key[q_start:q_end] # (q_len, Hk, D) + v_seq = value[q_start:q_end] # (q_len, Hk, D) + + if q_len == seq_len: + # First-chunk prefill: all K/V are in the current batch. + if _HAS_FLASH_ATTN: + out = torch.empty_like(q_seq) + _cu_2[1] = q_len + cu = _cu_2 + flash_attn_varlen_func( + q=q_seq, + k=k_seq, + v=v_seq, + cu_seqlens_q=cu, + cu_seqlens_k=cu, + max_seqlen_q=q_len, + max_seqlen_k=q_len, + softmax_scale=self.scale, + causal=True, + out=out, + ) + else: + q_t = q_seq.transpose(0, 1).contiguous() + k_t = k_seq.transpose(0, 1).contiguous() + v_t = v_seq.transpose(0, 1).contiguous() + out = F.scaled_dot_product_attention( + q_t, + k_t, + v_t, + is_causal=True, + scale=self.scale, + enable_gqa=use_gqa, + ).transpose(0, 1) + output[q_start:q_end] = out.to(query.dtype) + else: + # Continuation chunk: tokens already stored to TQ cache + # by do_kv_cache_update. Use decode kernel directly to + # avoid O(cached_len) full-dequant per continuation. + # For large continuations, fall back to _continuation_prefill. + cached_len = seq_len - q_len + if q_len <= _CONTINUATION_DECODE_THRESHOLD: + # Fast path: treat each query as a decode request + # with incremental seq_lens for causal masking. 
+ synth_seq_lens = torch.arange( + cached_len + 1, + seq_len + 1, + device=query.device, + dtype=attn_metadata.seq_lens.dtype, + ) + synth_bt = attn_metadata.block_table[i : i + 1].expand(q_len, -1) + out = triton_turboquant_decode_attention( + query=q_seq, + kv_cache=kv_cache, + block_table=synth_bt, + seq_lens=synth_seq_lens, + Pi=Pi, + centroids=centroids, + scale=self.scale, + mse_bits=self.tq_config.key_mse_bits, + key_packed_size=self.tq_config.key_packed_size, + value_quant_bits=(self.tq_config.effective_value_quant_bits), + key_fp8=self.tq_config.key_fp8, + norm_correction=self.tq_config.norm_correction, + PiT=PiT, + ) + else: + # Large continuation: dequant cached K/V and use + # flash_attn for better throughput. + out = self._continuation_prefill( + layer, + q_seq, + k_seq, + v_seq, + kv_cache, + attn_metadata.block_table[i : i + 1], + cached_len, + seq_len, + Pi, + centroids, + ) + output[q_start:q_end] = out.to(query.dtype) + + return output + + def _continuation_prefill( + self, + layer: Any, + query: torch.Tensor, # (q_len, Hq, D) + key_chunk: torch.Tensor, # (q_len, Hk, D) + val_chunk: torch.Tensor, # (q_len, Hk, D) + kv_cache: torch.Tensor, # (num_blocks, block_size, Hk, slot_size) + block_table: torch.Tensor, # (1, max_num_blocks) + cached_len: int, + seq_len: int, + Pi: torch.Tensor, + centroids: torch.Tensor, + ) -> torch.Tensor: + """Handle continuation chunk by dequanting cached K/V from TQ cache. + + Dequants previously cached K/V, concatenates with the current + chunk's raw K/V, then runs flash_attn with causal masking. + """ + q_len, Hq, D = query.shape + Hk = key_chunk.shape[1] + device = query.device + block_size = kv_cache.shape[1] + BLOCK_D = triton.next_power_of_2(D) + + mse_bytes = self._mse_bytes + val_data_bytes = self._val_data_bytes + + # Dequant cached K/V from TQ cache + # Allocate slightly over to align to block_size for the grid. + # Reuse cached buffers to avoid per-call allocation (~16MB at 8K). 
+ alloc_len = math.ceil(cached_len / block_size) * block_size + buf_shape = (1, Hk, alloc_len, D) + k_buf = getattr(layer, "_tq_k_dequant_buf", None) + if k_buf is None or k_buf.shape[2] < alloc_len: + k_buf = torch.empty(buf_shape, dtype=torch.float16, device=device) + v_buf = torch.empty(buf_shape, dtype=torch.float16, device=device) + layer._tq_k_dequant_buf = k_buf + layer._tq_v_dequant_buf = v_buf + else: + v_buf = layer._tq_v_dequant_buf + k_cached = k_buf[:, :, :alloc_len, :].zero_() + v_cached = v_buf[:, :, :alloc_len, :].zero_() + + grid = (alloc_len, 1 * Hk) + _tq_full_dequant_kv[grid]( + kv_cache, + block_table, + centroids, + k_cached, + v_cached, + k_cached.stride(0), + k_cached.stride(1), + k_cached.stride(2), + v_cached.stride(0), + v_cached.stride(1), + v_cached.stride(2), + kv_cache.stride(0), + kv_cache.stride(1), + kv_cache.stride(2), + block_table.stride(0), + HEAD_DIM=D, + BLOCK_SIZE=block_size, + NUM_KV_HEADS=Hk, + MSE_BYTES=mse_bytes, + KPS=self.tq_config.key_packed_size, + VQB=self.tq_config.effective_value_quant_bits, + VAL_DATA_BYTES=val_data_bytes, + MSE_BITS=self.tq_config.key_mse_bits, + KEY_FP8=1 if self.tq_config.key_fp8 else 0, + BLOCK_D=BLOCK_D, + NORM_CORRECTION=1 if self.tq_config.norm_correction else 0, + FP8_E4B15=_use_fp8_e4b15(device.index or 0), + num_warps=4, + ) + + # Inverse-rotate MSE keys back to original space + if not self.tq_config.key_fp8: + k_flat = k_cached[0, :, :cached_len, :].reshape(-1, D).float() + k_flat = k_flat @ Pi + k_cached_trim = ( + k_flat.to(torch.float16).reshape(Hk, cached_len, D).transpose(0, 1) + ) # (cached_len, Hk, D) + else: + k_cached_trim = ( + k_cached[0, :, :cached_len, :].transpose(0, 1).contiguous() + ) # (cached_len, Hk, D) + + v_cached_trim = ( + v_cached[0, :, :cached_len, :].transpose(0, 1).contiguous() + ) # (cached_len, Hk, D) + + # Concatenate cached + current chunk K/V (match query dtype) + qdtype = query.dtype + k_full = torch.cat([k_cached_trim.to(qdtype), key_chunk], dim=0) + 
v_full = torch.cat([v_cached_trim.to(qdtype), val_chunk], dim=0) + + # Attention: q_len queries attending to seq_len K/V with causal mask + if _HAS_FLASH_ATTN: + output = torch.empty(q_len, Hq, D, device=device, dtype=query.dtype) + cu_seqlens_q = torch.tensor([0, q_len], device=device, dtype=torch.int32) + cu_seqlens_k = torch.tensor([0, seq_len], device=device, dtype=torch.int32) + flash_attn_varlen_func( + q=query, + k=k_full, + v=v_full, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=q_len, + max_seqlen_k=seq_len, + softmax_scale=self.scale, + causal=True, + out=output, + ) + return output + else: + # SDPA fallback: expand KV for GQA, build causal mask + q_t = query.transpose(0, 1).unsqueeze(0) # (1, Hq, q_len, D) + k_t = k_full.transpose(0, 1).unsqueeze(0) # (1, Hk, seq_len, D) + v_t = v_full.transpose(0, 1).unsqueeze(0) # (1, Hk, seq_len, D) + # Build causal mask: query position p can attend to K position j + # where j <= cached_len + p (p is 0-indexed within chunk) + q_pos = torch.arange(q_len, device=device).unsqueeze(1) + cached_len + k_pos = torch.arange(seq_len, device=device).unsqueeze(0) + mask = k_pos <= q_pos # (q_len, seq_len) + out = F.scaled_dot_product_attention( + q_t, + k_t, + v_t, + attn_mask=mask, + scale=self.scale, + enable_gqa=(Hk < Hq), + ) # (1, Hq, q_len, D) + return out[0].transpose(0, 1) # (q_len, Hq, D) + + # ------------------------------------------------------------------ # + # Decode: Triton TQ decode attention # + # ------------------------------------------------------------------ # + def _decode_attention( + self, + query: torch.Tensor, # (B, Hq, D) + kv_cache: torch.Tensor, # (num_blocks, block_size, Hk, slot_size) + attn_metadata: TurboQuantMetadata, + Pi: torch.Tensor, + centroids: torch.Tensor, + PiT: torch.Tensor | None = None, + layer: torch.nn.Module | None = None, + ) -> torch.Tensor: + # Grab cached decode buffers from the layer (lazily allocated). 
+ mid_o_buf = output_buf = lse_buf = None + if layer is not None: + mid_o_buf = getattr(layer, "_tq_mid_o_buf", None) + output_buf = getattr(layer, "_tq_output_buf", None) + lse_buf = getattr(layer, "_tq_lse_buf", None) + + result = triton_turboquant_decode_attention( + query=query, + kv_cache=kv_cache, + block_table=attn_metadata.block_table, + seq_lens=attn_metadata.seq_lens, + Pi=Pi, + centroids=centroids, + scale=self.scale, + mse_bits=self.tq_config.key_mse_bits, + key_packed_size=self.tq_config.key_packed_size, + value_quant_bits=self.tq_config.effective_value_quant_bits, + key_fp8=self.tq_config.key_fp8, + norm_correction=self.tq_config.norm_correction, + PiT=PiT, + mid_o_buf=mid_o_buf, + output_buf=output_buf, + lse_buf=lse_buf, + buf_holder=layer, + max_num_kv_splits=self.max_num_kv_splits, + ) + return result diff --git a/vllm/v1/attention/ops/triton_turboquant_decode.py b/vllm/v1/attention/ops/triton_turboquant_decode.py new file mode 100644 index 000000000000..8b276e31eafb --- /dev/null +++ b/vllm/v1/attention/ops/triton_turboquant_decode.py @@ -0,0 +1,617 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Triton fused TurboQuant decode attention. + +Decode path: Triton stage1 (split-KV tiled attention scoring + value +accumulation) + stage2 (log-sum-exp reduction across splits). + +Supports FP8 (E4M3) keys, 3-bit and 4-bit uniform quantized values. 
+""" + +import math +from typing import Any + +import torch + +from vllm.triton_utils import tl, triton +from vllm.v1.attention.ops.triton_decode_attention import ( + _fwd_kernel_stage2, +) + +_FP8_E4B15: dict[int, int] = {} + + +def _use_fp8_e4b15(device: int = 0) -> int: + """Return 1 if device needs fp8e4b15 (Ampere/Ada, SM < 8.9), else 0.""" + if device not in _FP8_E4B15: + cap = torch.cuda.get_device_capability(device) + _FP8_E4B15[device] = 1 if cap < (8, 9) else 0 + return _FP8_E4B15[device] + + +# --------------------------------------------------------------------------- +# Stage 1: Fused TQ score + value accumulation (BLOCK_KV tiled) +# --------------------------------------------------------------------------- + + +@triton.jit +def _tq_decode_stage1( + # Precomputed query projection + Q_rot_ptr, # [B, Hq, D] float32 + # Compressed KV cache (combined K+V) + KV_cache_ptr, # [num_blocks, block_size, Hk, padded_slot] uint8 + # Block table and sequence info + Block_table_ptr, # [B, max_num_blocks] int32 + Seq_lens_ptr, # [B] int32 + # TQ parameters + Centroids_ptr, # [n_centroids] float32 + # Output (intermediate for stage2) + Mid_o_ptr, # [B, Hq, NUM_KV_SPLITS, D+1] float32 + # Strides + stride_qb, + stride_qh, # Q strides: [B, Hq, D] + stride_cache_block, + stride_cache_pos, + stride_cache_head, # KV cache + stride_bt_b, # block_table stride per batch + stride_mid_b, + stride_mid_h, + stride_mid_s, # mid_o strides + # Constexpr dims + NUM_KV_HEADS: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_SIZE: tl.constexpr, # KV cache block_size (pages) + NUM_KV_SPLITS: tl.constexpr, + KV_GROUP_SIZE: tl.constexpr, # Hq // Hk + # TQ layout constants + MSE_BITS: tl.constexpr, # 3 or 4 + MSE_BYTES: tl.constexpr, # ceil(D * mse_bits / 8) + KPS: tl.constexpr, # key_packed_size + VQB: tl.constexpr, # value_quant_bits (4 or 8=FP8) + VAL_DATA_BYTES: tl.constexpr, # ceil(D * vqb / 8) or D for FP8 + # Score constants + ATTN_SCALE: tl.constexpr, # 1/sqrt(D) + # Block tile 
sizes + BLOCK_D: tl.constexpr, # next_power_of_2(HEAD_DIM) + BLOCK_KV: tl.constexpr, # tokens per tile (16) + KEY_FP8: tl.constexpr, # 1 if K is stored as FP8 + NORM_CORRECTION: tl.constexpr = 0, # 1 = re-normalize centroids + FP8_E4B15: tl.constexpr = 0, # 1 = use e4b15 (Ampere/Ada), 0 = e4nv (Hopper+) +): + bid = tl.program_id(0) # batch index + hid = tl.program_id(1) # q_head index + sid = tl.program_id(2) # kv_split index + + kv_head = hid // KV_GROUP_SIZE + + # Sequence length for this batch + seq_len = tl.load(Seq_lens_ptr + bid) + + # KV split range + split_len = tl.cdiv(seq_len, NUM_KV_SPLITS) + split_start = split_len * sid + split_end = tl.minimum(split_start + split_len, seq_len) + + if split_start >= split_end: + return + + # Dimension offsets + d_offs = tl.arange(0, BLOCK_D) + d_mask = d_offs < HEAD_DIM + kv_range = tl.arange(0, BLOCK_KV) + + # Load query vector: q_rot — [BLOCK_D] float32 + q_base = bid * stride_qb + hid * stride_qh + q_rot = tl.load(Q_rot_ptr + q_base + d_offs, mask=d_mask, other=0.0).to(tl.float32) + + # Precompute byte/bit index vectors for MSE gather loads + if not KEY_FP8: + mse_bit_off = d_offs * MSE_BITS + mse_byte_idx = mse_bit_off // 8 + mse_bit_shift = mse_bit_off % 8 + mse_mask = (1 << MSE_BITS) - 1 + + # Precompute value bit/byte index vectors (loop-invariant) + if VQB == 3: + val_bit_off = d_offs * 3 + val_byte_idx = val_bit_off // 8 + val_bit_shift = val_bit_off % 8 + + # Online softmax accumulators + m_prev = -float("inf") + l_prev = 0.0 + acc = tl.zeros([BLOCK_D], dtype=tl.float32) + + bt_base = bid * stride_bt_b + + # ================================================================ + # TILED LOOP: process BLOCK_KV tokens per iteration + # ================================================================ + for start_n in range(split_start, split_end, BLOCK_KV): + kv_offs = start_n + kv_range + kv_mask = kv_offs < split_end + + page_idx = kv_offs // BLOCK_SIZE + page_off = kv_offs % BLOCK_SIZE + block_nums = tl.load( + 
Block_table_ptr + bt_base + page_idx, + mask=kv_mask, + other=0, + ) + + slot_bases = ( + block_nums * stride_cache_block + + page_off * stride_cache_pos + + kv_head * stride_cache_head + ) + + # ============================================================ + # COMPUTE ATTENTION SCORES: [BLOCK_KV] + # ============================================================ + if KEY_FP8: + k_addrs = slot_bases[:, None] + d_offs[None, :] + k_raw = tl.load( + KV_cache_ptr + k_addrs, + mask=kv_mask[:, None] & d_mask[None, :], + other=0, + ) + if FP8_E4B15: + k_float = k_raw.to(tl.float8e4b15, bitcast=True).to(tl.float32) + else: + k_float = k_raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) + scores = ( + tl.sum( + tl.where(d_mask[None, :], q_rot[None, :] * k_float, 0.0), + axis=1, + ) + * ATTN_SCALE + ) + scores = tl.where(kv_mask, scores, -float("inf")) + else: + # MSE unpack + norms + mse_addrs0 = slot_bases[:, None] + mse_byte_idx[None, :] + mse_raw0 = tl.load( + KV_cache_ptr + mse_addrs0, + mask=kv_mask[:, None] & d_mask[None, :], + other=0, + ).to(tl.int32) + mse_raw1 = tl.load( + KV_cache_ptr + mse_addrs0 + 1, + mask=kv_mask[:, None] & d_mask[None, :], + other=0, + ).to(tl.int32) + raw16 = mse_raw0 | (mse_raw1 << 8) + mse_idx = (raw16 >> mse_bit_shift[None, :]) & mse_mask + + # Centroid gather + dot product + c_vals = tl.load( + Centroids_ptr + mse_idx, + mask=kv_mask[:, None] & d_mask[None, :], + other=0.0, + ) + + # Norm correction: re-normalize centroid vector to unit norm + if NORM_CORRECTION: + c_norm_sq = tl.sum( + tl.where(d_mask[None, :], c_vals * c_vals, 0.0), + axis=1, + ) + c_inv_norm = 1.0 / tl.sqrt(c_norm_sq + 1e-16) + c_vals = c_vals * c_inv_norm[:, None] + + term1 = tl.sum( + tl.where(d_mask[None, :], q_rot[None, :] * c_vals, 0.0), + axis=1, + ) + + # Load norms (fp16 -> fp32): norms are at MSE_BYTES offset + norm_bases = slot_bases + MSE_BYTES + n_lo = tl.load(KV_cache_ptr + norm_bases, mask=kv_mask, other=0).to( + tl.uint16 + ) + n_hi = 
tl.load(KV_cache_ptr + norm_bases + 1, mask=kv_mask, other=0).to( + tl.uint16 + ) + vec_norms = (n_lo | (n_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) + + scores = vec_norms * term1 * ATTN_SCALE + scores = tl.where(kv_mask, scores, -float("inf")) + + # ============================================================ + # ONLINE SOFTMAX UPDATE (block-level) + # ============================================================ + n_e_max = tl.maximum(tl.max(scores, 0), m_prev) + re_scale = tl.exp(m_prev - n_e_max) + p = tl.exp(scores - n_e_max) + + # ============================================================ + # VALUE LOAD + DEQUANTIZE: [BLOCK_KV, BLOCK_D] + # ============================================================ + val_bases = slot_bases + KPS + + if VQB == 3: + val_addrs0 = val_bases[:, None] + val_byte_idx[None, :] + val_raw0 = tl.load( + KV_cache_ptr + val_addrs0, + mask=kv_mask[:, None] & d_mask[None, :], + other=0, + ).to(tl.int32) + val_raw1 = tl.load( + KV_cache_ptr + val_addrs0 + 1, + mask=kv_mask[:, None] & d_mask[None, :], + other=0, + ).to(tl.int32) + raw16 = val_raw0 | (val_raw1 << 8) + v_idx = ((raw16 >> val_bit_shift[None, :]) & 0x7).to(tl.float32) + + sc_bases = val_bases + VAL_DATA_BYTES + sc_lo = tl.load(KV_cache_ptr + sc_bases, mask=kv_mask, other=0).to( + tl.uint16 + ) + sc_hi = tl.load(KV_cache_ptr + sc_bases + 1, mask=kv_mask, other=0).to( + tl.uint16 + ) + v_scales = ( + (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) + ) + zr_lo = tl.load(KV_cache_ptr + sc_bases + 2, mask=kv_mask, other=0).to( + tl.uint16 + ) + zr_hi = tl.load(KV_cache_ptr + sc_bases + 3, mask=kv_mask, other=0).to( + tl.uint16 + ) + v_zeros = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) + values = v_idx * v_scales[:, None] + v_zeros[:, None] + else: # VQB == 4 + vb_idx = d_offs // 2 + vb_shift = (d_offs % 2) * 4 + val_addrs = val_bases[:, None] + vb_idx[None, :] + val_raw = tl.load( + KV_cache_ptr + val_addrs, + mask=kv_mask[:, 
None] & d_mask[None, :], + other=0, + ).to(tl.int32) + v_idx = ((val_raw >> vb_shift[None, :]) & 0xF).to(tl.float32) + + sc_bases = val_bases + VAL_DATA_BYTES + sc_lo = tl.load(KV_cache_ptr + sc_bases, mask=kv_mask, other=0).to( + tl.uint16 + ) + sc_hi = tl.load(KV_cache_ptr + sc_bases + 1, mask=kv_mask, other=0).to( + tl.uint16 + ) + v_scales = ( + (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) + ) + zr_lo = tl.load(KV_cache_ptr + sc_bases + 2, mask=kv_mask, other=0).to( + tl.uint16 + ) + zr_hi = tl.load(KV_cache_ptr + sc_bases + 3, mask=kv_mask, other=0).to( + tl.uint16 + ) + v_zeros = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) + values = v_idx * v_scales[:, None] + v_zeros[:, None] + + # ============================================================ + # WEIGHTED VALUE ACCUMULATION + # ============================================================ + acc = acc * re_scale + tl.sum(p[:, None] * values, 0) + l_prev = l_prev * re_scale + tl.sum(p, 0) + m_prev = n_e_max + + # Store partial result + out_base = bid * stride_mid_b + hid * stride_mid_h + sid * stride_mid_s + safe_l = tl.where(l_prev > 0.0, l_prev, 1.0) + tl.store(Mid_o_ptr + out_base + d_offs, acc / safe_l, mask=d_mask) + lse = m_prev + tl.log(safe_l) + tl.store(Mid_o_ptr + out_base + HEAD_DIM, lse) + + +# --------------------------------------------------------------------------- +# Pre-dequant kernel: Bulk dequant K (MSE+norms) and V to fp16 +# --------------------------------------------------------------------------- + + +@triton.jit +def _tq_full_dequant_kv( + KV_cache_ptr, + Block_table_ptr, + Centroids_ptr, + K_out_ptr, # [B, Hk, max_seq, D] float16 + V_out_ptr, # [B, Hk, max_seq, D] float16 + stride_ko_b, + stride_ko_h, + stride_ko_s, + stride_vo_b, + stride_vo_h, + stride_vo_s, + stride_cache_block, + stride_cache_pos, + stride_cache_head, + stride_bt_b, + HEAD_DIM: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + NUM_KV_HEADS: tl.constexpr, + MSE_BYTES: 
tl.constexpr, + KPS: tl.constexpr, + VQB: tl.constexpr, + VAL_DATA_BYTES: tl.constexpr, + MSE_BITS: tl.constexpr, + KEY_FP8: tl.constexpr, + BLOCK_D: tl.constexpr, + NORM_CORRECTION: tl.constexpr = 0, + FP8_E4B15: tl.constexpr = 0, # 1 = use e4b15 (Ampere/Ada), 0 = e4nv (Hopper+) +): + """Full dequant: reconstruct K (MSE centroids * norm or FP8) and V to fp16.""" + pos = tl.program_id(0) + bh = tl.program_id(1) + bid = bh // NUM_KV_HEADS + hid = bh % NUM_KV_HEADS + + page_idx = pos // BLOCK_SIZE + page_off = pos % BLOCK_SIZE + block_num = tl.load(Block_table_ptr + bid * stride_bt_b + page_idx) + slot_base = ( + block_num * stride_cache_block + + page_off * stride_cache_pos + + hid * stride_cache_head + ) + + d_offs = tl.arange(0, BLOCK_D) + d_mask = d_offs < HEAD_DIM + + # === K dequant === + ko_base = bid * stride_ko_b + hid * stride_ko_h + pos * stride_ko_s + if KEY_FP8: + k_raw = tl.load(KV_cache_ptr + slot_base + d_offs, mask=d_mask, other=0) + if FP8_E4B15: + k_recon = k_raw.to(tl.float8e4b15, bitcast=True).to(tl.float32) + else: + k_recon = k_raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) + tl.store(K_out_ptr + ko_base + d_offs, k_recon.to(tl.float16), mask=d_mask) + else: + # MSE unpack (3-bit or 4-bit) + norms + mse_bit_off = d_offs * MSE_BITS + mse_byte_idx = mse_bit_off // 8 + mse_bit_shift = mse_bit_off % 8 + mse_umask = (1 << MSE_BITS) - 1 + + mse_raw0 = tl.load( + KV_cache_ptr + slot_base + mse_byte_idx, mask=d_mask, other=0 + ).to(tl.int32) + mse_raw1 = tl.load( + KV_cache_ptr + slot_base + mse_byte_idx + 1, mask=d_mask, other=0 + ).to(tl.int32) + raw16_key = mse_raw0 | (mse_raw1 << 8) + mse_idx = (raw16_key >> mse_bit_shift) & mse_umask + + k_mse = tl.load(Centroids_ptr + mse_idx, mask=d_mask, other=0.0) + + # Norm correction: re-normalize centroid vector to unit norm + if NORM_CORRECTION: + c_norm_sq = tl.sum(tl.where(d_mask, k_mse * k_mse, 0.0), axis=0) + c_inv_norm = 1.0 / tl.sqrt(c_norm_sq + 1e-16) + k_mse = k_mse * c_inv_norm + + # Norms at 
MSE_BYTES offset (no QJL bytes) + norm_base = slot_base + MSE_BYTES + n_lo = tl.load(KV_cache_ptr + norm_base).to(tl.uint16) + n_hi = tl.load(KV_cache_ptr + norm_base + 1).to(tl.uint16) + vec_norm = (n_lo | (n_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) + + k_recon = vec_norm * k_mse + tl.store(K_out_ptr + ko_base + d_offs, k_recon.to(tl.float16), mask=d_mask) + + # === V dequant === + val_base = slot_base + KPS + if VQB == 4: + vb_idx = d_offs // 2 + vb_shift = (d_offs % 2) * 4 + val_raw = tl.load(KV_cache_ptr + val_base + vb_idx, mask=d_mask, other=0).to( + tl.int32 + ) + v_idx = ((val_raw >> vb_shift) & 0xF).to(tl.float32) + + sc_base = val_base + VAL_DATA_BYTES + sc_lo = tl.load(KV_cache_ptr + sc_base).to(tl.uint16) + sc_hi = tl.load(KV_cache_ptr + sc_base + 1).to(tl.uint16) + v_scale = (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) + zr_lo = tl.load(KV_cache_ptr + sc_base + 2).to(tl.uint16) + zr_hi = tl.load(KV_cache_ptr + sc_base + 3).to(tl.uint16) + v_zero = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) + v_vals = v_idx * v_scale + v_zero + elif VQB == 3: + # 3-bit value unpack: 8 values per 3 bytes + val_bit_off = d_offs * 3 + val_byte_idx = val_bit_off // 8 + val_bit_shift = val_bit_off % 8 + val_raw0 = tl.load( + KV_cache_ptr + val_base + val_byte_idx, mask=d_mask, other=0 + ).to(tl.int32) + val_raw1 = tl.load( + KV_cache_ptr + val_base + val_byte_idx + 1, mask=d_mask, other=0 + ).to(tl.int32) + raw16_val = val_raw0 | (val_raw1 << 8) + v_idx = ((raw16_val >> val_bit_shift) & 0x7).to(tl.float32) + + sc_base = val_base + VAL_DATA_BYTES + sc_lo = tl.load(KV_cache_ptr + sc_base).to(tl.uint16) + sc_hi = tl.load(KV_cache_ptr + sc_base + 1).to(tl.uint16) + v_scale = (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) + zr_lo = tl.load(KV_cache_ptr + sc_base + 2).to(tl.uint16) + zr_hi = tl.load(KV_cache_ptr + sc_base + 3).to(tl.uint16) + v_zero = (zr_lo | (zr_hi << 8)).to(tl.float16, 
bitcast=True).to(tl.float32) + v_vals = v_idx * v_scale + v_zero + else: + v_vals = tl.zeros([BLOCK_D], dtype=tl.float32) + + vo_base = bid * stride_vo_b + hid * stride_vo_h + pos * stride_vo_s + tl.store(V_out_ptr + vo_base + d_offs, v_vals.to(tl.float16), mask=d_mask) + + +# --------------------------------------------------------------------------- +# Stage 2: Reuse from triton_decode_attention.py +# --------------------------------------------------------------------------- + +# --------------------------------------------------------------------------- +# Launcher — cached constants + fused GEMM +# --------------------------------------------------------------------------- + +_layout_cache: dict = {} + + +def _get_layout(D, mse_bits, value_quant_bits, key_packed_size): + """Get cached layout constants.""" + key = (D, mse_bits, value_quant_bits, key_packed_size) + cfg = _layout_cache.get(key) + if cfg is None: + val_data_bytes = math.ceil(D * value_quant_bits / 8) + cfg = { + "mse_bytes": math.ceil(D * mse_bits / 8), + "val_data_bytes": val_data_bytes, + "mse_bits": mse_bits, + "n_centroids": 2**mse_bits, + "BLOCK_D": triton.next_power_of_2(D), + } + _layout_cache[key] = cfg + return cfg + + +def triton_turboquant_decode_attention( + query: torch.Tensor, # [B, Hq, D] — original query + kv_cache: torch.Tensor, # [num_blocks, block_size, Hk, padded_slot] uint8 + block_table: torch.Tensor, # [B, max_num_blocks] int32 + seq_lens: torch.Tensor, # [B] int32 + Pi: torch.Tensor, # [D, D] float32 + centroids: torch.Tensor, # [n_centroids] float32 + scale: float, + mse_bits: int, + key_packed_size: int, + value_quant_bits: int, + key_fp8: bool = False, + norm_correction: bool = False, + PiT: torch.Tensor | None = None, # [D, D] pre-computed Pi.T contiguous + # Pre-allocated buffers (optional, avoids per-call allocation) + mid_o_buf: torch.Tensor | None = None, + output_buf: torch.Tensor | None = None, + lse_buf: torch.Tensor | None = None, + buf_holder: Any = None, + 
max_num_kv_splits: int = 32, # fixed split count (must be constant for cudagraph) +) -> torch.Tensor: + """Launch fused TQ decode attention (Triton stage1 + stage2). + + Returns: output tensor [B, Hq, D] in query's dtype. + """ + B, Hq, D = query.shape + Hk = kv_cache.shape[2] + block_size = kv_cache.shape[1] + kv_group_size = Hq // Hk + device = query.device + + cfg = _get_layout(D, mse_bits, value_quant_bits, key_packed_size) + + # Compute q_rot = q @ Pi.T (rotated query for MSE key scoring) + # FP8 path: pass query directly (float16); kernel casts inline. + # MSE path: still needs external GEMM (cuBLAS), so q_rot is float32. + if key_fp8: + q_rot = query.contiguous() + else: + q_float = query.float() + if PiT is None: + PiT = Pi.T.contiguous() + q_rot = (q_float @ PiT).contiguous() + + NUM_KV_SPLITS = max_num_kv_splits + + if ( + mid_o_buf is not None + and mid_o_buf.shape[0] >= B + and mid_o_buf.shape[2] >= NUM_KV_SPLITS + ): + mid_o = mid_o_buf[:B, :Hq, :NUM_KV_SPLITS, :] + else: + mid_o = torch.empty( + B, + Hq, + NUM_KV_SPLITS, + D + 1, + dtype=torch.float32, + device=device, + ) + if buf_holder is not None: + buf_holder._tq_mid_o_buf = mid_o + + # Stage 1: split-KV tiled attention scoring + value accumulation + fp8_e4b15 = _use_fp8_e4b15(device.index or 0) + BLOCK_KV = 4 + grid = (B, Hq, NUM_KV_SPLITS) + _tq_decode_stage1[grid]( + q_rot, + kv_cache, + block_table, + seq_lens, + centroids, + mid_o, + q_rot.stride(0), + q_rot.stride(1), + kv_cache.stride(0), + kv_cache.stride(1), + kv_cache.stride(2), + block_table.stride(0), + mid_o.stride(0), + mid_o.stride(1), + mid_o.stride(2), + NUM_KV_HEADS=Hk, + HEAD_DIM=D, + BLOCK_SIZE=block_size, + NUM_KV_SPLITS=NUM_KV_SPLITS, + KV_GROUP_SIZE=kv_group_size, + MSE_BITS=mse_bits, + MSE_BYTES=cfg["mse_bytes"], + KPS=key_packed_size, + VQB=value_quant_bits, + VAL_DATA_BYTES=cfg["val_data_bytes"], + ATTN_SCALE=scale, + BLOCK_D=cfg["BLOCK_D"], + BLOCK_KV=BLOCK_KV, + KEY_FP8=1 if key_fp8 else 0, + NORM_CORRECTION=1 if 
norm_correction else 0, + FP8_E4B15=fp8_e4b15, + num_warps=1, + num_stages=1, + ) + + # Stage 2: Reduce across KV splits + if output_buf is not None and output_buf.shape[0] >= B: + output = output_buf[:B, :Hq, :D] + else: + output = torch.empty(B, Hq, D, dtype=torch.float32, device=device) + if buf_holder is not None: + buf_holder._tq_output_buf = output + if lse_buf is not None and lse_buf.shape[0] >= B: + lse = lse_buf[:B, :Hq] + else: + lse = torch.empty(B, Hq, dtype=torch.float32, device=device) + if buf_holder is not None: + buf_holder._tq_lse_buf = lse + + grid2 = (B, Hq) + _fwd_kernel_stage2[grid2]( + mid_o, + output, + lse, + seq_lens, + mid_o.stride(0), + mid_o.stride(1), + mid_o.stride(2), + output.stride(0), + output.stride(1), + lse.stride(0), + NUM_KV_SPLITS=NUM_KV_SPLITS, + BLOCK_DV=cfg["BLOCK_D"], + Lv=D, + num_warps=4, + num_stages=2, + ) + + return output.to(query.dtype) diff --git a/vllm/v1/attention/ops/triton_turboquant_store.py b/vllm/v1/attention/ops/triton_turboquant_store.py new file mode 100644 index 000000000000..3da3347d5df5 --- /dev/null +++ b/vllm/v1/attention/ops/triton_turboquant_store.py @@ -0,0 +1,441 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Fused Triton kernels for TurboQuant KV store. + +Two kernels: +1. _tq_fused_store_fp8: FP8 key scatter + value uniform quantization. +2. _tq_fused_store_mse: Fused binary-search bucketize + MSE index + packing + value quantization. + +The launcher `triton_turboquant_store` selects the appropriate kernel. 
+""" + +import math + +import torch + +from vllm.triton_utils import tl, triton +from vllm.v1.attention.ops.triton_turboquant_decode import _use_fp8_e4b15 + +# ═══════════════════════════════════════════════════════════════════════ +# Shared: value uniform quantization + pack + scale/zero store +# ═══════════════════════════════════════════════════════════════════════ + + +@triton.jit +def _store_quantized_value( + Value_ptr, + KV_cache_ptr, + base, # pid * D offset into Value_ptr + slot_base, # byte offset into KV_cache_ptr for this slot+head + d_offs, # tl.arange(0, BLOCK_D) + d_mask, # d_offs < D + D: tl.constexpr, + KPS: tl.constexpr, + VQB: tl.constexpr, + VAL_DATA_BYTES: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_VAL: tl.constexpr, + BLOCK_GRP: tl.constexpr, +): + """Uniform quantization of values to VQB bits, pack, and store with scale/zero.""" + val_cache_offset = KPS + + if VQB == 3: + val_vec = tl.load(Value_ptr + base + d_offs, mask=d_mask, other=0.0).to( + tl.float32 + ) + val_min = tl.min(tl.where(d_mask, val_vec, float("inf")), axis=0) + val_max = tl.max(tl.where(d_mask, val_vec, -float("inf")), axis=0) + v_scale = (val_max - val_min) / 7.0 + v_scale = tl.where(v_scale > 1e-8, v_scale, 1e-8) + + q_vals = tl.minimum( + tl.maximum(((val_vec - val_min) / v_scale + 0.5).to(tl.int32), 0), 7 + ) + + grp_offs = tl.arange(0, BLOCK_GRP) + grp_mask = grp_offs < (D // 8) + q_grp = tl.reshape(q_vals, [BLOCK_GRP, 8]) + shifts_3bit = tl.arange(0, 8) * 3 + packed_24 = tl.sum(q_grp << shifts_3bit[None, :], axis=1) + b0 = (packed_24 & 0xFF).to(tl.uint8) + b1 = ((packed_24 >> 8) & 0xFF).to(tl.uint8) + b2 = ((packed_24 >> 16) & 0xFF).to(tl.uint8) + tl.store( + KV_cache_ptr + slot_base + val_cache_offset + grp_offs * 3, + b0, + mask=grp_mask, + ) + tl.store( + KV_cache_ptr + slot_base + val_cache_offset + grp_offs * 3 + 1, + b1, + mask=grp_mask, + ) + tl.store( + KV_cache_ptr + slot_base + val_cache_offset + grp_offs * 3 + 2, + b2, + mask=grp_mask, + ) + + sc_offset 
= val_cache_offset + VAL_DATA_BYTES + sc_f16 = v_scale.to(tl.float16) + sc_u16 = sc_f16.to(tl.uint16, bitcast=True) + tl.store(KV_cache_ptr + slot_base + sc_offset, (sc_u16 & 0xFF).to(tl.uint8)) + tl.store( + KV_cache_ptr + slot_base + sc_offset + 1, + ((sc_u16 >> 8) & 0xFF).to(tl.uint8), + ) + zr_f16 = val_min.to(tl.float16) + zr_u16 = zr_f16.to(tl.uint16, bitcast=True) + tl.store(KV_cache_ptr + slot_base + sc_offset + 2, (zr_u16 & 0xFF).to(tl.uint8)) + tl.store( + KV_cache_ptr + slot_base + sc_offset + 3, + ((zr_u16 >> 8) & 0xFF).to(tl.uint8), + ) + + else: # VQB == 4 + val_vec = tl.load(Value_ptr + base + d_offs, mask=d_mask, other=0.0).to( + tl.float32 + ) + val_min = tl.min(tl.where(d_mask, val_vec, float("inf")), axis=0) + val_max = tl.max(tl.where(d_mask, val_vec, -float("inf")), axis=0) + v_scale = (val_max - val_min) / 15.0 + v_scale = tl.where(v_scale > 1e-8, v_scale, 1e-8) + + # Quantize all D elements from register (no re-load) + q_all = tl.minimum( + tl.maximum(((val_vec - val_min) / v_scale + 0.5).to(tl.int32), 0), 15 + ) + # Reshape to pairs and pack two 4-bit values per byte + q_pairs = tl.reshape(q_all, [BLOCK_D // 2, 2]) + shifts_4 = tl.arange(0, 2) * 4 + packed_val = tl.sum((q_pairs & 0xF) << shifts_4[None, :], axis=1).to(tl.uint8) + val_offs = tl.arange(0, BLOCK_D // 2) + val_mask = val_offs < VAL_DATA_BYTES + tl.store( + KV_cache_ptr + slot_base + val_cache_offset + val_offs, + packed_val, + mask=val_mask, + ) + + sc_offset = val_cache_offset + VAL_DATA_BYTES + sc_f16 = v_scale.to(tl.float16) + sc_u16 = sc_f16.to(tl.uint16, bitcast=True) + tl.store(KV_cache_ptr + slot_base + sc_offset, (sc_u16 & 0xFF).to(tl.uint8)) + tl.store( + KV_cache_ptr + slot_base + sc_offset + 1, + ((sc_u16 >> 8) & 0xFF).to(tl.uint8), + ) + zr_f16 = val_min.to(tl.float16) + zr_u16 = zr_f16.to(tl.uint16, bitcast=True) + tl.store(KV_cache_ptr + slot_base + sc_offset + 2, (zr_u16 & 0xFF).to(tl.uint8)) + tl.store( + KV_cache_ptr + slot_base + sc_offset + 3, + ((zr_u16 >> 8) 
& 0xFF).to(tl.uint8), + ) + + +# ═══════════════════════════════════════════════════════════════════════ +# FP8 key store + value uniform quantization +# ═══════════════════════════════════════════════════════════════════════ + + +@triton.jit +def _tq_fused_store_fp8( + Key_ptr, # [NH, D] float16/bfloat16 — raw keys + Value_ptr, # [NH, D] float16/bfloat16 — raw values + KV_cache_ptr, # [total_bytes] uint8 (flattened view) + Slot_mapping_ptr, # [N] int32 — per-token slot indices + # Cache strides (for computing byte offsets) + stride_cache_block: tl.constexpr, + stride_cache_pos: tl.constexpr, + stride_cache_head: tl.constexpr, + # Dimensions + D: tl.constexpr, + H: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_D: tl.constexpr, + # TQ layout + KPS: tl.constexpr, + # Value quantization + VQB: tl.constexpr, + VAL_DATA_BYTES: tl.constexpr, + # Packing block sizes + BLOCK_VAL: tl.constexpr, + BLOCK_GRP: tl.constexpr = 16, + FP8_E4B15: tl.constexpr = 0, # 1 = e4b15 (Ampere/Ada), 0 = e4nv (Hopper+) +): + """FP8 key cast+scatter + value uniform quantization.""" + pid = tl.program_id(0) + token_idx = pid // H + head_idx = pid % H + + slot = tl.load(Slot_mapping_ptr + token_idx) + if slot < 0: + return + blk = slot // BLOCK_SIZE + off = slot % BLOCK_SIZE + slot_base = ( + blk * stride_cache_block + off * stride_cache_pos + head_idx * stride_cache_head + ) + + base = pid * D + + # ── FP8 KEY: cast to FP8 in-kernel and store ───────────────── + d_offs = tl.arange(0, BLOCK_D) + d_mask = d_offs < D + k_vals = tl.load(Key_ptr + base + d_offs, mask=d_mask, other=0.0) + k_fp8 = k_vals.to(tl.float8e4b15) if FP8_E4B15 else k_vals.to(tl.float8e4nv) + k_bytes = k_fp8.to(tl.uint8, bitcast=True) + tl.store(KV_cache_ptr + slot_base + d_offs, k_bytes, mask=d_mask) + + # ── VALUE QUANTIZE + PACK ─────────────────────────────────────── + _store_quantized_value( + Value_ptr, + KV_cache_ptr, + base, + slot_base, + d_offs, + d_mask, + D=D, + KPS=KPS, + VQB=VQB, + 
VAL_DATA_BYTES=VAL_DATA_BYTES, + BLOCK_D=BLOCK_D, + BLOCK_VAL=BLOCK_VAL, + BLOCK_GRP=BLOCK_GRP, + ) + + +# ═══════════════════════════════════════════════════════════════════════ +# Fused MSE store: bucketize + MSE index pack + norm store + value pack +# (eliminates 4 PyTorch kernel launches per layer vs pack-only kernel) +# ═══════════════════════════════════════════════════════════════════════ + + +@triton.jit +def _tq_fused_store_mse( + # Post-rotation inputs + Y_ptr, # [NH, D] float32 — rotated normalized keys (x_hat @ PiT) + Norms_ptr, # [NH] float32 — key vector norms (||k||) + Value_ptr, # [NH, D] float32 — raw values + # Quantization tables + Midpoints_ptr, # [n_centroids-1] float32 + # Cache and indexing + KV_cache_ptr, # [total_bytes] uint8 (flattened view) + Slot_mapping_ptr, # [N] int32 — per-token slot indices + # Cache strides + stride_cache_block: tl.constexpr, + stride_cache_pos: tl.constexpr, + stride_cache_head: tl.constexpr, + # Dimensions + D: tl.constexpr, + H: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_D: tl.constexpr, + # TQ layout + MSE_BYTES: tl.constexpr, + KPS: tl.constexpr, + # Value quantization + VQB: tl.constexpr, + VAL_DATA_BYTES: tl.constexpr, + # Packing block sizes + BLOCK_VAL: tl.constexpr, + # MSE params + MSE_BITS: tl.constexpr, + N_CENTROIDS: tl.constexpr, + BLOCK_GRP: tl.constexpr = 16, +): + """Fused MSE quantize + pack + store. + + Performs binary-search bucketize, MSE index packing, norm storage, + and value quantization in one kernel. + """ + pid = tl.program_id(0) + token_idx = pid // H + head_idx = pid % H + + slot = tl.load(Slot_mapping_ptr + token_idx) + if slot < 0: + return + blk = slot // BLOCK_SIZE + off = slot % BLOCK_SIZE + slot_base = ( + blk * stride_cache_block + off * stride_cache_pos + head_idx * stride_cache_head + ) + + base = pid * D + d_offs = tl.arange(0, BLOCK_D) + d_mask = d_offs < D + + # ── 1. 
BINARY SEARCH BUCKETIZE ─────────────────────────────────── + # Midpoints are sorted (N_CENTROIDS-1 values); binary search finds + # insertion point in MSE_BITS iterations vs N_CENTROIDS-1 for linear. + y_vec = tl.load(Y_ptr + base + d_offs, mask=d_mask, other=0.0) + lo = tl.zeros([BLOCK_D], dtype=tl.int32) + hi = tl.full([BLOCK_D], N_CENTROIDS - 1, dtype=tl.int32) + for _ in range(MSE_BITS): + mid = (lo + hi) >> 1 + # Clamp to valid midpoint index [0, N_CENTROIDS-2] for load safety; + # the search result (lo) is still correct since converged lanes + # don't change. + safe_mid = tl.minimum(mid, N_CENTROIDS - 2) + mid_val = tl.load(Midpoints_ptr + safe_mid, mask=d_mask, other=0.0) + lo = tl.where(y_vec >= mid_val, mid + 1, lo) + hi = tl.where(y_vec >= mid_val, hi, mid) + idx = tl.minimum(lo, N_CENTROIDS - 1) + + # ── 2. PACK MSE INDICES from register idx ───────────────────────── + if MSE_BITS == 4: + idx_pairs = tl.reshape(idx, [BLOCK_D // 2, 2]) + shifts_4 = tl.arange(0, 2) * 4 + packed = tl.sum((idx_pairs & 0xF) << shifts_4[None, :], axis=1).to(tl.uint8) + mse_offs = tl.arange(0, BLOCK_D // 2) + mse_mask = mse_offs < MSE_BYTES + tl.store(KV_cache_ptr + slot_base + mse_offs, packed, mask=mse_mask) + + elif MSE_BITS == 3: + grp_offs = tl.arange(0, BLOCK_GRP) + grp_mask = grp_offs < (D // 8) + idx_grp = tl.reshape(idx, [BLOCK_GRP, 8]) + shifts_3 = tl.arange(0, 8) * 3 + packed_24 = tl.sum((idx_grp & 0x7) << shifts_3[None, :], axis=1) + b0 = (packed_24 & 0xFF).to(tl.uint8) + b1 = ((packed_24 >> 8) & 0xFF).to(tl.uint8) + b2 = ((packed_24 >> 16) & 0xFF).to(tl.uint8) + tl.store(KV_cache_ptr + slot_base + grp_offs * 3, b0, mask=grp_mask) + tl.store(KV_cache_ptr + slot_base + grp_offs * 3 + 1, b1, mask=grp_mask) + tl.store(KV_cache_ptr + slot_base + grp_offs * 3 + 2, b2, mask=grp_mask) + + # ── 3. 
STORE vec_norm (fp16, 2 bytes) ───────────────────────────── + norm_offset = MSE_BYTES + + vn_f16 = tl.load(Norms_ptr + pid).to(tl.float16) + vn_u16 = vn_f16.to(tl.uint16, bitcast=True) + tl.store(KV_cache_ptr + slot_base + norm_offset, (vn_u16 & 0xFF).to(tl.uint8)) + tl.store( + KV_cache_ptr + slot_base + norm_offset + 1, ((vn_u16 >> 8) & 0xFF).to(tl.uint8) + ) + + # ── 4. VALUE QUANTIZE + PACK ────────────────────────────────────── + _store_quantized_value( + Value_ptr, + KV_cache_ptr, + base, + slot_base, + d_offs, + d_mask, + D=D, + KPS=KPS, + VQB=VQB, + VAL_DATA_BYTES=VAL_DATA_BYTES, + BLOCK_D=BLOCK_D, + BLOCK_VAL=BLOCK_VAL, + BLOCK_GRP=BLOCK_GRP, + ) + + +# ═══════════════════════════════════════════════════════════════════════ +# Launcher +# ═══════════════════════════════════════════════════════════════════════ + + +def triton_turboquant_store( + key: torch.Tensor, # [N, H, D] — raw keys (post-RoPE) + value: torch.Tensor, # [N, H, D] — raw values + kv_cache: torch.Tensor, # [num_blocks, block_size, Hk, padded_slot] uint8 + slot_mapping: torch.Tensor, # [N] int32 + PiT: torch.Tensor, # [D, D] float32 + midpoints: torch.Tensor, # [n_centroids-1] float32 + mse_bits: int, + key_packed_size: int, + value_quant_bits: int, + key_fp8: bool = False, +): + """Launch TQ store kernel (FP8 or MSE path).""" + N, H, D = key.shape + NH = N * H + block_size = kv_cache.shape[1] + BLOCK_D = triton.next_power_of_2(D) + mse_bytes = math.ceil(D * mse_bits / 8) + n_centroids = 2**mse_bits + + val_data_bytes = math.ceil(D * value_quant_bits / 8) + + BLOCK_VAL = triton.next_power_of_2(val_data_bytes) + + # Cache strides (element_size=1 for uint8, so stride in bytes = stride()) + stride_block = kv_cache.stride(0) + stride_pos = kv_cache.stride(1) + stride_head = kv_cache.stride(2) + + block_grp = triton.next_power_of_2(D // 8) if D >= 8 else 1 + + # ── FP8 PATH: in-kernel FP8 cast + scatter via fp8 kernel ── + if key_fp8: + k_flat = key.reshape(NH, D).contiguous() + v_flat = 
value.reshape(NH, D).contiguous() + + fp8_e4b15 = _use_fp8_e4b15(key.device.index or 0) + + grid = (NH,) + _tq_fused_store_fp8[grid]( + k_flat, + v_flat, + kv_cache.view(-1), + slot_mapping, + stride_cache_block=stride_block, + stride_cache_pos=stride_pos, + stride_cache_head=stride_head, + D=D, + H=H, + BLOCK_SIZE=block_size, + BLOCK_D=BLOCK_D, + KPS=key_packed_size, + VQB=value_quant_bits, + VAL_DATA_BYTES=val_data_bytes, + BLOCK_VAL=BLOCK_VAL, + BLOCK_GRP=block_grp, + FP8_E4B15=fp8_e4b15, + num_warps=4, + num_stages=1, + ) + return + + # ── MSE PATH: external GEMM + fused bucketize/pack kernel ── + # Normalize + rotation GEMM externally (cuBLAS is faster than in-kernel) + k_flat = key.float().reshape(NH, D) + norms = k_flat.norm(dim=1, keepdim=True) + x_hat = k_flat / (norms + 1e-8) + y = x_hat @ PiT + + v_flat = value.float().reshape(NH, D) + + # Fused kernel: bucketize + MSE index pack + norm store + value pack + grid = (NH,) + _tq_fused_store_mse[grid]( + y, + norms.squeeze(1), + v_flat, + midpoints, + kv_cache.view(-1), + slot_mapping, + stride_cache_block=stride_block, + stride_cache_pos=stride_pos, + stride_cache_head=stride_head, + D=D, + H=H, + BLOCK_SIZE=block_size, + BLOCK_D=BLOCK_D, + MSE_BYTES=mse_bytes, + KPS=key_packed_size, + VQB=value_quant_bits, + VAL_DATA_BYTES=val_data_bytes, + BLOCK_VAL=BLOCK_VAL, + MSE_BITS=mse_bits, + N_CENTROIDS=n_centroids, + BLOCK_GRP=block_grp, + num_warps=4, + num_stages=1, + ) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index fa5395685e6a..30061462008f 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -21,6 +21,7 @@ MLAAttentionSpec, SinkFullAttentionSpec, SlidingWindowSpec, + TQFullAttentionSpec, ) from vllm.v1.request import Request @@ -209,7 +210,7 @@ def allocate_new_computed_blocks( cdiv(num_total_computed_tokens, self.block_size) - len(req_blocks) ) req_blocks.extend(allocated_blocks) - if 
type(self.kv_cache_spec) is FullAttentionSpec: + if type(self.kv_cache_spec) in (FullAttentionSpec, TQFullAttentionSpec): self.new_block_ids.extend(b.block_id for b in allocated_blocks) def allocate_new_blocks( @@ -237,7 +238,7 @@ def allocate_new_blocks( else: new_blocks = self.block_pool.get_new_blocks(num_new_blocks) req_blocks.extend(new_blocks) - if type(self.kv_cache_spec) is FullAttentionSpec: + if type(self.kv_cache_spec) in (FullAttentionSpec, TQFullAttentionSpec): self.new_block_ids.extend(b.block_id for b in new_blocks) return new_blocks @@ -1114,6 +1115,7 @@ def __init__( spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, + TQFullAttentionSpec: FullAttentionManager, MLAAttentionSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager, ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 6f8ad8e7d8ef..8aed95ddd0eb 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -245,6 +245,32 @@ def real_page_size_bytes(self) -> int: ) +@dataclass(frozen=True, kw_only=True) +class TQFullAttentionSpec(FullAttentionSpec): + """FullAttentionSpec with TQ-aware page size. + + Python equivalent of the C++ TQ4FullAttentionSpec. Overrides + real_page_size_bytes to use TQ slot bytes instead of the raw + head_size * dtype formula. + """ + + tq_slot_size: int = 0 + + @property + def real_page_size_bytes(self) -> int: + if self.tq_slot_size > 0: + return self.block_size * self.num_kv_heads * self.tq_slot_size + return super().real_page_size_bytes + + @classmethod + def merge(cls, specs: list[Self]) -> Self: + merged = super().merge(specs) + assert all(s.tq_slot_size == specs[0].tq_slot_size for s in specs), ( + "All TQ layers in the same KV cache group must use the same tq_slot_size." 
+ ) + return replace(merged, tq_slot_size=specs[0].tq_slot_size) + + @dataclass(frozen=True, kw_only=True) class MLAAttentionSpec(FullAttentionSpec): # TODO(Lucas/Chen): less hacky way to do this diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 83fc12cb5c3b..5780624c2226 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -120,7 +120,7 @@ def init_meta( for group in attn_groups_iter: spec = group.kv_cache_spec - if type(spec) is not FullAttentionSpec: + if not isinstance(spec, FullAttentionSpec): continue if group.kv_cache_group_id >= len(kernel_block_sizes): continue