diff --git a/README.md b/README.md index 2961101..ad75725 100644 --- a/README.md +++ b/README.md @@ -507,6 +507,7 @@ Qwen3.5 uses Gated DeltaNet (75% RNN) + full attention (25% KV). Other engines r | **Hybrid cache sync** | Keep trimmable KV + non-trimmable RNN layers in sync | Qwen3.5 (Gated DeltaNet + attention) | | **Tool logits bias** | Jump-forward decoding — bias logits toward structured tokens | All models with `--enable-tool-logits-bias` | | **Auto tool recovery** | Detect broken text-format tool calls, convert to structured | All 18 parser formats (incl. Gemma 4) | +| **TurboQuant V-cache** | Rotate + Lloyd-Max compress V cache (86% savings on dense models) | All models with `--kv-cache-turboquant` | | **KV cache quantization** | Quantize prefix cache entries to reduce memory | All models with `--kv-cache-quantization` | | **Prefill chunking** | Configurable step size for large-prompt throughput | All models | | **Cloud routing** | Offload high-token requests to cloud LLM when local is slow | All models with `--cloud-model` | @@ -585,6 +586,7 @@ Also: logprobs API, structured JSON output (`response_format`), continuous batch | Flag | Description | Default | |------|-------------|---------| | `--prefill-step-size` | Tokens per prefill chunk | `2048` | +| `--kv-cache-turboquant` | TurboQuant V-cache compression (3-4 bit, 86% savings on dense models) | off | | `--kv-cache-quantization` | Quantize prefix cache entries for memory savings | off | | `--enable-prefix-cache` | Cache common prefixes across requests | off | | `--gpu-memory-utilization` | Fraction of device memory to use (0.0-1.0) | `0.90` | diff --git a/tests/test_turboquant.py b/tests/test_turboquant.py new file mode 100644 index 0000000..0f4e6f6 --- /dev/null +++ b/tests/test_turboquant.py @@ -0,0 +1,611 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for TurboQuant KV cache compression.""" + +import mlx.core as mx +import numpy as np +import pytest + +from vllm_mlx.turboquant import ( + 
LLOYD_MAX_BOUNDARIES, + LLOYD_MAX_CODEBOOKS, + TurboQuantConfig, + TurboQuantKVCache, + auto_select_bits, + generate_rotation_matrix, + turboquant_decode, + turboquant_encode, +) + +# --------------------------------------------------------------------------- +# TurboQuantConfig +# --------------------------------------------------------------------------- + + +class TestTurboQuantConfig: + def test_valid_3bit(self): + cfg = TurboQuantConfig(bits=3) + assert cfg.bits == 3 + + def test_valid_4bit(self): + cfg = TurboQuantConfig(bits=4) + assert cfg.bits == 4 + + def test_invalid_bits(self): + with pytest.raises(ValueError, match="bits must be 3 or 4"): + TurboQuantConfig(bits=2) + + def test_invalid_group_size(self): + with pytest.raises(ValueError, match="group_size must be >= 1"): + TurboQuantConfig(group_size=0) + + def test_defaults(self): + cfg = TurboQuantConfig() + assert cfg.bits == 3 + assert cfg.group_size == 32 + assert cfg.rotation_seed == 42 + + +# --------------------------------------------------------------------------- +# auto_select_bits +# --------------------------------------------------------------------------- + + +class TestAutoSelectBits: + def test_large_head_dim(self): + assert auto_select_bits(128) == 3 + + def test_medium_head_dim(self): + assert auto_select_bits(96) == 3 + + def test_small_head_dim(self): + assert auto_select_bits(64) == 4 + + def test_tiny_head_dim(self): + assert auto_select_bits(32) == 4 + + +# --------------------------------------------------------------------------- +# Lloyd-Max codebooks +# --------------------------------------------------------------------------- + + +class TestLloydMaxCodebooks: + def test_3bit_size(self): + assert LLOYD_MAX_CODEBOOKS[3].shape == (8,) + + def test_4bit_size(self): + assert LLOYD_MAX_CODEBOOKS[4].shape == (16,) + + def test_3bit_boundaries_size(self): + assert LLOYD_MAX_BOUNDARIES[3].shape == (7,) + + def test_4bit_boundaries_size(self): + assert LLOYD_MAX_BOUNDARIES[4].shape 
== (15,) + + def test_codebook_sorted(self): + for bits in (3, 4): + cb = np.array(LLOYD_MAX_CODEBOOKS[bits]) + assert np.all(cb[:-1] <= cb[1:]), f"{bits}-bit codebook not sorted" + + def test_boundaries_sorted(self): + for bits in (3, 4): + bd = np.array(LLOYD_MAX_BOUNDARIES[bits]) + assert np.all(bd[:-1] <= bd[1:]), f"{bits}-bit boundaries not sorted" + + def test_codebook_symmetric(self): + """Codebook should be approximately symmetric around 0.""" + for bits in (3, 4): + cb = np.array(LLOYD_MAX_CODEBOOKS[bits]) + assert abs(cb.sum()) < 0.1, f"{bits}-bit codebook not symmetric" + + +# --------------------------------------------------------------------------- +# Rotation matrix +# --------------------------------------------------------------------------- + + +class TestRotationMatrix: + def test_orthogonality(self): + """Q @ Q.T should be identity.""" + Q = generate_rotation_matrix(128, seed=42) + Q_np = np.array(Q, dtype=np.float32) + product = Q_np @ Q_np.T + np.testing.assert_allclose(product, np.eye(128), atol=1e-5) + + def test_deterministic(self): + """Same seed and dim should produce same matrix.""" + Q1 = generate_rotation_matrix(64, seed=123) + Q2 = generate_rotation_matrix(64, seed=123) + np.testing.assert_array_equal(np.array(Q1), np.array(Q2)) + + def test_different_seeds(self): + """Different seeds should produce different matrices.""" + Q1 = generate_rotation_matrix(64, seed=1) + Q2 = generate_rotation_matrix(64, seed=2) + assert not np.allclose(np.array(Q1), np.array(Q2)) + + def test_different_dims(self): + Q64 = generate_rotation_matrix(64, seed=42) + Q128 = generate_rotation_matrix(128, seed=42) + assert Q64.shape == (64, 64) + assert Q128.shape == (128, 128) + + def test_caching(self): + """Second call should return cached result.""" + Q1 = generate_rotation_matrix(32, seed=99) + Q2 = generate_rotation_matrix(32, seed=99) + # Should be the exact same object (cached) + assert Q1 is Q2 + + +# 
--------------------------------------------------------------------------- +# Encode / Decode roundtrip +# --------------------------------------------------------------------------- + + +class TestEncodeDecode: + @pytest.fixture + def rotation_128(self): + return generate_rotation_matrix(128, seed=42) + + @pytest.fixture + def rotation_64(self): + return generate_rotation_matrix(64, seed=42) + + @pytest.fixture + def gaussian_data_128(self): + """Simulate V tensor: (1, 8, 32, 128) — batch=1, 8 heads, 32 tokens, head_dim=128.""" + np.random.seed(0) + return mx.array(np.random.randn(1, 8, 32, 128).astype(np.float16)) + + @pytest.fixture + def gaussian_data_64(self): + np.random.seed(0) + return mx.array(np.random.randn(1, 8, 32, 64).astype(np.float16)) + + def test_4bit_roundtrip_quality_128(self, gaussian_data_128, rotation_128): + indices, scales, zeros = turboquant_encode( + gaussian_data_128, bits=4, group_size=32, rotation=rotation_128 + ) + reconstructed = turboquant_decode( + indices, + scales, + zeros, + bits=4, + group_size=32, + rotation=rotation_128, + head_dim=128, + ) + + # Cosine similarity per vector + orig = np.array(gaussian_data_128.reshape(-1, 128), dtype=np.float32) + recon = np.array(reconstructed.reshape(-1, 128), dtype=np.float32) + cosines = np.sum(orig * recon, axis=-1) / ( + np.linalg.norm(orig, axis=-1) * np.linalg.norm(recon, axis=-1) + 1e-8 + ) + mean_cosine = cosines.mean() + assert mean_cosine > 0.95, f"4-bit cosine {mean_cosine:.4f} < 0.95" + + def test_3bit_roundtrip_quality_128(self, gaussian_data_128, rotation_128): + indices, scales, zeros = turboquant_encode( + gaussian_data_128, bits=3, group_size=32, rotation=rotation_128 + ) + reconstructed = turboquant_decode( + indices, + scales, + zeros, + bits=3, + group_size=32, + rotation=rotation_128, + head_dim=128, + ) + + orig = np.array(gaussian_data_128.reshape(-1, 128), dtype=np.float32) + recon = np.array(reconstructed.reshape(-1, 128), dtype=np.float32) + cosines = np.sum(orig 
* recon, axis=-1) / ( + np.linalg.norm(orig, axis=-1) * np.linalg.norm(recon, axis=-1) + 1e-8 + ) + mean_cosine = cosines.mean() + assert mean_cosine > 0.90, f"3-bit cosine {mean_cosine:.4f} < 0.90" + + def test_4bit_roundtrip_quality_64(self, gaussian_data_64, rotation_64): + """head_dim=64 needs 4-bit for decent quality.""" + indices, scales, zeros = turboquant_encode( + gaussian_data_64, bits=4, group_size=32, rotation=rotation_64 + ) + reconstructed = turboquant_decode( + indices, + scales, + zeros, + bits=4, + group_size=32, + rotation=rotation_64, + head_dim=64, + ) + + orig = np.array(gaussian_data_64.reshape(-1, 64), dtype=np.float32) + recon = np.array(reconstructed.reshape(-1, 64), dtype=np.float32) + cosines = np.sum(orig * recon, axis=-1) / ( + np.linalg.norm(orig, axis=-1) * np.linalg.norm(recon, axis=-1) + 1e-8 + ) + mean_cosine = cosines.mean() + assert mean_cosine > 0.93, f"4-bit head_dim=64 cosine {mean_cosine:.4f} < 0.93" + + def test_4bit_mse(self, gaussian_data_128, rotation_128): + """MSE should be low for 4-bit.""" + indices, scales, zeros = turboquant_encode( + gaussian_data_128, bits=4, group_size=32, rotation=rotation_128 + ) + reconstructed = turboquant_decode( + indices, + scales, + zeros, + bits=4, + group_size=32, + rotation=rotation_128, + head_dim=128, + ) + mse = float(mx.mean((gaussian_data_128 - reconstructed) ** 2)) + assert mse < 0.05, f"4-bit MSE {mse:.4f} > 0.05" + + def test_3bit_mse(self, gaussian_data_128, rotation_128): + indices, scales, zeros = turboquant_encode( + gaussian_data_128, bits=3, group_size=32, rotation=rotation_128 + ) + reconstructed = turboquant_decode( + indices, + scales, + zeros, + bits=3, + group_size=32, + rotation=rotation_128, + head_dim=128, + ) + mse = float(mx.mean((gaussian_data_128 - reconstructed) ** 2)) + assert mse < 0.15, f"3-bit MSE {mse:.4f} > 0.15" + + def test_indices_dtype(self, gaussian_data_128, rotation_128): + indices, _, _ = turboquant_encode( + gaussian_data_128, bits=4, 
group_size=32, rotation=rotation_128 + ) + assert indices.dtype == mx.uint8 + + def test_packed_indices_range_4bit(self, gaussian_data_128, rotation_128): + """Packed indices are uint8 with nibble-packed values.""" + packed, _, _ = turboquant_encode( + gaussian_data_128, bits=4, group_size=32, rotation=rotation_128 + ) + assert packed.dtype == mx.uint8 + # Each byte has high nibble + low nibble, each in [0,15] + assert int(mx.max(packed)) <= 255 + + def test_packed_indices_range_3bit(self, gaussian_data_128, rotation_128): + packed, _, _ = turboquant_encode( + gaussian_data_128, bits=3, group_size=32, rotation=rotation_128 + ) + assert packed.dtype == mx.uint8 + + def test_output_shapes(self, gaussian_data_128, rotation_128): + """Verify output shapes are correct (packed indices).""" + packed, scales, zeros = turboquant_encode( + gaussian_data_128, bits=3, group_size=32, rotation=rotation_128 + ) + # packed indices: last dim = ceil(head_dim/2) due to nibble packing + assert packed.shape == (1, 8, 32, 64) # 128 // 2 + # scales/zeros: (..., seq_len, n_groups) + n_groups = 128 // 32 # = 4 + assert scales.shape == (1, 8, 32, n_groups) + assert zeros.shape == (1, 8, 32, n_groups) + + def test_non_divisible_group_size(self): + """head_dim not divisible by group_size should still work.""" + np.random.seed(0) + data = mx.array(np.random.randn(1, 4, 16, 100).astype(np.float16)) + rotation = generate_rotation_matrix(100, seed=42) + + indices, scales, zeros = turboquant_encode( + data, bits=4, group_size=32, rotation=rotation + ) + reconstructed = turboquant_decode( + indices, + scales, + zeros, + bits=4, + group_size=32, + rotation=rotation, + head_dim=100, + ) + assert reconstructed.shape == data.shape + + def test_single_token(self): + """Single-token V should work.""" + np.random.seed(0) + data = mx.array(np.random.randn(1, 4, 1, 128).astype(np.float16)) + rotation = generate_rotation_matrix(128, seed=42) + + indices, scales, zeros = turboquant_encode( + data, bits=4, 
group_size=32, rotation=rotation + ) + reconstructed = turboquant_decode( + indices, + scales, + zeros, + bits=4, + group_size=32, + rotation=rotation, + head_dim=128, + ) + assert reconstructed.shape == data.shape + + +# --------------------------------------------------------------------------- +# TurboQuantKVCache +# --------------------------------------------------------------------------- + + +class TestTurboQuantKVCache: + @pytest.fixture + def mock_kv_cache(self): + """Create a mock KVCache-like object.""" + from unittest.mock import MagicMock + + kv = MagicMock() + np.random.seed(0) + kv.keys = mx.array(np.random.randn(1, 8, 32, 128).astype(np.float16)) + kv.values = mx.array(np.random.randn(1, 8, 32, 128).astype(np.float16)) + kv.offset = 32 + return kv + + @pytest.fixture + def config(self): + return TurboQuantConfig(bits=4, group_size=32) + + def test_from_kv_cache(self, mock_kv_cache, config): + tq = TurboQuantKVCache.from_kv_cache(mock_kv_cache, config) + assert tq.keys is not None + assert tq.values_compressed[0] is not None # indices + assert tq.offset == 32 + assert tq.head_dim == 128 + + def test_to_kv_cache_roundtrip(self, mock_kv_cache, config): + tq = TurboQuantKVCache.from_kv_cache(mock_kv_cache, config) + restored = tq.to_kv_cache() + + # Keys should be identical (FP16, no compression) + np.testing.assert_array_equal( + np.array(restored.keys), np.array(mock_kv_cache.keys) + ) + + # Values should be close (compressed + decompressed) + orig = np.array(mock_kv_cache.values, dtype=np.float32) + recon = np.array(restored.values, dtype=np.float32) + cosines = np.sum(orig.reshape(-1, 128) * recon.reshape(-1, 128), axis=-1) / ( + np.linalg.norm(orig.reshape(-1, 128), axis=-1) + * np.linalg.norm(recon.reshape(-1, 128), axis=-1) + + 1e-8 + ) + assert cosines.mean() > 0.93 + + def test_keys_unchanged(self, mock_kv_cache, config): + """K must stay FP16, not be compressed.""" + original_keys = np.array(mock_kv_cache.keys) + tq = 
TurboQuantKVCache.from_kv_cache(mock_kv_cache, config) + np.testing.assert_array_equal(np.array(tq.keys), original_keys) + + def test_memory_savings(self, mock_kv_cache, config): + """Compressed V should use less memory than FP16 V.""" + fp16_v_bytes = mock_kv_cache.values.nbytes + tq = TurboQuantKVCache.from_kv_cache(mock_kv_cache, config) + + indices, scales, zeros = tq.values_compressed + compressed_bytes = indices.nbytes + scales.nbytes + zeros.nbytes + + ratio = compressed_bytes / fp16_v_bytes + # Nibble-packed indices (half size) + fp16 scales/zeros: ~31% of FP16 V + assert ratio < 0.40, f"Compression ratio {ratio:.2f} > 0.40" + + def test_3bit_memory_savings(self, mock_kv_cache): + config3 = TurboQuantConfig(bits=3, group_size=32) + fp16_v_bytes = mock_kv_cache.values.nbytes + tq = TurboQuantKVCache.from_kv_cache(mock_kv_cache, config3) + + indices, scales, zeros = tq.values_compressed + compressed_bytes = indices.nbytes + scales.nbytes + zeros.nbytes + + ratio = compressed_bytes / fp16_v_bytes + assert ratio < 0.40, f"3-bit ratio {ratio:.2f} > 0.40" + + def test_is_trimmable(self, mock_kv_cache, config): + tq = TurboQuantKVCache.from_kv_cache(mock_kv_cache, config) + assert tq.is_trimmable() + + def test_trim(self, mock_kv_cache, config): + tq = TurboQuantKVCache.from_kv_cache(mock_kv_cache, config) + tq.trim(10) + assert tq.offset == 22 + assert tq.keys.shape[-2] == 22 + + def test_trim_all(self, mock_kv_cache, config): + tq = TurboQuantKVCache.from_kv_cache(mock_kv_cache, config) + tq.trim(100) # More than offset + assert tq.offset == 0 + + def test_empty_cache(self, config): + from unittest.mock import MagicMock + + kv = MagicMock() + kv.keys = None + kv.values = None + kv.offset = 0 + + tq = TurboQuantKVCache.from_kv_cache(kv, config) + assert tq.keys is None + assert tq.offset == 0 + + restored = tq.to_kv_cache() + assert restored.keys is None + + def test_memory_bytes_property(self, mock_kv_cache, config): + tq = 
TurboQuantKVCache.from_kv_cache(mock_kv_cache, config) + mem = tq.memory_bytes + assert mem > 0 + # Should be less than FP16 keys + FP16 values + fp16_total = mock_kv_cache.keys.nbytes + mock_kv_cache.values.nbytes + assert mem < fp16_total + + +# --------------------------------------------------------------------------- +# Integration: memory_cache compress/decompress +# --------------------------------------------------------------------------- + + +class TestMemoryCacheIntegration: + """Test TurboQuant wiring in memory_cache.py.""" + + def _make_cache_list(self, n_layers=4, seq_len=32, n_heads=8, head_dim=128): + """Create a list of real KVCache layers.""" + from mlx_lm.models.cache import KVCache + + cache = [] + np.random.seed(0) + for _ in range(n_layers): + kv = KVCache() + kv.keys = mx.array( + np.random.randn(1, n_heads, seq_len, head_dim).astype(np.float16) + ) + kv.values = mx.array( + np.random.randn(1, n_heads, seq_len, head_dim).astype(np.float16) + ) + kv.offset = seq_len + cache.append(kv) + return cache + + def test_compress_decompress_roundtrip(self): + """Compress then decompress should produce valid KVCache layers.""" + from vllm_mlx.memory_cache import ( + _turboquant_compress_cache, + _turboquant_decompress_cache, + ) + + cache = self._make_cache_list() + compressed = _turboquant_compress_cache(cache, bits=4, group_size=32) + + # All layers should be TurboQuantKVCache + for layer in compressed: + assert isinstance(layer, TurboQuantKVCache) + + decompressed = _turboquant_decompress_cache(compressed) + + # All layers should have keys and values + for layer in decompressed: + assert layer.keys is not None + assert layer.values is not None + + def test_compress_memory_reduction(self): + """Compressed cache should use less total memory.""" + from vllm_mlx.memory_cache import ( + _turboquant_compress_cache, + estimate_kv_cache_memory, + ) + + cache = self._make_cache_list() + original_mem = sum(layer.keys.nbytes + layer.values.nbytes for layer in 
cache) + + compressed = _turboquant_compress_cache(cache, bits=4, group_size=32) + compressed_mem = estimate_kv_cache_memory(compressed) + + # Compressed should be significantly smaller + ratio = compressed_mem / original_mem + assert ratio < 0.75, f"Compression ratio {ratio:.2f} > 0.75" + assert compressed_mem > 0, "Memory estimate should not be 0" + + def test_none_layers_passthrough(self): + """None layers should pass through unchanged.""" + from vllm_mlx.memory_cache import ( + _turboquant_compress_cache, + _turboquant_decompress_cache, + ) + + cache = [None, None] + compressed = _turboquant_compress_cache(cache, bits=4, group_size=32) + assert compressed == [None, None] + + decompressed = _turboquant_decompress_cache(compressed) + assert decompressed == [None, None] + + def test_mixed_layers(self): + """Non-KVCache layers should pass through unchanged.""" + from unittest.mock import MagicMock + + from mlx_lm.models.cache import KVCache + + from vllm_mlx.memory_cache import _turboquant_compress_cache + + # Create a mix: KVCache + non-KVCache + kv = KVCache() + np.random.seed(0) + kv.keys = mx.array(np.random.randn(1, 8, 32, 128).astype(np.float16)) + kv.values = mx.array(np.random.randn(1, 8, 32, 128).astype(np.float16)) + kv.offset = 32 + + mamba = MagicMock() # Not a KVCache instance + + cache = [kv, mamba, None] + compressed = _turboquant_compress_cache(cache, bits=4, group_size=32) + + assert isinstance(compressed[0], TurboQuantKVCache) + assert compressed[1] is mamba # Passed through + assert compressed[2] is None # Passed through + + def test_trim_cache_offset_with_turboquant(self): + """_trim_cache_offset should trim TurboQuantKVCache without mutating original.""" + from vllm_mlx.memory_cache import ( + _trim_cache_offset, + _turboquant_compress_cache, + ) + + cache = self._make_cache_list(n_layers=2, seq_len=32) + compressed = _turboquant_compress_cache(cache, bits=4, group_size=32) + + # Save original offsets + orig_offsets = [c.offset for c in 
compressed] + orig_keys_shapes = [c.keys.shape for c in compressed] + + # Trim 10 tokens + trimmed = _trim_cache_offset(compressed, trim_by=10) + + # Trimmed copies should have reduced offset + for tc in trimmed: + assert tc.offset == 22 # 32 - 10 + + # Original entries must NOT be mutated + for i, c in enumerate(compressed): + assert c.offset == orig_offsets[i] + assert c.keys.shape == orig_keys_shapes[i] + + def test_trim_cache_offset_mixed_layers(self): + """_trim_cache_offset handles mixed TurboQuantKVCache + other layers.""" + from unittest.mock import MagicMock + + from mlx_lm.models.cache import KVCache + + from vllm_mlx.memory_cache import ( + _trim_cache_offset, + _turboquant_compress_cache, + ) + + # Create KVCache + non-KVCache + kv = KVCache() + kv.keys = mx.array(np.random.randn(1, 4, 20, 128).astype(np.float16)) + kv.values = mx.array(np.random.randn(1, 4, 20, 128).astype(np.float16)) + kv.offset = 20 + + other = MagicMock() + + compressed = _turboquant_compress_cache([kv], bits=4, group_size=32) + mixed = compressed + [other, None] + + trimmed = _trim_cache_offset(mixed, trim_by=5) + assert len(trimmed) == 3 + assert trimmed[0].offset == 15 # TurboQuantKVCache trimmed + assert trimmed[2] is None diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py index 6b13e42..eeecabe 100644 --- a/vllm_mlx/cli.py +++ b/vllm_mlx/cli.py @@ -266,6 +266,14 @@ def serve_command(args): if getattr(args, "specprefill", False): print("\n ⚠ --specprefill is deprecated and has no effect.\n") + # Mutual exclusion: turboquant vs standard quantization + if args.kv_cache_turboquant and args.kv_cache_quantization: + print( + "\n Error: --kv-cache-turboquant and --kv-cache-quantization are " + "mutually exclusive. 
Choose one.\n" + ) + sys.exit(1) + # Build scheduler config enable_prefix_cache = args.enable_prefix_cache and not args.disable_prefix_cache @@ -294,6 +302,10 @@ def serve_command(args): kv_cache_quantization_bits=args.kv_cache_quantization_bits, kv_cache_quantization_group_size=args.kv_cache_quantization_group_size, kv_cache_min_quantize_tokens=args.kv_cache_min_quantize_tokens, + # TurboQuant V-only compression + kv_cache_turboquant=args.kv_cache_turboquant, + kv_cache_turboquant_bits=args.kv_cache_turboquant_bits, + kv_cache_turboquant_group_size=args.kv_cache_turboquant_group_size, ) print("Mode: Continuous batching (for multiple concurrent users)") @@ -313,7 +325,17 @@ def serve_command(args): else f"{args.cache_memory_percent * 100:.0f}% of RAM" ) print(f"Memory-aware cache: {cache_info}") - if args.kv_cache_quantization: + if args.kv_cache_turboquant: + bits_str = ( + str(args.kv_cache_turboquant_bits) + if args.kv_cache_turboquant_bits + else "auto" + ) + print( + f"TurboQuant V-cache: {bits_str}-bit, " + f"group_size={args.kv_cache_turboquant_group_size} (K stays FP16)" + ) + elif args.kv_cache_quantization: print( f"KV cache quantization: {args.kv_cache_quantization_bits}-bit, " f"group_size={args.kv_cache_quantization_group_size}" @@ -1009,6 +1031,28 @@ def main(): default=256, help="Minimum tokens for quantization to apply (default: 256)", ) + # TurboQuant KV cache compression (V-only, experimental) + serve_parser.add_argument( + "--kv-cache-turboquant", + action="store_true", + help="Enable TurboQuant V-cache compression (3-4 bit, ~86%% prefix cache savings " + "on dense models). K stays FP16. 
Experimental — mutually exclusive with " + "--kv-cache-quantization.", + ) + serve_parser.add_argument( + "--kv-cache-turboquant-bits", + type=int, + default=None, + choices=[3, 4], + help="Bit width for TurboQuant (default: auto-select by head_dim — " + "3-bit for head_dim>=96, 4-bit for head_dim=64)", + ) + serve_parser.add_argument( + "--kv-cache-turboquant-group-size", + type=int, + default=32, + help="Group size for TurboQuant quantization (default: 32)", + ) serve_parser.add_argument( "--stream-interval", type=int, diff --git a/vllm_mlx/memory_cache.py b/vllm_mlx/memory_cache.py index 503c508..a794743 100644 --- a/vllm_mlx/memory_cache.py +++ b/vllm_mlx/memory_cache.py @@ -109,6 +109,12 @@ def estimate_kv_cache_memory(cache: list[Any]) -> int: for layer_cache in cache: if layer_cache is None: continue + # TurboQuantKVCache: has values_compressed instead of values + from .turboquant import TurboQuantKVCache + + if isinstance(layer_cache, TurboQuantKVCache): + total_bytes += layer_cache.memory_bytes + continue # Handle different cache object types # Check dict first since dicts have .keys() method that would match below if isinstance(layer_cache, dict) and "state" in layer_cache: @@ -170,6 +176,10 @@ class MemoryCacheConfig: kv_bits: int = 8 kv_group_size: int = 64 kv_min_quantize_tokens: int = 256 + # TurboQuant V-only compression (asymmetric: K=FP16, V=3-4bit) + kv_turboquant: bool = False + kv_turboquant_bits: int | None = None # None = auto-select by head_dim + kv_turboquant_group_size: int = 32 def __post_init__(self) -> None: if not 0.0 < self.max_memory_percent <= 1.0: @@ -289,6 +299,11 @@ def _trim_cache_offset(cache: list[Any], trim_by: int) -> list[Any]: tc.group_size = layer_cache.group_size tc.bits = layer_cache.bits trimmed.append(tc) + elif hasattr(layer_cache, "values_compressed"): + # TurboQuantKVCache — use its trim method on a copy + tc = copy.copy(layer_cache) + tc.trim(trim_by) + trimmed.append(tc) elif ( hasattr(layer_cache, "offset") and 
hasattr(layer_cache, "keys") @@ -406,6 +421,53 @@ def _dequantize_cache(cache: list[Any]) -> list[Any]: return result +def _turboquant_compress_cache( + cache: list[Any], bits: int | None, group_size: int +) -> list[Any]: + """Compress KVCache V tensors using TurboQuant (K stays FP16).""" + from mlx_lm.models.cache import KVCache + + from .turboquant import TurboQuantConfig, TurboQuantKVCache, auto_select_bits + + compressed_count = 0 + result = [] + for layer in cache: + if layer is None: + result.append(layer) + continue + if isinstance(layer, KVCache) and layer.keys is not None: + head_dim = layer.values.shape[-1] if layer.values is not None else 128 + actual_bits = bits if bits is not None else auto_select_bits(head_dim) + config = TurboQuantConfig(bits=actual_bits, group_size=group_size) + result.append(TurboQuantKVCache.from_kv_cache(layer, config)) + compressed_count += 1 + else: + result.append(layer) + + if compressed_count > 0: + logger.debug( + f"TurboQuant compressed {compressed_count}/{len(cache)} layers " + f"({bits or 'auto'}-bit, group_size={group_size})" + ) + return result + + +def _turboquant_decompress_cache(cache: list[Any]) -> list[Any]: + """Decompress TurboQuantKVCache layers back to regular KVCache.""" + from .turboquant import TurboQuantKVCache + + result = [] + for layer in cache: + if layer is None: + result.append(layer) + continue + if isinstance(layer, TurboQuantKVCache) and layer.keys is not None: + result.append(layer.to_kv_cache()) + else: + result.append(layer) + return result + + class MemoryAwarePrefixCache: """ Prefix cache with memory-based eviction. 
@@ -464,6 +526,14 @@ def __init__( f"max_entries={self._config.max_entries}" ) + def _decompress_cache(self, cache: list[Any]) -> list[Any]: + """Decompress cache layers (TurboQuant or standard quantization).""" + if self._config.kv_turboquant: + return _turboquant_decompress_cache(cache) + elif self._config.kv_quantize: + return _dequantize_cache(cache) + return cache + def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]: """ Find cached KV state for the given tokens. @@ -500,8 +570,7 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]: # Deep copy: cache objects have mutable offset/state that # generation modifies in-place, corrupting the stored entry. cache_out = copy.deepcopy(entry.cache) - if self._config.kv_quantize: - cache_out = _dequantize_cache(cache_out) + cache_out = self._decompress_cache(cache_out) return cache_out, [] # --- O(log N) prefix & supersequence match via sorted index --- @@ -576,11 +645,7 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]: self._stats.hits += 1 self._stats.tokens_saved += n_requested self._last_match_type = "supersequence" - trimmed_cache = ( - _dequantize_cache(trimmed_cache) - if self._config.kv_quantize - else trimmed_cache - ) + trimmed_cache = self._decompress_cache(trimmed_cache) return trimmed_cache, [] else: self._entries.move_to_end(best_super.tokens) @@ -588,8 +653,7 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]: self._stats.tokens_saved += n_requested self._last_match_type = "supersequence" cache_out = copy.deepcopy(best_super.cache) - if self._config.kv_quantize: - cache_out = _dequantize_cache(cache_out) + cache_out = self._decompress_cache(cache_out) return cache_out, [] # --- Prefix match --- @@ -600,8 +664,7 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]: remaining = tokens[best_length:] self._last_match_type = "prefix" cache_out = copy.deepcopy(best_match.cache) - if 
self._config.kv_quantize: - cache_out = _dequantize_cache(cache_out) + cache_out = self._decompress_cache(cache_out) return cache_out, remaining # --- LCP (Longest Common Prefix) for divergent sequences --- @@ -668,11 +731,7 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]: f"trimmed={excess} remaining={len(remaining)}" ) self._last_match_type = "lcp" - trimmed_cache = ( - _dequantize_cache(trimmed_cache) - if self._config.kv_quantize - else trimmed_cache - ) + trimmed_cache = self._decompress_cache(trimmed_cache) return trimmed_cache, remaining self._stats.misses += 1 @@ -715,8 +774,17 @@ def store( # Trim oversized KV arrays to actual used size cache = _trim_to_offset(cache) - # Quantize if enabled and sequence is long enough + # Compress cache for storage (TurboQuant or standard quantization) if ( + self._config.kv_turboquant + and len(tokens) >= self._config.kv_min_quantize_tokens + ): + cache = _turboquant_compress_cache( + cache, + self._config.kv_turboquant_bits, + self._config.kv_turboquant_group_size, + ) + elif ( self._config.kv_quantize and len(tokens) >= self._config.kv_min_quantize_tokens ): diff --git a/vllm_mlx/scheduler.py b/vllm_mlx/scheduler.py index c6e40d2..9163782 100644 --- a/vllm_mlx/scheduler.py +++ b/vllm_mlx/scheduler.py @@ -85,6 +85,11 @@ class SchedulerConfig: kv_cache_quantization_group_size: int = 64 kv_cache_min_quantize_tokens: int = 256 + # TurboQuant V-only compression (asymmetric: K=FP16, V=3-4bit rotated Lloyd-Max) + kv_cache_turboquant: bool = False + kv_cache_turboquant_bits: int | None = None # None = auto-select by head_dim + kv_cache_turboquant_group_size: int = 32 + # Paged cache settings (experimental - for memory efficiency) use_paged_cache: bool = ( False # Use BlockAwarePrefixCache instead of PrefixCacheManager @@ -1107,6 +1112,9 @@ def __init__( kv_bits=self.config.kv_cache_quantization_bits, kv_group_size=self.config.kv_cache_quantization_group_size, 
kv_min_quantize_tokens=self.config.kv_cache_min_quantize_tokens, + kv_turboquant=self.config.kv_cache_turboquant, + kv_turboquant_bits=self.config.kv_cache_turboquant_bits, + kv_turboquant_group_size=self.config.kv_cache_turboquant_group_size, ) self.memory_aware_cache = MemoryAwarePrefixCache( model=model, diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index 4128fad..78e3b78 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -709,6 +709,14 @@ def main(): parser.add_argument("--kv-group-size", type=int, default=64, help=_ap.SUPPRESS) parser.add_argument("--draft-model", type=str, default=None, help=_ap.SUPPRESS) parser.add_argument("--num-draft-tokens", type=int, default=4, help=_ap.SUPPRESS) + # TurboQuant flags — accepted but only functional via rapid-mlx serve (cli.py) + parser.add_argument("--kv-cache-turboquant", action="store_true", help=_ap.SUPPRESS) + parser.add_argument( + "--kv-cache-turboquant-bits", type=int, default=None, help=_ap.SUPPRESS + ) + parser.add_argument( + "--kv-cache-turboquant-group-size", type=int, default=32, help=_ap.SUPPRESS + ) parser.add_argument( "--mcp-config", type=str, diff --git a/vllm_mlx/turboquant.py b/vllm_mlx/turboquant.py new file mode 100644 index 0000000..e0c7e3b --- /dev/null +++ b/vllm_mlx/turboquant.py @@ -0,0 +1,412 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +TurboQuant KV cache compression for prefix cache. + +V-only asymmetric compression: K stays FP16, V is quantized to 3-4 bits +using random orthogonal rotation + Lloyd-Max codebook quantization. + +Based on the TurboQuant paper (arXiv 2504.19874, ICLR 2026). 
+ +Usage:: + + config = TurboQuantConfig(bits=3) + tq_cache = TurboQuantKVCache.from_kv_cache(kv_cache, config) + restored = tq_cache.to_kv_cache() +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass + +import mlx.core as mx +import numpy as np + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class TurboQuantConfig: + """TurboQuant compression settings.""" + + bits: int = 3 # 3 or 4 + group_size: int = 32 + rotation_seed: int = 42 + + def __post_init__(self): + if self.bits not in (3, 4): + raise ValueError(f"bits must be 3 or 4, got {self.bits}") + if self.group_size < 1: + raise ValueError(f"group_size must be >= 1, got {self.group_size}") + + +def auto_select_bits(head_dim: int) -> int: + """Select bit width based on head dimension. + + 3-bit is safe for head_dim >= 96 (cosine > 0.95). + 4-bit is required for head_dim = 64 (3-bit degrades below 0.85). + """ + return 3 if head_dim >= 96 else 4 + + +# --------------------------------------------------------------------------- +# Lloyd-Max codebooks (precomputed for unit Gaussian) +# --------------------------------------------------------------------------- + +# Optimal Lloyd-Max quantizer for N(0,1) data. +# Centroids = conditional expectations E[X | X in bin_i]. +# Boundaries = decision thresholds between adjacent centroids. +# Reference: Lloyd (1982), Max (1960). Values from scipy Lloyd-Max solver. 
# fmt: off

# 3-bit: 8 centroids, 7 boundaries
_LLOYD_MAX_3BIT = mx.array([
    -2.1519, -1.3440, -0.7560, -0.2451, 0.2451, 0.7560, 1.3440, 2.1519
], dtype=mx.float16)

_LLOYD_MAX_3BIT_BOUNDS = mx.array([
    -1.7479, -1.0500, -0.5005, 0.0000, 0.5005, 1.0500, 1.7479
], dtype=mx.float16)

# 4-bit: 16 centroids, 15 boundaries
_LLOYD_MAX_4BIT = mx.array([
    -2.7326, -2.0690, -1.6180, -1.2562, -0.9423, -0.6568, -0.3881, -0.1284,
    0.1284, 0.3881, 0.6568, 0.9423, 1.2562, 1.6180, 2.0690, 2.7326
], dtype=mx.float16)

_LLOYD_MAX_4BIT_BOUNDS = mx.array([
    -2.4008, -1.8435, -1.4371, -1.0993, -0.7996, -0.5224, -0.2582, 0.0000,
    0.2582, 0.5224, 0.7996, 1.0993, 1.4371, 1.8435, 2.4008
], dtype=mx.float16)
# fmt: on

# Public lookup tables keyed by bit width (3 or 4).
LLOYD_MAX_CODEBOOKS = {3: _LLOYD_MAX_3BIT, 4: _LLOYD_MAX_4BIT}
LLOYD_MAX_BOUNDARIES = {3: _LLOYD_MAX_3BIT_BOUNDS, 4: _LLOYD_MAX_4BIT_BOUNDS}


# ---------------------------------------------------------------------------
# Bit-packing: 2 indices per uint8 (nibble packing)
# ---------------------------------------------------------------------------
# Both 3-bit (indices 0-7) and 4-bit (indices 0-15) fit in a single nibble,
# so one packing routine serves both widths; 3-bit simply leaves the top bit
# of each nibble unused.


def _pack_nibbles(indices: mx.array) -> mx.array:
    """Pack pairs of 4-bit indices into uint8 (2 per byte).

    Input shape: (..., N). Values in [0, 15]. If N is odd, a zero nibble
    is appended before packing (the unpacker trims it via original_len).
    Output shape: (..., ceil(N/2)) dtype uint8.
    """
    # Pad to even length if needed
    *batch, n = indices.shape
    if n % 2 != 0:
        indices = mx.pad(indices, [(0, 0)] * len(batch) + [(0, 1)])
        n += 1

    # Adjacent pairs: element 0 goes to the high nibble, element 1 to the low.
    reshaped = indices.reshape(*batch, n // 2, 2)
    high = reshaped[..., 0].astype(mx.uint8) << 4
    low = reshaped[..., 1].astype(mx.uint8) & 0x0F
    return (high | low).astype(mx.uint8)


def _unpack_nibbles(packed: mx.array, original_len: int) -> mx.array:
    """Unpack uint8 nibble-packed array back to individual indices.

    Inverse of _pack_nibbles; ``original_len`` drops any padding nibble.

    Input shape: (..., N//2) dtype uint8.
    Output shape: (..., original_len) dtype uint8.
    """
    high = (packed >> 4) & 0x0F
    low = packed & 0x0F
    *batch, n_packed = packed.shape
    # Interleave high and low nibbles to restore the original pair order.
    unpacked = mx.concatenate(
        [mx.expand_dims(high, -1), mx.expand_dims(low, -1)], axis=-1
    ).reshape(*batch, n_packed * 2)
    return unpacked[..., :original_len]


# ---------------------------------------------------------------------------
# Rotation matrix (cached per head_dim)
# ---------------------------------------------------------------------------

# Process-wide cache keyed by (dim, seed). Unbounded, but there are only a
# handful of distinct head dims per process, so this stays tiny.
_rotation_cache: dict[tuple[int, int], mx.array] = {}


def generate_rotation_matrix(dim: int, seed: int = 42) -> mx.array:
    """Generate a fixed random orthogonal matrix Q via QR decomposition.

    The same (dim, seed) always yields the same Q, so encode and decode
    can regenerate it independently. Result is cached per (dim, seed).
    """
    key = (dim, seed)
    if key in _rotation_cache:
        return _rotation_cache[key]

    # Use numpy for deterministic QR (mlx doesn't have linalg.qr)
    rng = np.random.RandomState(seed)
    random_matrix = rng.randn(dim, dim).astype(np.float32)
    q, _ = np.linalg.qr(random_matrix)
    # Keep float32 for rotation to preserve orthogonality during matmul.
    # The V data is upcast to float32 for rotation, then back to float16.
    rotation = mx.array(q, dtype=mx.float32)

    _rotation_cache[key] = rotation
    return rotation


# ---------------------------------------------------------------------------
# Encode / Decode
# ---------------------------------------------------------------------------


def turboquant_encode(
    values: mx.array,
    bits: int,
    group_size: int,
    rotation: mx.array,
) -> tuple[mx.array, mx.array, mx.array]:
    """Compress V tensor using TurboQuant.

    Pipeline: rotate along head_dim -> per-group mean/std normalization ->
    Lloyd-Max bin assignment -> nibble-pack the bin indices.

    Args:
        values: V tensor, shape (..., seq_len, head_dim). FP16.
        bits: 3 or 4.
        group_size: Elements per quantization group.
        rotation: Orthogonal matrix, shape (head_dim, head_dim).

    Returns:
        (packed_indices, scales, zeros) where:
        - packed_indices: uint8, shape (..., seq_len, ceil(head_dim/2)) — nibble-packed
        - scales: per-group std, shape (..., seq_len, n_groups)
        - zeros: per-group mean, shape (..., seq_len, n_groups)

    NOTE(review): scales/zeros are produced by the float32 arithmetic below,
    so they come out float32 (not float16 as originally documented) — confirm,
    and consider an explicit cast to halve metadata size.
    """
    # 1. Rotate along head_dim: V @ Q^T (in float32 for precision)
    rotated = values.astype(mx.float32) @ rotation.T

    # 2. Per-group normalize to unit Gaussian
    orig_shape = rotated.shape
    head_dim = orig_shape[-1]
    n_groups = (head_dim + group_size - 1) // group_size

    # Pad if head_dim not divisible by group_size. Zero-padding slightly
    # biases the last group's statistics; exact for the common head dims
    # (64/96/128) with the default group_size of 32.
    if head_dim % group_size != 0:
        pad_size = group_size * n_groups - head_dim
        rotated = mx.pad(rotated, [(0, 0)] * (len(orig_shape) - 1) + [(0, pad_size)])

    # Reshape to (..., seq_len, n_groups, group_size)
    grouped = rotated.reshape(*orig_shape[:-1], n_groups, group_size)

    # Compute per-group statistics
    group_mean = mx.mean(grouped, axis=-1, keepdims=True)  # (..., n_groups, 1)
    # Floor the std so constant groups don't divide by ~zero.
    group_std = mx.maximum(
        mx.sqrt(mx.mean((grouped - group_mean) ** 2, axis=-1, keepdims=True)),
        mx.array(1e-6, dtype=mx.float16),
    )

    # Normalize to ~N(0,1) so the unit-Gaussian codebook applies.
    normalized = (grouped - group_mean) / group_std

    # 3. Quantize using Lloyd-Max codebook via broadcasting comparison
    # For each value, count how many boundaries it exceeds → gives the bin index.
    # boundaries shape: (n_levels - 1,), normalized shape: (..., group_size)
    boundaries = LLOYD_MAX_BOUNDARIES[bits]
    # Expand for broadcasting: normalized[..., None] > boundaries[None, ...]
    # Sum across boundary dim gives index
    expanded = mx.expand_dims(normalized, axis=-1)  # (..., group_size, 1)
    # boundaries reshaped to (1, ..., 1, n_bounds) for broadcast
    bounds = boundaries.reshape((1,) * len(normalized.shape) + (-1,))
    indices = mx.sum(expanded > bounds, axis=-1).astype(mx.uint8)  # (..., group_size)

    # Reshape indices back to (..., seq_len, padded_head_dim)
    indices = indices.reshape(*orig_shape[:-1], n_groups * group_size)
    # Trim padding
    if head_dim % group_size != 0:
        indices = indices[..., :head_dim]

    # Scales and zeros: squeeze keepdim
    scales = group_std.squeeze(-1)  # (..., seq_len, n_groups)
    zeros = group_mean.squeeze(-1)  # (..., seq_len, n_groups)

    # 4. Bit-pack indices: 2 per uint8 (halves index memory)
    packed_indices = _pack_nibbles(indices)

    return packed_indices, scales, zeros


def turboquant_decode(
    packed_indices: mx.array,
    scales: mx.array,
    zeros: mx.array,
    bits: int,
    group_size: int,
    rotation: mx.array,
    head_dim: int,
) -> mx.array:
    """Decompress V tensor from TurboQuant format.

    Exact inverse of turboquant_encode up to quantization error:
    unpack indices -> codebook lookup -> denormalize -> inverse rotation.

    Args:
        packed_indices: nibble-packed uint8 indices, shape (..., seq_len, head_dim//2)
        scales: per-group scale, shape (..., seq_len, n_groups)
        zeros: per-group mean, shape (..., seq_len, n_groups)
        bits: 3 or 4
        group_size: Elements per quantization group
        rotation: Orthogonal matrix, shape (head_dim, head_dim)
        head_dim: Original head dimension (before any padding)

    Returns:
        Reconstructed V tensor, shape (..., seq_len, head_dim). FP16.
    """
    codebook = LLOYD_MAX_CODEBOOKS[bits]
    n_groups = scales.shape[-1]

    # 1. Unpack nibble-packed indices and look up codebook values
    indices = _unpack_nibbles(packed_indices, head_dim)
    dequantized = codebook[indices]  # (..., seq_len, head_dim)

    # 2. Pad if needed, reshape to groups (mirrors the encoder's padding)
    padded_dim = n_groups * group_size
    if head_dim < padded_dim:
        pad_size = padded_dim - head_dim
        dequantized = mx.pad(
            dequantized, [(0, 0)] * (len(dequantized.shape) - 1) + [(0, pad_size)]
        )

    orig_batch_shape = dequantized.shape[:-1]
    grouped = dequantized.reshape(*orig_batch_shape, n_groups, group_size)

    # 3. Denormalize: x = x * scale + mean
    scales_expanded = mx.expand_dims(scales, axis=-1)  # (..., n_groups, 1)
    zeros_expanded = mx.expand_dims(zeros, axis=-1)
    grouped = grouped * scales_expanded + zeros_expanded

    # 4. Reshape back and trim padding
    rotated = grouped.reshape(*orig_batch_shape, padded_dim)
    if head_dim < padded_dim:
        rotated = rotated[..., :head_dim]

    # 5. Inverse rotation: V_reconstructed = rotated @ Q (float32 for precision)
    # Q is orthogonal, so Q^-1 == Q^T; encode used @ Q.T, hence @ Q here.
    values = rotated.astype(mx.float32) @ rotation

    return values.astype(mx.float16)


# ---------------------------------------------------------------------------
# TurboQuantKVCache — prefix cache storage wrapper
# ---------------------------------------------------------------------------


class TurboQuantKVCache:
    """KV cache with TurboQuant V compression for prefix cache storage.

    K stays FP16. V is compressed to 3-4 bits using rotation + Lloyd-Max.
    This class is used in the prefix cache (store/fetch), not during
    model forward passes.
    """

    def __init__(
        self,
        keys: mx.array | None,
        values_compressed: tuple[mx.array | None, mx.array | None, mx.array | None],
        offset: int,
        config: TurboQuantConfig,
        head_dim: int,
    ):
        """Store compressed state directly; use from_kv_cache to build one.

        Args:
            keys: FP16 K tensor trimmed to ``offset`` tokens, or None when empty.
            values_compressed: (packed_indices, scales, zeros) from
                turboquant_encode, or (None, None, None) when empty.
            offset: Number of valid tokens along the sequence axis.
            config: Compression settings used to encode (needed to decode).
            head_dim: Original V head dimension (0 when empty).
        """
        self.keys = keys
        self.values_compressed = values_compressed  # (indices, scales, zeros)
        self.offset = offset
        self.config = config
        self.head_dim = head_dim

    @classmethod
    def from_kv_cache(cls, kv_cache, config: TurboQuantConfig) -> TurboQuantKVCache:
        """Compress a standard KVCache into TurboQuant format."""
        keys = kv_cache.keys
        values = kv_cache.values
        offset = kv_cache.offset

        # Empty cache: keep an empty wrapper so round-tripping still works.
        if keys is None or values is None:
            return cls(
                keys=None,
                values_compressed=(None, None, None),
                offset=0,
                config=config,
                head_dim=0,
            )

        # Get actual data up to offset (KVCache may over-allocate the seq axis)
        if offset < keys.shape[-2]:
            keys = keys[..., :offset, :]
            values = values[..., :offset, :]

        head_dim = values.shape[-1]
        rotation = generate_rotation_matrix(head_dim, config.rotation_seed)

        indices, scales, zeros = turboquant_encode(
            values, config.bits, config.group_size, rotation
        )

        return cls(
            keys=keys,
            values_compressed=(indices, scales, zeros),
            offset=offset,
            config=config,
            head_dim=head_dim,
        )

    def to_kv_cache(self):
        """Decompress back to a standard KVCache (lossy in V, exact in K)."""
        from mlx_lm.models.cache import KVCache

        kv = KVCache()

        # Empty wrapper decompresses to an empty KVCache.
        if self.keys is None:
            return kv

        # Same (dim, seed) regenerates the exact rotation used at encode time.
        rotation = generate_rotation_matrix(self.head_dim, self.config.rotation_seed)
        indices, scales, zeros = self.values_compressed

        values = turboquant_decode(
            indices,
            scales,
            zeros,
            self.config.bits,
            self.config.group_size,
            rotation,
            self.head_dim,
        )

        kv.keys = self.keys
        kv.values = values
        kv.offset = self.offset
        return kv

    def is_trimmable(self) -> bool:
        # Always trimmable: compressed values are sliceable along the seq axis.
        return True

    def trim(self, n: int) -> None:
        """Trim n tokens from the end (no-op when empty or n <= 0)."""
        if self.keys is not None and n > 0:
            new_offset = max(0, self.offset - n)
            # All per-token arrays index the sequence on axis -2, so a plain
            # slice keeps keys, indices, scales and zeros in sync.
            self.keys = self.keys[..., :new_offset, :]
            indices, scales, zeros = self.values_compressed
            self.values_compressed = (
                indices[..., :new_offset, :] if indices is not None else None,
                scales[..., :new_offset, :] if scales is not None else None,
                zeros[..., :new_offset, :] if zeros is not None else None,
            )
            self.offset = new_offset

    @property
    def memory_bytes(self) -> int:
        """Estimate memory usage in bytes (sum of nbytes of stored arrays)."""
        total = 0
        if self.keys is not None:
            total += self.keys.nbytes
        indices, scales, zeros = self.values_compressed
        if indices is not None:
            total += indices.nbytes
        if scales is not None:
            total += scales.nbytes
        if zeros is not None:
            total += zeros.nbytes
        return total