diff --git a/tests/quantization/test_turboquant.py b/tests/quantization/test_turboquant.py index 78c137e67628..5378db4800cd 100644 --- a/tests/quantization/test_turboquant.py +++ b/tests/quantization/test_turboquant.py @@ -12,6 +12,7 @@ from vllm.model_executor.layers.quantization.turboquant.centroids import ( get_centroids, + get_residual_scale, solve_lloyd_max, ) from vllm.model_executor.layers.quantization.turboquant.config import ( @@ -47,14 +48,25 @@ def _is_power_of_2(n: int) -> bool: key_mse_bits=0, value_quant_bits=4, mse_bits=4, n_centroids=16, centroid_bits=4, norm_correction=False, + use_qjl_residual=False, key_packed_size=128, value_packed_size=68, slot_size=196, slot_size_aligned=196, ), + "turboquant_4bit": dict( + key_fp8=False, key_quant_bits=3, + key_mse_bits=3, value_quant_bits=4, + mse_bits=3, n_centroids=8, centroid_bits=3, + norm_correction=True, + use_qjl_residual=True, + key_packed_size=64, value_packed_size=68, + slot_size=132, slot_size_aligned=132, + ), "turboquant_4bit_nc": dict( key_fp8=False, key_quant_bits=4, key_mse_bits=4, value_quant_bits=4, mse_bits=4, n_centroids=16, centroid_bits=4, norm_correction=True, + use_qjl_residual=False, key_packed_size=66, value_packed_size=68, slot_size=134, slot_size_aligned=134, ), @@ -63,14 +75,25 @@ def _is_power_of_2(n: int) -> bool: key_mse_bits=3, value_quant_bits=4, mse_bits=3, n_centroids=8, centroid_bits=3, norm_correction=True, + use_qjl_residual=False, key_packed_size=50, value_packed_size=68, slot_size=118, slot_size_aligned=118, ), + "turboquant_3bit": dict( + key_fp8=False, key_quant_bits=2, + key_mse_bits=2, value_quant_bits=3, + mse_bits=2, n_centroids=4, centroid_bits=2, + norm_correction=True, + use_qjl_residual=True, + key_packed_size=44, value_packed_size=52, + slot_size=96, slot_size_aligned=96, + ), "turboquant_3bit_nc": dict( key_fp8=False, key_quant_bits=3, key_mse_bits=3, value_quant_bits=3, mse_bits=3, n_centroids=8, centroid_bits=3, norm_correction=True, + use_qjl_residual=False, key_packed_size=50, value_packed_size=52, slot_size=102, slot_size_aligned=102, ), @@ -122,6 +145,11 @@ def test_norm_correction(self, preset): cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128) assert cfg.norm_correction is PRESET_EXPECTED[preset]["norm_correction"] + @pytest.mark.parametrize("preset", ALL_PRESETS) + def test_qjl_mode(self, preset): + cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128) + assert cfg.use_qjl_residual is PRESET_EXPECTED[preset]["use_qjl_residual"] + @pytest.mark.parametrize("preset", ALL_PRESETS) def test_packed_sizes(self, preset): cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128) @@ -171,7 +199,7 @@ def test_mse_key_or_fp8_exclusive(self, preset): assert cfg.key_quant_bits == 8 else: assert cfg.key_mse_bits > 0 - assert cfg.key_quant_bits in (3, 4) + assert cfg.key_quant_bits in (2, 3, 4) @pytest.mark.parametrize("preset", ALL_PRESETS) @pytest.mark.parametrize("head_dim", [64, 96, 128, 256]) @@ -460,7 +488,7 @@ class TestStoreDecodeRoundTrip: @pytest.mark.parametrize( "preset", - ["turboquant_k8v4", "turboquant_4bit_nc"], + ["turboquant_k8v4", "turboquant_4bit_nc", "turboquant_k3v4_nc"], ) def test_single_token_roundtrip(self, preset): """Store 1 token, decode with query=key, check attention output. @@ -494,6 +522,17 @@ def test_single_token_roundtrip(self, preset): H = _build_hadamard(D, "cuda") PiT = (signs.unsqueeze(1) * H).contiguous().float() Pi = PiT.T.contiguous() + if cfg.use_qjl_residual: + qjl_signs = generate_wht_signs(D, seed=43, device=device) + PhiT = (qjl_signs.unsqueeze(1) * H).contiguous().float() + qjl_scale = torch.tensor( + get_residual_scale(D, cfg.key_mse_bits), + device=device, + dtype=torch.float32, + ) + else: + PhiT = None + qjl_scale = None # Generate centroids centroids, _ = solve_lloyd_max(D, cfg.centroid_bits) @@ -526,10 +565,12 @@ def test_single_token_roundtrip(self, preset): slot_mapping, PiT, midpoints, + PhiT, mse_bits=cfg.key_mse_bits, key_packed_size=cfg.key_packed_size, value_quant_bits=cfg.effective_value_quant_bits, key_fp8=cfg.key_fp8, + qjl_residual_bits=cfg.qjl_residual_bits, ) # Decode: use key as query so attention = softmax([1]) * V = V @@ -551,6 +592,8 @@ def test_single_token_roundtrip(self, preset): key_fp8=cfg.key_fp8, norm_correction=cfg.norm_correction, PiT=PiT, + PhiT=PhiT, + qjl_scale=qjl_scale, max_num_kv_splits=4, ) diff --git a/tests/quantization/test_turboquant_reference.py b/tests/quantization/test_turboquant_reference.py new file mode 100644 index 000000000000..56b57a3bf616 --- /dev/null +++ b/tests/quantization/test_turboquant_reference.py @@ -0,0 +1,236 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.v1.attention.ops.turboquant_kv_cache import ( + apply_turboquant_query_transforms, + build_turboquant_outlier_masks, + canonicalize_turboquant_dtype, + dequantize_turboquant_vectors, + get_turboquant_bits, + get_turboquant_centroids, + get_turboquant_group_dims, + get_turboquant_layout, + get_turboquant_packed_dim, + get_turboquant_qjl_matrix, + get_turboquant_qjl_inverse_transform_matrix, + get_turboquant_rotation, + get_turboquant_mse_inverse_transform_matrix, + get_turboquant_mse_to_qjl_matrix, + pack_turboquant_indices, + quantize_turboquant_vectors, + unpack_turboquant_indices, + validate_turboquant_group_indices, +) +from vllm.v1.attention.ops.turboquant_metadata import ( + TurboQuantTensorMetadata, + build_default_turboquant_metadata, + discover_turboquant_metadata_path, + load_turboquant_metadata, + save_turboquant_metadata, +) +from vllm.v1.attention.ops.triton_turboquant_decode import ( + triton_turboquant_decode_attention, +) +from vllm.v1.attention.ops.triton_turboquant_store import triton_turboquant_store + + +def test_turboquant_aliases_match_reference_recipes(): + assert canonicalize_turboquant_dtype("turboquant_3bit") == "turboquant25" + assert canonicalize_turboquant_dtype("turboquant_4bit") == "turboquant35" + assert get_turboquant_bits("turboquant_3bit") == 2.5 + assert get_turboquant_bits("turboquant_4bit") == 3.5 + + +def test_turboquant_layout_is_consistent(): + layout = get_turboquant_layout("turboquant_4bit", 128) + high_dim, low_dim = get_turboquant_group_dims(128, "turboquant_4bit") + assert (high_dim, low_dim) == (64, 64) + assert layout.groups[0].dim == 64 + assert layout.groups[0].mse_bits == 3 + assert layout.groups[1].dim == 64 + assert layout.groups[1].mse_bits == 2 + assert layout.packed_dim == get_turboquant_packed_dim(128, "turboquant35") + + +def test_pack_unpack_roundtrip(): + values = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3]], dtype=torch.uint8) + packed = pack_turboquant_indices(values, 2) + unpacked = unpack_turboquant_indices(packed, values.shape[-1], 2) + assert torch.equal(values, unpacked) + + +def test_quantize_dequantize_reference_path_shapes(): + torch.manual_seed(0) + x = torch.randn(3, 2, 128, dtype=torch.float32) + high_idx, low_idx = build_turboquant_outlier_masks(x, "turboquant_3bit") + device = torch.device("cpu") + rotations = ( + get_turboquant_rotation(device, high_idx.shape[-1]), + get_turboquant_rotation(device, low_idx.shape[-1]), + ) + qjl_matrices = ( + get_turboquant_qjl_matrix(device, high_idx.shape[-1]), + get_turboquant_qjl_matrix(device, low_idx.shape[-1]), + ) + centroids = { + 1: get_turboquant_centroids(device, low_idx.shape[-1], 1), + 2: get_turboquant_centroids(device, high_idx.shape[-1], 2), + } + packed = quantize_turboquant_vectors( + x, "turboquant_3bit", rotations, qjl_matrices, centroids, (high_idx, low_idx) + ) + restored = dequantize_turboquant_vectors( + packed, + "turboquant_3bit", + 128, + rotations, + qjl_matrices, + centroids, + (high_idx, low_idx), + x.dtype, + ) + assert packed.shape[-1] == get_turboquant_layout("turboquant_3bit", 128).packed_dim + assert restored.shape == x.shape + assert torch.isfinite(restored).all() + + +def test_turboquant_metadata_roundtrip(tmp_path): + metadata = build_default_turboquant_metadata( + recipe="turboquant_4bit", + head_size=128, + num_kv_heads=2, + layer_names=["model.layers.0.self_attn"], + model_name="tests/turboquant", + ) + path = tmp_path / "turboquant_kv.json" + save_turboquant_metadata(metadata, path) + loaded = load_turboquant_metadata(str(path)) + assert loaded.recipe == "turboquant35" + assert loaded.get_layer("language_model.model.layers.0.self_attn.attn") == ( + loaded.layers["model.layers.0.self_attn"] + ) + assert discover_turboquant_metadata_path(str(tmp_path), None) == str(path.resolve()) + + +def test_turboquant_tensor_metadata_group_indices_shape(): + metadata = TurboQuantTensorMetadata( + high_precision_indices=((0, 1, 2, 3, 4, 5, 6, 7),) + ) + high, low = metadata.get_group_indices( + device=torch.device("cpu"), + head_size=32, + kv_cache_dtype="turboquant_3bit", + ) + assert high.shape == (1, 8) + assert low.shape == (1, 24) + + +def test_grouped_op_layer_store_decode_reference_path(): + torch.manual_seed(0) + device = torch.device("cpu") + key = torch.randn(1, 2, 128, dtype=torch.float32, device=device) + value = torch.randn(1, 2, 128, dtype=torch.float32, device=device) + high_idx, low_idx = build_turboquant_outlier_masks(key, "turboquant_3bit") + rotations = ( + get_turboquant_rotation(device, high_idx.shape[-1]), + get_turboquant_rotation(device, low_idx.shape[-1]), + ) + qjl_matrices = ( + get_turboquant_qjl_matrix(device, high_idx.shape[-1]), + get_turboquant_qjl_matrix(device, low_idx.shape[-1]), + ) + centroids = { + 1: get_turboquant_centroids(device, low_idx.shape[-1], 1), + 2: get_turboquant_centroids(device, high_idx.shape[-1], 2), + } + layout = get_turboquant_layout("turboquant_3bit", 128) + value_bytes = (128 * 3 + 7) // 8 + 4 + kv_cache = torch.zeros(1, 16, 2, layout.packed_dim + value_bytes, dtype=torch.uint8) + slot_mapping = torch.tensor([0], dtype=torch.int32) + triton_turboquant_store( + key=key, + value=value, + kv_cache=kv_cache, + slot_mapping=slot_mapping, + PiT=None, + midpoints=None, + PhiT=None, + mse_bits=2, + key_packed_size=layout.packed_dim, + value_quant_bits=3, + grouped_recipe="turboquant_3bit", + group_rotations=rotations, + group_qjl=qjl_matrices, + group_centroids=centroids, + group_indices=(high_idx, low_idx), + ) + output = triton_turboquant_decode_attention( + query=key, + kv_cache=kv_cache, + block_table=torch.tensor([[0]], dtype=torch.int32), + seq_lens=torch.tensor([1], dtype=torch.int32), + Pi=None, + centroids=None, + scale=1.0 / (128**0.5), + mse_bits=2, + key_packed_size=layout.packed_dim, + value_quant_bits=3, + grouped_recipe="turboquant_3bit", + group_rotations=rotations, + group_qjl=qjl_matrices, + group_centroids=centroids, + group_indices=(high_idx, low_idx), + ) + assert output.shape == value.shape + assert torch.isfinite(output).all() + + +def test_grouped_query_transforms_shapes(): + torch.manual_seed(0) + query = torch.randn(2, 4, 128, dtype=torch.float32) + group0 = torch.arange(64, dtype=torch.int64).repeat(2, 1) + group1 = torch.arange(64, 128, dtype=torch.int64).repeat(2, 1) + kv_head_for_query_head = torch.tensor([0, 0, 1, 1], dtype=torch.int64) + rotations = ( + get_turboquant_rotation(torch.device("cpu"), 64), + get_turboquant_rotation(torch.device("cpu"), 64, seed_offset=1), + ) + qjl = ( + get_turboquant_qjl_matrix(torch.device("cpu"), 64), + get_turboquant_qjl_matrix(torch.device("cpu"), 64, seed_offset=1), + ) + (q_rot0, q_rot1), (q_qjl0, q_qjl1) = apply_turboquant_query_transforms( + query, + (group0, group1), + rotations, + qjl, + kv_head_for_query_head=kv_head_for_query_head, + ) + assert q_rot0.shape == (2, 4, 64) + assert q_rot1.shape == (2, 4, 64) + assert q_qjl0.shape == (2, 4, 64) + assert q_qjl1.shape == (2, 4, 64) + + +def test_transform_matrix_helpers_shapes(): + device = torch.device("cpu") + mse_inv = get_turboquant_mse_inverse_transform_matrix(device, 64) + qjl_inv = get_turboquant_qjl_inverse_transform_matrix(device, 64) + mse_to_qjl = get_turboquant_mse_to_qjl_matrix(device, 64) + assert mse_inv.shape == (64, 64) + assert qjl_inv.shape == (64, 64) + assert mse_to_qjl.shape == (64, 64) + + +def test_group_index_validation_rejects_head_mismatch(): + x = torch.randn(1, 2, 128, dtype=torch.float32) + bad_group0 = torch.arange(64, dtype=torch.int64).repeat(3, 1) + bad_group1 = torch.arange(64, 128, dtype=torch.int64).repeat(3, 1) + try: + validate_turboquant_group_indices(x, (bad_group0, bad_group1)) + except ValueError as exc: + assert "KV head count" in str(exc) + else: + raise AssertionError("Expected validate_turboquant_group_indices to fail") diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 47a655f22d53..63fdfd838ef2 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -25,8 +25,10 @@ "fp8_inc", "fp8_ds_mla", "turboquant_k8v4", + "turboquant_4bit", "turboquant_4bit_nc", "turboquant_k3v4_nc", + "turboquant_3bit", "turboquant_3bit_nc", "int8_per_token_head", "fp8_per_token_head", diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index a92e2f4ad188..546e8c6cb658 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -407,6 +407,7 @@ def _init_turboquant_buffers( """Initialize TurboQuant rotation/projection matrices and centroids.""" from vllm.model_executor.layers.quantization.turboquant.centroids import ( get_centroids, + get_residual_scale, ) from vllm.model_executor.layers.quantization.turboquant.config import ( TurboQuantConfig, @@ -414,8 +415,14 @@ def _init_turboquant_buffers( from vllm.model_executor.layers.quantization.turboquant.quantizer import ( generate_wht_signs, ) + from vllm.v1.attention.ops.turboquant_metadata import ( + build_default_turboquant_metadata, + discover_turboquant_metadata_path, + load_turboquant_metadata, + ) tq_config = TurboQuantConfig.from_cache_dtype(cache_dtype, head_size) + _vllm_cfg = get_current_vllm_config() # Each layer needs a unique rotation matrix so quantization errors # don't correlate across layers. Stride must exceed max head_dim to @@ -427,21 +434,78 @@ def _init_turboquant_buffers( layer_idx = extract_layer_index(prefix) seed = tq_config.seed + layer_idx * _TQ_LAYER_SEED_STRIDE - self.register_buffer( - "_tq_signs", - generate_wht_signs(head_size, seed=seed), - ) - self.register_buffer( - "_tq_centroids", - get_centroids(head_size, tq_config.centroid_bits), - ) + if tq_config.use_grouped_layout: + metadata_path = discover_turboquant_metadata_path( + _vllm_cfg.model_config.model if _vllm_cfg.model_config is not None else None, + None, + ) + metadata = None + if metadata_path is not None: + try: + metadata = load_turboquant_metadata(metadata_path) + layer_metadata = metadata.get_layer(prefix) + except (KeyError, ValueError) as exc: + logger.warning_once( + "TurboQuant metadata at %s could not be used for layer %s: %s. " + "Falling back to default grouped indices.", + metadata_path, + prefix, + exc, + ) + metadata = None + if metadata is None: + metadata = build_default_turboquant_metadata( + recipe=cache_dtype, + head_size=head_size, + num_kv_heads=self.num_kv_heads, + layer_names=[prefix], + model_name=( + _vllm_cfg.model_config.model + if _vllm_cfg.model_config is not None + else None + ), + ) + layer_metadata = metadata.get_layer(prefix) + key_high_idx, key_low_idx = layer_metadata.key.get_group_indices( + device=torch.device("cpu"), + head_size=head_size, + kv_cache_dtype=cache_dtype, + ) + self.register_buffer( + "_tq_group0_idx", + key_high_idx, + ) + self.register_buffer( + "_tq_group1_idx", + key_low_idx, + ) + else: + self.register_buffer( + "_tq_signs", + generate_wht_signs(head_size, seed=seed), + ) + self.register_buffer( + "_tq_centroids", + get_centroids(head_size, tq_config.centroid_bits), + ) + if tq_config.use_qjl_residual: + self.register_buffer( + "_tq_qjl_signs", + generate_wht_signs(head_size, seed=seed + 1), + ) + self.register_buffer( + "_tq_qjl_scale", + torch.tensor( + get_residual_scale(head_size, tq_config.key_mse_bits), + dtype=torch.float32, + ), + ) self._tq_config = tq_config # Pre-allocate decode intermediate buffers so model.to(device) moves # them to GPU *before* the memory profiler runs. Without this the # profiler gives all free memory to KV cache blocks and the first # decode OOMs when these buffers are lazily allocated. - _vllm_cfg = get_current_vllm_config() B = _vllm_cfg.scheduler_config.max_num_seqs Hq = self.num_heads S = _vllm_cfg.attention_config.tq_max_kv_splits_for_cuda_graph diff --git a/vllm/model_executor/layers/quantization/turboquant/centroids.py b/vllm/model_executor/layers/quantization/turboquant/centroids.py index 490265747c5b..6104aa823675 100644 --- a/vllm/model_executor/layers/quantization/turboquant/centroids.py +++ b/vllm/model_executor/layers/quantization/turboquant/centroids.py @@ -84,3 +84,29 @@ def get_centroids(d: int, bits: int) -> torch.Tensor: """Get precomputed Lloyd-Max centroids (cached).""" centroids, _ = solve_lloyd_max(d, bits) return centroids + + +@lru_cache(maxsize=32) +def get_residual_scale(d: int, bits: int) -> float: + """Approximate 1-bit residual correction scale for TurboQuant. + + After Lloyd-Max quantization, we model the residual as an approximately + isotropic zero-mean random vector. A 1-bit sign sketch then uses the + Gaussian proxy E|X| = sigma * sqrt(2 / pi) to convert signs back into a + score correction constant. + """ + centroids, boundaries = solve_lloyd_max(d, bits) + sigma2 = 1.0 / d + + def pdf(x: float) -> float: + return _gaussian_pdf(x, sigma2) + + mse = 0.0 + finite_edges = [-10.0 / math.sqrt(d), *boundaries.tolist(), 10.0 / math.sqrt(d)] + for i, c in enumerate(centroids.tolist()): + a = finite_edges[i] + b = finite_edges[i + 1] + mse += _trapz(lambda x: (x - c) ** 2 * pdf(x), a, b, n=400) + + sigma_resid = math.sqrt(max(mse, 1e-12)) + return sigma_resid * math.sqrt(2.0 / math.pi) diff --git a/vllm/model_executor/layers/quantization/turboquant/config.py b/vllm/model_executor/layers/quantization/turboquant/config.py index 289bed120773..c59ff3185d95 100644 --- a/vllm/model_executor/layers/quantization/turboquant/config.py +++ b/vllm/model_executor/layers/quantization/turboquant/config.py @@ -5,29 +5,51 @@ import math from dataclasses import dataclass +from vllm.v1.attention.ops.turboquant_kv_cache import get_turboquant_layout + # Named TQ presets: each maps to frozen config parameters. -# key_quant_bits: 8 = FP8 keys, 3-4 = MSE (Lloyd-Max) quantized keys. +# key_quant_bits: 8 = FP8 keys, 2-4 = MSE (Lloyd-Max) quantized keys. # value_quant_bits: 3-4 = uniform quantized values. +# +# ``turboquant_4bit`` and ``turboquant_3bit`` are the canonical presets that +# match the official blog's user-facing naming. The older ``*_nc`` names are +# kept as compatibility aliases for existing configs and tests. TQ_PRESETS: dict[str, dict] = { "turboquant_k8v4": { "key_quant_bits": 8, "value_quant_bits": 4, "norm_correction": False, + "qjl_residual_bits": 0, + }, + "turboquant_4bit": { + "key_quant_bits": 3, + "value_quant_bits": 4, + "norm_correction": True, + "qjl_residual_bits": 1, }, "turboquant_4bit_nc": { "key_quant_bits": 4, "value_quant_bits": 4, "norm_correction": True, + "qjl_residual_bits": 0, }, "turboquant_k3v4_nc": { "key_quant_bits": 3, "value_quant_bits": 4, "norm_correction": True, + "qjl_residual_bits": 0, + }, + "turboquant_3bit": { + "key_quant_bits": 2, + "value_quant_bits": 3, + "norm_correction": True, + "qjl_residual_bits": 1, }, "turboquant_3bit_nc": { "key_quant_bits": 3, "value_quant_bits": 3, "norm_correction": True, + "qjl_residual_bits": 0, }, } @@ -36,21 +58,24 @@ class TurboQuantConfig: """Configuration for TurboQuant KV-cache quantization. - Uses PolarQuant (WHT rotation + Lloyd-Max scalar quantization) for keys - and uniform quantization for values. QJL is intentionally omitted — - community consensus (5+ independent groups) found it hurts attention - quality by amplifying variance through softmax. + Uses the current vLLM TurboQuant serving path: + 1. WHT-rotated Lloyd-Max scalar quantization for keys + 2. optional 1-bit residual sign correction for the official 3-bit / 4-bit + presets + 3. uniform quantization for values Named presets (use via --kv-cache-dtype): turboquant_k8v4: FP8 keys + 4-bit values, 2.6x, +1.17% PPL + turboquant_4bit: official-style 4-bit keys (3-bit base + 1-bit residual) turboquant_4bit_nc: 4-bit MSE keys + 4-bit values + NC, 3.8x, +2.71% turboquant_k3v4_nc: 3-bit MSE keys + 4-bit values + NC, ~3.5x, +10.63% + turboquant_3bit: official-style 3-bit keys (2-bit base + 1-bit residual) turboquant_3bit_nc: 3-bit MSE keys + 3-bit values + NC, 4.9x, +20.59% Args: head_dim: Attention head dimension (e.g. 64, 96, 128). key_quant_bits: Bits for key quantization. 8 = FP8 keys (no - rotation/MSE). 3-4 = Lloyd-Max MSE quantized keys. + rotation/MSE). 2-4 = Lloyd-Max MSE quantized keys. value_quant_bits: Bits per value dimension for uniform quantization. 3 = 8 levels, 4 = 16 levels (default). seed: Base seed for deterministic random matrix generation. @@ -61,10 +86,12 @@ class TurboQuantConfig: """ head_dim: int = 128 - key_quant_bits: int = 3 # 3-4 = MSE keys, 8 = FP8 keys + key_quant_bits: int = 3 # 2-4 = MSE keys, 8 = FP8 keys value_quant_bits: int = 4 # 3-4 = uniform quantized values seed: int = 42 norm_correction: bool = False + qjl_residual_bits: int = 0 + preset_name: str | None = None @property def key_fp8(self) -> bool: @@ -83,6 +110,26 @@ def mse_bits(self) -> int: return self.value_quant_bits return self.key_quant_bits + @property + def use_qjl_residual(self) -> bool: + return not self.key_fp8 and self.qjl_residual_bits > 0 + + @property + def use_grouped_layout(self) -> bool: + """Whether this preset uses the grouped official-style layout. + + The canonical ``turboquant_4bit`` / ``turboquant_3bit`` modes map to a + grouped MSE+QJL representation, while legacy ``*_nc`` presets keep the + older flat per-dimension layout. + """ + return self.use_qjl_residual + + @property + def grouped_recipe(self) -> str | None: + if not self.use_grouped_layout: + return None + return self.preset_name + @property def key_mse_bits(self) -> int: """MSE bits actually used for key quantization (0 if FP8 keys).""" @@ -109,12 +156,17 @@ def key_packed_size(self) -> int: TQ mode: - MSE indices: ceil(head_dim * key_mse_bits / 8) bytes - vec_norm: 2 bytes (float16) + - residual: ceil(head_dim * qjl_residual_bits / 8) bytes """ if self.key_fp8: return self.head_dim # 1 byte per element + if self.use_grouped_layout: + assert self.grouped_recipe is not None + return get_turboquant_layout(self.grouped_recipe, self.head_dim).packed_dim mse_bytes = math.ceil(self.head_dim * self.key_mse_bits / 8) norm_bytes = 2 # vec_norm fp16 - return mse_bytes + norm_bytes + qjl_bytes = math.ceil(self.head_dim * self.qjl_residual_bits / 8) + return mse_bytes + norm_bytes + qjl_bytes @property def effective_value_quant_bits(self) -> int: @@ -168,7 +220,8 @@ def get_boundary_skip_layers(num_layers: int, n: int = 2) -> list[str]: def from_cache_dtype(cache_dtype: str, head_dim: int) -> "TurboQuantConfig": """Create config from a named preset. - Valid presets: turboquant_k8v4, turboquant_4bit_nc, etc. + Valid presets include turboquant_4bit and turboquant_3bit, along with + compatibility aliases such as turboquant_4bit_nc. """ if cache_dtype not in TQ_PRESETS: valid = ", ".join(TQ_PRESETS.keys()) @@ -182,4 +235,6 @@ def from_cache_dtype(cache_dtype: str, head_dim: int) -> "TurboQuantConfig": key_quant_bits=preset["key_quant_bits"], value_quant_bits=preset["value_quant_bits"], norm_correction=preset["norm_correction"], + qjl_residual_bits=preset["qjl_residual_bits"], + preset_name=cache_dtype, ) diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index 26e377de69cb..8e195ade0ffe 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -43,8 +43,10 @@ "fp8_inc": torch.float8_e4m3fn, "fp8_ds_mla": torch.uint8, "turboquant_k8v4": torch.uint8, + "turboquant_4bit": torch.uint8, "turboquant_4bit_nc": torch.uint8, "turboquant_k3v4_nc": torch.uint8, + "turboquant_3bit": torch.uint8, "turboquant_3bit_nc": torch.uint8, } diff --git a/vllm/v1/attention/backends/turboquant_attn.py b/vllm/v1/attention/backends/turboquant_attn.py index 279fcb04ace4..d07bd7dbf825 100644 --- a/vllm/v1/attention/backends/turboquant_attn.py +++ b/vllm/v1/attention/backends/turboquant_attn.py @@ -27,6 +27,12 @@ from vllm.config import get_current_vllm_config from vllm.config.cache import CacheDType from vllm.triton_utils import triton +from vllm.v1.attention.ops.turboquant_kv_cache import ( + get_turboquant_centroids, + get_turboquant_layout, + get_turboquant_qjl_matrix, + get_turboquant_rotation, +) from vllm.v1.attention.backend import ( AttentionBackend, AttentionCGSupport, @@ -91,8 +97,10 @@ class TurboQuantAttentionBackend(AttentionBackend): ] supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ "turboquant_k8v4", + "turboquant_4bit", "turboquant_4bit_nc", "turboquant_k3v4_nc", + "turboquant_3bit", "turboquant_3bit_nc", ] @@ -126,7 +134,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, - cache_dtype_str: str = "turboquant_4bit_nc", + cache_dtype_str: str = "turboquant_4bit", ) -> tuple[int, ...]: """Combined K+V cache shape — no leading 2 dimension. @@ -285,19 +293,53 @@ def _ensure_on_device(self, layer, device): correct device via register_buffer + model.to(device). """ if not hasattr(layer, "_tq_cached"): - D = layer._tq_signs.shape[0] - signs = layer._tq_signs.to(device=device, dtype=torch.float32) - - # WHT rotation: orthonormal + self-inverse, enabling future - # in-kernel butterfly fusion and trivial inverse for continuation. - H = _build_hadamard(D, str(device)) - layer._tq_PiT = (signs.unsqueeze(1) * H).contiguous() - layer._tq_Pi = layer._tq_PiT.T.contiguous() - - c = layer._tq_centroids.to(device=device, dtype=torch.float32) - # Precompute midpoints for threshold-based quantization - c_sorted, _ = c.sort() - layer._tq_midpoints = (c_sorted[:-1] + c_sorted[1:]) / 2 + if self.tq_config.use_grouped_layout: + group0_idx = layer._tq_group0_idx.to(device=device) + group1_idx = layer._tq_group1_idx.to(device=device) + group_dims = (group0_idx.shape[-1], group1_idx.shape[-1]) + layer._tq_group_indices = ( + group0_idx.contiguous(), + group1_idx.contiguous(), + ) + layer._tq_group_rotations = ( + get_turboquant_rotation(device, group_dims[0], seed_offset=101), + get_turboquant_rotation(device, group_dims[1], seed_offset=211), + ) + layer._tq_group_qjl = ( + get_turboquant_qjl_matrix(device, group_dims[0], seed_offset=307), + get_turboquant_qjl_matrix(device, group_dims[1], seed_offset=401), + ) + recipe = self.tq_config.grouped_recipe + assert recipe is not None + layout = get_turboquant_layout(recipe, self.head_size) + layer._tq_group_centroids = { + group.mse_bits: get_turboquant_centroids( + device, group.dim, group.mse_bits + ) + for group in layout.groups + } + else: + D = layer._tq_signs.shape[0] + signs = layer._tq_signs.to(device=device, dtype=torch.float32) + + # WHT rotation: orthonormal + self-inverse, enabling future + # in-kernel butterfly fusion and trivial inverse for continuation. + H = _build_hadamard(D, str(device)) + layer._tq_PiT = (signs.unsqueeze(1) * H).contiguous() + layer._tq_Pi = layer._tq_PiT.T.contiguous() + + c = layer._tq_centroids.to(device=device, dtype=torch.float32) + # Precompute midpoints for threshold-based quantization + c_sorted, _ = c.sort() + layer._tq_midpoints = (c_sorted[:-1] + c_sorted[1:]) / 2 + if self.tq_config.use_qjl_residual: + qjl_signs = layer._tq_qjl_signs.to(device=device, dtype=torch.float32) + H_qjl = _build_hadamard(D, str(device)) + layer._tq_qjl_PhiT = (qjl_signs.unsqueeze(1) * H_qjl).contiguous() + layer._tq_qjl_Phi = layer._tq_qjl_PhiT.T.contiguous() + layer._tq_qjl_scale_dev = layer._tq_qjl_scale.to( + device=device, dtype=torch.float32 + ) # Decode buffers (_tq_mid_o_buf, _tq_output_buf, _tq_lse_buf) # are pre-allocated via register_buffer in Attention.__init__ # and moved to GPU by model.to(device) — no allocation needed @@ -366,9 +408,9 @@ def forward( tq_layer: Any = layer device = q.device self._ensure_on_device(tq_layer, device) - Pi = tq_layer._tq_Pi - PiT = tq_layer._tq_PiT - centroids = tq_layer._tq_centroids + Pi = getattr(tq_layer, "_tq_Pi", None) + PiT = getattr(tq_layer, "_tq_PiT", None) + centroids = getattr(tq_layer, "_tq_centroids", None) # Compute attention (KV cache was already updated by do_kv_cache_update) # With reorder_batch_threshold=1, decodes come first in the batch. @@ -480,10 +522,17 @@ def _store_kv( slot_mapping, layer._tq_PiT, layer._tq_midpoints, + getattr(layer, "_tq_qjl_PhiT", None), mse_bits=self.tq_config.key_mse_bits, key_packed_size=self.tq_config.key_packed_size, value_quant_bits=self.tq_config.effective_value_quant_bits, key_fp8=self.tq_config.key_fp8, + qjl_residual_bits=self.tq_config.qjl_residual_bits, + grouped_recipe=self.tq_config.grouped_recipe, + group_rotations=getattr(layer, "_tq_group_rotations", None), + group_qjl=getattr(layer, "_tq_group_qjl", None), + group_centroids=getattr(layer, "_tq_group_centroids", None), + group_indices=getattr(layer, "_tq_group_indices", None), ) # ------------------------------------------------------------------ # @@ -592,8 +641,8 @@ def _prefill_attention( # For large continuations, fall back to _continuation_prefill. cached_len = seq_len - q_len if q_len <= _CONTINUATION_DECODE_THRESHOLD: - # Fast path: treat each query as a decode request - # with incremental seq_lens for causal masking. + # Fast path: treat each query as a decode request with + # incremental seq_lens for causal masking. synth_seq_lens = torch.arange( cached_len + 1, seq_len + 1, @@ -615,6 +664,17 @@ def _prefill_attention( key_fp8=self.tq_config.key_fp8, norm_correction=self.tq_config.norm_correction, PiT=PiT, + PhiT=getattr(layer, "_tq_qjl_PhiT", None), + qjl_scale=( + layer._tq_qjl_scale_dev + if self.tq_config.use_qjl_residual + else None + ), + grouped_recipe=self.tq_config.grouped_recipe, + group_rotations=getattr(layer, "_tq_group_rotations", None), + group_qjl=getattr(layer, "_tq_group_qjl", None), + group_centroids=getattr(layer, "_tq_group_centroids", None), + group_indices=getattr(layer, "_tq_group_indices", None), ) else: # Large continuation: dequant cached K/V and use @@ -700,6 +760,11 @@ def _continuation_prefill( NUM_KV_HEADS=Hk, MSE_BYTES=mse_bytes, KPS=self.tq_config.key_packed_size, + QJL_BYTES=( + math.ceil(D * self.tq_config.qjl_residual_bits / 8) + if self.tq_config.use_qjl_residual + else 0 + ), VQB=self.tq_config.effective_value_quant_bits, VAL_DATA_BYTES=val_data_bytes, MSE_BITS=self.tq_config.key_mse_bits, @@ -803,6 +868,15 @@ def _decode_attention( key_fp8=self.tq_config.key_fp8, norm_correction=self.tq_config.norm_correction, PiT=PiT, + PhiT=getattr(layer, "_tq_qjl_PhiT", None), + qjl_scale=( + layer._tq_qjl_scale_dev if self.tq_config.use_qjl_residual else None + ), + grouped_recipe=self.tq_config.grouped_recipe, + group_rotations=getattr(layer, "_tq_group_rotations", None), + group_qjl=getattr(layer, "_tq_group_qjl", None), + group_centroids=getattr(layer, "_tq_group_centroids", None), + group_indices=getattr(layer, "_tq_group_indices", None), mid_o_buf=mid_o_buf, output_buf=output_buf, lse_buf=lse_buf, diff --git a/vllm/v1/attention/ops/triton_turboquant_decode.py b/vllm/v1/attention/ops/triton_turboquant_decode.py index 8b276e31eafb..df0181415ffa 100644 --- a/vllm/v1/attention/ops/triton_turboquant_decode.py +++ b/vllm/v1/attention/ops/triton_turboquant_decode.py @@ -12,13 +12,40 @@ from typing import Any import torch +import torch.nn.functional as F from vllm.triton_utils import tl, triton +from vllm.v1.attention.ops.turboquant_kv_cache import ( + dequantize_turboquant_vectors, + validate_turboquant_group_indices, +) from vllm.v1.attention.ops.triton_decode_attention import ( _fwd_kernel_stage2, ) _FP8_E4B15: dict[int, int] = {} +def _unpack_uniform_values_reference( + packed: torch.Tensor, + head_dim: int, + bits: int, +) -> torch.Tensor: + data_bytes = math.ceil(head_dim * bits / 8) + data = packed[..., :data_bytes] + tail = packed[..., data_bytes : data_bytes + 4] + scale = tail[..., :2].contiguous().view(torch.float16).to(torch.float32).unsqueeze(-1) + v_min = tail[..., 2:4].contiguous().view(torch.float16).to(torch.float32).unsqueeze(-1) + if bits == 4: + q0 = data & 0xF + q1 = (data >> 4) & 0xF + q = torch.stack((q0, q1), dim=-1).reshape(*data.shape[:-1], head_dim) + else: + raw = data.reshape(*data.shape[:-1], head_dim // 8, 3).to(torch.int32) + packed24 = raw[..., 0] | (raw[..., 1] << 8) | (raw[..., 2] << 16) + shifts = (torch.arange(8, device=packed.device, dtype=torch.int32) * 3).view( + *((1,) * (raw.ndim - 1)), 8 + ) + q = ((packed24.unsqueeze(-1) >> shifts) & 0x7).reshape(*raw.shape[:-2], head_dim) + return q.to(torch.float32) * scale + v_min def _use_fp8_e4b15(device: int = 0) -> int: @@ -38,6 +65,7 @@ def _use_fp8_e4b15(device: int = 0) -> int: def _tq_decode_stage1( # Precomputed query projection Q_rot_ptr, # [B, Hq, D] float32 + Q_qjl_ptr, # [B, Hq, D] float32 # Compressed KV cache (combined K+V) KV_cache_ptr, # [num_blocks, block_size, Hk, padded_slot] uint8 # Block table and sequence info @@ -67,6 +95,7 @@ def _tq_decode_stage1( MSE_BITS: tl.constexpr, # 3 or 4 MSE_BYTES: tl.constexpr, # ceil(D * mse_bits / 8) KPS: tl.constexpr, # key_packed_size + QJL_BYTES: tl.constexpr, VQB: tl.constexpr, # value_quant_bits (4 or 8=FP8) VAL_DATA_BYTES: tl.constexpr, # ceil(D * vqb / 8) or D for FP8 # Score constants @@ -75,6 +104,8 @@ def _tq_decode_stage1( BLOCK_D: tl.constexpr, # next_power_of_2(HEAD_DIM) BLOCK_KV: tl.constexpr, # tokens per tile (16) KEY_FP8: tl.constexpr, # 1 if K is stored as FP8 + QJL_ENABLED: tl.constexpr = 0, + QJL_SCALE: tl.constexpr = 0.0, NORM_CORRECTION: tl.constexpr = 0, # 1 = re-normalize centroids FP8_E4B15: tl.constexpr = 0, # 1 = use e4b15 (Ampere/Ada), 0 = e4nv (Hopper+) ): @@ -103,6 +134,7 @@ def _tq_decode_stage1( # Load query vector: q_rot — [BLOCK_D] float32 q_base = bid * stride_qb + hid * stride_qh q_rot = tl.load(Q_rot_ptr + q_base + d_offs, mask=d_mask, other=0.0).to(tl.float32) + q_qjl = tl.load(Q_qjl_ptr + q_base + d_offs, mask=d_mask, other=0.0).to(tl.float32) # Precompute byte/bit index vectors for MSE gather loads if not KEY_FP8: @@ -215,6 +247,25 @@ def _tq_decode_stage1( vec_norms = (n_lo | (n_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) scores = vec_norms * term1 * ATTN_SCALE + if QJL_ENABLED: + qjl_base = slot_bases + MSE_BYTES + 2 + qjl_byte_idx = d_offs // 8 + qjl_bit_shift = d_offs % 8 + qjl_raw = tl.load( + KV_cache_ptr + qjl_base[:, None] + qjl_byte_idx[None, :], + mask=kv_mask[:, None] & d_mask[None, :], + other=0, + ).to(tl.int32) + qjl_sign = tl.where( + ((qjl_raw >> qjl_bit_shift[None, :]) & 0x1) > 0, + 1.0, + -1.0, + ) + term2 = tl.sum( + tl.where(d_mask[None, :], q_qjl[None, :] * qjl_sign, 0.0), + axis=1, + ) + scores += vec_norms * QJL_SCALE * term2 * ATTN_SCALE scores = tl.where(kv_mask, scores, -float("inf")) # ============================================================ @@ -307,6 +358,278 @@ def _tq_decode_stage1( tl.store(Mid_o_ptr + out_base + HEAD_DIM, lse) +@triton.jit +def _grouped_tq_decode_stage1( + Q_rot_0_ptr, + Q_qjl_0_ptr, + Q_rot_1_ptr, + Q_qjl_1_ptr, + KV_cache_ptr, + Block_table_ptr, + Seq_lens_ptr, + Centroids_0_ptr, + Centroids_1_ptr, + Mid_o_ptr, + stride_q0_b, + stride_q0_h, + stride_q1_b, + stride_q1_h, + stride_cache_block, + stride_cache_pos, + stride_cache_head, + stride_cache_dim, + stride_bt_b, + stride_mid_b, + stride_mid_h, + stride_mid_s, + NUM_KV_HEADS: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + KV_GROUP_SIZE: tl.constexpr, + KPS: tl.constexpr, + VQB: tl.constexpr, + VAL_DATA_BYTES: tl.constexpr, + ATTN_SCALE: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_KV: tl.constexpr, + G0_DIM: tl.constexpr, + G0_PADDED: tl.constexpr, + G0_MSE_BITS: tl.constexpr, + G0_GROUP_OFFSET: tl.constexpr, + G0_QJL_OFFSET: tl.constexpr, + G0_VECTOR_NORM_OFFSET: tl.constexpr, + G0_RESIDUAL_NORM_OFFSET: tl.constexpr, + G0_QJL_SCALE: tl.constexpr, + G1_DIM: tl.constexpr, + G1_PADDED: tl.constexpr, + G1_MSE_BITS: tl.constexpr, + G1_GROUP_OFFSET: tl.constexpr, + G1_QJL_OFFSET: tl.constexpr, + G1_VECTOR_NORM_OFFSET: tl.constexpr, + G1_RESIDUAL_NORM_OFFSET: tl.constexpr, + G1_QJL_SCALE: tl.constexpr, +): + bid = tl.program_id(0) + hid = tl.program_id(1) + sid = tl.program_id(2) + + kv_head = hid // KV_GROUP_SIZE + seq_len = tl.load(Seq_lens_ptr + bid) + split_len = tl.cdiv(seq_len, NUM_KV_SPLITS) + split_start = split_len * sid + split_end = tl.minimum(split_start + split_len, seq_len) + if split_start >= split_end: + return + + d_offs = tl.arange(0, BLOCK_D) + d_mask = d_offs < HEAD_DIM + kv_range = tl.arange(0, BLOCK_KV) + offs_d0 = tl.arange(0, G0_PADDED) + mask_d0 = offs_d0 < G0_DIM + offs_d1 = tl.arange(0, G1_PADDED) + mask_d1 = offs_d1 < G1_DIM + + q0_base = bid * stride_q0_b + hid * stride_q0_h + q_rot_0 = tl.load( + Q_rot_0_ptr + q0_base + offs_d0, mask=mask_d0, other=0.0 + ).to(tl.float32) + q_qjl_0 = tl.load( + Q_qjl_0_ptr + q0_base + offs_d0, mask=mask_d0, other=0.0 + ).to(tl.float32) + q1_base = bid * stride_q1_b + hid * stride_q1_h + q_rot_1 = tl.load( + Q_rot_1_ptr + q1_base + offs_d1, mask=mask_d1, other=0.0 + ).to(tl.float32) + q_qjl_1 = tl.load( + Q_qjl_1_ptr + q1_base + offs_d1, mask=mask_d1, other=0.0 + ).to(tl.float32) + + if VQB == 3: + val_bit_off = d_offs * 3 + val_byte_idx = val_bit_off // 8 + val_bit_shift = val_bit_off % 8 + + m_prev = -float("inf") + l_prev = 0.0 + acc = tl.zeros([BLOCK_D], dtype=tl.float32) + bt_base = bid * stride_bt_b + + for start_n in range(split_start, split_end, BLOCK_KV): + kv_offs = start_n + kv_range + kv_mask = kv_offs < split_end + page_idx = kv_offs // BLOCK_SIZE + page_off = kv_offs % BLOCK_SIZE + block_nums = tl.load( + Block_table_ptr + bt_base + page_idx, + mask=kv_mask, + other=0, + ) + slot_bases = ( + block_nums * stride_cache_block + + page_off * stride_cache_pos + + kv_head * stride_cache_head + ) + + key_indices_0 = _grouped_unpack_fixed_indices( + KV_cache_ptr, + slot_bases, + offs_d0, + stride_cache_dim, + G0_MSE_BITS, + BLOCK_KV, + G0_PADDED, + G0_GROUP_OFFSET, + kv_mask[:, None] & mask_d0[None, :], + ) + key_centroids_0 = tl.load( + Centroids_0_ptr + key_indices_0, + mask=kv_mask[:, None] & mask_d0[None, :], + other=0.0, + ) + key_qjl_signs_0 = _grouped_unpack_signs( + KV_cache_ptr, + slot_bases, + offs_d0, + stride_cache_dim, + G0_QJL_OFFSET, + kv_mask[:, None] & mask_d0[None, :], + ) + key_vector_norm_0_base = slot_bases + G0_VECTOR_NORM_OFFSET * stride_cache_dim + kv0_lo = tl.load(KV_cache_ptr + key_vector_norm_0_base, mask=kv_mask, other=0).to( + tl.uint16 + ) + kv0_hi = tl.load( + KV_cache_ptr + key_vector_norm_0_base + stride_cache_dim, + mask=kv_mask, + other=0, + ).to(tl.uint16) + key_vector_norm_0 = (kv0_lo | (kv0_hi << 8)).to( + tl.float16, bitcast=True + ).to(tl.float32) + kr0_base = slot_bases + G0_RESIDUAL_NORM_OFFSET * stride_cache_dim + kr0_lo = tl.load(KV_cache_ptr + kr0_base, mask=kv_mask, other=0).to(tl.uint16) + kr0_hi = tl.load( + KV_cache_ptr + kr0_base + stride_cache_dim, mask=kv_mask, other=0 + ).to(tl.uint16) + key_residual_norm_0 = (kr0_lo | (kr0_hi << 8)).to( + tl.float16, bitcast=True + ).to(tl.float32) + + key_indices_1 = _grouped_unpack_fixed_indices( + KV_cache_ptr, + slot_bases, + offs_d1, + stride_cache_dim, + G1_MSE_BITS, + BLOCK_KV, + G1_PADDED, + G1_GROUP_OFFSET, + kv_mask[:, None] & mask_d1[None, :], + ) + key_centroids_1 = tl.load( + Centroids_1_ptr + key_indices_1, + mask=kv_mask[:, None] & mask_d1[None, :], + other=0.0, + ) + key_qjl_signs_1 = _grouped_unpack_signs( + KV_cache_ptr, + slot_bases, + offs_d1, + stride_cache_dim, + G1_QJL_OFFSET, + kv_mask[:, None] & mask_d1[None, :], + ) + key_vector_norm_1_base = slot_bases + G1_VECTOR_NORM_OFFSET * stride_cache_dim + kv1_lo = tl.load(KV_cache_ptr + key_vector_norm_1_base, mask=kv_mask, other=0).to( + tl.uint16 + ) + kv1_hi = tl.load( + KV_cache_ptr + key_vector_norm_1_base + stride_cache_dim, + mask=kv_mask, + other=0, + ).to(tl.uint16) + key_vector_norm_1 = (kv1_lo | (kv1_hi << 8)).to( + tl.float16, bitcast=True + ).to(tl.float32) + kr1_base = slot_bases + G1_RESIDUAL_NORM_OFFSET * stride_cache_dim + kr1_lo = tl.load(KV_cache_ptr + kr1_base, mask=kv_mask, other=0).to(tl.uint16) + kr1_hi = tl.load( + KV_cache_ptr + kr1_base + stride_cache_dim, mask=kv_mask, other=0 + ).to(tl.uint16) + key_residual_norm_1 = (kr1_lo | (kr1_hi << 8)).to( + tl.float16, bitcast=True + ).to(tl.float32) + + scores = key_vector_norm_0 * tl.sum( + key_centroids_0 * q_rot_0[None, :], axis=1 + ) + scores += key_vector_norm_0 * key_residual_norm_0 * G0_QJL_SCALE * tl.sum( + key_qjl_signs_0 * q_qjl_0[None, :], axis=1 + ) + scores += key_vector_norm_1 * tl.sum( + key_centroids_1 * q_rot_1[None, :], axis=1 + ) + scores += key_vector_norm_1 * key_residual_norm_1 * G1_QJL_SCALE * tl.sum( + key_qjl_signs_1 * q_qjl_1[None, :], axis=1 + ) + scores = tl.where(kv_mask, scores * ATTN_SCALE, -float("inf")) + + n_e_max = tl.maximum(tl.max(scores, 0), m_prev) + re_scale = tl.exp(m_prev - n_e_max) + p = tl.exp(scores - n_e_max) + + val_bases = slot_bases + KPS * stride_cache_dim + if VQB == 3: + val_addrs0 = val_bases[:, None] + val_byte_idx[None, :] * stride_cache_dim + val_raw0 = tl.load( + KV_cache_ptr + val_addrs0, + mask=kv_mask[:, None] & d_mask[None, :], + other=0, + ).to(tl.int32) + val_raw1 = tl.load( + KV_cache_ptr + val_addrs0 + stride_cache_dim, + mask=kv_mask[:, None] & d_mask[None, :], + other=0, + ).to(tl.int32) + raw16 = val_raw0 | (val_raw1 << 8) + v_idx = ((raw16 >> val_bit_shift[None, :]) & 0x7).to(tl.float32) + else: + vb_idx = d_offs // 2 + vb_shift = (d_offs % 2) * 4 + val_addrs = val_bases[:, None] + vb_idx[None, :] * stride_cache_dim + val_raw = tl.load( + KV_cache_ptr + val_addrs, + mask=kv_mask[:, None] & d_mask[None, :], + other=0, + ).to(tl.int32) + v_idx = ((val_raw >> vb_shift[None, :]) & 0xF).to(tl.float32) + + sc_bases = val_bases + VAL_DATA_BYTES * stride_cache_dim + sc_lo = tl.load(KV_cache_ptr + sc_bases, mask=kv_mask, other=0).to(tl.uint16) + sc_hi = tl.load( + KV_cache_ptr + sc_bases + stride_cache_dim, mask=kv_mask, other=0 + ).to(tl.uint16) + v_scales = (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) + zr_lo = tl.load( + KV_cache_ptr + sc_bases + 2 * stride_cache_dim, mask=kv_mask, other=0 + ).to(tl.uint16) + zr_hi = tl.load( + KV_cache_ptr + sc_bases + 3 * stride_cache_dim, mask=kv_mask, other=0 + ).to(tl.uint16) + v_zeros = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) + values = v_idx * v_scales[:, None] + v_zeros[:, None] + + acc = acc * re_scale + tl.sum(p[:, None] * values, 0) + l_prev = l_prev * re_scale + tl.sum(p, 0) + m_prev = n_e_max + + out_base = bid * stride_mid_b + hid * stride_mid_h + sid * stride_mid_s + safe_l = tl.where(l_prev > 0.0, l_prev, 1.0) + tl.store(Mid_o_ptr + out_base + d_offs, acc / safe_l, mask=d_mask) + tl.store(Mid_o_ptr + out_base + HEAD_DIM, m_prev + tl.log(safe_l)) + + # --------------------------------------------------------------------------- # Pre-dequant kernel: Bulk dequant K (MSE+norms) and V to fp16 # --------------------------------------------------------------------------- @@ -334,6 +657,7 @@ def _tq_full_dequant_kv( NUM_KV_HEADS: tl.constexpr, MSE_BYTES: tl.constexpr, KPS: tl.constexpr, + QJL_BYTES: tl.constexpr, VQB: tl.constexpr, VAL_DATA_BYTES: tl.constexpr, MSE_BITS: tl.constexpr, @@ -471,6 +795,7 @@ def _get_layout(D, mse_bits, value_quant_bits, key_packed_size): "val_data_bytes": val_data_bytes, "mse_bits": mse_bits, "n_centroids": 2**mse_bits, + "qjl_bytes": max(0, key_packed_size - math.ceil(D * mse_bits / 8) - 2), "BLOCK_D": triton.next_power_of_2(D), } _layout_cache[key] = cfg @@ -482,8 +807,8 @@ def triton_turboquant_decode_attention( kv_cache: torch.Tensor, # [num_blocks, block_size, Hk, padded_slot] uint8 block_table: torch.Tensor, # [B, max_num_blocks] int32 seq_lens: torch.Tensor, # [B] int32 - Pi: torch.Tensor, # [D, D] float32 - centroids: torch.Tensor, # [n_centroids] float32 + Pi: torch.Tensor | None, # [D, D] float32 + centroids: torch.Tensor | None, # [n_centroids] float32 scale: float, mse_bits: int, key_packed_size: int, @@ -491,6 +816,13 @@ def triton_turboquant_decode_attention( key_fp8: bool = False, norm_correction: bool = False, PiT: torch.Tensor | None = None, # [D, D] pre-computed Pi.T contiguous + PhiT: torch.Tensor | None = None, # [D, D] second orthogonal sketch matrix + qjl_scale: torch.Tensor | None = None, + grouped_recipe: str | None = None, + group_rotations: tuple[torch.Tensor, torch.Tensor] | None = None, + group_qjl: tuple[torch.Tensor, torch.Tensor] | None = None, + group_centroids: dict[int, torch.Tensor] | None = None, + group_indices: tuple[torch.Tensor, torch.Tensor] | None = None, # Pre-allocated buffers (optional, avoids per-call allocation) mid_o_buf: torch.Tensor | None = None, output_buf: torch.Tensor | None = None, @@ -503,23 +835,84 @@ def triton_turboquant_decode_attention( Returns: output tensor [B, Hq, D] in query's dtype. """ B, Hq, D = query.shape + if grouped_recipe is not None: + if ( + group_rotations is None + or group_qjl is None + or group_centroids is None + or group_indices is None + ): + raise ValueError("Grouped TurboQuant decode requires grouped tables.") + validate_turboquant_group_indices( + torch.empty( + query.shape[0], kv_cache.shape[2], D, device=query.device, dtype=query.dtype + ), + group_indices, + ) + output = torch.zeros_like(query) + key_bytes = key_packed_size + value_bytes = math.ceil(D * value_quant_bits / 8) + 4 + block_size = kv_cache.shape[1] + for seq_idx, seq_len in enumerate(seq_lens.tolist()): + num_blocks = math.ceil(seq_len / block_size) + block_ids = block_table[seq_idx, :num_blocks].to(torch.int64) + seq_cache = kv_cache.index_select(0, block_ids).reshape( + num_blocks * block_size, kv_cache.shape[2], kv_cache.shape[3] + )[:seq_len] + k_seq = dequantize_turboquant_vectors( + seq_cache[..., :key_bytes], + grouped_recipe, + D, + group_rotations, + group_qjl, + group_centroids, + group_indices, + query.dtype, + ) + v_seq = _unpack_uniform_values_reference( + seq_cache[..., key_bytes : key_bytes + value_bytes], + D, + value_quant_bits, + ).to(query.dtype) + q_t = query[seq_idx : seq_idx + 1].transpose(0, 1).unsqueeze(0).contiguous() + k_t = k_seq.transpose(0, 1).unsqueeze(0).contiguous() + v_t = v_seq.transpose(0, 1).unsqueeze(0).contiguous() + output[seq_idx : seq_idx + 1] = F.scaled_dot_product_attention( + q_t, + k_t, + v_t, + is_causal=False, + scale=scale, + enable_gqa=(k_t.shape[1] < q_t.shape[1]), + )[0].transpose(0, 1) + return output + if Pi is None or centroids is None: + raise ValueError("Pi and centroids are required for legacy TurboQuant decode.") + Hk = kv_cache.shape[2] block_size = kv_cache.shape[1] kv_group_size = Hq // Hk device = query.device cfg = _get_layout(D, mse_bits, value_quant_bits, key_packed_size) + if key_fp8: + cfg = {**cfg, "qjl_bytes": 0} # Compute q_rot = q @ Pi.T (rotated query for MSE key scoring) # FP8 path: pass query directly (float16); kernel casts inline. # MSE path: still needs external GEMM (cuBLAS), so q_rot is float32. if key_fp8: q_rot = query.contiguous() + q_qjl = query.contiguous() else: q_float = query.float() if PiT is None: PiT = Pi.T.contiguous() q_rot = (q_float @ PiT).contiguous() + if cfg["qjl_bytes"] > 0 and PhiT is not None: + q_qjl = (q_rot @ PhiT).contiguous() + else: + q_qjl = q_rot NUM_KV_SPLITS = max_num_kv_splits @@ -547,6 +940,7 @@ def triton_turboquant_decode_attention( grid = (B, Hq, NUM_KV_SPLITS) _tq_decode_stage1[grid]( q_rot, + q_qjl, kv_cache, block_table, seq_lens, @@ -569,12 +963,15 @@ def triton_turboquant_decode_attention( MSE_BITS=mse_bits, MSE_BYTES=cfg["mse_bytes"], KPS=key_packed_size, + QJL_BYTES=cfg["qjl_bytes"], VQB=value_quant_bits, VAL_DATA_BYTES=cfg["val_data_bytes"], ATTN_SCALE=scale, BLOCK_D=cfg["BLOCK_D"], BLOCK_KV=BLOCK_KV, KEY_FP8=1 if key_fp8 else 0, + QJL_ENABLED=1 if cfg["qjl_bytes"] > 0 else 0, + QJL_SCALE=(float(qjl_scale.item()) if qjl_scale is not None else 0.0), NORM_CORRECTION=1 if norm_correction else 0, FP8_E4B15=fp8_e4b15, num_warps=1, diff --git a/vllm/v1/attention/ops/triton_turboquant_store.py b/vllm/v1/attention/ops/triton_turboquant_store.py index 3da3347d5df5..1659e3ce88bc 100644 --- a/vllm/v1/attention/ops/triton_turboquant_store.py +++ b/vllm/v1/attention/ops/triton_turboquant_store.py @@ -14,9 +14,48 @@ import torch +from vllm.model_executor.layers.quantization.turboquant.centroids import get_centroids from vllm.triton_utils import tl, triton +from vllm.v1.attention.ops.turboquant_kv_cache import ( + quantize_turboquant_vectors, + validate_turboquant_group_indices, +) from vllm.v1.attention.ops.triton_turboquant_decode import _use_fp8_e4b15 + +def _pack_uniform_values_reference(value: torch.Tensor, bits: int) -> torch.Tensor: + levels = (1 << bits) - 1 + v_min = value.amin(dim=-1, keepdim=True) + v_max = value.amax(dim=-1, keepdim=True) + scale = (v_max - v_min) / max(levels, 1) + scale = torch.clamp(scale, min=1e-8) + q = torch.clamp(((value - v_min) / scale).round(), 0, levels).to(torch.uint8) + if bits == 4: + pairs = q.reshape(*q.shape[:-1], q.shape[-1] // 2, 2) + data = (pairs[..., 0] | (pairs[..., 1] << 4)).contiguous() + else: + groups = q.reshape(*q.shape[:-1], q.shape[-1] // 8, 8).to(torch.int32) + shifts = (torch.arange(8, device=value.device, dtype=torch.int32) * 3).view( + *((1,) * (groups.ndim - 1)), 8 + ) + packed24 = torch.sum(groups << shifts, dim=-1) + data = torch.stack( + ( + (packed24 & 0xFF).to(torch.uint8), + ((packed24 >> 8) & 0xFF).to(torch.uint8), + ((packed24 >> 16) & 0xFF).to(torch.uint8), + ), + dim=-1, + ).reshape(*q.shape[:-1], -1) + tail = torch.cat( + ( + scale.to(torch.float16).view(torch.uint8).reshape(*scale.shape[:-1], 2), + v_min.to(torch.float16).view(torch.uint8).reshape(*v_min.shape[:-1], 2), + ), + dim=-1, + ) + return torch.cat((data, tail), dim=-1) + # ═══════════════════════════════════════════════════════════════════════ # Shared: value uniform quantization + pack + scale/zero store # ═══════════════════════════════════════════════════════════════════════ @@ -220,6 +259,7 @@ def _tq_fused_store_mse( Y_ptr, # [NH, D] float32 — rotated normalized keys (x_hat @ PiT) Norms_ptr, # [NH] float32 — key vector norms (||k||) Value_ptr, # [NH, D] float32 — raw values + QJLPacked_ptr, # [NH, QJL_BYTES] uint8 — packed residual sign bits # Quantization tables Midpoints_ptr, # [n_centroids-1] float32 # Cache and indexing @@ -237,6 +277,7 @@ def _tq_fused_store_mse( # TQ layout MSE_BYTES: tl.constexpr, KPS: tl.constexpr, + QJL_BYTES: tl.constexpr, # Value quantization VQB: tl.constexpr, VAL_DATA_BYTES: tl.constexpr, @@ -307,6 +348,13 @@ def _tq_fused_store_mse( tl.store(KV_cache_ptr + slot_base + grp_offs * 3, b0, mask=grp_mask) tl.store(KV_cache_ptr + slot_base + grp_offs * 3 + 1, b1, mask=grp_mask) tl.store(KV_cache_ptr + slot_base + grp_offs * 3 + 2, b2, mask=grp_mask) + elif MSE_BITS == 2: + idx_quads = tl.reshape(idx, [BLOCK_D // 4, 4]) + shifts_2 = tl.arange(0, 4) * 2 + packed = tl.sum((idx_quads & 0x3) << shifts_2[None, :], axis=1).to(tl.uint8) + mse_offs = tl.arange(0, BLOCK_D // 4) + mse_mask = mse_offs < MSE_BYTES + tl.store(KV_cache_ptr + slot_base + mse_offs, packed, mask=mse_mask) # ── 3. STORE vec_norm (fp16, 2 bytes) ───────────────────────────── norm_offset = MSE_BYTES @@ -318,6 +366,11 @@ def _tq_fused_store_mse( KV_cache_ptr + slot_base + norm_offset + 1, ((vn_u16 >> 8) & 0xFF).to(tl.uint8) ) + if QJL_BYTES > 0: + qjl_offs = tl.arange(0, QJL_BYTES) + qjl_vals = tl.load(QJLPacked_ptr + pid * QJL_BYTES + qjl_offs) + tl.store(KV_cache_ptr + slot_base + norm_offset + 2 + qjl_offs, qjl_vals) + # ── 4. VALUE QUANTIZE + PACK ────────────────────────────────────── _store_quantized_value( Value_ptr, @@ -336,6 +389,59 @@ def _tq_fused_store_mse( ) +@triton.jit +def _tq_grouped_store_value( + Value_ptr, + KV_cache_ptr, + Slot_mapping_ptr, + stride_cache_block: tl.constexpr, + stride_cache_pos: tl.constexpr, + stride_cache_head: tl.constexpr, + stride_cache_dim: tl.constexpr, + D: tl.constexpr, + H: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_D: tl.constexpr, + KPS: tl.constexpr, + VQB: tl.constexpr, + VAL_DATA_BYTES: tl.constexpr, + BLOCK_VAL: tl.constexpr, + BLOCK_GRP: tl.constexpr = 16, +): + pid = tl.program_id(0) + token_idx = pid // H + head_idx = pid % H + + slot = tl.load(Slot_mapping_ptr + token_idx) + if slot < 0: + return + blk = slot // BLOCK_SIZE + off = slot % BLOCK_SIZE + slot_base = ( + blk * stride_cache_block + off * stride_cache_pos + head_idx * stride_cache_head + ) + + base = pid * D + d_offs = tl.arange(0, BLOCK_D) + d_mask = d_offs < D + + _store_quantized_value( + Value_ptr, + KV_cache_ptr, + base, + slot_base, + d_offs, + d_mask, + D=D, + KPS=KPS * stride_cache_dim, + VQB=VQB, + VAL_DATA_BYTES=VAL_DATA_BYTES * stride_cache_dim, + BLOCK_D=BLOCK_D, + BLOCK_VAL=BLOCK_VAL, + BLOCK_GRP=BLOCK_GRP, + ) + + # ═══════════════════════════════════════════════════════════════════════ # Launcher # ═══════════════════════════════════════════════════════════════════════ @@ -346,20 +452,28 @@ def triton_turboquant_store( value: torch.Tensor, # [N, H, D] — raw values kv_cache: torch.Tensor, # [num_blocks, block_size, Hk, padded_slot] uint8 slot_mapping: torch.Tensor, # [N] int32 - PiT: torch.Tensor, # [D, D] float32 - midpoints: torch.Tensor, # [n_centroids-1] float32 + PiT: torch.Tensor | None, # [D, D] float32 + midpoints: torch.Tensor | None, # [n_centroids-1] float32 + PhiT: torch.Tensor | None, mse_bits: int, key_packed_size: int, value_quant_bits: int, key_fp8: bool = False, + qjl_residual_bits: int = 0, + grouped_recipe: str | None = None, + group_rotations: tuple[torch.Tensor, torch.Tensor] | None = None, + group_qjl: tuple[torch.Tensor, torch.Tensor] | None = None, + group_centroids: dict[int, torch.Tensor] | None = None, + group_indices: tuple[torch.Tensor, torch.Tensor] | None = None, ): - """Launch TQ store kernel (FP8 or MSE path).""" + """Launch TQ store kernel or grouped reference store path.""" N, H, D = key.shape NH = N * H block_size = kv_cache.shape[1] BLOCK_D = triton.next_power_of_2(D) mse_bytes = math.ceil(D * mse_bits / 8) n_centroids = 2**mse_bits + qjl_bytes = math.ceil(D * qjl_residual_bits / 8) val_data_bytes = math.ceil(D * value_quant_bits / 8) @@ -373,6 +487,44 @@ def triton_turboquant_store( block_grp = triton.next_power_of_2(D // 8) if D >= 8 else 1 # ── FP8 PATH: in-kernel FP8 cast + scatter via fp8 kernel ── + if grouped_recipe is not None: + if ( + group_rotations is None + or group_qjl is None + or group_centroids is None + or group_indices is None + ): + raise ValueError("Grouped TurboQuant store requires grouped tables.") + validate_turboquant_group_indices(key, group_indices) + valid_mask = slot_mapping >= 0 + if not torch.any(valid_mask): + return + valid_slots = slot_mapping[valid_mask].to(torch.int64) + blocks = torch.div(valid_slots, kv_cache.shape[1], rounding_mode="floor") + offsets = torch.remainder(valid_slots, kv_cache.shape[1]) + value_packed_size = val_data_bytes + 4 + kv_cache[blocks, offsets, :, :] = 0 + packed_value = _pack_uniform_values_reference( + value.to(torch.float32), value_quant_bits + ) + packed_key = quantize_turboquant_vectors( + key.to(torch.float32), + grouped_recipe, + group_rotations, + group_qjl, + group_centroids, + group_indices, + ) + kv_cache[blocks, offsets, :, :key_packed_size] = packed_key[valid_mask] + kv_cache[ + blocks, + offsets, + :, + key_packed_size : key_packed_size + value_packed_size, + ] = packed_value[valid_mask] + return + if PiT is None or midpoints is None: + raise ValueError("PiT and midpoints are required for legacy TurboQuant store.") if key_fp8: k_flat = key.reshape(NH, D).contiguous() v_flat = value.reshape(NH, D).contiguous() @@ -410,6 +562,26 @@ def triton_turboquant_store( x_hat = k_flat / (norms + 1e-8) y = x_hat @ PiT + if qjl_residual_bits > 0: + if PhiT is None: + raise ValueError("PhiT is required when qjl_residual_bits > 0") + # Reconstruct base centroids in PyTorch, then sketch the residual with + # a second orthogonal sign transform. + # bucketize returns [0, n_centroids-1] against sorted midpoints. + idx = torch.bucketize(y, midpoints) + centroid_table = get_centroids(D, mse_bits).to(device=y.device, dtype=y.dtype) + c_vals = centroid_table[idx] + resid = y - c_vals + resid_proj = resid @ PhiT + sign_bits = (resid_proj >= 0).to(torch.uint8) + qjl_packed = torch.zeros(NH, qjl_bytes, device=y.device, dtype=torch.uint8) + for bit in range(D): + byte_idx = bit // 8 + bit_shift = bit % 8 + qjl_packed[:, byte_idx] |= sign_bits[:, bit] << bit_shift + else: + qjl_packed = torch.empty(NH, max(qjl_bytes, 1), device=y.device, dtype=torch.uint8) + v_flat = value.float().reshape(NH, D) # Fused kernel: bucketize + MSE index pack + norm store + value pack @@ -418,6 +590,7 @@ def triton_turboquant_store( y, norms.squeeze(1), v_flat, + qjl_packed, midpoints, kv_cache.view(-1), slot_mapping, @@ -430,6 +603,7 @@ def triton_turboquant_store( BLOCK_D=BLOCK_D, MSE_BYTES=mse_bytes, KPS=key_packed_size, + QJL_BYTES=qjl_bytes, VQB=value_quant_bits, VAL_DATA_BYTES=val_data_bytes, BLOCK_VAL=BLOCK_VAL, diff --git a/vllm/v1/attention/ops/turboquant_kv_cache.py b/vllm/v1/attention/ops/turboquant_kv_cache.py new file mode 100644 index 000000000000..c377f7eeebe7 --- /dev/null +++ b/vllm/v1/attention/ops/turboquant_kv_cache.py @@ -0,0 +1,733 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import math +from dataclasses import dataclass +from functools import cache +from typing import Any + +import torch + +TURBOQUANT_COMPAT_DTYPES = { + "turboquant25": "turboquant25", + "turboquant35": "turboquant35", + "turboquant_3bit": "turboquant25", + "turboquant_4bit": "turboquant35", +} +TURBOQUANT_KV_CACHE_BITS = { + "turboquant25": 2.5, + "turboquant35": 3.5, +} +TURBOQUANT_OUTLIER_RATIOS = { + "turboquant25": 0.25, + "turboquant35": 0.50, +} +TURBOQUANT_GROUP_BITS = { + "turboquant25": (3, 2), + "turboquant35": (4, 3), +} +TURBOQUANT_VECTOR_NORM_BYTES = 2 +TURBOQUANT_RESIDUAL_NORM_BYTES = 2 +TURBOQUANT_GROUP_ALIGNMENT = 16 +TURBOQUANT_SEED = 20250428 +TURBOQUANT_QJL_SEED_OFFSET = 10_000 +TURBOQUANT_QJL_SCALE = math.sqrt(math.pi / 2.0) +TURBOQUANT_CODEBOOK_GRID_POINTS = 32768 +TURBOQUANT_CODEBOOK_EPS = 1e-6 + + +@dataclass(frozen=True) +class TurboQuantGroupLayout: + dim: int + bits: int + mse_bits: int + mse_payload_bytes: int + qjl_payload_bytes: int + qjl_offset: int + vector_norm_offset: int + residual_norm_offset: int + packed_bytes: int + + +@dataclass(frozen=True) +class TurboQuantLayout: + groups: tuple[TurboQuantGroupLayout, TurboQuantGroupLayout] + packed_dim: int + + +@dataclass(frozen=True) +class TurboQuantKernelMeta: + decode_block_n: int + decode_num_warps: int + update_tile: int + update_num_warps: int + postprocess_num_warps: int + + +def canonicalize_turboquant_dtype(bits_or_dtype: float | int | str) -> str: + if isinstance(bits_or_dtype, str): + try: + return TURBOQUANT_COMPAT_DTYPES[bits_or_dtype] + except KeyError as e: + raise ValueError( + f"Unsupported TurboQuant KV cache dtype: {bits_or_dtype}" + ) from e + + bits = float(bits_or_dtype) + if bits == 2.5: + return "turboquant25" + if bits == 3.5: + return "turboquant35" + raise ValueError(f"Unsupported TurboQuant bit-width: {bits}") + + +def is_turboquant_kv_cache(kv_cache_dtype: str) -> bool: + return kv_cache_dtype in TURBOQUANT_COMPAT_DTYPES + + +def get_turboquant_bits(kv_cache_dtype: str) -> float: + return TURBOQUANT_KV_CACHE_BITS[canonicalize_turboquant_dtype(kv_cache_dtype)] + + +def get_turboquant_outlier_count(head_size: int, kv_cache_dtype: str) -> int: + recipe = canonicalize_turboquant_dtype(kv_cache_dtype) + if head_size % TURBOQUANT_GROUP_ALIGNMENT != 0: + raise ValueError( + "TurboQuant KV cache requires head_size to be a multiple of 16." + ) + ratio = TURBOQUANT_OUTLIER_RATIOS[recipe] + aligned = int( + round(head_size * ratio / TURBOQUANT_GROUP_ALIGNMENT) + * TURBOQUANT_GROUP_ALIGNMENT + ) + if aligned <= 0 or aligned >= head_size: + raise ValueError( + f"Unsupported TurboQuant head_size {head_size} for {kv_cache_dtype}." + ) + return aligned + + +def get_turboquant_group_dims( + head_size: int, + kv_cache_dtype: str, +) -> tuple[int, int]: + outlier_count = get_turboquant_outlier_count(head_size, kv_cache_dtype) + return outlier_count, head_size - outlier_count + + +@cache +def _layout_cached(kv_cache_dtype: str, head_size: int) -> TurboQuantLayout: + recipe = canonicalize_turboquant_dtype(kv_cache_dtype) + group_dims = get_turboquant_group_dims(head_size, recipe) + group_bits = TURBOQUANT_GROUP_BITS[recipe] + groups: list[TurboQuantGroupLayout] = [] + cursor = 0 + for dim, bits in zip(group_dims, group_bits, strict=True): + mse_bits = bits - 1 + mse_payload_bytes = (dim * mse_bits + 7) // 8 + qjl_payload_bytes = (dim + 7) // 8 + qjl_offset = cursor + mse_payload_bytes + vector_norm_offset = qjl_offset + qjl_payload_bytes + residual_norm_offset = vector_norm_offset + TURBOQUANT_VECTOR_NORM_BYTES + packed_bytes = ( + mse_payload_bytes + + qjl_payload_bytes + + TURBOQUANT_VECTOR_NORM_BYTES + + TURBOQUANT_RESIDUAL_NORM_BYTES + ) + groups.append( + TurboQuantGroupLayout( + dim=dim, + bits=bits, + mse_bits=mse_bits, + mse_payload_bytes=mse_payload_bytes, + qjl_payload_bytes=qjl_payload_bytes, + qjl_offset=qjl_offset, + vector_norm_offset=vector_norm_offset, + residual_norm_offset=residual_norm_offset, + packed_bytes=packed_bytes, + ) + ) + cursor += packed_bytes + return TurboQuantLayout(groups=tuple(groups), packed_dim=cursor) + + +def get_turboquant_layout(kv_cache_dtype: str, head_size: int) -> TurboQuantLayout: + return _layout_cached(kv_cache_dtype, head_size) + + +def get_turboquant_packed_dim(head_size: int, bits_or_dtype: float | int | str) -> int: + return get_turboquant_layout( + canonicalize_turboquant_dtype(bits_or_dtype), head_size + ).packed_dim + + +def get_turboquant_kernel_meta( + device: torch.device, + head_size: int, +) -> TurboQuantKernelMeta: + if device.type != "cuda": + raise ValueError("TurboQuant Triton kernels require CUDA tensors.") + capability = torch.cuda.get_device_capability(device) + if capability == (8, 6): + return TurboQuantKernelMeta( + decode_block_n=8, + decode_num_warps=2, + update_tile=16, + update_num_warps=2, + postprocess_num_warps=2, + ) + return TurboQuantKernelMeta( + decode_block_n=8 if head_size >= 256 else 16, + decode_num_warps=4, + update_tile=32, + update_num_warps=4, + postprocess_num_warps=4, + ) + + +@cache +def _hadamard_block_sizes(dim: int) -> tuple[int, ...]: + sizes: list[int] = [] + remaining = dim + while remaining > 0: + block = 1 << (remaining.bit_length() - 1) + sizes.append(block) + remaining -= block + return tuple(sizes) + + +def _fwht_pow2(x: torch.Tensor) -> torch.Tensor: + orig_shape = x.shape + size = orig_shape[-1] + out = x.reshape(-1, size) + block = 1 + while block < size: + out = out.reshape(out.shape[0], -1, block * 2) + left = out[..., :block] + right = out[..., block : 2 * block] + out = torch.cat((left + right, left - right), dim=-1) + out = out.reshape(-1, size) + block *= 2 + return out.reshape(orig_shape) + + +@cache +def _structured_signs_cached( + device_type: str, + device_index: int | None, + dim: int, + seed: int, +) -> torch.Tensor: + device = torch.device(device_type, device_index) + generator = torch.Generator(device=device.type) + generator.manual_seed(seed) + randint_kwargs: dict[str, Any] = {"generator": generator, "device": device} + if device.type == "cuda": + with torch.accelerator.device_index(device.index): + draws = torch.randint(0, 2, (dim,), **randint_kwargs) + else: + draws = torch.randint(0, 2, (dim,), **randint_kwargs) + return torch.where( + draws > 0, + torch.ones(dim, dtype=torch.float32, device=device), + -torch.ones(dim, dtype=torch.float32, device=device), + ) + + +def _apply_block_hadamard( + x: torch.Tensor, + signs: torch.Tensor, + *, + normalized: bool, + inverse: bool, +) -> torch.Tensor: + outputs: list[torch.Tensor] = [] + cursor = 0 + for block_size in _hadamard_block_sizes(x.shape[-1]): + block = x[..., cursor : cursor + block_size] + block_signs = signs[cursor : cursor + block_size] + if inverse: + block = _fwht_pow2(block) + block = block * block_signs + else: + block = block * block_signs + block = _fwht_pow2(block) + if normalized: + block = block / math.sqrt(block_size) + outputs.append(block) + cursor += block_size + return torch.cat(outputs, dim=-1) + + +def _apply_mse_transform(x: torch.Tensor, signs: torch.Tensor) -> torch.Tensor: + return _apply_block_hadamard(x, signs, normalized=True, inverse=False) + + +def _apply_mse_inverse_transform(x: torch.Tensor, signs: torch.Tensor) -> torch.Tensor: + return _apply_block_hadamard(x, signs, normalized=True, inverse=True) + + +def _apply_qjl_transform(x: torch.Tensor, signs: torch.Tensor) -> torch.Tensor: + return _apply_block_hadamard(x, signs, normalized=False, inverse=False) + + +def _apply_qjl_inverse_transform(x: torch.Tensor, signs: torch.Tensor) -> torch.Tensor: + return _apply_block_hadamard(x, signs, normalized=False, inverse=True) + + +def get_turboquant_rotation( + device: torch.device, + dim: int, + seed_offset: int = 0, +) -> torch.Tensor: + return _structured_signs_cached( + device.type, device.index, dim, TURBOQUANT_SEED + seed_offset + dim + ) + + +def get_turboquant_qjl_matrix( + device: torch.device, + dim: int, + seed_offset: int = 0, +) -> torch.Tensor: + return _structured_signs_cached( + device.type, + device.index, + dim, + TURBOQUANT_SEED + TURBOQUANT_QJL_SEED_OFFSET + seed_offset + dim, + ) + + +@cache +def _transform_matrix_cached( + device_type: str, + device_index: int | None, + dim: int, + seed_offset: int, + kind: str, +) -> torch.Tensor: + device = torch.device(device_type, device_index) + identity = torch.eye(dim, dtype=torch.float32, device=device) + if kind == "mse_forward": + return _apply_mse_transform( + identity, get_turboquant_rotation(device, dim, seed_offset) + ) + if kind == "mse_inverse": + return _apply_mse_inverse_transform( + identity, get_turboquant_rotation(device, dim, seed_offset) + ) + if kind == "qjl_forward": + return _apply_qjl_transform( + identity, get_turboquant_qjl_matrix(device, dim, seed_offset) + ) + if kind == "qjl_inverse": + return _apply_qjl_inverse_transform( + identity, get_turboquant_qjl_matrix(device, dim, seed_offset) + ) + raise ValueError(f"Unsupported TurboQuant transform kind: {kind}") + + +def get_turboquant_mse_transform_matrix( + device: torch.device, + dim: int, + seed_offset: int = 0, +) -> torch.Tensor: + return _transform_matrix_cached( + device.type, device.index, dim, seed_offset, "mse_forward" + ) + + +def get_turboquant_mse_inverse_transform_matrix( + device: torch.device, + dim: int, + seed_offset: int = 0, +) -> torch.Tensor: + return _transform_matrix_cached( + device.type, device.index, dim, seed_offset, "mse_inverse" + ) + + +def get_turboquant_qjl_transform_matrix( + device: torch.device, + dim: int, + seed_offset: int = 0, +) -> torch.Tensor: + return _transform_matrix_cached( + device.type, device.index, dim, seed_offset, "qjl_forward" + ) + + +def get_turboquant_qjl_inverse_transform_matrix( + device: torch.device, + dim: int, + seed_offset: int = 0, +) -> torch.Tensor: + return _transform_matrix_cached( + device.type, device.index, dim, seed_offset, "qjl_inverse" + ) + + +@cache +def _mse_to_qjl_matrix_cached( + device_type: str, + device_index: int | None, + dim: int, + mse_seed_offset: int, + qjl_seed_offset: int, +) -> torch.Tensor: + device = torch.device(device_type, device_index) + mse_inverse = get_turboquant_mse_inverse_transform_matrix( + device, dim, mse_seed_offset + ) + qjl_forward = get_turboquant_qjl_transform_matrix(device, dim, qjl_seed_offset) + return torch.matmul(mse_inverse, qjl_forward) + + +def get_turboquant_mse_to_qjl_matrix( + device: torch.device, + dim: int, + mse_seed_offset: int = 0, + qjl_seed_offset: int = 0, +) -> torch.Tensor: + return _mse_to_qjl_matrix_cached( + device.type, + device.index, + dim, + mse_seed_offset, + qjl_seed_offset, + ) + + +def _beta_coordinate_pdf(x: torch.Tensor, dim: int) -> torch.Tensor: + exponent = 0.5 * (dim - 3) + log_norm = ( + torch.lgamma(torch.tensor(dim / 2.0, dtype=torch.float64)) + - 0.5 * math.log(math.pi) + - torch.lgamma(torch.tensor((dim - 1) / 2.0, dtype=torch.float64)) + ) + base = (1.0 - x.square()).clamp_min(TURBOQUANT_CODEBOOK_EPS) + return torch.exp(log_norm + exponent * torch.log(base)) + + +@cache +def _dimension_aware_codebook(dim: int, bits: int) -> torch.Tensor: + levels = 1 << bits + grid = torch.linspace( + -1.0 + TURBOQUANT_CODEBOOK_EPS, + 1.0 - TURBOQUANT_CODEBOOK_EPS, + TURBOQUANT_CODEBOOK_GRID_POINTS, + dtype=torch.float64, + ) + weights = _beta_coordinate_pdf(grid, dim) + centroids = torch.linspace( + -1.0 + 1.0 / (levels + 1), + 1.0 - 1.0 / (levels + 1), + levels, + dtype=torch.float64, + ) + for _ in range(200): + bounds = torch.empty(levels + 1, dtype=torch.float64) + bounds[0] = -1.0 + bounds[-1] = 1.0 + bounds[1:-1] = 0.5 * (centroids[:-1] + centroids[1:]) + assignments = torch.bucketize(grid, bounds[1:-1]) + masses = torch.zeros(levels, dtype=torch.float64) + sums = torch.zeros(levels, dtype=torch.float64) + masses.scatter_add_(0, assignments, weights) + sums.scatter_add_(0, assignments, weights * grid) + new_centroids = sums / masses.clamp_min(1e-18) + if torch.max(torch.abs(new_centroids - centroids)) < 1e-10: + centroids = new_centroids + break + centroids = new_centroids + return centroids.to(torch.float32) + + +def get_turboquant_centroids( + device: torch.device, + dim: int, + bits: int, +) -> torch.Tensor: + return _dimension_aware_codebook(dim, bits).to(device=device) + + +@cache +def _bit_layout( + device_type: str, + device_index: int | None, + head_size: int, + bits: int, +) -> tuple[torch.Tensor, torch.Tensor]: + device = torch.device(device_type, device_index) + flat_positions = torch.arange(head_size * bits, dtype=torch.int64, device=device) + return flat_positions // 8, flat_positions % 8 + + +def pack_turboquant_indices(indices: torch.Tensor, bits: int) -> torch.Tensor: + if bits <= 0: + shape = (*indices.shape[:-1], 0) + return torch.empty(shape, dtype=torch.uint8, device=indices.device) + head_size = indices.shape[-1] + num_bytes = (head_size * bits + 7) // 8 + byte_idx, bit_shift = _bit_layout( + indices.device.type, indices.device.index, head_size, bits + ) + bits_view = torch.zeros( + (*indices.shape[:-1], num_bytes), dtype=torch.uint8, device=indices.device + ) + expanded = indices.to(torch.int64).unsqueeze(-1) + offsets = torch.arange(bits, dtype=torch.int64, device=indices.device) + expanded_bits = ((expanded >> offsets) & 1).reshape(*indices.shape[:-1], -1) + bits_view.scatter_add_( + -1, + byte_idx.reshape(*((1,) * (indices.ndim - 1)), -1).expand_as(expanded_bits), + (expanded_bits << bit_shift.reshape(*((1,) * (indices.ndim - 1)), -1)).to( + torch.uint8 + ), + ) + return bits_view + + +def unpack_turboquant_indices( + packed: torch.Tensor, + head_size: int, + bits: int, +) -> torch.Tensor: + if bits <= 0: + shape = (*packed.shape[:-1], head_size) + return torch.zeros(shape, dtype=torch.uint8, device=packed.device) + byte_idx, bit_shift = _bit_layout( + packed.device.type, packed.device.index, head_size, bits + ) + gathered = packed.index_select(-1, byte_idx).to(torch.int64) + gathered = gathered.reshape(*packed.shape[:-1], head_size, bits) + shifts = bit_shift.reshape(*((1,) * (packed.ndim - 1)), head_size, bits) + values = ((gathered >> shifts) & 1) << torch.arange( + bits, dtype=torch.int64, device=packed.device + ).reshape(*((1,) * (packed.ndim - 1)), 1, bits) + return values.sum(dim=-1).to(torch.uint8) + + +def _norms_to_bytes(norms: torch.Tensor, byte_width: int) -> torch.Tensor: + norm_half = norms.to(torch.float16).contiguous() + return norm_half.reshape(-1).view(torch.uint8).reshape(*norm_half.shape, byte_width) + + +def _bytes_to_norms(norm_bytes: torch.Tensor, byte_width: int) -> torch.Tensor: + raw = norm_bytes.contiguous().reshape(-1, byte_width).view(torch.float16) + return raw.reshape(*norm_bytes.shape[:-1], 1).to(torch.float32) + + +def build_turboquant_outlier_masks( + x: torch.Tensor, + kv_cache_dtype: str, +) -> tuple[torch.Tensor, torch.Tensor]: + outlier_count, _ = get_turboquant_group_dims(x.shape[-1], kv_cache_dtype) + x_fp32 = x.to(torch.float32) + score = x_fp32.reshape(-1, x_fp32.shape[-2], x_fp32.shape[-1]).square().mean(dim=0) + outlier_idx = torch.topk(score, k=outlier_count, dim=-1).indices + outlier_idx = torch.sort(outlier_idx, dim=-1).values + all_idx = torch.arange(x.shape[-1], device=x.device, dtype=torch.int64) + all_idx = all_idx.unsqueeze(0).expand(score.shape[0], -1) + regular_mask = torch.ones_like(all_idx, dtype=torch.bool) + regular_mask.scatter_(1, outlier_idx, False) + regular_idx = all_idx[regular_mask].reshape(score.shape[0], -1) + return outlier_idx, regular_idx + + +def _gather_group(x: torch.Tensor, group_indices: torch.Tensor) -> torch.Tensor: + return torch.gather( + x, + dim=-1, + index=group_indices.unsqueeze(0).expand(x.shape[0], -1, -1), + ) + + +def validate_turboquant_group_indices( + x: torch.Tensor, + group_indices: tuple[torch.Tensor, torch.Tensor], +) -> None: + if group_indices[0].shape[0] != x.shape[1] or group_indices[1].shape[0] != x.shape[1]: + raise ValueError("TurboQuant group metadata must match the KV head count.") + + +def apply_turboquant_query_transforms( + query: torch.Tensor, + group_indices: tuple[torch.Tensor, torch.Tensor], + rotations: tuple[torch.Tensor, torch.Tensor], + qjl_matrices: tuple[torch.Tensor, torch.Tensor], + kv_head_for_query_head: torch.Tensor | None = None, + per_query_group_indices: tuple[torch.Tensor, torch.Tensor] | None = None, +) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: + query_fp32 = query.to(torch.float32) + if per_query_group_indices is None: + assert kv_head_for_query_head is not None + gathered_indices = tuple( + group.index_select(0, kv_head_for_query_head) for group in group_indices + ) + else: + gathered_indices = per_query_group_indices + gathered_groups = tuple( + _gather_group(query_fp32, group) for group in gathered_indices + ) + q_rot = tuple( + _apply_mse_transform(group_tensor, rotation) + for group_tensor, rotation in zip(gathered_groups, rotations, strict=True) + ) + q_qjl = tuple( + _apply_qjl_transform(group_tensor, qjl_matrix) + * (TURBOQUANT_QJL_SCALE / group_tensor.shape[-1]) + for group_tensor, qjl_matrix in zip(gathered_groups, qjl_matrices, strict=True) + ) + return q_rot, q_qjl + + +def scatter_turboquant_output( + head_size: int, + dtype: torch.dtype, + group_outputs: tuple[torch.Tensor, torch.Tensor], + group_indices: tuple[torch.Tensor, torch.Tensor], + out: torch.Tensor | None = None, +) -> torch.Tensor: + if out is None: + output = torch.zeros( + (*group_outputs[0].shape[:-1], head_size), + dtype=torch.float32, + device=group_outputs[0].device, + ) + else: + output = out + output.zero_() + for group_output, indices in zip(group_outputs, group_indices, strict=True): + output.scatter_add_( + -1, + indices.unsqueeze(0).expand(group_output.shape[0], -1, -1), + group_output.to(dtype=output.dtype), + ) + if out is not None: + return out + return output.to(dtype=dtype) + + +def quantize_turboquant_vectors( + x: torch.Tensor, + kv_cache_dtype: str, + rotations: tuple[torch.Tensor, torch.Tensor], + qjl_matrices: tuple[torch.Tensor, torch.Tensor], + centroids: dict[int, torch.Tensor], + group_indices: tuple[torch.Tensor, torch.Tensor], +) -> torch.Tensor: + validate_turboquant_group_indices(x, group_indices) + layout = get_turboquant_layout(kv_cache_dtype, x.shape[-1]) + groups = tuple( + _gather_group(x.to(torch.float32), indices) for indices in group_indices + ) + packed_groups: list[torch.Tensor] = [] + for group_x, group_layout, rotation, qjl_matrix in zip( + groups, layout.groups, rotations, qjl_matrices, strict=True + ): + vector_norms = group_x.norm(dim=-1, keepdim=True).clamp_min(1e-12) + unit = group_x / vector_norms + rotated = _apply_mse_transform(unit, rotation) + mse_indices = torch.zeros_like(rotated, dtype=torch.uint8) + rotated_hat = torch.zeros_like(rotated, dtype=torch.float32) + if group_layout.mse_bits > 0: + group_centroids = centroids[group_layout.mse_bits] + mse_indices = torch.abs(rotated.unsqueeze(-1) - group_centroids).argmin( + dim=-1 + ).to(torch.uint8) + rotated_hat = group_centroids[mse_indices.long()] + mse_hat = _apply_mse_inverse_transform(rotated_hat, rotation) + residual = unit - mse_hat + residual_norms = residual.norm(dim=-1, keepdim=True) + qjl_bits = (_apply_qjl_transform(residual, qjl_matrix) >= 0).to(torch.uint8) + packed_groups.append( + torch.cat( + ( + pack_turboquant_indices(mse_indices, group_layout.mse_bits), + pack_turboquant_indices(qjl_bits, 1), + _norms_to_bytes( + vector_norms.squeeze(-1), TURBOQUANT_VECTOR_NORM_BYTES + ), + _norms_to_bytes( + residual_norms.squeeze(-1), TURBOQUANT_RESIDUAL_NORM_BYTES + ), + ), + dim=-1, + ) + ) + return torch.cat(packed_groups, dim=-1) + + +def dequantize_turboquant_vectors( + packed: torch.Tensor, + kv_cache_dtype: str, + head_size: int, + rotations: tuple[torch.Tensor, torch.Tensor], + qjl_matrices: tuple[torch.Tensor, torch.Tensor], + centroids: dict[int, torch.Tensor], + group_indices: tuple[torch.Tensor, torch.Tensor], + dtype: torch.dtype, +) -> torch.Tensor: + if packed.shape[-2] != group_indices[0].shape[0]: + raise ValueError("TurboQuant packed tensor KV head count does not match metadata.") + layout = get_turboquant_layout(kv_cache_dtype, head_size) + group_outputs: list[torch.Tensor] = [] + cursor = 0 + for group_layout, rotation, qjl_matrix in zip( + layout.groups, rotations, qjl_matrices, strict=True + ): + group_packed = packed[..., cursor : cursor + group_layout.packed_bytes] + cursor += group_layout.packed_bytes + group_cursor = 0 + mse_indices = unpack_turboquant_indices( + group_packed[ + ..., group_cursor : group_cursor + group_layout.mse_payload_bytes + ], + group_layout.dim, + group_layout.mse_bits, + ) + group_cursor += group_layout.mse_payload_bytes + qjl_bits = unpack_turboquant_indices( + group_packed[ + ..., group_cursor : group_cursor + group_layout.qjl_payload_bytes + ], + group_layout.dim, + 1, + ) + group_cursor += group_layout.qjl_payload_bytes + vector_norms = _bytes_to_norms( + group_packed[ + ..., group_cursor : group_cursor + TURBOQUANT_VECTOR_NORM_BYTES + ], + TURBOQUANT_VECTOR_NORM_BYTES, + ) + group_cursor += TURBOQUANT_VECTOR_NORM_BYTES + residual_norms = _bytes_to_norms( + group_packed[ + ..., group_cursor : group_cursor + TURBOQUANT_RESIDUAL_NORM_BYTES + ], + TURBOQUANT_RESIDUAL_NORM_BYTES, + ) + rotated_hat = torch.zeros( + (*group_packed.shape[:-1], group_layout.dim), + dtype=torch.float32, + device=packed.device, + ) + if group_layout.mse_bits > 0: + rotated_hat = centroids[group_layout.mse_bits][mse_indices.long()] + mse_hat = _apply_mse_inverse_transform(rotated_hat, rotation) + qjl_signs = qjl_bits.to(torch.float32).mul_(2.0).sub_(1.0) + qjl_hat = _apply_qjl_inverse_transform(qjl_signs, qjl_matrix) * ( + TURBOQUANT_QJL_SCALE / group_layout.dim + ) + group_outputs.append((mse_hat + qjl_hat * residual_norms) * vector_norms) + return scatter_turboquant_output( + head_size=head_size, + dtype=dtype, + group_outputs=(group_outputs[0], group_outputs[1]), + group_indices=group_indices, + ) diff --git a/vllm/v1/attention/ops/turboquant_metadata.py b/vllm/v1/attention/ops/turboquant_metadata.py new file mode 100644 index 000000000000..b01e0a0baa5b --- /dev/null +++ b/vllm/v1/attention/ops/turboquant_metadata.py @@ -0,0 +1,314 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import json +from dataclasses import dataclass +from functools import cache +from pathlib import Path + +import torch + +from vllm.v1.attention.ops.turboquant_kv_cache import ( + TURBOQUANT_GROUP_ALIGNMENT, + canonicalize_turboquant_dtype, + get_turboquant_outlier_count, +) + +TURBOQUANT_METADATA_VERSION = 1 +TURBOQUANT_TRANSFORM_VERSION = "structured_hadamard_v1" +TURBOQUANT_CODEBOOK_VERSION = "lloyd_beta_v1" + + +@dataclass(frozen=True) +class TurboQuantTensorMetadata: + high_precision_indices: tuple[tuple[int, ...], ...] + + def get_group_indices( + self, + device: torch.device, + head_size: int, + kv_cache_dtype: str, + ) -> tuple[torch.Tensor, torch.Tensor]: + high_cpu, low_cpu = _cached_group_indices( + self.high_precision_indices, + head_size, + canonicalize_turboquant_dtype(kv_cache_dtype), + ) + if device.type == "cpu": + return high_cpu, low_cpu + return high_cpu.to(device=device), low_cpu.to(device=device) + + def to_json(self) -> list[list[int]]: + return [list(indices) for indices in self.high_precision_indices] + + +@cache +def _cached_group_indices( + high_precision_indices: tuple[tuple[int, ...], ...], + head_size: int, + kv_cache_dtype: str, +) -> tuple[torch.Tensor, torch.Tensor]: + outlier_count = get_turboquant_outlier_count(head_size, kv_cache_dtype) + if len(high_precision_indices) == 0: + raise ValueError("TurboQuant metadata must contain at least one KV head.") + all_idx = torch.arange(head_size, dtype=torch.int64) + high_groups: list[torch.Tensor] = [] + low_groups: list[torch.Tensor] = [] + for head_idx, high_indices in enumerate(high_precision_indices): + if len(high_indices) != outlier_count: + raise ValueError( + "TurboQuant metadata high-precision group size mismatch for " + f"head {head_idx}: expected {outlier_count}, got {len(high_indices)}." + ) + high = torch.tensor(high_indices, dtype=torch.int64) + if torch.any(high[:-1] >= high[1:]): + raise ValueError( + "TurboQuant metadata high-precision indices must be strictly sorted." + ) + if high.min().item() < 0 or high.max().item() >= head_size: + raise ValueError( + "TurboQuant metadata high-precision indices are out of range." + ) + if torch.unique(high).numel() != high.numel(): + raise ValueError( + "TurboQuant metadata high-precision indices must be unique." + ) + low_mask = torch.ones(head_size, dtype=torch.bool) + low_mask.scatter_(0, high, False) + low = all_idx[low_mask] + high_groups.append(high) + low_groups.append(low) + return torch.stack(high_groups, dim=0), torch.stack(low_groups, dim=0) + + +@dataclass(frozen=True) +class TurboQuantLayerMetadata: + key: TurboQuantTensorMetadata + value: TurboQuantTensorMetadata + + def to_json(self) -> dict[str, list[list[int]]]: + return { + "key_high_precision_indices": self.key.to_json(), + "value_high_precision_indices": self.value.to_json(), + } + + +@dataclass(frozen=True) +class TurboQuantCalibrationMetadata: + method: str + objective: str + num_prompts: int + max_seq_len: int + batch_size: int + num_observed_tokens: int + dtype: str + device: str + prompts_sha256: str + + def to_json(self) -> dict[str, object]: + return self.__dict__.copy() + + +@dataclass(frozen=True) +class TurboQuantMetadata: + recipe: str + head_size: int + model_name: str | None + layers: dict[str, TurboQuantLayerMetadata] + calibration: TurboQuantCalibrationMetadata | None = None + version: int = TURBOQUANT_METADATA_VERSION + transform_version: str = TURBOQUANT_TRANSFORM_VERSION + codebook_version: str = TURBOQUANT_CODEBOOK_VERSION + + def get_layer(self, layer_name: str) -> TurboQuantLayerMetadata: + candidate_names = _turboquant_layer_name_candidates(layer_name) + for candidate_name in candidate_names: + layer = self.layers.get(candidate_name) + if layer is not None: + return layer + raise KeyError( + "TurboQuant metadata does not contain layer " + f"{layer_name!r}. Tried aliases: " + f"{', '.join(repr(name) for name in candidate_names)}." + ) + + def to_json(self) -> dict[str, object]: + payload: dict[str, object] = { + "version": self.version, + "recipe": self.recipe, + "head_size": self.head_size, + "model_name": self.model_name, + "transform_version": self.transform_version, + "codebook_version": self.codebook_version, + "layers": { + layer_name: layer_metadata.to_json() + for layer_name, layer_metadata in self.layers.items() + }, + } + if self.calibration is not None: + payload["calibration"] = self.calibration.to_json() + return payload + + +def _parse_tensor_metadata( + payload: object, + field_name: str, +) -> TurboQuantTensorMetadata: + if not isinstance(payload, list): + raise ValueError(f"TurboQuant metadata field {field_name!r} must be a list.") + high_precision_indices: list[tuple[int, ...]] = [] + for head_payload in payload: + if not isinstance(head_payload, list) or not all( + isinstance(index, int) for index in head_payload + ): + raise ValueError( + f"TurboQuant metadata field {field_name!r} must contain integer lists." + ) + high_precision_indices.append(tuple(head_payload)) + return TurboQuantTensorMetadata(tuple(high_precision_indices)) + + +def turboquant_metadata_from_json(payload: dict[str, object]) -> TurboQuantMetadata: + version = int(payload.get("version", TURBOQUANT_METADATA_VERSION)) + if version != TURBOQUANT_METADATA_VERSION: + raise ValueError( + f"Unsupported TurboQuant metadata version {version}. Expected " + f"{TURBOQUANT_METADATA_VERSION}." + ) + recipe = payload.get("recipe") + if not isinstance(recipe, str): + raise ValueError("TurboQuant metadata must define a string recipe.") + recipe = canonicalize_turboquant_dtype(recipe) + head_size = payload.get("head_size") + if not isinstance(head_size, int) or head_size % TURBOQUANT_GROUP_ALIGNMENT != 0: + raise ValueError( + "TurboQuant metadata must define an aligned integer head_size." + ) + model_name = payload.get("model_name") + if model_name is not None and not isinstance(model_name, str): + raise ValueError("TurboQuant metadata model_name must be a string or null.") + layers_payload = payload.get("layers") + if not isinstance(layers_payload, dict): + raise ValueError("TurboQuant metadata must define an object-valued layers.") + layers: dict[str, TurboQuantLayerMetadata] = {} + for layer_name, layer_payload in layers_payload.items(): + if not isinstance(layer_name, str) or not isinstance(layer_payload, dict): + raise ValueError("TurboQuant metadata layers must be object-valued.") + layers[layer_name] = TurboQuantLayerMetadata( + key=_parse_tensor_metadata( + layer_payload.get("key_high_precision_indices"), + "key_high_precision_indices", + ), + value=_parse_tensor_metadata( + layer_payload.get("value_high_precision_indices"), + "value_high_precision_indices", + ), + ) + calibration_payload = payload.get("calibration") + calibration: TurboQuantCalibrationMetadata | None = None + if calibration_payload is not None: + if not isinstance(calibration_payload, dict): + raise ValueError( + "TurboQuant metadata calibration field must be an object or null." + ) + calibration = TurboQuantCalibrationMetadata( + method=str(calibration_payload.get("method", "")), + objective=str(calibration_payload.get("objective", "")), + num_prompts=int(calibration_payload.get("num_prompts", 0)), + max_seq_len=int(calibration_payload.get("max_seq_len", 0)), + batch_size=int(calibration_payload.get("batch_size", 0)), + num_observed_tokens=int(calibration_payload.get("num_observed_tokens", 0)), + dtype=str(calibration_payload.get("dtype", "")), + device=str(calibration_payload.get("device", "")), + prompts_sha256=str(calibration_payload.get("prompts_sha256", "")), + ) + return TurboQuantMetadata( + recipe=recipe, + head_size=head_size, + model_name=model_name, + layers=layers, + calibration=calibration, + version=version, + transform_version=str( + payload.get("transform_version", TURBOQUANT_TRANSFORM_VERSION) + ), + codebook_version=str(payload.get("codebook_version", TURBOQUANT_CODEBOOK_VERSION)), + ) + + +@cache +def load_turboquant_metadata(path: str) -> TurboQuantMetadata: + with open(path, encoding="utf-8") as f: + payload = json.load(f) + if not isinstance(payload, dict): + raise ValueError("TurboQuant metadata root must be a JSON object.") + return turboquant_metadata_from_json(payload) + + +def save_turboquant_metadata(metadata: TurboQuantMetadata, path: str | Path) -> None: + output_path = Path(path) + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w", encoding="utf-8") as f: + json.dump(metadata.to_json(), f, indent=2, sort_keys=True) + f.write("\n") + + +def discover_turboquant_metadata_path( + model_name_or_path: str | None, + explicit_path: str | None, +) -> str | None: + if explicit_path is not None: + return str(Path(explicit_path).expanduser().resolve()) + if model_name_or_path is None: + return None + model_path = Path(model_name_or_path).expanduser().resolve() + if model_path.is_file(): + model_path = model_path.parent + elif not model_path.is_dir(): + return None + metadata_path = model_path / "turboquant_kv.json" + if metadata_path.is_file(): + return str(metadata_path.resolve()) + return None + + +def build_default_turboquant_metadata( + *, + recipe: str, + head_size: int, + num_kv_heads: int, + layer_names: list[str], + model_name: str | None = None, +) -> TurboQuantMetadata: + recipe = canonicalize_turboquant_dtype(recipe) + outlier_count = get_turboquant_outlier_count(head_size, recipe) + default_high = tuple(tuple(range(outlier_count)) for _ in range(num_kv_heads)) + layer_metadata = TurboQuantLayerMetadata( + key=TurboQuantTensorMetadata(default_high), + value=TurboQuantTensorMetadata(default_high), + ) + return TurboQuantMetadata( + recipe=recipe, + head_size=head_size, + model_name=model_name, + layers={layer_name: layer_metadata for layer_name in layer_names}, + ) + + +def _turboquant_layer_name_candidates(layer_name: str) -> tuple[str, ...]: + candidates: list[str] = [] + + def add(name: str) -> None: + if name not in candidates: + candidates.append(name) + + add(layer_name) + if layer_name.endswith(".attn"): + add(layer_name.removesuffix(".attn")) + if layer_name.startswith("language_model."): + add(layer_name.removeprefix("language_model.")) + if layer_name.endswith(".self_attn.attn"): + add(layer_name.removesuffix(".attn")) + return tuple(candidates)