[Quantization] Add TurboQuant dynamic kv cache compression #38280

lishunyang12 wants to merge 52 commits into vllm-project:main
Conversation
Code Review
This pull request implements TurboQuant, an online vector quantization algorithm for KV cache compression, supporting sub-4-bit quantization (including fractional bit-widths) through random rotations and Lloyd-Max scalar quantization. The changes include the core quantization logic, bit-packing utilities, optimized Triton kernels for encoding/decoding, and integration into the vLLM attention layer via a pre-dequantization step. Feedback was provided regarding a contradiction between a code comment and the actual device initialization logic in the attention layer.
```python
init_device = torch.device("cuda") if torch.cuda.is_available() \
    else torch.device("cpu")
```
The comment "Use CPU init, will be moved to GPU on first use" directly contradicts the code, which initializes init_device to torch.device("cuda") if CUDA is available. This discrepancy can lead to confusion regarding the actual device placement and memory allocation strategy for TurboQuantState objects, potentially impacting performance expectations or debugging efforts. Please either update the comment to accurately reflect that the initialization occurs on CUDA when available, or modify the code to consistently initialize on CPU if that was the intended behavior for torch.compile compatibility.
How's performance?
Thanks for your attention. I am still debugging this PR, as the Triton kernels are not fully in place. The Needle-in-a-Haystack test was based on my early pure-PyTorch implementation, which is not on par with the results reported in the paper.
Phase 1 Benchmark Results

Model: Qwen/Qwen2.5-1.5B-Instruct, GPU: H200, Mode: enforce_eager

Key takeaways:
- Quality: 100% match at ALL bit-widths
- TTFT: no overhead (<1%)
- ITL + E2E latency: no overhead
- Prefill throughput: identical
- Batched throughput (gen tok/s): slight improvement at bs=16
- `TQ_QJL=1`: enables 1-bit residual correction on top of MSE quantization.
- `TQ_BITS=3`: sets key bit width (default 4).
- Slot layout updated to include QJL sign bytes + QJL norm.

Signed-off-by: lishunyang <lishunyang12@163.com>
- Encode: after MSE quantize, compute residual, project onto S matrix, extract 1-bit signs, pack into cache slot alongside MSE indices.
- Decode: unpack signs, reconstruct correction via sqrt(pi/2)/d * ||r|| * (signs @ S), add to MSE reconstruction.
- Slot layout: [outlier | mse_packed | qjl_signs | norm | qjl_norm]
- Activated via TQ_QJL=1 TQ_BITS=3 for 3-bit effective (2-bit MSE + 1-bit QJL).
- QJL uses unfused path (fused kernels fall back when QJL enabled).

Signed-off-by: lishunyang <lishunyang12@163.com>
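The decode formula in the commit message above (sqrt(pi/2)/d * ||r|| * (signs @ S)) can be sanity-checked with a small NumPy sketch; the Gaussian projection `S`, the dimension, and the seed here are illustrative assumptions, not the PR's actual layout:

```python
import numpy as np

# Hedged sketch of 1-bit QJL residual correction: keep only the signs of a
# random projection of the residual plus its norm, then reconstruct with
# sqrt(pi/2)/d * ||r|| * (signs @ S) as in the commit message above.
rng = np.random.default_rng(0)
d = 64
S = rng.standard_normal((d, d))  # assumed Gaussian JL projection

def qjl_encode(residual):
    signs = np.sign(residual @ S.T).astype(np.int8)  # 1 bit per projection
    return signs, np.linalg.norm(residual)

def qjl_decode(signs, norm):
    return np.sqrt(np.pi / 2) / d * norm * (signs @ S)

r = rng.standard_normal(d)
signs, norm = qjl_encode(r)
r_hat = qjl_decode(signs, norm)
cos = float(r @ r_hat / (np.linalg.norm(r) * np.linalg.norm(r_hat)))
print(round(cos, 3))  # positive: the 1-bit estimate tracks the residual
```

The sign-only estimate is noisy on its own, which matches its use here as a *correction* on top of the MSE reconstruction rather than a standalone quantizer.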
When QJL is enabled, bit_width=3 but mse_bits=2 (1 bit reserved for QJL signs). The pack/unpack must use mse_bits for correct slot sizing. Signed-off-by: lishunyang <lishunyang12@163.com>
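As a toy illustration of that slot-sizing rule, here is a 4-bit pack/unpack roundtrip in NumPy (two indices per byte, so n indices occupy n/2 bytes; the index values are arbitrary examples, not the PR's kernel):

```python
import numpy as np

# Toy 4-bit pack/unpack roundtrip: high nibble holds the even-position
# index, low nibble the odd-position one. Pack and unpack must agree on
# the bit width, which is the mse_bits concern described above.
idx = np.array([1, 2, 15, 0, 7, 9], dtype=np.uint8)  # 4-bit codes, 0..15
packed = (idx[0::2] << 4) | idx[1::2]                # 3 bytes for 6 codes
unpacked = np.empty(idx.size, dtype=np.uint8)
unpacked[0::2] = packed >> 4
unpacked[1::2] = packed & 0x0F
print(packed.nbytes, unpacked.tolist())  # 3 [1, 2, 15, 0, 7, 9]
```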
- Delete triton_turboquant.py (dead O(d²) rotation, never used in prod)
- Remove QJL from packed cache backend (doesn't work due to Hadamard ordering mismatch between PyTorch and Triton — keep in emulation mode)
- Remove unused Pi/PiT fields from TurboQuantState
- Remove QJL slot layout from get_kv_cache_spec
- Remove TQ_QJL env var
- Simplify mse_bits → bit_width (no QJL reservation needed)
- Remove test for deleted file

Signed-off-by: lishunyang <lishunyang12@163.com>
```python
has_outliers = normal_idx is not None and n_outliers > 0
out = torch.empty(N, head_size, dtype=torch.bfloat16, device=cache.device)
scratch = torch.empty(N, BLOCK_D, dtype=torch.float32, device=cache.device)
```
```
vllm serve /mnt/data3/models/MiniMax/MiniMax-M2.5 -tp 4 --trust-remote-code --kv-cache-dtype turboquant
```

```
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]   File "/mnt/data4/jxy/vllm/vllm/model_executor/layers/attention/kv_transfer_utils.py", line 39, in wrapper
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]     return func(*args, **kwargs)
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]            ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]   File "/mnt/data4/jxy/vllm/vllm/model_executor/layers/attention/attention.py", line 819, in unified_attention_with_output
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]     self.impl.forward(
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]   File "/mnt/data4/jxy/vllm/vllm/v1/attention/backends/turboquant_attn.py", line 239, in forward
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]     key_cache, value_cache, block_table = self._decode_turboquant_cache(
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]   File "/mnt/data4/jxy/venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1181, in _fn
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]     return fn(*args, **kwargs)
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]            ^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]   File "/mnt/data4/jxy/vllm/vllm/v1/attention/backends/turboquant_attn.py", line 333, in _decode_turboquant_cache
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]     return self._decode_fused_4bit(
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]            ^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]   File "/mnt/data4/jxy/vllm/vllm/v1/attention/backends/turboquant_attn.py", line 395, in _decode_fused_4bit
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]     decoded = fused_paged_decode(
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]               ^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]   File "/mnt/data4/jxy/vllm/vllm/v1/attention/ops/triton_fused_turboquant.py", line 467, in fused_paged_decode
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]     scratch = torch.empty(N, BLOCK_D, dtype=torch.float32, device=cache.device)
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949] torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 96.00 GiB. GPU 0 has a total capacity of 139.81 GiB of which 14.61 GiB is free. Including non-PyTorch memory, this process has 125.19 GiB memory in use. Of the allocated memory 121.24 GiB is allocated by PyTorch, with 81.75 MiB allocated in private pools (e.g., CUDA Graphs), and 81.89 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
```
I’m trying to test MiniMax-M2.5 on H20 and ran into the issue mentioned above.
`scratch = torch.empty(N, BLOCK_D, dtype=torch.float32, device=cache.device)`

Just a quick question: could `scratch` be consuming a large amount of GPU memory?
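Quite possibly: the scratch buffer holds N × BLOCK_D float32 values, so the 96 GiB allocation in the traceback implies roughly 2×10⁸ rows. A back-of-envelope check (BLOCK_D = 128 is an assumption here; N presumably scales with total cached tokens times KV heads):

```python
# Back-of-envelope for the OOM above: torch.empty(N, BLOCK_D, float32)
# costs N * BLOCK_D * 4 bytes. BLOCK_D = 128 is an assumed padded head size.
BLOCK_D = 128
bytes_per_row = BLOCK_D * 4            # 512 bytes of fp32 scratch per row
implied_rows = 96 * 2**30 // bytes_per_row
print(implied_rows)  # 201326592 rows for the reported 96 GiB request
```

Sizing the scratch per Triton launch tile rather than over the whole decoded range would avoid the one-shot allocation, but that is a design question for the kernel author.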
Qwen3.5 currently fails with
Benchmark — MiniMax-M2.5, H20-3e TurboQuant
I tested the long-context performance of Qwen3-8B on an H20. The results below clearly demonstrate a significant drop in performance after enabling TurboQuant.
Is this normal?
This PR may have better code quality than the usual vibe-coded ones, since the contributor uses an anime avatar.
FYI — PR #38479 has working hybrid Mamba model support. Tested Nemotron-Cascade-2-30B-A3B (Mamba+MoE+Attention hybrid, head_dim=128) on 8x RTX A4000 with
Why are there multiple TurboQuant PRs already, instead of focusing on one and getting the feature shipped?
…instances) Signed-off-by: lishunyang <lishunyang12@163.com>
Second this. @lishunyang12, are you going to look at the other implementation and combine the two, or give up on supporting Qwen3.5's hybrid Mamba models? I am running Qwen3.5 models, and would really appreciate it if TurboQuant could work smoothly with them.
This pull request has merge conflicts that must be resolved before it can be merged.
```python
# Initialize on CUDA if available, CPU otherwise
init_device = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("cpu")
)
```
Suggested change, from:

```python
# Initialize on CUDA if available, CPU otherwise
init_device = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("cpu")
)
```

to:

```python
# Initialize on accelerator if available, CPU otherwise
init_device = torch.device(current_platform.device_name)
```
```python
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
```
Does this kernel support TurboQuant?
I started with the latest commit aa6e58e in the H20 environment using the command:
Closing now, as I'm afraid I don't have the bandwidth to push this further in the near future (I am still a uni student and exam period is coming). I hope this PR can be preserved as a reference that might be useful for the final integration of TurboQuant into this repo. I also learned a lot, from reading the paper through POC, benchmarking, and optimizing. Please move to #38479, which is another promising integration.
TurboQuant KV Cache Quantization
Adds TurboQuant (ICLR 2026, Google Research) as a new `--kv-cache-dtype` option. Compresses the KV cache from bf16 to packed 4-bit uint8 using Hadamard rotation + Lloyd-Max scalar quantization + outlier-aware channel allocation.

Architecture
```mermaid
graph TB
    subgraph "vLLM Attention Layer"
        A[Query, Key, Value from model] --> B{attn_type?}
        B -->|DECODER| C[TurboQuantBackend]
        B -->|Encoder / Sliding Window| D[FlashAttn / Triton<br>auto fallback]
    end
    subgraph "TurboQuant Backend"
        C --> E[do_kv_cache_update]
        C --> F[forward]
        E --> G[Fused Triton Encode]
        G --> H[(Paged uint8<br>KV Cache)]
        F --> I[Fused Triton Decode]
        H --> I
        I --> J[bf16 K,V tensors]
        J --> K[unified_attention<br>standard Triton kernel]
        K --> L[Attention Output]
    end
    style H fill:#f96,stroke:#333
    style G fill:#6b6,stroke:#333,color:#fff
    style I fill:#6b6,stroke:#333,color:#fff
    style K fill:#69f,stroke:#333,color:#fff
```

Encode Pipeline (on token write)
```mermaid
graph LR
    A["K/V vector<br>[128 dims, bf16]"] --> B{Split channels}
    B -->|19 outlier channels| C["Keep bf16<br>(38 bytes)"]
    B -->|109 normal channels| D[Normalize L2]
    D --> E["Sign flip<br>(random ±1)"]
    E --> F["Hadamard butterfly<br>(7 levels, O(d log d))"]
    F --> G["Lloyd-Max 4-bit<br>(16 centroids)"]
    G --> H["Bit-pack<br>(2 idx/byte)"]
    C --> I["Cache slot [95 bytes]"]
    H --> I
    J["Norm fp16<br>(2 bytes)"] --> I
    style I fill:#f96,stroke:#333
    style F fill:#6b6,stroke:#333,color:#fff
    style G fill:#6b6,stroke:#333,color:#fff
```

Decode Pipeline (on attention)
```mermaid
graph LR
    A["Cache slot<br>[95 bytes uint8]"] --> B["Unpack<br>4-bit indices"]
    B --> C["Codebook lookup<br>(16 centroids)"]
    C --> D["Inverse Hadamard<br>(7 levels)"]
    D --> E["Sign flip +<br>scale by norm"]
    A --> F["Read outlier<br>bf16 channels"]
    E --> G["Interleave normal<br>+ outlier channels"]
    F --> G
    G --> H["Reconstructed K/V<br>[128 dims, bf16]"]
    H --> I["unified_attention<br>(standard Triton)"]
    style A fill:#f96,stroke:#333
    style D fill:#6b6,stroke:#333,color:#fff
    style I fill:#69f,stroke:#333,color:#fff
```

Memory Layout
```mermaid
block-beta
    columns 3
    block:slot["Cache Slot (95 bytes per token per head)"]:3
        A["Outlier bf16\n38 bytes\n(19 channels × 2B)"]
        B["Packed 4-bit indices\n55 bytes\n(109 channels / 2)"]
        C["Norm fp16\n2 bytes"]
    end
    block:baseline["Baseline bf16 (256 bytes per token per head)"]:3
        D["128 dimensions × 2 bytes = 256 bytes"]
    end
    style A fill:#f96
    style B fill:#6b6,color:#fff
    style C fill:#69f,color:#fff
    style D fill:#ddd
```

Known Limitations & Status
1. Throughput overhead (0.36x baseline)
The current architecture decompresses the entire KV cache from packed uint8 to bf16 on every forward call before running attention. For a batch of 8 sequences at 128 tokens each, this means decoding ~64 blocks × 28 layers × 2 (K+V) = 3,584 inverse Hadamard transforms per forward step. The actual attention kernel (`unified_attention`) runs on the decompressed bf16 — identical to baseline — but the decompression dominates latency.

The fix is a fused decode+attention kernel that dequantizes KV blocks on-the-fly inside the attention dot product, never materializing the bf16 buffer.
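The transform count quoted above follows directly from the shapes in the paragraph (the 16-token block size is an assumption; the batch, layer, and K/V counts are as stated):

```python
# 8 sequences x 128 tokens = 1024 cached tokens; at an assumed 16-token
# block size that is 64 blocks, each decoded for K and V in all 28 layers.
blocks = (8 * 128) // 16
transforms = blocks * 28 * 2
print(blocks, transforms)  # 64 3584
```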
2. Hybrid models (Qwen3.5) not supported
TurboQuant correctly auto-skips non-DECODER layers (Mamba, GDN, sliding-window) by falling back to the standard FlashAttn backend. However, vLLM requires all KV cache specs in a cache group to have compatible page sizes (`unify_kv_cache_spec_page_size` in `kv_cache_utils.py`). TurboQuant uses uint8 slots of 95 bytes (vs bf16 at 256 bytes), and this page size cannot be reconciled with Mamba state cache pages. The fix requires framework-level changes to allow heterogeneous page sizes per cache group — not something fixable from the attention backend alone.

Tested: `Qwen/Qwen3.5-35B-A3B-FP8` fails with `NotImplementedError: The page size of the layer is not divisible by the maximum page size.`

3. Minimum 4-bit quantization
3-bit and 2-bit quantization produce garbage output. At lower bit widths, the quantization error is too large for the 28-layer attention stack to tolerate. The TurboQuant paper uses QJL (Quantized Johnson-Lindenstrauss) 1-bit residual correction to fix this (e.g., 2-bit MSE + 1-bit QJL = 3-bit effective). QJL is partially implemented but does not produce correct results because the encode path computes residuals using PyTorch Hadamard (which has different butterfly element ordering than the Triton kernel).
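For intuition on why lower bit widths collapse, here is a rotate-quantize-reconstruct roundtrip in NumPy with a uniform grid standing in for the Lloyd-Max codebook, so the error figures only indicate the trend, not the PR's exact numbers:

```python
import numpy as np

def fwht(x):
    # Orthonormal fast Walsh-Hadamard transform (its own inverse),
    # O(d log d) butterfly levels as in the encode diagram above.
    x = x.copy()
    h = 1
    while h < len(x):
        for i in range(0, len(x), 2 * h):
            a, b = x[i:i + h].copy(), x[i + h:i + 2 * h].copy()
            x[i:i + h], x[i + h:i + 2 * h] = a + b, a - b
        h *= 2
    return x / np.sqrt(len(x))

rng = np.random.default_rng(0)
d = 128
v = rng.standard_normal(d)
flip = rng.choice([-1.0, 1.0], size=d)  # random sign flip

def roundtrip(bits):
    # Normalize -> sign flip -> Hadamard -> uniform quantize -> invert.
    levels = 2 ** bits
    norm = np.linalg.norm(v)
    rot = fwht(v / norm * flip)
    lo, hi = rot.min(), rot.max()
    idx = np.round((rot - lo) / (hi - lo) * (levels - 1))
    deq = idx / (levels - 1) * (hi - lo) + lo
    rec = fwht(deq) * flip * norm
    return np.linalg.norm(rec - v) / np.linalg.norm(v)

for bits in (4, 3, 2):
    print(bits, round(roundtrip(bits), 3))  # relative error grows fast below 4 bits
```

Halving the number of levels roughly doubles the per-channel error, and that error then compounds across the attention stack, which is consistent with the observed collapse at 2-3 bits without a residual correction.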
4. CUDA graph memory overhead
TurboQuant decode allocates temporary bf16 buffers (for the decompressed cache) that get captured inside CUDA graphs. This increases graph pool memory from 0.5 GiB (baseline) to 5.9 GiB. The buffers are reused across graph replays (no per-call allocation), but the one-time capture cost is significant. The fused decode+attention kernel (limitation #1) would also eliminate this overhead.
Benchmark — Qwen2.5-7B-Instruct, H100 80GB, 50% GPU util
Reproduce
Setup
Quality test
Throughput benchmark
Unit tests
Configuration
- `--kv-cache-dtype turboquant`
- `TQ_LITE=1` env var
- `TQ_CUDA_WPH=1` env var
- `TQ_BITS=3` env var
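For reference, a launch combining these options might look like the following (model name is illustrative; the flag and env-var names are taken from the list above, with their exact semantics per the PR):

```shell
# Hypothetical invocation; TQ_LITE / TQ_CUDA_WPH / TQ_BITS are the env-var
# toggles listed above (TQ_BITS defaults to 4 per the commit messages).
TQ_LITE=1 vllm serve Qwen/Qwen2.5-7B-Instruct --kv-cache-dtype turboquant
```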