
[Feature] KV cache per-token-head Int2/Int4 Quantization + Triton_Quant_KV Interface #39074

Open
JartX wants to merge 44 commits into vllm-project:main from JartX:feature/int2_int4_per_token_head

Conversation

@JartX
Contributor

@JartX JartX commented Apr 6, 2026

This PR expands per-token KV cache quantization by adding Int2/Int4 per-token-head formats.

_effective_head_size has been added so the effective head size can be computed based on the quantization format in use.

The tests were performed on Qwen3.5 35B A3B GPTQ

I used a Walsh–Hadamard-based rotation. For the INT2 setting, I applied a plain Walsh–Hadamard transform to the KV vectors. For the INT4 setting, I used a single-round Randomized Hadamard Transform, implemented as a deterministic sign flip followed by the Walsh–Hadamard transform, in the Triton kernels.
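For reference, the rotation scheme described above can be sketched in plain Python. This is only an illustrative sketch, not the Triton kernel code, and the function names `walsh_hadamard` and `randomized_hadamard` are invented for this example:

```python
# Illustrative sketch of the rotation scheme described in the PR text; the
# real implementation lives in the Triton kernels. Names are invented here.

def walsh_hadamard(v):
    """In-place fast Walsh-Hadamard transform; len(v) must be a power of two."""
    n = len(v)
    h = 1
    while h < n:
        for i in range(0, n, h * 2):
            for j in range(i, i + h):
                a, b = v[j], v[j + h]
                v[j], v[j + h] = a + b, a - b
        h *= 2
    # Normalize so the transform is orthonormal (and hence its own inverse).
    scale = n ** -0.5
    for i in range(n):
        v[i] *= scale
    return v

def randomized_hadamard(v, signs):
    """Single-round RHT: deterministic sign flip, then the Walsh-Hadamard transform."""
    return walsh_hadamard([x * s for x, s in zip(v, signs)])
```

Because the normalized transform is orthonormal, applying it twice recovers the original vector, so dequantization can undo the rotation with the same transform (plus the inverse sign flip for the INT4 path).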

GSM8K

INT8_PER_TOKEN_HEAD

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|-------|---------|--------|--------|--------|-------|--------|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.940 | ± 0.0151 |
| | | strict-match | 5 | exact_match | 0.916 | ± 0.0176 |

INT4_PER_TOKEN_HEAD

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|-------|---------|--------|--------|--------|-------|--------|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.912 | ± 0.0273 |
| | | strict-match | 5 | exact_match | 0.91 | ± 0.0273 |

INT2_PER_TOKEN_HEAD

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|-------|---------|--------|--------|--------|-------|--------|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.864 | ± 0.0217 |
| | | strict-match | 5 | exact_match | 0.832 | ± 0.0237 |

FP8_PER_TOKEN_HEAD

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|-------|---------|--------|--------|--------|-------|--------|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.940 | ± 0.0151 |
| | | strict-match | 5 | exact_match | 0.916 | ± 0.0176 |

FP8 per-tensor

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|-------|---------|--------|--------|--------|-------|--------|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.924 | ± 0.0168 |
| | | strict-match | 5 | exact_match | 0.912 | ± 0.0180 |

FP16

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|-------|---------|--------|--------|--------|-------|--------|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.940 | ± 0.0151 |
| | | strict-match | 5 | exact_match | 0.912 | ± 0.0180 |

JartX added 8 commits April 2, 2026 23:53
Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: JartX <sagformas@epdcenter.es>
@mergify
Contributor

mergify bot commented Apr 6, 2026

Documentation preview: https://vllm--39074.org.readthedocs.build/en/39074/

@mergify mergify bot added documentation Improvements or additions to documentation v1 labels Apr 6, 2026
Signed-off-by: JartX <sagformas@epdcenter.es>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces INT2 and INT4 per-token-head KV cache quantization for the Triton attention backend. It implements specialized Triton kernels for INT4 asymmetric quantization with zero-point steganography and INT2 quantization using Walsh-Hadamard Transforms and Lloyd-Max centroids. The update includes modifications to the attention backend for packed storage handling and comprehensive tests for accuracy. Review feedback focused on improving the error messages for assertions within the quantization kernels to prevent silent failures and aid debugging.

Comment thread vllm/v1/attention/ops/triton_reshape_and_cache_flash.py Outdated
@mergify
Contributor

mergify bot commented Apr 6, 2026

Hi @JartX, the pre-commit checks have failed. Please run:

uv pip install 'pre-commit>=4.5.1'
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

Member

@yewentao256 yewentao256 left a comment

Enable CI to help test

@yewentao256 yewentao256 added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 6, 2026
JartX added 4 commits April 6, 2026 15:42
Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: JartX <sagformas@epdcenter.es>
JartX added 2 commits April 15, 2026 10:42
Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: JartX <sagformas@epdcenter.es>
@JartX JartX marked this pull request as draft April 15, 2026 12:02
JartX added 5 commits April 15, 2026 14:37
Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: JartX <sagformas@epdcenter.es>
# Conflicts:
#	docs/design/attention_backends.md
#	tests/quantization/test_per_token_kv_cache.py
#	vllm/config/cache.py
Signed-off-by: JartX <sagformas@epdcenter.es>
@mergify mergify bot removed the needs-rebase label Apr 15, 2026
@JartX JartX force-pushed the feature/int2_int4_per_token_head branch from f4a799b to 0013943 Compare April 16, 2026 17:46
@JartX JartX marked this pull request as ready for review April 16, 2026 17:50
JartX added 2 commits April 16, 2026 23:57
Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: JartX <sagformas@epdcenter.es>
@kylesayrs
Contributor

I regression tested with the following commands:

vllm serve meta-llama/Llama-3.1-8B --max_model_len 8192  --kv-cache-dtype int8_per_token_head --attention-backend triton_attn --gpu-memory-utilization 0.90
vllm serve meta-llama/Llama-3.1-8B --max_model_len 8192  --kv-cache-dtype int2_per_token_head --attention-backend triton_attn --gpu-memory-utilization 0.90
vllm serve meta-llama/Llama-3.1-8B --max_model_len 8192  --kv-cache-dtype int4_per_token_head --attention-backend triton_attn --gpu-memory-utilization 0.90

cdiv_fn,
compute_kv_seq_mask,
compute_tile_loop_bounds,
find_seq_idx,
Member

Are all of these helper functions related to KV cache quantization? find_seq_idx, for example, doesn't have anything to do with quantization afaik.

if IS_3D:
    tiles_per_segment = cdiv_fn(seq_len, NUM_SEGMENTS_PER_SEQ * TILE_SIZE)
    if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len:
        return
Member

This PR contains too many changes imo. Can we first create a PR with the 2D/3D refactor independently of KV cache quantization?

Contributor Author

I think this is probably reasonable to split, but I wanted to explain why I’d prefer to keep it in the same PR first.

The 2D/3D merge wasn’t really the end goal by itself. It mostly fell out of the INT4/INT2 work, since otherwise we would have ended up carrying four near-identical variants: 2D and 3D, each in packed and non-packed forms.

Bundling it this way also keeps the relevant flows testable end to end. If this gets split too early, the intermediate pieces only contain part of the execution path, so you can’t really exercise the actual packed/non-packed and 2D/3D combinations as they would run in practice.

After the last round of comments, the diff also separates fairly cleanly:

triton_attention_helpers.py is just helper extraction, with no intended behavior change.

triton_unified_attention.py contains the 2D/3D unification behind IS_3D, plus the dispatch branch for the packed backend.

packed_per_token_head.py is a new implementation and doesn’t modify an existing path.

So while the overall diff looks large at first glance, most of that is the new file plus the kernel body move, rather than lots of scattered edits across the codebase.

For that reason, I’d prefer to keep it in the same PR: it keeps the review aligned with the actual change we want to land, and it preserves a testable end-to-end path instead of splitting it into intermediate states that are harder to validate meaningfully.

I’m happy to split it if you’d still prefer that direction — I just wanted to give the rationale for keeping it together as submitted.

segm_max = tl.load(segm_max_ptr + segm_offset, mask=segm_mask, other=float("-inf"))
overall_max = tl.max(segm_max)

# load and rescale segment exp sums
Member

why are we removing comments?

Contributor Author

Restored



@triton.jit
def _attn_packed(
Member

I still don't understand why we need a complete re-implementation of the kernel in this file

Contributor Author

Fair point. I think it looks more like a reimplementation than it really is.

Most of the logic that should be shared already lives in vllm/v1/attention/ops/triton_attention_helpers.py, including sequence/query length handling, tile-loop bounds, the online softmax steps, scalar reduction storage, and the usual masking / ALiBi / soft-cap / QQ-bias pieces.

What’s left in _attn_packed is mostly the stuff that doesn’t map cleanly onto the main kernel structure.

First, packed Q is fundamentally different. We load it as PACKING_FACTOR interleaved streams — 2-way for INT4 nibble pairs, 4-way for INT2 quartets — so the tensor shape and GEMM layout are different from the normal path.

Second, the accumulation is different too. Instead of a single acc, the packed path needs per-stream accumulators (acc_s0..acc_s3) because the score comes from a split dot product:
tl.dot(Q_s0, K_s0) + tl.dot(Q_s1, K_s1) [+ ...].
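The identity behind that split is just deinterleaving: a dot product equals the sum of dot products over its interleaved streams. A pure-Python sketch (illustrative only, not the kernel code):

```python
# Sketch of the split dot product over interleaved streams. Illustrative
# only; the kernel performs this with tl.dot on deinterleaved tiles.

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def split_dot(q, k, packing_factor):
    # Deinterleave q and k into `packing_factor` streams, then sum the
    # per-stream dot products; this equals the flat dot product exactly.
    return sum(
        dot(q[s::packing_factor], k[s::packing_factor])
        for s in range(packing_factor)
    )

q = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
k = [0.5, -1.0, 2.0, 0.0, 1.5, 1.0, -2.0, 3.0]
assert split_dot(q, k, 2) == dot(q, k)  # 2-way split, as for INT4 nibble pairs
assert split_dot(q, k, 4) == dot(q, k)  # 4-way split, as for INT2 quartets
```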

Third, the sub-byte dequant path is specific to these formats. INT4 uses unpack_int4_nibbles, which just gives plain unsigned values in [0, 15]. INT2 uses unpack_int2_quartet plus the Lloyd-Max centroid LUT. None of that exists in the non-packed kernel today.
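The quartet packing itself can be sketched in pure Python. The centroid values below are placeholders, since real Lloyd-Max centroids are fitted to the post-Hadamard value distribution, and the function names only mirror the ones mentioned above:

```python
# Sketch of the INT2 quartet pack/unpack with a centroid LUT. The centroid
# values are placeholders, not the fitted Lloyd-Max codebook.

CENTROIDS = [-1.5, -0.5, 0.5, 1.5]  # hypothetical 2-bit codebook

def pack_int2_quartet(codes):
    """Pack four 2-bit codes (0..3) into one byte, lowest bits first."""
    assert len(codes) == 4 and all(0 <= c <= 3 for c in codes)
    byte = 0
    for i, c in enumerate(codes):
        byte |= c << (2 * i)
    return byte

def unpack_int2_quartet(byte):
    """Recover the four 2-bit codes from one packed byte."""
    return [(byte >> (2 * i)) & 0b11 for i in range(4)]

def dequant_int2(byte, scale):
    """Look up each code's centroid and apply the per-token-head scale."""
    return [CENTROIDS[c] * scale for c in unpack_int2_quartet(byte)]
```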

INT4 asymmetric zero-point handling also adds correction terms that don’t have an equivalent in the normal path:

on the score side: raw_dot - Q_sum * k_zp before scaling

on the accumulator side: subtracting Pv_zp_sum from every stream

And finally, the epilogue is different because we have to write out PACKING_FACTOR interleaved stripes using PACKED_HEAD_PADDED offsets instead of doing one flat store.

So yes, we could fold this into kernel_unified_attention behind something like PACKING_FACTOR: tl.constexpr. I’m not against that. My hesitation is that it would add a lot more constexpr branching to the main kernel and inline a bunch of packed-only logic into the common path, even though the vast majority of callers are still on the non-packed variants.

That’s why I kept the sub-byte kernel separate: the shared masking / bias / online-softmax logic is already factored into helpers, and the remaining differences are the parts that are genuinely structurally different.

If you still want the single-kernel direction, I’m happy to do it — I just wanted to explain why it’s split this way today.
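For what it's worth, the zero-point correction on the score side rests on a simple algebraic identity, sketched here in pure Python with made-up values:

```python
# Sketch of the asymmetric zero-point correction: for unsigned codes u with
# k = scale * (u - zp), it holds that
#   dot(q, k) = scale * (dot(q, u) - sum(q) * zp)
# so a kernel can take the dot product on raw codes and subtract
# Q_sum * k_zp afterwards. Values below are illustrative.

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.5, -1.0, 2.0, 0.25]
u = [3, 0, 15, 7]        # unsigned INT4 codes in [0, 15]
zp, scale = 8.0, 0.1     # per-token-head zero point and scale

# Reference: dequantize first, then take the dot product.
k = [scale * (c - zp) for c in u]
reference = dot(q, k)

# Kernel-style: dot on raw codes, then the zero-point correction.
raw_dot = dot(q, u)
corrected = scale * (raw_dot - sum(q) * zp)

assert abs(reference - corrected) < 1e-12
```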

@JartX JartX marked this pull request as draft April 18, 2026 00:16
@mergify
Contributor

mergify bot commented Apr 18, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @JartX.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 18, 2026
@mergify mergify bot removed the needs-rebase label Apr 18, 2026
@JartX JartX marked this pull request as ready for review April 18, 2026 10:01
@claude claude bot left a comment

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.


Labels

documentation: Improvements or additions to documentation
ready: ONLY add when PR is ready to merge/full CI is needed
rocm: Related to AMD ROCm
v1

Projects

Status: Todo

Development

Successfully merging this pull request may close these issues.

7 participants