
Deterministic Hadamard KQ rotation #1

Open
captainpete wants to merge 13 commits into main from feat/hadamard-kq-rotation

Conversation


@captainpete captainpete commented Apr 5, 2026

Purpose

First incremental step toward serving compressed-tensors checkpoints that use transform_config for Hadamard KV cache rotation (SpinQuant R3). See vllm-project#28538 and vllm-project/compressed-tensors#436.

Scope of this PR:

  • Deterministic Hadamard only (hadamard type, randomize=False)
  • FP16/BF16 weights; no INT4 or FP8 KV cache quantization
  • K_CACHE and Q_ATTN transform locations only
  • Llama-family models (other architectures need one line added to their load_weights)

What is not yet supported (raises NotImplementedError at model load):

  • random-hadamard type and randomize=True (require loading a per-layer rotation matrix from the checkpoint)
  • ROCm (CUDA-only kernel)

compressed-tensors PR#436 introduced TransformConfig, TransformScheme, and TransformArgs as the schema for describing these rotations. This PR reads that schema from the checkpoint's quantization_config and applies it at inference.

Changes:

  • CompressedTensorsKVCacheMethod.create_weights() reads transform_config, identifies layers with K_CACHE or Q_ATTN transform locations, and sets _kq_attn_transform = True on the corresponding Attention module
  • Attention.forward() applies ops.hadacore_transform to Q and K after reshape, immediately before the paged attention call
  • Validation at model load raises clear errors for unsupported configurations: random-hadamard, randomize=True, ROCm, head_dim mismatch, non-power-of-two head_dim, and conflict with FP8 KV scales
  • is_hadamard_transform_weight() in weight_utils.py handles stored rotation matrices in R3 checkpoints (R3_q_attn.weight, R3_k_cache.weight). The deterministic FWHT reconstructs H from head_dim alone, so these weights are skipped on load following the same pattern as maybe_remap_kv_scale_name
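The skip check in the last bullet amounts to a suffix match. A minimal sketch, assuming the `{group}_{location}` names quoted in the PR description (a real implementation would likely derive the suffixes from `transform_config` rather than hard-coding them):

```python
# Sketch of a detector for stored Hadamard rotation weights in an R3
# checkpoint. The suffixes below (R3_q_attn, R3_k_cache) come from the PR
# description; other checkpoints may use different group names.
_HADAMARD_TRANSFORM_SUFFIXES = (
    "R3_q_attn.weight",
    "R3_k_cache.weight",
)


def is_hadamard_transform_weight(name: str) -> bool:
    """Return True if `name` is a stored rotation matrix that can be
    skipped at load time (the deterministic FWHT needs only head_dim)."""
    return name.endswith(_HADAMARD_TRANSFORM_SUFFIXES)
```

A load loop would then `continue` past any name for which this returns True, the same way `maybe_remap_kv_scale_name` callers skip unmatched KV-scale names.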

The rotation is checkpoint-driven: it fires when the checkpoint carries transform_config, with no CLI argument required. This aligns with the direction in #sig-quantization to deprecate the quantization= CLI arg for offline quantization.

Known limitation: is_hadamard_transform_weight() is currently called only in llama.py. Other models need the same call added when R3 checkpoints are produced for those architectures. This is the same per-model propagation pattern as maybe_remap_kv_scale_name, which is already in 64 model files.

Test Plan

Unit tests (no GPU, runs on all platforms):

.venv/bin/python -m pytest tests/quantization/test_hadamard_kv_dispatch.py -v
.venv/bin/python -m pytest tests/model_executor/test_weight_utils.py::TestIsHadamardTransformWeight -v

End-to-end against a real R3 checkpoint (~16GB VRAM, not suitable for CI):

# With rotation (default)
VLLM_WORKER_MULTIPROC_METHOD=spawn .venv/bin/python tests/quantization/test_hadamard_kq_r3_e2e.py

# Baseline (rotation stripped from config.json)
VLLM_WORKER_MULTIPROC_METHOD=spawn .venv/bin/python tests/quantization/test_hadamard_kq_r3_e2e.py --no-rotation

Test Result

Unit tests: all pass (no GPU required).

E2E on nm-testing/Meta-Llama-3-8B-Instruct-spinquantR3, RTX 3090, bf16:

| Run | Throughput |
| --- | --- |
| Hadamard rotation enabled | 240.7 tok/s |
| Baseline (rotation stripped) | 240.5 tok/s |

Throughput is indistinguishable between the two runs (0.1% difference, within noise). Both runs produce coherent outputs with no crashes or NaNs. The accuracy benefit of Hadamard rotation is only measurable under INT4/FP8 KV cache quantization; the bf16 run confirms correctness of the dispatch and kernel execution path.

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>

Signed-off-by: Peter Hollows <github@dojo7.com>
Add _has_kq_attn_transform() to CompressedTensorsKVCacheMethod, which
walks the checkpoint's transform_config to determine whether a given
attention layer should have K_CACHE/Q_ATTN Hadamard rotation applied.
Result is stored as layer._kq_attn_transform (bool) in create_weights.

Per-layer targeting uses is_match(), mirroring the linear transform
dispatch path.
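A toy illustration of glob-style per-layer targeting; the real `is_match()` in compressed-tensors is richer (regex targets, ignore lists), so this is only a sketch of the idea:

```python
import fnmatch


def matches_target(layer_name: str, targets: list[str]) -> bool:
    # Simplified stand-in for compressed-tensors' is_match(): a layer is
    # targeted if any glob pattern matches its dotted module name.
    return any(fnmatch.fnmatch(layer_name, pat) for pat in targets)
```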

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>

Signed-off-by: Peter Hollows <github@dojo7.com>
Guard against random-hadamard and random-matrix scheme types targeting
K_CACHE/Q_ATTN. Silently accepting these schemes would apply the wrong
rotation (or none) at serving time, corrupting attention for checkpoints
that require the stored matrix; raising at load makes the gap explicit.

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>

Signed-off-by: Peter Hollows <github@dojo7.com>
Raise ValueError at model load if:
- the layer has no head_size attribute
- scheme.head_dim is set but doesn't match layer.head_size
- head_dim is not a power of two (hadacore kernel constraint)
- head_dim exceeds 2^15 (hadacore kernel constraint)
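Those checks can be sketched as follows (function name and message text are illustrative, not the PR's actual code):

```python
def validate_head_dim(head_dim: int, max_log2: int = 15) -> None:
    """Load-time guards for the hadacore kernel constraints listed above:
    head_dim must be a power of two and must not exceed 2**max_log2."""
    if head_dim <= 0 or (head_dim & (head_dim - 1)) != 0:
        raise ValueError(f"head_dim={head_dim} must be a power of two")
    if head_dim > 2 ** max_log2:
        raise ValueError(f"head_dim={head_dim} exceeds 2^{max_log2}")
```

Raising here, at model load, surfaces a readable error instead of a kernel failure on the first forward pass.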

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>

Signed-off-by: Peter Hollows <github@dojo7.com>
hadacore_transform is a CUDA-only kernel. Raising at model load gives a
clear error rather than a cryptic op failure during the first forward
pass.

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>

Signed-off-by: Peter Hollows <github@dojo7.com>
Apply ops.hadacore_transform to Q and K post-RoPE, post-reshape when
_kq_attn_transform is set. Initialise the flag to False before
_init_kv_cache_quant so create_weights has the final word.

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>

Signed-off-by: Peter Hollows <github@dojo7.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>

Signed-off-by: Peter Hollows <github@dojo7.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>

Signed-off-by: Peter Hollows <github@dojo7.com>
…eights

compressed-tensors checkpoints with transform_config (e.g. SpinQuant R3)
store Hadamard rotation matrices alongside model weights. vLLM applies the
deterministic FWHT via ops.hadacore_transform, which reconstructs H from
head_dim alone -- the stored matrix is never needed at serving time.

Adds is_hadamard_transform_weight(name) to detect these weights by their
{group}_{location}.weight suffix pattern so callers can skip them cleanly
rather than hitting a KeyError. Mirrors the existing maybe_remap_kv_scale_name
pattern in weight_utils.py.
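For reference, "reconstructs H from head_dim alone" means the orthonormal Hadamard transform can be applied butterfly-style without ever materialising the stored matrix. A pure-Python sketch (the hadacore kernel does the equivalent per attention head on the GPU):

```python
import math


def fwht(vec: list[float]) -> list[float]:
    """Fast Walsh-Hadamard transform, scaled by 1/sqrt(n) so it is
    orthonormal (applying it twice returns the input). len(vec) must be
    a power of two. Reference sketch only, not the vLLM kernel."""
    out = list(vec)
    n = len(out)
    h = 1
    while h < n:
        # Butterfly stage: combine pairs h apart.
        for i in range(0, n, h * 2):
            for j in range(i, i + h):
                a, b = out[j], out[j + h]
                out[j], out[j + h] = a + b, a - b
        h *= 2
    scale = 1.0 / math.sqrt(n)
    return [x * scale for x in out]
```

Because the transform is its own inverse under this scaling, rotating Q and K by the same H leaves Q·Kᵀ unchanged, which is why the bf16 end-to-end run can only confirm correctness, not an accuracy gain.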

Co-authored-by: Claude <noreply@anthropic.com>

Signed-off-by: Peter Hollows <github@dojo7.com>
R3 checkpoints store rotation matrices (e.g. R3_q_attn.weight,
R3_k_cache.weight) at the self_attn level. vLLM's Attention module sits
one level deeper (self_attn.attn), so these weights have no corresponding
parameter in params_dict.

Replace the broad name-not-in-params_dict guard with the targeted
is_hadamard_transform_weight() check from weight_utils so that only
known transform matrices are skipped -- other unexpected weight names
still raise KeyError as intended.

Co-authored-by: Claude <noreply@anthropic.com>

Signed-off-by: Peter Hollows <github@dojo7.com>
Standalone script (not pytest) that loads nm-testing/Meta-Llama-3-8B-
Instruct-spinquantR3 and runs inference with K_CACHE/Q_ATTN Hadamard
rotation enabled (default) or disabled (--no-rotation flag).

The --no-rotation baseline strips transform_config from config.json into
a temp directory so the same checkpoint can be loaded without rotation.
Outputs throughput and sample text for both runs.
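The config-stripping step might look like the following sketch (`strip_transform_config` is a hypothetical helper, not the script's actual code; linking the weight files into the temp directory is omitted here):

```python
import json
import tempfile
from pathlib import Path


def strip_transform_config(checkpoint_dir: str) -> str:
    """Copy config.json into a temp dir with transform_config removed
    from quantization_config, so the same checkpoint loads without the
    Hadamard rotation. Illustrative only."""
    src = Path(checkpoint_dir) / "config.json"
    cfg = json.loads(src.read_text())
    cfg.get("quantization_config", {}).pop("transform_config", None)
    out_dir = Path(tempfile.mkdtemp(prefix="no-rotation-"))
    (out_dir / "config.json").write_text(json.dumps(cfg, indent=2))
    return str(out_dir)
```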

Not suitable for CI (~16GB VRAM). Run with:
  VLLM_WORKER_MULTIPROC_METHOD=spawn .venv/bin/python \
    tests/quantization/test_hadamard_kq_r3_e2e.py [--no-rotation]

Co-authored-by: Claude <noreply@anthropic.com>

Signed-off-by: Peter Hollows <github@dojo7.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>

Signed-off-by: Peter Hollows <github@dojo7.com>
@captainpete force-pushed the feat/hadamard-kq-rotation branch from 029b010 to 4452c84 on April 6, 2026 at 01:22
Add class-level annotation to narrow self.quant_config from the base
QuantizationConfig to CompressedTensorsConfig, fixing attr-defined
errors on transform_config. Replace stale type: ignore[attr-defined]
comments on total_num_heads/total_num_kv_heads with assert guards now
that mypy can see the actual int | None types.

Co-authored-by: Claude <noreply@anthropic.com>

Signed-off-by: Peter Hollows <github@dojo7.com>