[Feature] KV cache per-token-head Int2/Int4 Quantization + Triton_Quant_KV Interface #39074
JartX wants to merge 44 commits into vllm-project:main
Conversation
Signed-off-by: JartX <sagformas@epdcenter.es>
Documentation preview: https://vllm--39074.org.readthedocs.build/en/39074/ |
Code Review
This pull request introduces INT2 and INT4 per-token-head KV cache quantization for the Triton attention backend. It implements specialized Triton kernels for INT4 asymmetric quantization with zero-point handling and INT2 quantization using Walsh-Hadamard Transforms and Lloyd-Max centroids. The update includes modifications to the attention backend for packed storage handling and comprehensive tests for accuracy. Review feedback focused on improving the error messages for assertions within the quantization kernels to prevent silent failures and aid debugging.
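For readers unfamiliar with per-token-head asymmetric quantization, the idea can be sketched in NumPy: each token-head vector gets its own scale and zero point, and values are mapped onto the unsigned 4-bit range. This is an illustrative sketch only; the function names and exact rounding choices here are assumptions, not the PR's actual kernel API.

```python
import numpy as np

# Sketch of per-token-head asymmetric INT4 quantization: one (scale,
# zero_point) pair per token-head vector. Names are illustrative and do
# not match the PR's Triton kernels.
def quantize_int4_per_token_head(kv):
    # kv: (num_tokens, num_heads, head_size)
    lo = kv.min(axis=-1, keepdims=True)
    hi = kv.max(axis=-1, keepdims=True)
    scale = np.where(hi > lo, (hi - lo) / 15.0, 1.0)  # 4-bit range [0, 15]
    zp = np.round(-lo / scale)                        # per-vector zero point
    q = np.clip(np.round(kv / scale + zp), 0, 15).astype(np.uint8)
    return q, scale, zp

def dequantize_int4(q, scale, zp):
    return (q.astype(np.float32) - zp) * scale

rng = np.random.default_rng(0)
kv = rng.standard_normal((4, 2, 64)).astype(np.float32)
q, scale, zp = quantize_int4_per_token_head(kv)
err = np.abs(dequantize_int4(q, scale, zp) - kv).max()
```

The round trip is lossy; the per-element error is bounded by roughly one quantization step, which is what the accuracy tests in the PR are exercising at a model level.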
Hi @JartX, the pre-commit checks have failed. Please run:
uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files
Then, commit the changes and push to your branch.
# Conflicts:
#	docs/design/attention_backends.md
#	tests/quantization/test_per_token_kv_cache.py
#	vllm/config/cache.py
Force-pushed from f4a799b to 0013943.
I regression-tested with the following commands:
vllm serve meta-llama/Llama-3.1-8B --max_model_len 8192 --kv-cache-dtype int8_per_token_head --attention-backend triton_attn --gpu-memory-utilization 0.90
vllm serve meta-llama/Llama-3.1-8B --max_model_len 8192 --kv-cache-dtype int2_per_token_head --attention-backend triton_attn --gpu-memory-utilization 0.90
vllm serve meta-llama/Llama-3.1-8B --max_model_len 8192 --kv-cache-dtype int4_per_token_head --attention-backend triton_attn --gpu-memory-utilization 0.90
cdiv_fn,
compute_kv_seq_mask,
compute_tile_loop_bounds,
find_seq_idx,
Are all of these helper functions related to KV cache quantization? find_seq_idx, for example, doesn't have anything to do with quantization afaik.
if IS_3D:
    tiles_per_segment = cdiv_fn(seq_len, NUM_SEGMENTS_PER_SEQ * TILE_SIZE)
    if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len:
        return
This PR contains too many changes imo. Can we first create a PR with the 2D/3D refactor independently of KV cache quantization?
I think this is probably reasonable to split, but I wanted to explain why I’d prefer to keep it in the same PR first.
The 2D/3D merge wasn’t really the end goal by itself. It mostly fell out of the INT4/INT2 work, since otherwise we would have ended up carrying four near-identical variants: 2D and 3D, each in packed and non-packed forms.
Bundling it this way also keeps the relevant flows testable end to end. If this gets split too early, the intermediate pieces only contain part of the execution path, so you can’t really exercise the actual packed/non-packed and 2D/3D combinations as they would run in practice.
After the last round of comments, the diff also separates fairly cleanly:
triton_attention_helpers.py is just helper extraction, with no intended behavior change.
triton_unified_attention.py contains the 2D/3D unification behind IS_3D, plus the dispatch branch for the packed backend.
packed_per_token_head.py is a new implementation and doesn’t modify an existing path.
So while the overall diff looks large at first glance, most of that is the new file plus the kernel body move, rather than lots of scattered edits across the codebase.
For that reason, I’d prefer to keep it in the same PR: it keeps the review aligned with the actual change we want to land, and it preserves a testable end-to-end path instead of splitting it into intermediate states that are harder to validate meaningfully.
I’m happy to split it if you’d still prefer that direction — I just wanted to give the rationale for keeping it together as submitted.
segm_max = tl.load(segm_max_ptr + segm_offset, mask=segm_mask, other=float("-inf"))
overall_max = tl.max(segm_max)

# load and rescale segment exp sums
@triton.jit
def _attn_packed(
I still don't understand why we need a complete re-implementation of the kernel in this file
Fair point. I think it looks more like a reimplementation than it really is.
Most of the logic that should be shared already lives in vllm/v1/attention/ops/triton_attention_helpers.py, including sequence/query length handling, tile-loop bounds, the online softmax steps, scalar reduction storage, and the usual masking / ALiBi / soft-cap / QQ-bias pieces.
What’s left in _attn_packed is mostly the stuff that doesn’t map cleanly onto the main kernel structure.
First, packed Q is fundamentally different. We load it as PACKING_FACTOR interleaved streams — 2-way for INT4 nibble pairs, 4-way for INT2 quartets — so the tensor shape and GEMM layout are different from the normal path.
Second, the accumulation is different too. Instead of a single acc, the packed path needs per-stream accumulators (acc_s0..acc_s3) because the score comes from a split dot product:
tl.dot(Q_s0, K_s0) + tl.dot(Q_s1, K_s1) [+ ...].
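A toy NumPy check makes the split-accumulator argument concrete: if the head dimension is divided into PACKING_FACTOR interleaved streams, the sum of per-stream dot products equals the full dot product, which is why the packed kernel can accumulate per stream and sum at the end. (This is an illustrative sketch; the actual kernel operates on Triton tiles, not 1-D vectors.)

```python
import numpy as np

# Toy equivalence check (not the Triton kernel): splitting the head dim
# into PACKING_FACTOR interleaved streams and summing per-stream dot
# products reproduces the full dot product.
PACKING_FACTOR = 2  # 2 for INT4 nibble pairs, 4 for INT2 quartets
rng = np.random.default_rng(0)
q = rng.standard_normal(64)
k = rng.standard_normal(64)

full = q @ k
split = sum(q[s::PACKING_FACTOR] @ k[s::PACKING_FACTOR]
            for s in range(PACKING_FACTOR))
```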
Third, the sub-byte dequant path is specific to these formats. INT4 uses unpack_int4_nibbles, which just gives plain unsigned values in [0, 15]. INT2 uses unpack_int2_quartet plus the Lloyd-Max centroid LUT. None of that exists in the non-packed kernel today.
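To illustrate what "plain unsigned values in [0, 15]" means here, a minimal sketch of nibble unpacking, assuming a low-nibble-first byte layout (the actual bit layout used by unpack_int4_nibbles in the PR may differ):

```python
import numpy as np

# Illustrative unpacking of two INT4 codes per byte, low nibble first.
# The real kernel helper may use a different layout.
def unpack_int4_nibbles(packed: np.ndarray):
    lo = packed & 0x0F          # low nibble: unsigned values in [0, 15]
    hi = (packed >> 4) & 0x0F   # high nibble
    return lo, hi

packed = np.array([0x21, 0xF0], dtype=np.uint8)
lo, hi = unpack_int4_nibbles(packed)
# lo == [1, 0], hi == [2, 15]
```

The INT2 path would do the analogous 4-way split (two bits per code) and then map each code through the Lloyd-Max centroid LUT instead of using the raw code value.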
INT4 asymmetric zero-point handling also adds correction terms that don’t have an equivalent in the normal path:
- on the score side: raw_dot - Q_sum * k_zp before scaling
- on the accumulator side: subtracting Pv_zp_sum from every stream
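The score-side correction follows directly from the asymmetric dequantization formula: with k = k_scale * (k_q - k_zp), the true score is q . k = k_scale * (q . k_q - k_zp * sum(q)), i.e. exactly raw_dot - Q_sum * k_zp before scaling. A toy NumPy check (illustrative values, not the kernel):

```python
import numpy as np

# Verify the asymmetric zero-point correction on the score side:
#   q . k = k_scale * (raw_dot - Q_sum * k_zp)
rng = np.random.default_rng(0)
q = rng.standard_normal(64)
k_q = rng.integers(0, 16, 64).astype(np.float32)  # unsigned INT4 codes
k_zp, k_scale = 7.0, 0.1                          # illustrative values
k = k_scale * (k_q - k_zp)                        # dequantized K

raw_dot = q @ k_q
corrected = k_scale * (raw_dot - q.sum() * k_zp)
```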
And finally, the epilogue is different because we have to write out PACKING_FACTOR interleaved stripes using PACKED_HEAD_PADDED offsets instead of doing one flat store.
So yes, we could fold this into kernel_unified_attention behind something like PACKING_FACTOR: tl.constexpr. I’m not against that. My hesitation is that it would add a lot more constexpr branching to the main kernel and inline a bunch of packed-only logic into the common path, even though the vast majority of callers are still on the non-packed variants.
That’s why I kept the sub-byte kernel separate: the shared masking / bias / online-softmax logic is already factored into helpers, and the remaining differences are the parts that are genuinely structurally different.
If you still want the single-kernel direction, I’m happy to do it — I just wanted to explain why it’s split this way today.
This pull request has merge conflicts that must be resolved before it can be merged.
This PR focuses on expanding per-token KV cache quantization by adding INT2/INT4 per-token-head modes.
_effective_head_size has been added to allow head-size calculation based on the per-token quantization type. The tests were performed on Qwen3.5 35B A3B GPTQ.
I used a Walsh-Hadamard-based rotation. For the INT2 setting, I applied a plain Walsh-Hadamard transform to the KV vectors. For the INT4 setting, I used a single-round Randomized Hadamard Transform, implemented as a deterministic sign flip followed by the Walsh-Hadamard transform in the Triton kernels.
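As a rough reference for the rotation step, a minimal NumPy sketch of the orthonormal fast Walsh-Hadamard transform and the sign-flip variant, assuming a power-of-two head size; the PR implements this inside the Triton kernels, and the fixed alternating sign pattern below is purely illustrative:

```python
import numpy as np

def wht(x: np.ndarray) -> np.ndarray:
    """Orthonormal fast Walsh-Hadamard transform over the last axis.
    The last-axis length must be a power of two."""
    n = x.shape[-1]
    y = x.reshape(-1, n).copy()
    h = 1
    while h < n:
        for i in range(0, n, h * 2):          # standard butterfly pass
            a = y[:, i:i + h].copy()
            b = y[:, i + h:i + 2 * h].copy()
            y[:, i:i + h] = a + b
            y[:, i + h:i + 2 * h] = a - b
        h *= 2
    return (y / np.sqrt(n)).reshape(x.shape)  # orthonormal scaling

def randomized_hadamard(x, signs):
    # Single-round RHT: deterministic sign flip, then WHT.
    return wht(x * signs)

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 64))
signs = np.where(np.arange(64) % 2 == 0, 1.0, -1.0)  # illustrative pattern
y = randomized_hadamard(x, signs)
```

Because the transform is orthonormal it preserves vector norms (and is its own inverse), which is what makes it a safe pre-rotation before aggressive sub-byte quantization: it spreads outlier coordinates across the whole vector without changing the attention scores after the inverse rotation.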
GSM8K (kv-cache dtype variants tested):
- INT8_PER_TOKEN_HEAD
- INT4_PER_TOKEN_HEAD
- INT2_PER_TOKEN_HEAD
- FP8_PER_TOKEN_HEAD
- FP8 per-tensor
- FP16