
[Quantization] Add TurboQuant dynamic kv cache compression#38280

Closed
lishunyang12 wants to merge 52 commits into vllm-project:main from lishunyang12:feat/turboquant-kv-cache

Conversation


@lishunyang12 lishunyang12 commented Mar 26, 2026

TurboQuant KV Cache Quantization

Adds TurboQuant (ICLR 2026, Google Research) as a new --kv-cache-dtype option. Compresses the KV cache from bf16 to packed 4-bit uint8 using Hadamard rotation + Lloyd-Max scalar quantization + outlier-aware channel allocation.

vllm serve Qwen/Qwen2.5-7B-Instruct --kv-cache-dtype turboquant

Architecture

graph TB
    subgraph "vLLM Attention Layer"
        A[Query, Key, Value from model] --> B{attn_type?}
        B -->|DECODER| C[TurboQuantBackend]
        B -->|Encoder / Sliding Window| D[FlashAttn / Triton<br>auto fallback]
    end

    subgraph "TurboQuant Backend"
        C --> E[do_kv_cache_update]
        C --> F[forward]
        E --> G[Fused Triton Encode]
        G --> H[(Paged uint8<br>KV Cache)]
        F --> I[Fused Triton Decode]
        H --> I
        I --> J[bf16 K,V tensors]
        J --> K[unified_attention<br>standard Triton kernel]
        K --> L[Attention Output]
    end

    style H fill:#f96,stroke:#333
    style G fill:#6b6,stroke:#333,color:#fff
    style I fill:#6b6,stroke:#333,color:#fff
    style K fill:#69f,stroke:#333,color:#fff

Encode Pipeline (on token write)

graph LR
    A["K/V vector<br>[128 dims, bf16]"] --> B{Split channels}
    B -->|19 outlier channels| C["Keep bf16<br>(38 bytes)"]
    B -->|109 normal channels| D[Normalize L2]
    D --> E["Sign flip<br>(random ±1)"]
    E --> F["Hadamard butterfly<br>(7 levels, O(d log d))"]
    F --> G["Lloyd-Max 4-bit<br>(16 centroids)"]
    G --> H["Bit-pack<br>(2 idx/byte)"]
    C --> I["Cache slot [95 bytes]"]
    H --> I
    J["Norm fp16<br>(2 bytes)"] --> I

    style I fill:#f96,stroke:#333
    style F fill:#6b6,stroke:#333,color:#fff
    style G fill:#6b6,stroke:#333,color:#fff
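The encode steps above can be sketched in NumPy. This is illustrative only: the real path runs as a fused Triton kernel, and the uniform 16-level codebook below is a stand-in for the trained Lloyd-Max centroids.

```python
import numpy as np

def hadamard_butterfly(x: np.ndarray) -> np.ndarray:
    """Fast Walsh-Hadamard transform, O(d log d); len(x) must be a power of 2."""
    x = x.copy()
    h = 1
    while h < len(x):
        for i in range(0, len(x), 2 * h):
            a = x[i:i + h].copy()
            b = x[i + h:i + 2 * h].copy()
            x[i:i + h] = a + b
            x[i + h:i + 2 * h] = a - b
        h *= 2
    return x / np.sqrt(len(x))  # orthonormal scaling

def encode(v: np.ndarray, rng: np.random.Generator):
    """Quantize the 109 'normal' channels of one K/V head vector."""
    d = 128                                    # pad up to the next power of 2
    norm = np.linalg.norm(v)
    x = np.zeros(d, dtype=np.float64)
    x[: len(v)] = v / norm                     # L2-normalize
    signs = rng.choice([-1.0, 1.0], size=d)    # random sign flip
    rot = hadamard_butterfly(x * signs)        # 7-level butterfly (log2 128)
    # Uniform 16-level codebook over [-0.5, 0.5] as a Lloyd-Max stand-in:
    idx = np.clip(np.round((rot + 0.5) * 15), 0, 15).astype(np.uint8)
    packed = (idx[0::2] << 4) | idx[1::2]      # bit-pack: two indices per byte
    return packed, signs, np.float16(norm)
```

One 109-channel vector thus compresses to 64 packed bytes plus the fp16 norm (the sign pattern is shared, not stored per token).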

Decode Pipeline (on attention)

graph LR
    A["Cache slot<br>[95 bytes uint8]"] --> B["Unpack<br>4-bit indices"]
    B --> C["Codebook lookup<br>(16 centroids)"]
    C --> D["Inverse Hadamard<br>(7 levels)"]
    D --> E["Sign flip +<br>scale by norm"]
    A --> F["Read outlier<br>bf16 channels"]
    E --> G["Interleave normal<br>+ outlier channels"]
    F --> G
    G --> H["Reconstructed K/V<br>[128 dims, bf16]"]
    H --> I["unified_attention<br>(standard Triton)"]

    style A fill:#f96,stroke:#333
    style D fill:#6b6,stroke:#333,color:#fff
    style I fill:#69f,stroke:#333,color:#fff
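A matching NumPy sketch of the decode path, with the same caveats (uniform codebook standing in for Lloyd-Max centroids, not the fused Triton kernel). Since the orthonormal Hadamard transform is its own inverse, the same butterfly runs in both directions.

```python
import numpy as np

def hadamard_butterfly(x: np.ndarray) -> np.ndarray:
    """Orthonormal fast Walsh-Hadamard transform; it is its own inverse."""
    x = x.copy()
    h = 1
    while h < len(x):
        for i in range(0, len(x), 2 * h):
            a = x[i:i + h].copy()
            b = x[i + h:i + 2 * h].copy()
            x[i:i + h] = a + b
            x[i + h:i + 2 * h] = a - b
        h *= 2
    return x / np.sqrt(len(x))

CODEBOOK = np.linspace(-0.5, 0.5, 16)  # stand-in for Lloyd-Max centroids

def decode(packed: np.ndarray, signs: np.ndarray, norm: float,
           n_normal: int = 109) -> np.ndarray:
    idx = np.empty(2 * len(packed), dtype=np.uint8)
    idx[0::2] = packed >> 4               # unpack two 4-bit indices per byte
    idx[1::2] = packed & 0x0F
    rot = CODEBOOK[idx]                   # codebook lookup
    x = hadamard_butterfly(rot)           # inverse Hadamard
    return (x * signs)[:n_normal] * norm  # undo sign flip, rescale by norm
```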

Memory Layout

block-beta
    columns 3
    block:slot["Cache Slot (95 bytes per token per head)"]:3
        A["Outlier bf16\n38 bytes\n(19 channels × 2B)"]
        B["Packed 4-bit indices\n55 bytes\n(109 channels / 2)"]
        C["Norm fp16\n2 bytes"]
    end
    block:baseline["Baseline bf16 (256 bytes per token per head)"]:3
        D["128 dimensions × 2 bytes = 256 bytes"]
    end

    style A fill:#f96
    style B fill:#6b6,color:#fff
    style C fill:#69f,color:#fff
    style D fill:#ddd
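The 2.7x capacity figure in the benchmarks follows directly from this layout:

```python
# Slot-size arithmetic from the layout above.
outlier_bytes = 19 * 2            # 19 outlier channels kept in bf16
packed_bytes = (109 + 1) // 2     # 109 normal channels at 4 bits, 2 per byte
norm_bytes = 2                    # fp16 norm
slot = outlier_bytes + packed_bytes + norm_bytes
baseline = 128 * 2                # bf16: 128 dims x 2 bytes
print(slot, baseline / slot)      # 95 bytes per slot, ~2.69x compression
```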

Status: WIP — This PR is functional and produces coherent output, but has significant limitations listed below. Feedback welcome on the approach before further optimization.

Known Limitations & Status

1. Throughput overhead (0.36x baseline)

The current architecture decompresses the entire KV cache from packed uint8 to bf16 on every forward call before running attention. For a batch of 8 sequences at 128 tokens each, this means decoding ~64 blocks × 28 layers × 2 (K+V) = 3,584 Hadamard inverse transforms per forward step. The actual attention kernel (unified_attention) runs on the decompressed bf16 — identical to baseline — but the decompression dominates latency.

The fix is a fused decode+attention kernel that dequantizes KV blocks on-the-fly inside the attention dot product, never materializing the bf16 buffer.

2. Hybrid models (Qwen3.5) not supported

TurboQuant correctly auto-skips non-DECODER layers (Mamba, GDN, sliding-window) by falling back to the standard FlashAttn backend. However, vLLM requires all KV cache specs in a cache group to have compatible page sizes (unify_kv_cache_spec_page_size in kv_cache_utils.py). TurboQuant uses uint8 slots of 95 bytes (vs bf16 at 256 bytes), and this page size cannot be reconciled with Mamba state cache pages. The fix requires framework-level changes to allow heterogeneous page sizes per cache group — not something fixable from the attention backend alone.
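The clash can be seen with back-of-envelope arithmetic. The layer shapes below are hypothetical (only the 95-byte and 256-byte slot sizes come from this PR); the real check lives in `unify_kv_cache_spec_page_size`, which per the error above requires page sizes to divide evenly.

```python
# Hypothetical per-layer page sizes; block_size and kv_heads are made up.
block_size, kv_heads = 16, 4
bf16_page = block_size * kv_heads * 128 * 2  # 16384 bytes
tq_page = block_size * kv_heads * 95         # 6080 bytes
# Unification requires the smaller page size to divide the larger one:
assert bf16_page % tq_page != 0              # leaves a 4224-byte remainder
```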

Tested: Qwen/Qwen3.5-35B-A3B-FP8 fails with NotImplementedError: The page size of the layer is not divisible by the maximum page size.

3. Minimum 4-bit quantization

3-bit and 2-bit quantization produce garbage output. At lower bit widths, the quantization error is too large for the 28-layer attention stack to tolerate. The TurboQuant paper uses QJL (Quantized Johnson-Lindenstrauss) 1-bit residual correction to fix this (e.g., 2-bit MSE + 1-bit QJL = 3-bit effective). QJL is partially implemented but does not produce correct results because the encode path computes residuals using PyTorch Hadamard (which has different butterfly element ordering than the Triton kernel).
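For reference, the QJL residual correction can be sketched as below. This is a hedged illustration: `S` is assumed to be a random Gaussian projection matrix, and the `sqrt(pi/2)` rescaling follows the reconstruction formula noted in this PR's commit messages; the actual kernels and packing differ.

```python
import numpy as np

def qjl_encode(residual: np.ndarray, S: np.ndarray):
    """Keep only 1 bit (the sign) per random projection of the residual."""
    signs = np.sign(S @ residual)          # one bit per row of S
    return signs, np.linalg.norm(residual)

def qjl_decode(signs: np.ndarray, r_norm: float, S: np.ndarray) -> np.ndarray:
    """Sign-based JL estimator: reconstruct the residual direction from
    its projection signs, rescaled by the stored residual norm."""
    m = S.shape[0]
    return np.sqrt(np.pi / 2) / m * r_norm * (signs @ S)
```

The estimator is unbiased for Gaussian `S`, so the correction added to the MSE reconstruction shrinks the quantization error on average rather than eliminating it exactly.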

4. CUDA graph memory overhead

TurboQuant decode allocates temporary bf16 buffers (for the decompressed cache) that get captured inside CUDA graphs. This increases graph pool memory from 0.5 GiB (baseline) to 5.9 GiB. The buffers are reused across graph replays (no per-call allocation), but the one-time capture cost is significant. The fused decode+attention kernel (limitation #1) would also eliminate this overhead.


Benchmark — Qwen2.5-7B-Instruct, H100 80GB, 50% GPU util

| Metric | Baseline (bf16) | TurboQuant (4-bit) |
|---|---|---|
| KV cache capacity | 453K tokens | 1,221K tokens (2.7x) |
| Max concurrent requests (4K ctx) | 110 | 298 (2.7x) |
| Throughput (bs=8, 128 tok) | 767 tok/s | 274 tok/s (0.36x) |
| Output quality | Coherent | Coherent |
| CUDA graphs | Yes | Yes |
| torch.compile | Yes | Yes |

Reproduce

Setup

git fetch origin pull/38280/head:pr-38280 && git checkout pr-38280
pip install -e . && pip install tblib

Quality test

cat > /tmp/test_tq.py << 'EOF'
from vllm import LLM, SamplingParams
if __name__ == "__main__":
    llm = LLM("Qwen/Qwen2.5-7B-Instruct", kv_cache_dtype="turboquant",
              max_model_len=4096, gpu_memory_utilization=0.5)
    for o in llm.generate(
        ["What is 2+2?", "Explain gravity in 3 sentences.", "Write a haiku about the moon."],
        SamplingParams(max_tokens=100),
    ):
        print(o.outputs[0].text[:200])
        print()
EOF
python /tmp/test_tq.py

Throughput benchmark

cat > /tmp/bench_tq.py << 'EOF'
import time, torch
from vllm import LLM, SamplingParams
PROMPT = "Explain the theory of general relativity in detail."
def bench(dtype, label):
    llm = LLM("Qwen/Qwen2.5-7B-Instruct", kv_cache_dtype=dtype,
              max_model_len=4096, gpu_memory_utilization=0.5)
    p = SamplingParams(max_tokens=128, temperature=0.0)
    llm.generate([PROMPT], p)
    torch.cuda.synchronize(); t0 = time.perf_counter()
    out = llm.generate([PROMPT] * 8, p)
    torch.cuda.synchronize(); dt = time.perf_counter() - t0
    toks = sum(len(o.outputs[0].token_ids) for o in out)
    print(f"{label}: {toks} tokens in {dt:.2f}s = {toks/dt:.1f} tok/s")
    del llm; torch.cuda.empty_cache()
    return toks / dt
if __name__ == "__main__":
    t1 = bench("auto", "Baseline")
    t2 = bench("turboquant", "TurboQuant")
    print(f"Ratio: {t2/t1:.2f}x")
EOF
python /tmp/bench_tq.py

Unit tests

pytest tests/kernels/test_turboquant_fused.py -v

Configuration

| Mode | Usage | Description |
|---|---|---|
| Standard | `--kv-cache-dtype turboquant` | Full Hadamard + 4-bit + outlier |
| TQ_LITE | `TQ_LITE=1` env var | No rotation, pure scalar quantize |
| CUDA WPH | `TQ_CUDA_WPH=1` env var | Warp-shuffle Hadamard decode (JIT) |
| Custom bits | `TQ_BITS=3` env var | Set MSE quantization bit width |
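For example, the modes compose as environment variables on the serve command (illustrative invocation, not benchmarked here):

```shell
# 3-bit MSE quantization with the standard Hadamard path
TQ_BITS=3 vllm serve Qwen/Qwen2.5-7B-Instruct --kv-cache-dtype turboquant
```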


@claude claude bot left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify mergify bot added the v1 label Mar 26, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request implements TurboQuant, an online vector quantization algorithm for KV cache compression, supporting sub-4-bit quantization (including fractional bit-widths) through random rotations and Lloyd-Max scalar quantization. The changes include the core quantization logic, bit-packing utilities, optimized Triton kernels for encoding/decoding, and integration into the vLLM attention layer via a pre-dequantization step. Feedback was provided regarding a contradiction between a code comment and the actual device initialization logic in the attention layer.

Comment on lines +381 to +382
init_device = torch.device("cuda") if torch.cuda.is_available() \
else torch.device("cpu")
Contributor


Severity: high

The comment "Use CPU init, will be moved to GPU on first use" directly contradicts the code, which initializes init_device to torch.device("cuda") if CUDA is available. This discrepancy can lead to confusion regarding the actual device placement and memory allocation strategy for TurboQuantState objects, potentially impacting performance expectations or debugging efforts. Please either update the comment to accurately reflect that the initialization occurs on CUDA when available, or modify the code to consistently initialize on CPU if that was the intended behavior for torch.compile compatibility.

@lishunyang12 lishunyang12 force-pushed the feat/turboquant-kv-cache branch from 86d28fc to 0ceb79c Compare March 26, 2026 20:08
@ExtReMLapin
Contributor

How's performance ?

@lishunyang12 lishunyang12 force-pushed the feat/turboquant-kv-cache branch from 0ceb79c to 89ff32f Compare March 26, 2026 20:18
@lishunyang12
Author

lishunyang12 commented Mar 26, 2026

> How's performance ?

Thanks for your attention. I am still debugging this PR, as the Triton kernels are not fully in place. The Needle-in-a-Haystack test was based on my early pure-PyTorch implementation, which is not on par with the performance shown in the paper.

@lishunyang12 lishunyang12 marked this pull request as draft March 26, 2026 20:51
@lishunyang12 lishunyang12 force-pushed the feat/turboquant-kv-cache branch 8 times, most recently from 5e74522 to 35f09ab Compare March 26, 2026 21:25
@lishunyang12
Author

lishunyang12 commented Mar 26, 2026

Phase 1 Benchmark Results

Model: Qwen/Qwen2.5-1.5B-Instruct, GPU: H200, Mode: enforce_eager

Quality: 100% match at ALL bit-widths

| Config | First-sentence match | Exact match |
|---|---|---|
| baseline | 12/12 | 12/12 |
| TurboQuant 4-bit | 12/12 | 12/12 |
| TurboQuant 3-bit | 12/12 | 12/12 |
| TurboQuant 2-bit | 12/12 | 12/12 |

TTFT: No overhead (<1%)

| Config | Short prompt | Long prompt | Overhead |
|---|---|---|---|
| baseline | 9.3 ms | 9.3 ms | - |
| TQ 4-bit | 9.3 ms | 9.3 ms | 0.99x |
| TQ 3-bit | 9.1 ms | 9.3 ms | 0.99x |
| TQ 2-bit | 9.2 ms | 9.3 ms | 0.99x |

ITL + E2E Latency: No overhead

| Config | ITL (ms/tok) | E2E (ms) | Overhead |
|---|---|---|---|
| baseline | 8.5 | 422.8 | - |
| TQ 4-bit | 8.4 | 417.7 | 0.99x |
| TQ 3-bit | 8.3 | 416.7 | 0.99x |
| TQ 2-bit | 8.4 | 418.7 | 0.99x |

Prefill Throughput: Identical

| Config | 32 tok | 128 tok | 512 tok |
|---|---|---|---|
| baseline | 771 | 5,740 | 21,121 |
| TQ 4-bit | 779 | 5,781 | 21,166 |
| TQ 3-bit | 780 | 5,752 | 21,075 |
| TQ 2-bit | 776 | 5,767 | 21,163 |

Batched Throughput (gen tok/s): Slight improvement at bs=16

| Config | bs=1 | bs=4 | bs=8 | bs=16 |
|---|---|---|---|---|
| baseline | 119.0 | 473.6 | 935.5 | 1,524.4 |
| TQ 4-bit | 119.6 | 474.9 | 946.8 | 1,849.1 |
| TQ 3-bit | 119.5 | 478.3 | 944.2 | 1,853.6 |
| TQ 2-bit | 119.8 | 474.8 | 941.2 | 1,847.1 |

Key takeaways

  1. Zero quality loss — all 12 prompts produce identical output at every bit-width (2, 3, 4-bit), matching the paper's claims
  2. Zero latency overhead — TTFT, ITL, E2E, and prefill throughput are within noise of baseline
  3. ~21% throughput improvement at bs=16 — likely due to reduced memory pressure from quantized KV cache writes
  4. Pre-dequant mode proves the integration works correctly; Phase 2 (packed storage) will deliver actual memory savings

TQ_QJL=1: enables 1-bit residual correction on top of MSE quantization.
TQ_BITS=3: sets key bit width (default 4).
Slot layout updated to include QJL sign bytes + QJL norm.

Signed-off-by: lishunyang <lishunyang12@163.com>
Encode: after MSE quantize, compute residual, project onto S matrix,
extract 1-bit signs, pack into cache slot alongside MSE indices.

Decode: unpack signs, reconstruct correction via sqrt(pi/2)/d * ||r|| * (signs @ S),
add to MSE reconstruction.

Slot layout: [outlier | mse_packed | qjl_signs | norm | qjl_norm]

Activated via TQ_QJL=1 TQ_BITS=3 for 3-bit effective (2-bit MSE + 1-bit QJL).
QJL uses unfused path (fused kernels fall back when QJL enabled).

Signed-off-by: lishunyang <lishunyang12@163.com>
When QJL is enabled, bit_width=3 but mse_bits=2 (1 bit reserved for
QJL signs). The pack/unpack must use mse_bits for correct slot sizing.

Signed-off-by: lishunyang <lishunyang12@163.com>
@lishunyang12 lishunyang12 changed the title [Quantization] Add TurboQuant KV cache quantization (Phase 1, Working on Phase 2) [Quantization] TurboQuant KV cache quantization with separate backend + fused kernels Mar 30, 2026
@lishunyang12 lishunyang12 changed the title [Quantization] TurboQuant KV cache quantization with separate backend + fused kernels [Quantization] TurboQuant KV cache quantization — 2.7x memory, separate backend Mar 30, 2026
@lishunyang12 lishunyang12 changed the title [Quantization] TurboQuant KV cache quantization — 2.7x memory, separate backend [Quantization] Add TurboQuant KV cache quantization — 2.7x memory savings Mar 30, 2026
- Delete triton_turboquant.py (dead O(d²) rotation, never used in prod)
- Remove QJL from packed cache backend (doesn't work due to Hadamard
  ordering mismatch between PyTorch and Triton — keep in emulation mode)
- Remove unused Pi/PiT fields from TurboQuantState
- Remove QJL slot layout from get_kv_cache_spec
- Remove TQ_QJL env var
- Simplify mse_bits → bit_width (no QJL reservation needed)
- Remove test for deleted file

Signed-off-by: lishunyang <lishunyang12@163.com>

@claude claude bot left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.


has_outliers = normal_idx is not None and n_outliers > 0
out = torch.empty(N, head_size, dtype=torch.bfloat16, device=cache.device)
scratch = torch.empty(N, BLOCK_D, dtype=torch.float32, device=cache.device)
Collaborator


vllm serve /mnt/data3/models/MiniMax/MiniMax-M2.5 -tp 4 --trust-remote-code --kv-cache-dtype turboquant
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]   File "/mnt/data4/jxy/vllm/vllm/model_executor/layers/attention/kv_transfer_utils.py", line 39, in wrapper
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]     return func(*args, **kwargs)
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]            ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]   File "/mnt/data4/jxy/vllm/vllm/model_executor/layers/attention/attention.py", line 819, in unified_attention_with_output
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]     self.impl.forward(
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]   File "/mnt/data4/jxy/vllm/vllm/v1/attention/backends/turboquant_attn.py", line 239, in forward
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]     key_cache, value_cache, block_table = self._decode_turboquant_cache(
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]   File "/mnt/data4/jxy/venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1181, in _fn
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]     return fn(*args, **kwargs)
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]            ^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]   File "/mnt/data4/jxy/vllm/vllm/v1/attention/backends/turboquant_attn.py", line 333, in _decode_turboquant_cache
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]     return self._decode_fused_4bit(
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]            ^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]   File "/mnt/data4/jxy/vllm/vllm/v1/attention/backends/turboquant_attn.py", line 395, in _decode_fused_4bit
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]     decoded = fused_paged_decode(
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]               ^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]   File "/mnt/data4/jxy/vllm/vllm/v1/attention/ops/triton_fused_turboquant.py", line 467, in fused_paged_decode
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]     scratch = torch.empty(N, BLOCK_D, dtype=torch.float32, device=cache.device)
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949] torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 96.00 GiB. GPU 0 has a total capacity of 139.81 GiB of which 14.61 GiB is free. Including non-PyTorch memory, this process has 125.19 GiB memory in use. Of the allocated memory 121.24 GiB is allocated by PyTorch, with 81.75 MiB allocated in private pools (e.g., CUDA Graphs), and 81.89 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
(Worker_TP0 pid=1417539) ERROR 03-30 13:58:43 [multiproc_executor.py:949] 

I’m trying to test MiniMax-M2.5 on H20 and ran into the issue mentioned above.

scratch = torch.empty(N, BLOCK_D, dtype=torch.float32, device=cache.device)

Just a quick question: could scratch be consuming a large amount of GPU memory?

@bogoconic1

Qwen3.5 currently fails with

%%writefile infer_llm_turbo.py

def main():

    from vllm import LLM, SamplingParams
    
    llm = LLM(
        model="/kaggle/input/models/qwen-lm/qwen-3-5/transformers/qwen3.5-35b-a3b/1",
        gpu_memory_utilization=0.95,
        kv_cache_dtype="turboquant",
        max_model_len=32768,
    )
    
    output = llm.generate(
        ["Solve for x in the equation x^2 + 6*x - 9 = 0. Put the answer within \\boxed"],
        SamplingParams(temperature=0.6, max_tokens=4096),
    )
    print(output[0].outputs[0].text)

if __name__ == '__main__':
    main()
(EngineCore pid=4011)   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 124, in decorate_context
(EngineCore pid=4011)     return func(*args, **kwargs)
(EngineCore pid=4011)            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=4011)   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 5855, in profile_cudagraph_memory
(EngineCore pid=4011)     self._init_minimal_kv_cache_for_profiling()
(EngineCore pid=4011)   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 5795, in _init_minimal_kv_cache_for_profiling
(EngineCore pid=4011)     kv_cache_groups = get_kv_cache_groups(self.vllm_config, kv_cache_spec)
(EngineCore pid=4011)                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=4011)   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/core/kv_cache_utils.py", line 1257, in get_kv_cache_groups
(EngineCore pid=4011)     kv_cache_spec = unify_kv_cache_spec_page_size(kv_cache_spec)
(EngineCore pid=4011)                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore pid=4011)   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/core/kv_cache_utils.py", line 942, in unify_kv_cache_spec_page_size
(EngineCore pid=4011)     raise NotImplementedError(
(EngineCore pid=4011) NotImplementedError: The page size of the layer is not divisible by the maximum page size. Cannot unify by adjusting block_size.

@chaunceyjiang
Collaborator

Benchmark — MiniMax-M2.5, H20-3e
baseline

vllm serve /mnt/data3/models/MiniMax/MiniMax-M2.5 -tp 4   --trust-remote-code  --max-model-len=4090
vllm bench serve --backend vllm  --endpoint /v1/completions --dataset-name random --random-input 3000 --random-output 500 --max-concurrency 64 --num-prompt 512 --ignore-eos --temperature 0.0
============ Serving Benchmark Result ============
Successful requests:                     512       
Failed requests:                         0         
Maximum request concurrency:             64        
Benchmark duration (s):                  188.56    
Total input tokens:                      1536000   
Total generated tokens:                  256000    
Request throughput (req/s):              2.72      
Output token throughput (tok/s):         1357.67   
Peak output token throughput (tok/s):    2112.00   
Peak concurrent requests:                79.00     
Total token throughput (tok/s):          9503.71   
---------------Time to First Token----------------
Mean TTFT (ms):                          1751.63   
Median TTFT (ms):                        1453.59   
P99 TTFT (ms):                           8178.78   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          43.65     
Median TPOT (ms):                        43.99     
P99 TPOT (ms):                           46.52     
---------------Inter-token Latency----------------
Mean ITL (ms):                           43.65     
Median ITL (ms):                         31.06     
P99 ITL (ms):                            364.40    
==================================================

GPU KV cache size: 1,120,752 tokens
Maximum concurrency for 4,090 tokens per request: 273.62x

TurboQuant

 vllm serve /mnt/data3/models/MiniMax/MiniMax-M2.5 -tp 4   --trust-remote-code --kv-cache-dtype turboquant --max-model-len=4090
vllm bench serve --backend vllm  --endpoint /v1/completions --dataset-name random --random-input 3000 --random-output 500 --max-concurrency 64 --num-prompt 512 --ignore-eos --temperature 0.0
============ Serving Benchmark Result ============
Successful requests:                     512       
Failed requests:                         0         
Maximum request concurrency:             64        
Benchmark duration (s):                  701.14    
Total input tokens:                      1536000   
Total generated tokens:                  256000    
Request throughput (req/s):              0.73      
Output token throughput (tok/s):         365.12    
Peak output token throughput (tok/s):    448.00    
Peak concurrent requests:                72.00     
Total token throughput (tok/s):          2555.83   
---------------Time to First Token----------------
Mean TTFT (ms):                          2369.38   
Median TTFT (ms):                        1695.48   
P99 TTFT (ms):                           12208.58  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          170.67    
Median TPOT (ms):                        171.64    
P99 TPOT (ms):                           174.13    
---------------Inter-token Latency----------------
Mean ITL (ms):                           170.67    
Median ITL (ms):                         154.75    
P99 ITL (ms):                            569.67    
==================================================

GPU KV cache size: 3,019,056 tokens
Maximum concurrency for 4,090 tokens per request: 737.07x

@Luis-xu

Luis-xu commented Mar 30, 2026

I tested the long-context performance of Qwen3-8B on an H20. The results below clearly demonstrate a significant drop in performance after enabling TurboQuant.

| KV cache type | TTFT (s) | TPOT (ms) | Decode throughput (tok/s) | Total throughput (tok/s) |
|---|---|---|---|---|
| bf16 | 0.79 | 19.56 | 90.63 | 8269.74 |
| tq | 1.75 | 284.84 | 3.43 | 1027.93 |

Is this normal?

@vibhavagarwal5
Contributor

@Luis-xu can you check #38479? I'm trying to close this gap, although I feel that getting to parity is very hard at this point.


@MidasMining

FYI — PR #38479 has working hybrid Mamba model support. Tested Nemotron-Cascade-2-30B-A3B (Mamba+MoE+Attention hybrid, head_dim=128) on 8x RTX A4000 with --kv-cache-dtype tq3 + TQ_VALUE_BITS=8. 100% quality on a 14-check reasoning benchmark at 2x KV compression. The combined K+V slot layout in #38479 sidesteps the page size incompatibility listed in your Known Limitations.

@gaby

gaby commented Mar 31, 2026

Why are there multiple TurboQuant PRs already, instead of focusing on one and getting the feature shipped?

…instances)

Signed-off-by: lishunyang <lishunyang12@163.com>
@q2000s

q2000s commented Mar 31, 2026

> Why are there multiple TurboQuant PRs already, instead of focusing on one and getting the feature shipped?

Second this. @lishunyang12, are you going to check the other implementation and combine the two, or give up on supporting Qwen3.5's hybrid Mamba models? I am running Qwen3.5 models and would really appreciate it if TurboQuant could work peacefully with them.

@mergify
Contributor

mergify bot commented Apr 1, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @lishunyang12.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Comment on lines +419 to +424
# Initialize on CUDA if available, CPU otherwise
init_device = (
torch.device("cuda")
if torch.cuda.is_available()
else torch.device("cpu")
)
Contributor


Suggested change
# Initialize on CUDA if available, CPU otherwise
init_device = (
torch.device("cuda")
if torch.cuda.is_available()
else torch.device("cpu")
)
# Initialize on accelerator if available, CPU otherwise
init_device = torch.device(current_platform.device_name)

key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
Contributor


does this kernel support turboquant?

@yangyang-cs95

I started with the latest commit aa6e58e in the H20 environment using the command:
vllm serve /data/Qwen3-8B --trust-remote-code --kv-cache-dtype turboquant --gpu-memory-utilization 0.5
There is still an error:
EngineCore pid=3660882) torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 80.00 GiB. GPU 0 has a total capacity of 95.09 GiB of which 23.40 GiB is free. Including non-PyTorch memory, this process has 71.68 GiB memory in use. Of the allocated memory 71.10 GiB is allocated by PyTorch, with 48.00 MiB allocated in private pools (e.g., CUDA Graphs), and 68.78 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
This occurs regardless of whether --gpu-memory-utilization is set to 0.7 or other values.

@lishunyang12
Author

Closing now, as I'm afraid I don't have the bandwidth to push this further in the near future (I am still a uni student and exam period is coming). I hope this PR can be preserved as a reference that might be useful for the final integration of TurboQuant into this repo. I also learned a lot, from reading the paper through POC, benchmarking, and optimizing. Please move to #38479, which is another promising integration.
