
QVAC-14555: TurboQuant (Vulkan): KV cache quantization (TBQ3_0 / TBQ4_0 / PQ3_0 / PQ4_0)#115

Merged
gianni-cor merged 27 commits into tetherto:temp-7248 from jesusmb1995:turboquant
Apr 28, 2026

Conversation


@jesusmb1995 jesusmb1995 commented Mar 27, 2026

Summary

Implements TurboQuant KV cache quantization (Zandieh et al., ICLR 2026) for the CPU and Vulkan backends with full Flash Attention support. It compresses the KV cache to 3.25-4.25 bits per value, enabling ~4-5x larger context windows on the same hardware.

Paper: https://arxiv.org/pdf/2504.19874
Community discussion:

Recommended configurations:

  • High compression + speed: K=pq3_0 V=pq3_0 — codebook-only, no QJL overhead. Minimal PPL/speed loss at 3.25 bpw with a small retrieval quality trade-off on long contexts.
  • High compression + long-context quality: K=tbq3_0 V=pq3_0 — QJL-corrected keys with codebook-only values. Best retrieval accuracy at 3.75 avg bpw, with a moderate speed cost from QJL correction in the FA shader.

Features

  • Full set of TurboQuant types: tbq3_0, tbq4_0, pq3_0, pq4_0 (and _64 variants)
  • Automatic head_dim detection (64 vs 128) — the user specifies pq3_0 and the matching internal block variant is auto-selected
  • Coopmat1 and Coopmat2 Flash Attention support (noticeable prefill speedup)
  • Pre-compiled fused Flash Attention shaders for mixed K/V types (asymmetric compression)
  • QJL Stage 2 correction in all FA paths (scalar, cm1, cm2)
  • Comprehensive test/benchmark scripts (perplexity, throughput, RULER)
  • Cooperative copy_to_quant Vulkan path for TBQ/PQ (faster KV writes)

How does TurboQuant work?

Random rotations spread values evenly across coordinates, preventing concentration on a few axes where zero-coordinates waste bits. In high dimensions, the marginal distribution of each coordinate of a unit-sphere vector follows a Beta distribution that converges to N(0, 1/d) as d grows. The algorithm exploits this by placing Lloyd-Max codebook centroids at optimal positions for this known distribution, minimizing MSE reconstruction error. Centroids are found by solving a continuous 1-dimensional k-means problem.
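For intuition, here is a minimal, self-contained sketch of how such a codebook arises, as a 1-D k-means (Lloyd iterations) over samples from N(0, 1/d). The codebooks in this PR are pre-computed constants; the sample count, seed, and iteration count below are illustrative assumptions:

    // Sketch: derive an 8-centroid Lloyd-Max codebook for N(0, 1/d) coordinates.
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <random>
    #include <vector>

    int main() {
        const int d = 128, K = 8;          // head_dim and centroid count (pq3_0-style)
        const int N = 1 << 20;             // Monte-Carlo stand-in for the continuous problem
        std::mt19937 rng(42);
        std::normal_distribution<double> coord(0.0, 1.0 / std::sqrt((double) d));

        std::vector<double> x(N);
        for (double & v : x) v = coord(rng);
        std::sort(x.begin(), x.end());

        std::vector<double> c(K);
        for (int k = 0; k < K; ++k)        // quantile initialization
            c[k] = x[(size_t) (((double) k + 0.5) * N / K)];

        for (int it = 0; it < 50; ++it) {  // Lloyd: assign to nearest centroid, re-center
            std::vector<double> sum(K, 0.0);
            std::vector<int>    cnt(K, 0);
            for (double v : x) {
                int best = 0;
                for (int k = 1; k < K; ++k)
                    if (std::fabs(v - c[k]) < std::fabs(v - c[best])) best = k;
                sum[best] += v;
                cnt[best]++;
            }
            for (int k = 0; k < K; ++k)
                if (cnt[k] > 0) c[k] = sum[k] / cnt[k];
        }
        for (int k = 0; k < K; ++k) printf("centroid[%d] = %+.5f\n", k, c[k]);
        return 0;
    }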

An additional QJL correction step (Stage 2) reduces bias in dot-product estimation. It quantizes the residual error from Stage 1 to 1-bit by storing only the signs of the residual vector after applying a random rotation (Hadamard × sign diagonal). Since only signs are stored (no centroid rounding), the paper proves this yields an unbiased dot-product estimator. This step is important for maintaining retrieval quality on long contexts.
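In symbols (a hedged reading of the paper's estimator, using the same notation as the Stage 2 formula in Path E below): with r the Stage-1 residual of a key, H the Hadamard transform, D the random sign diagonal, and m the projection dimension,

    <q, r> ≈ ||r|| · √(π/2) / m · Σ_j sign((HDr)_j) · (HDq)_j

Only the m sign bits and the residual norm ||r|| (d_r in the block layout) are stored per block.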

Optimization details

  • Hadamard instead of dense rotation: Hadamard-based rotations use the butterfly pattern in O(d log d) instead of O(d²). Hadamard itself is deterministic, but composing it with a random sign diagonal restores randomness while keeping the transform orthogonal and invertible (see the sketch after this list).

  • Dense rotation for K/V/Q at graph level, FHT in shader for QJL: At block sizes d=64/128, O(d²) is negligible and utilizes better GPU parallelism for the graph-level rotation. The butterfly FHT is used inside the Flash Attention shader for the QJL projection, avoiding the need to copy a dense matrix into the shader (which would add memory pressure). Since there is no Q cache, the QJL projection of Q must be recomputed every step to apply corrections against the 1-bit signs stored in K blocks.
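
A minimal sketch of the butterfly FHT plus the sign diagonal referenced above (scalar C++; the shader version runs the same access pattern across a workgroup, and the 1/√d normalization keeps the rotation orthonormal):

    #include <cmath>
    #include <cstdint>
    #include <vector>

    // In-place fast Walsh-Hadamard transform: log2(d) butterfly stages, O(d log d).
    void fht(std::vector<float> & v) {  // v.size() must be a power of two
        const size_t d = v.size();
        for (size_t len = 1; len < d; len <<= 1)
            for (size_t i = 0; i < d; i += 2 * len)
                for (size_t j = i; j < i + len; ++j) {
                    const float a = v[j], b = v[j + len];
                    v[j]       = a + b;
                    v[j + len] = a - b;
                }
        const float norm = 1.0f / std::sqrt((float) d);
        for (float & x : v) x *= norm;
    }

    // Randomized rotation H·D: flip signs with a fixed ±1 diagonal, then transform.
    void hadamard_rotate(std::vector<float> & v, const std::vector<int8_t> & signs) {
        for (size_t i = 0; i < v.size(); ++i) v[i] *= signs[i];
        fht(v);
    }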

Type   | Bits/val | Block size | Compression vs FP16 | Description
-------+----------+------------+---------------------+-------------------------------
q4_0   | 4.50     | 18 B       | 3.5x                | Baseline: 16 linear values
pq3_0  | 3.25     | 52 B       | 4.9x                | 8 Lloyd-Max centroids
pq4_0  | 4.25     | 68 B       | 3.8x                | 16 Lloyd-Max centroids
tbq3_0 | 4.25     | 68 B       | 3.8x                | 8 centroids + QJL correction
tbq4_0 | 5.25     | 84 B       | 3.0x                | 16 centroids + QJL correction
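
The block sizes are consistent with 128-value blocks (an assumed decomposition; the exact field layout lives in the type headers): pq3_0 = 128 × 3-bit indices (48 B) plus 4 B of per-block scale/norm data = 52 B → 52·8/128 = 3.25 bpw; pq4_0 = 128 × 4-bit (64 B) + 4 B = 68 B → 4.25 bpw; the tbq variants add a 16 B QJL sign bitmap (128 bits): tbq3_0 = 68 B → 4.25 bpw, tbq4_0 = 84 B → 5.25 bpw.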

Implementation overview

  • vulkan-shaders-gen.cpp — orchestrates SPIR-V compilation of all variant combos
  • ggml-vulkan.cpp — host-side: creates pipeline objects, dispatches compute

TurboQuant KV cache shader flow (TBQ/PQ is ONLY a KV cache type, never model weights):

STEP 1: Write to cache (same for all paths)

  • copy_to_quant.comp: float K/V → TBQ/PQ quantized blocks
    • L2 norm, codebook binary search, 3/4-bit index packing
    • TBQ only: also computes QJL residual (qjl[], d_r)
    • PQ only: no QJL, smaller block, faster
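
A scalar C++ sketch of that Stage-1 encode for a pq3_0-style block (struct layout and names are illustrative assumptions, not the shader's):

    #include <cmath>
    #include <cstdint>
    #include <cstring>

    struct pq3_block {        // assumed layout: 4 B norm + 48 B packed indices = 52 B
        float   d;            // block L2 norm (dequant scale)
        uint8_t qs[48];       // 128 x 3-bit centroid indices
    };

    void encode_pq3(const float * x /* 128 rotated values */,
                    const float * cb /* 8 sorted centroids */, pq3_block & out) {
        float norm2 = 0.0f;
        for (int i = 0; i < 128; ++i) norm2 += x[i] * x[i];
        out.d = std::sqrt(norm2);
        const float inv = out.d > 0.0f ? 1.0f / out.d : 0.0f;

        std::memset(out.qs, 0, sizeof(out.qs));
        for (int i = 0; i < 128; ++i) {
            const float v = x[i] * inv;
            int lo = 0, hi = 7;           // binary search over centroid midpoints
            while (lo < hi) {
                const int mid = (lo + hi) / 2;
                if (v > 0.5f * (cb[mid] + cb[mid + 1])) lo = mid + 1;
                else                                    hi = mid;
            }
            const int bit = 3 * i;        // pack the 3-bit index, possibly straddling
            out.qs[bit >> 3] |= (uint8_t) ((lo & 7) << (bit & 7));
            if ((bit & 7) > 5)            // spills into the next byte
                out.qs[(bit >> 3) + 1] |= (uint8_t) ((lo & 7) >> (8 - (bit & 7)));
        }
    }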

STEP 2: Read cache at attention time (paths diverge here)

PATH A: Scalar Flash Attention (broad HW support, baseline)

  • flash_attn.comp
  • Includes: types.glsl, tq_utils.comp (via flash_attn_base.glsl), dequant_funcs.glsl
  • Dequantizes K/V inline, element by element
  • For TBQ/PQ K: uses centroid-gather optimization (reorders Q·K into per-centroid partial sums)
  • For TBQ K only: applies QJL correction to attention scores
  • Full fused kernel: QK^T → softmax → PV → output
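
A scalar sketch of that centroid-gather reordering (illustrative C++; the shader operates on packed blocks): instead of dequantizing every key element and multiplying, the query values are first binned by centroid index, so the 128 multiplies collapse into one per centroid:

    #include <cstdint>

    // dot(Q, K_block) for a TBQ/PQ key block, centroid-gather style.
    float dot_q_kblock(const float * q,        // 128 query values (rotated like K)
                       const uint8_t * idx,    // unpacked 3/4-bit centroid indices
                       const float * cb,       // centroid table
                       int n_centroids,
                       float d) {              // block norm (dequant scale)
        float bins[16] = { 0.0f };
        for (int i = 0; i < 128; ++i)
            bins[idx[i]] += q[i];              // gather: one add per element, no multiply
        float acc = 0.0f;
        for (int c = 0; c < n_centroids; ++c)
            acc += bins[c] * cb[c];            // 8 or 16 multiplies total
        return d * acc;                        // undo per-block normalization
    }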

PATH B: Cooperative matrix v1 Flash Attention (KHR, cross-vendor)

  • flash_attn_cm1.comp
  • K is fully dequantized into shared memory, then coopMatMulAdd for K·Q^T (subgroup-scope 16×16 tiles)
  • P·V accumulation is still scalar with inline dequant
  • Same QJL correction as scalar (applied to sfsh[] after coopmat store)

PATH C: Cooperative matrix v2 Flash Attention (NV only, most efficient)

  • flash_attn_cm2.comp
  • K and V loaded via coopMatLoadTensorNV with decode callback (dequant-on-load, no shared memory staging)
  • Both K·Q^T and P·V use coopMatMulAdd (workgroup-scope matrices)
  • QJL correction via raw byte reads from data_k[] with hardcoded byte offsets per type

PATH D: No-FA fallback, small N (MUL_MAT with N ≤ 8, e.g. decode)

  • mul_mat_vec_tbq3_0.comp / mul_mat_vec_tbq4_0.comp
  • Fused dequant + dot product, no centroid gather
  • QJL correction applied in the same kernel

PATH E: No-FA fallback, large N (K·Q MUL_MAT with N > 8, e.g. prefill)

  • This is the path exercised by -fa off with a TBQ/PQ K cache. Only the K·Q matmul is affected: V stays f16 under -fa off (upstream guard), so V·A stays on the existing f16 path.
  • Stage 1: mul_mm.comp runs with TBQ/PQ load_a_to_shmem — centroid dequant × d into shared memory, then generic tiled matmul (scalar / cm1 pipelines; cm2 falls through to cm1/scalar since no _mat_f16 cm2 shader exists for TBQ/PQ).
  • Stage 2 (TBQ only): mul_mm_tbq_qjl_correction.comp is dispatched after the main matmul as an additive pass — one workgroup per (row, col, batch), QUANT_K threads running the same Walsh–Hadamard + QJL dot product as the vec shader, accumulating d_r · √(π/2) / QUANT_K · sum_qjl(H(B)) into D.
  • PQ has no Stage 2 (no qjl[] / d_r), so Stage 1 alone is exact.
  • Requires B (src1) as f32; the scheduler is expected to feed f32 on this path. f16 src1 for standalone TBQ MUL_MAT is reported as unsupported and falls back to CPU.
  • Fixes external review Issue 3 on this PR (#115): before this patch, supports_op claimed TBQ/PQ MUL_MAT support on cm2 devices (RTX 5090) but had no pipeline behind it, so the correctness run segfaulted. tests/test-backend-ops.cpp now covers all 8 TBQ/PQ types × n ∈ {1, 8, 16, 32} as a repro.
  • Non-dim01-contiguous quantized src0 (permuted layouts) is now routed to the matrix path as well, so TBQ/PQ MUL_MAT works regardless of src0 stride pattern.
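
Written out, the Stage-2 additive term (same notation as above; s_j ∈ {−1, +1} are the QJL sign bits stored in the K block, H(B) the Walsh-Hadamard transform of the f32 src1 column):

    D += d_r · √(π/2) / QUANT_K · Σ_j s_j · H(B)_j

which matches the 1-bit dot-product estimator from the QJL section, applied per (row, col, batch) cell.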

Example usage

llama-cli -m model.gguf --cache-type-k tbq3_0 --cache-type-v pq3_0
llama-cli -m model.gguf --cache-type-k pq3_0 --cache-type-v pq3_0

Works transparently with both head_dim=128 (Llama-3.1, Qwen, Mistral) and head_dim=64 (Llama-3.2-1B/3B) — the right block size is auto-selected.

Results / testing

Automatic CI/CTest should already cover the relevant backend regressions: test-backend-ops includes the TBQ/PQ backend-op cases, and test-copy-tbq-subgroups covers the Vulkan subgroup copy path. For a normal rebase/regression check, those automatic tests should be enough.

The shell scripts below are more useful for manual TurboQuant-focused testing and analysis, especially when you want to skip unrelated tests and compare report numbers more directly. The quickest manual TurboQuant sanity run is:

bash tests/test-turboquant.sh --full

That runs the TurboQuant correctness/sanity checks fairly quickly.

For a heavier but still manageable check against the simple report numbers, run the PPL and throughput scripts on the 5090 node from my checkout so the same in-place generated input files are reused:

cd ~/jberlanga/llama.cpp
bash tests/test-kv-cache-quantization-perp.sh -c large -m "mistral-q4km"
bash tests/test-kv-cache-quantization-perf.sh -c huge -m "mistral-q4km"

Compare the resulting PPL vs F16 and TG% numbers against the simple report. These longer scripts are probably overkill for a simple backend regression, but they are useful for analysis/report validation.

The test scripts are in this PR, but the input text is downloaded or auto-generated the first time the scripts run. In theory, test-kv-cache-quantization-perp.sh should use the same text offsets because the slice seed is fixed, but existing cached wiki.test.offset_<n_ctx>.raw files are reused. The report also includes a zip of the same generated input files for anyone who wants to reproduce the exact input texts. RULER is more sensitive because its data is generated from the NVIDIA RULER repo and depends on the generated validation.jsonl, tokenizer, dependency versions, and source data. PPL can also vary across hardware/backends due to numerical differences. For strict reproducibility, use the same hardware and the same already-generated input folders.

Please see Asana for latest available data: https://app.asana.com/1/45238840754660/task/1214143691877486/comment/1214346089994897?focus=true

PR for testing integration on LLM Addon: tetherto/qvac#1564

Limitations

  • head_dim must be 64 or 128. Codebooks and Hadamard transform are pre-computed for these dimensions.
  • d=64 quality is poor on small models — expected, as KV cache quantization generally degrades more on small models.
  • Metal shaders and vectorized CPU not yet implemented.
  • Optimized Flash Attention shaders require K to be PQ or TBQ, and V to be PQ, TBQ, Q4, Q8, or F16.
  • Quantized V with -fa off is not supported by this PR. Upstream llama_init_from_model rejects quantized V when flash attention is disabled ("V cache quantization requires flash_attn"), and that guard is intentionally left in place. The -fa off K·Q MUL_MAT fix in this PR would extend cleanly to A·V for a quantized V as well, but the v_trans V-cache layout used under -fa off is populated by ggml_set_rows with row_size=1, which corrupts any blck_size > 1 type at write time (reproducible on CPU as well, independent of backend). Fixing that is a KV-cache refactor out of scope here; the guard will be revisited once that lands.

TBQ / PQ Vulkan support matrix

What runs on the GPU versus what is refused at context init, across FA on/off on dense and MoE models. The MoE KV-cache rows behave the same as dense because attention itself is plain MUL_MAT / FLASH_ATTN_EXT, not MUL_MAT_ID; MoE routing (MUL_MAT_ID) only applies to the FFN weights, which are never stored as TBQ/PQ.

Scenario               | FA  | K-type              | V-type                                            | Path                                                           | Status
-----------------------+-----+---------------------+---------------------------------------------------+----------------------------------------------------------------+---------------------------------------
Dense / MoE — KV cache | on  | tbq3/4_0 or pq3/4_0 | pq/tbq/q4_0/q8_0/f16                              | Fused FA (scalar / cm1 / cm2), QJL in kernel                   | Full GPU
Dense / MoE — KV cache | on  | tbq3/4_0 or pq3/4_0 | other quantized (q5_0, q4_1, iq4_nl, k-quants, …) | No matching Vulkan FA pipeline → per-layer backend split       | Runs, but attention falls back to CPU
Dense / MoE — KV cache | off | tbq3/4_0 or pq3/4_0 | f16                                               | K·Q via mul_mm.comp + QJL correction; V·A on existing f16 path | Full GPU (Path E)
Dense / MoE — KV cache | off | tbq3/4_0 or pq3/4_0 | any quantized type (incl. tbq/pq)                 | —                                                              | Context init refused (upstream FA-off rule: "V cache quantization requires flash_attn")

Notes:

  • Head dimensions of both 128 and 64 are supported; the _64 block variants (tbq*_0_64, pq*_0_64) have their own pipelines, codebooks, and sign tables.
  • MoE FFN weights are not in this table on purpose: TBQ/PQ are KV-cache quantizations only (llama-quantize has no TBQ/PQ target, and no GGUF stores FFN experts in those types), so MUL_MAT_ID never receives TBQ/PQ src0. Attention in MoE models is a plain MUL_MAT / FLASH_ATTN_EXT and therefore falls under the "KV cache" rows above.

Remaining work

  • SIMD optimization — AVX2/NEON for CPU quantize/dequantize
  • Metal shaders — Apple GPU backend support
  • 2-bit variant — even higher compression
  • Direct cosine similarity evaluation

@jesusmb1995 jesusmb1995 self-assigned this Mar 27, 2026
@jesusmb1995 jesusmb1995 changed the title from "Draft: TurboQuant" to "TurboQuant: KV cache quantization with Hadamard transform (TQ3_0 / TQ4_0)" Mar 27, 2026
@jesusmb1995 jesusmb1995 force-pushed the turboquant branch 2 times, most recently from 69522fb to 6497a86 Compare March 31, 2026 16:37
@jesusmb1995 jesusmb1995 changed the title from "TurboQuant: KV cache quantization with Hadamard transform (TQ3_0 / TQ4_0)" to "TurboQuant: KV cache quantization with Hadamard transform (TBQ3_0 / TBQ4_0)" Mar 31, 2026

zoq commented Apr 1, 2026

Are you planning to merge this before the rebase to the latest version of llama.cpp?

@jesusmb1995 jesusmb1995 force-pushed the turboquant branch 2 times, most recently from f7ba069 to 9d2a659 Compare April 7, 2026 18:20

jesusmb1995 commented Apr 7, 2026

> Are you planning to merge this before the rebase to the latest version of llama.cpp?

Not particularly; if the rebase to the latest version of llama.cpp happens soon, I will change the target to the correct temp branch. I think it's better if I target the latest llama.cpp version.

Edit: @zoq Since it seems we want this merged in about 1-2 weeks, I will target this version for now. Yes, planning to merge this before the rebase.

Fix a latent correctness bug in the TurboQuant / PolarQuant copy_to_quant
cooperative shader that silently produces wrong bytes on any device whose
gl_SubgroupSize is less than the 32-thread workgroup (Intel Xe/Arc at 8/16,
ARM Mali 4/8/16, some Adreno configurations). Make the path cover every
supported subgroup size, plumb a runtime knob for testing, and add a
dedicated test suite with both real-hardware and software-Vulkan coverage.

Motivation
----------
The original copy_to_quant.comp TBQ/PQ path uses subgroupAdd() for the
per-block norm reductions and subgroupBallot() for the QJL sign-bit sketch,
assuming gl_SubgroupSize == 32 (= the workgroup size). On devices where the
native subgroup is smaller, those ops reduce only within a subgroup, not the
whole workgroup, so each subgroup sees its own partial sum and the output
bytes become whatever the first-subgroup partial happened to produce. The
SET_ROWS path has the same issue. The bug does not reproduce on most
production GPUs (NVIDIA fixed-32, AMD RDNA 32/64, Apple 32) but bites Intel
and several mobile GPUs.

Shader changes (copy_to_quant.comp)
-----------------------------------
* New specialization constant SG_SIZE at constant_id = 1 (slot 0 is already
  used by generic_binary_head.glsl's `norepeat` in the SET_ROWS path).
  Defaults to 32 so hosts that pass no spec info get the original shader.
* TQ_WG fixed at 32 (the workgroup size); NSG = TQ_WG / SG_SIZE is the
  number of subgroups per workgroup.
* New helper tq_wg_add(x): if NSG == 1 (SG_SIZE >= TQ_WG) returns
  subgroupAdd(x) -- identical to the original fast path and
  dead-code-eliminated by spec-constant folding; if NSG > 1 the per-
  subgroup subgroupAdd results are written to shared memory (tq_sh_red)
  and stitched with an [[unroll]]-ed sum. Replaces every subgroupAdd() in
  the TBQ/PQ/norm-correction paths.
* QJL sign-bit pack: when SG_SIZE >= TQ_WG the original subgroupBallot
  fast path runs; when SG_SIZE < TQ_WG it falls back to atomicOr into a
  shared uint array and a serial write-out. Same fast-path guard lets
  specialization fold the slow branch away when SG_SIZE == 32.
* SG_SIZE > TQ_WG (e.g. AMD wave64 with WG=32) is treated as NSG == 1
  via clamp(SG_SIZE, TQ_WG) in tq_wg_add, so those devices take the fast
  path even though half the wave is masked off.
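
A host-side C++ emulation of the failure mode and the stitch (illustrative only; the
real helper is GLSL using subgroupAdd plus the shared tq_sh_red array):

    #include <cstdio>
    #include <vector>

    int main() {
        const int TQ_WG = 32, SG_SIZE = 8;   // e.g. Intel Xe running subgroups of 8
        const int NSG = TQ_WG / SG_SIZE;     // 4 subgroups per workgroup

        std::vector<float> lane(TQ_WG);
        for (int i = 0; i < TQ_WG; ++i) lane[i] = (float) i;

        // subgroupAdd(): every lane only sees the sum of its own subgroup.
        std::vector<float> partial(NSG, 0.0f);
        for (int i = 0; i < TQ_WG; ++i) partial[i / SG_SIZE] += lane[i];
        printf("buggy path (first subgroup's partial): %g\n", partial[0]);

        // tq_wg_add(): per-subgroup partials go to shared memory and get stitched.
        float total = 0.0f;
        for (int s = 0; s < NSG; ++s) total += partial[s];
        printf("stitched workgroup total:              %g\n", total);
        return 0;
    }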

Host plumbing (ggml-vulkan.cpp)
-------------------------------
* vk_device_struct grows a tbq_copy_sg_size field (0 = no override).
* Device init reads GGML_VK_TBQ_COPY_SG_SIZE from env, validates against
  {4, 8, 16, 32, 64} intersected with the device's
  [subgroup_min_size, subgroup_max_size], and emits a structured
  "tbq_copy_sg_size_status requested=R applied=A reason=X" line so tests
  can tell whether the override was applied or rejected (distinct from
  success/failure of the run itself).
* ggml_vk_load_shaders picks the (SG_SIZE spec const, requiredSubgroupSize)
  pair used for every CPY-to-quant and SET_ROWS-to-quant pipeline:
    - if the env override is set: that value
    - else if the device supports size control: mul_mat_subgroup_size
    - else: 0 (shader default SG_SIZE=32, no required size) -- matches
      pre-patch behaviour on drivers without VK_EXT_subgroup_size_control.
  The two-element spec-const vector is {0, SG_SIZE} for the plain CPY
  path (slot 0 is ignored by generic_unary_head.glsl) and {1, SG_SIZE}
  for SET_ROWS (slot 0 is `norepeat`, always 1).
* Adds a device-selection opt-in GGML_VK_ALLOW_CPU_DEVICES=1 so tests can
  pick up software Vulkan ICDs (lavapipe, SwiftShader) that ggml-vulkan
  normally filters out. Production code never sets this env var and the
  behaviour is unchanged when it isn't set.

New test (tests/test-copy-tbq-subgroups.cpp + CMakeLists)
---------------------------------------------------------
Self-spawning C++ test that for each (SG in {0, 4, 8, 16, 32, 64}, type,
shape) triple runs GPU quantize, compares against a CPU
ggml_quantize_chunk reference, and reports byte-mismatch + dequant NMSE
+ throughput. Key design choices:
  * Self-spawn (popen of --child N with a different
    GGML_VK_TBQ_COPY_SG_SIZE value per child) because the env var is
    consumed once at device init and can only be changed across processes.
  * Parses the structured status line from the backend to distinguish
    "applied" from "rejected" rows. Rejected rows are labelled
    SKIP-<reason> in the per-case table and excluded from the
    NMSE-spread assertion (they are duplicates of sg=0 and don't add
    independent coverage). Prior phrasing that labelled them OK was
    misleading.
  * --types comma-separated filter keeps the default CI run fast by
    iterating only a subset of TBQ/PQ types.
  * Shared pass/fail rule: nmse(gpu vs cpu) <= 1e-6 for every applied
    SG; the per-case table stays OK on the legs that couldn't exercise
    the stitch path on the host GPU.

Cross-subgroup-size coverage via lavapipe (tests/test-turboquant.sh)
--------------------------------------------------------------------
Real desktop GPUs (NVIDIA, AMD RDNA, Apple, most Adreno) have
minSubgroupSize >= 32, so VK_EXT_subgroup_size_control cannot request the
smaller subgroups the stitch path was written for. To actually exercise
NSG > 1 in CI, the script now also runs the test under lavapipe (Mesa's
CPU Vulkan driver) at LP_NATIVE_VECTOR_WIDTH in {128, 256, 512}, which
gives native subgroupSize {4, 8, 16} respectively and therefore covers
every distinct NSG branch the shader supports:

    LP_NATIVE_VECTOR_WIDTH | lavapipe SG | NSG (= TQ_WG / SG)
    -----------------------+-------------+--------------------
         128               |      4      |  8  (8-way stitch)
         256               |      8      |  4  (4-way stitch)
         512               |     16      |  2  (2-way stitch)

Combined with the native-GPU leg (NSG=1, fast path), this gives full
coverage of the helper's {1, 2, 4, 8} NSG branches on any host.

Usage and modes
---------------
  tests/test-turboquant.sh          # short mode (default): CI-friendly
  tests/test-turboquant.sh --full   # all TBQ/PQ types, full matrix

Short mode restricts the SG-coverage legs to tbq3_0 / pq3_0 / *_64 to keep
default CI runtime bounded; full mode covers all 8 TBQ/PQ types. Both
modes render a Unicode-boxed summary table at the end covering every
subgroup-coverage leg that ran.

@jesusmb1995 jesusmb1995 force-pushed the turboquant branch 2 times, most recently from 01a747f to 97ccf17 Compare April 27, 2026 11:50
Keep omitted V-cache overrides on f16 when flash attention is disabled, and reject explicit quantized V sweeps early.
Do not advertise Vulkan MUL_MAT_ID for TBQ/PQ types because no ID pipelines exist for them. Plain MUL_MAT support remains enabled.
Centralize TBQ/PQ type checks so rotation and Vulkan support gates use the same type set. Keep Hadamard rotation limited to TBQ/PQ KV caches.
Cover small-n permuted TBQ/PQ MUL_MAT cases so standalone QJL and PQ controls are exercised by the TurboQuant test suite.
Run the standalone QJL correction when small-n TBQ is forced onto the matrix path, and index permuted TBQ batches with separate dim2/dim3 strides.
Exercise the head_dim=64 TBQ/PQ variants in standalone MUL_MAT and mixed FLASH_ATTN_EXT so CI catches regressions in the _64 Vulkan paths, not just copy_to_quant.

Use each type's block size when choosing MUL_MAT k so _64 cases run with real 64-block geometry instead of inheriting the d=128 shape.

Keep the GGML type comments aligned with the actual TBQ/PQ block sizes so the enum documents the correct storage cost.
Drop placeholder QJL seed macros that were immediately undefined before use. The numeric seeds stay unchanged; this only removes confusing preprocessor noise around the real constants.
Use subgroup reductions with shared-memory stitching for the standalone TBQ QJL correction, matching the subgroup-size handling used by copy_to_quant. This removes the serial thread-0 reduction while keeping QUANT_K-wide workgroups correct across smaller hardware subgroups.
Exercise the standalone non-FA TBQ QJL correction under lavapipe subgroup sizes 4, 8, and 16. Record the legs in the existing subgroup summary so multi-subgroup reduction regressions are visible in the TurboQuant test run.
oneAPI 2026 removed syclcompat/math.hpp, which the current SYCL helper still includes. Install the versioned 2025.3 compiler and MKL packages in both Ubuntu SYCL jobs so CI keeps using the supported toolchain.

@gianni-cor gianni-cor merged commit 01ac2a5 into tetherto:temp-7248 Apr 28, 2026
63 of 80 checks passed
