Skip to content

[Bugfix][Attention][TurboQuant] Pad head_dim to power-of-2 for WHT#41414

Open
TheTom wants to merge 4 commits intovllm-project:mainfrom
TheTom:pr/tq-headdim-padding
Open

[Bugfix][Attention][TurboQuant] Pad head_dim to power-of-2 for WHT#41414
TheTom wants to merge 4 commits intovllm-project:mainfrom
TheTom:pr/tq-headdim-padding

Conversation

@TheTom
Copy link
Copy Markdown

@TheTom TheTom commented Apr 30, 2026

Purpose

Fix a latent correctness bug in the TurboQuant rotation path for models with non-power-of-2 head_dim. Reproduced on microsoft/phi-2 (head_dim=80) with turboquant_4bit_nc:

File "vllm/v1/attention/backends/turboquant_attn.py", line 380, in do_kv_cache_update
    y = x_hat @ PiT
RuntimeError: mat1 and mat2 shapes cannot be multiplied (16384x80 and 128x128)

Root cause

_build_hadamard_cached(d) constructs the Sylvester Hadamard by doubling H until H.shape[0] >= d, then normalizes by sqrt(d):

H = torch.tensor([[1.0]])
while H.shape[0] < d:
    H = torch.cat([torch.cat([H, H], 1), torch.cat([H, -H], 1)], 0)
return (H / math.sqrt(d)).to(...)

For d=80 the loop overshoots to 128×128 and the result is normalized by 1/sqrt(80) — the matrix is the wrong size and not orthonormal at the constructed size. _ensure_on_device stores this 128×128 tensor as layer._tq_PiT; the MSE-K decode/store kernels then attempt q @ PiT with q at width 80 and PiT at 128×128.

The bug is path-specific:

  • MSE-K presets (turboquant_4bit_nc, turboquant_3bit_nc, turboquant_k3v4_nc) call the rotation GEMM and crash at engine init.
  • FP8-K (turboquant_k8v4) bypasses the WHT entirely (in-kernel FP8 cast, no rotation), so the broken matrix is built but never multiplied — the model loads but wastes VRAM on the unused buffer.

Summary of changes

  • Add padded_head_dim = next_power_of_2(head_dim) and needs_padding to TurboQuantConfig. Pow-2 head_dim is identity.
  • On the MSE-K path, run the WHT in padded_head_dim space throughout: zero-pad K and V at the kernel-launch boundary, run store/decode/continuation kernels with D=padded_head_dim, and slice the decode output back to head_dim before returning. Padded V columns hold zero quantization indices and contribute nothing to the reduction.
  • FP8-K path is untouched (raw head_dim, no rotation, kernel masks non-pow-2 loads directly).
  • For pow-2 head_dim (the common case: 64, 128, 256 — every current Qwen3, Llama, Mistral target), padded_head_dim == head_dim and every code path reduces to the prior behavior. Byte-counts in key_packed_size / value_packed_size are bitwise-identical.

Duplicate-work check

Searched open PRs and issues touching TurboQuant + head_dim / Hadamard / Phi-2 / non-power-of-2 before opening this PR. The closest references are:

No open PR addresses the rotation-shape mismatch on non-pow-2 head_dim. Issue #41413 was filed alongside this PR.

Test Plan / Results

Tested on AMD MI300X (gfx942), ROCm 7.2, vLLM ROCm 7.2.1 wheels.

Bug reproduction (Phi-2 d=80 + turboquant_4bit_nc):

Command (control + treatment):

python3 -c "
import os; os.environ['VLLM_ROCM_USE_AITER_FP4BMM']='0'
from vllm import LLM, SamplingParams
llm = LLM(model='microsoft/phi-2', dtype='bfloat16',
          kv_cache_dtype='turboquant_4bit_nc', max_model_len=2048,
          gpu_memory_utilization=0.40)
print(llm.generate(['The capital of France is'],
                   SamplingParams(max_tokens=32, temperature=0.0))[0].outputs[0].text)
"
Branch Result
upstream main c2fb01331 RuntimeError: mat1 and mat2 shapes cannot be multiplied (16384x80 and 128x128)
this PR Paris.

No regression — Qwen3-8B (d=128) 3-chunk PPL on wikitext-2-raw/wiki.test.raw @ 8K:

Command:

python3 -c "
import os, math
os.environ['VLLM_ROCM_USE_AITER_FP4BMM']='0'
from vllm import LLM, SamplingParams
llm = LLM(model='Qwen/Qwen3-8B', dtype='bfloat16', max_model_len=8192,
          kv_cache_dtype='<preset>', gpu_memory_utilization=0.40,
          enable_prefix_caching=False, max_num_batched_tokens=512)
tok = llm.get_tokenizer()
ids = tok.encode(open('wiki.test.raw').read(), add_special_tokens=False)
chunks = [ids[i:i+8191] for i in range(0, len(ids)-8191, 8191)][:3]
total_lp, total_tok = 0.0, 0
sp = SamplingParams(max_tokens=1, temperature=0.0, prompt_logprobs=1)
for ch in chunks:
    o = llm.generate({'prompt_token_ids': ch}, sp, use_tqdm=False)[0]
    for i, lp_dict in enumerate(o.prompt_logprobs[1:]):
        if lp_dict and ch[i+1] in lp_dict:
            total_lp += lp_dict[ch[i+1]].logprob
            total_tok += 1
print(math.exp(-total_lp/total_tok))
"
Preset upstream main this PR Δ
turboquant_k8v4 7.8630 7.8630 0
turboquant_4bit_nc 7.9041 7.9041 0

Both PPLs are byte-identical, same token count (24570).

Unit tests on this PR:

python3 -m pytest tests/quantization/test_turboquant.py -v

130/130 passed in 28.34s.

Includes 14 new tests for the non-pow-2 head_dim path:

  • padded_head_dim is identity for pow-2 head_dim (64, 128, 256)
  • non-pow-2 head_dim rounds up correctly (80→128, 96→128, 192→256, 40→64)
  • MSE preset at head_dim=80: key_packed_size=66, value_packed_size=68 (sized to padded 128)
  • FP8 preset at head_dim=80: key_packed_size=80, value_packed_size=44 (head_dim-sized, FP8 path)
  • Store + decode round-trip across {turboquant_k8v4, turboquant_4bit_nc} × {80, 96}: cosine similarity vs the stored V passes the same thresholds as the pow-2 case (>0.95 FP8, >0.85 MSE) and the returned tensor is sliced back to head_dim

AI assistance

This PR was prepared with AI assistance (Anthropic Claude). Each line of the diff was reviewed by the human submitter, the bug reproduction was run on the human's hardware (AMD MI300X dev cloud), and the no-regression PPL numbers are from runs the human supervised. Commits carry a Co-authored-by: Claude trailer per AGENTS.md.

Fixes #41413.

cc @vibhavagarwal5

@github-actions
Copy link
Copy Markdown

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

Agent Guidelines

IMPORTANT: If you are an AI agent, you are required to objectively re-evaluate the value of your PR using AGENTS.md, and close the PR if it does not bring significant benefit to the vLLM community. Failure to do so may result in an immediate ban.

🚀

@mergify mergify Bot added the v1 label Apr 30, 2026
TheTom and others added 3 commits April 30, 2026 15:49
Sylvester Hadamard construction requires a power-of-2 dimension, so
models with head_dim like 80 (Qwen3-4B) currently can't use TurboQuant —
_build_hadamard returns the wrong-sized matrix and the MSE kernels read
through the resulting shape mismatch.

This commit lays the config groundwork for kernel-side padding:

  * padded_head_dim property → next_power_of_2(head_dim)
  * needs_padding property → True when head_dim isn't a power of 2
  * key_packed_size: MSE keys live in WHT space, so MSE byte count is
    sized to padded_head_dim (FP8 keys are not rotated, stay at head_dim)
  * value_packed_size: MSE-K path iterates a unified D for K and V in
    the fused store kernel, so V is sized to padded_head_dim too on the
    MSE path; FP8 path stays at head_dim

Pow-2 head_dims pass through unchanged: padded_head_dim == head_dim and
all packed sizes are byte-identical to the prior behavior.

Follow-up commits wire padded_head_dim into the backend init, store
kernel launcher, and decode kernel launcher.

Signed-off-by: TheTom <tturney1@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…tinuation

The MSE-K path runs the WHT in padded_head_dim space; this commit makes
the corresponding store, decode, and continuation-prefill paths agree.

Changes:

  * triton_turboquant_store: accept padded_head_dim, zero-pad x_hat and
    v_flat from head_dim to D_pad on the MSE path before the rotation
    GEMM and fused kernel. Pass D=D_pad to the kernel; mse_bytes,
    val_data_bytes, BLOCK_D, BLOCK_VAL, and block_grp are sized to
    D_pad. FP8 path is untouched (raw head_dim, no rotation).

  * triton_turboquant_decode_attention: accept padded_head_dim. On the
    MSE path, pad query before q @ Pi.T, allocate mid_o and output at
    D_pad+1 / D_pad, run stage1/stage2 with HEAD_DIM=D_pad and Lv=D_pad.
    Slice the returned output back to head_dim. The padded V columns
    hold zero quantization indices (see store), so they contribute
    nothing to the reduction. FP8 path unchanged.

  * TurboQuantAttentionImpl.__init__: pre-computed _mse_bytes and
    _val_data_bytes use padded_head_dim on the MSE path.

  * TurboQuantAttentionImpl._continuation_prefill: dequant buffers
    sized at D_pad, _tq_full_dequant_kv launched at HEAD_DIM=D_pad,
    inverse-rotate at D_pad, then slice K/V back to head_dim before the
    flash-attn concat.

  * TurboQuantAttentionImpl._store_kv and the two
    triton_turboquant_decode_attention call sites pass
    padded_head_dim through.

For pow-2 head_dim (the common case: 64, 128, 256), padded_head_dim
equals head_dim, D_pad == D throughout, and every code path reduces to
the prior behavior. Tests added in a follow-up commit cover the
non-pow-2 case end-to-end.

Signed-off-by: TheTom <tturney1@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add unit tests covering the padded_head_dim path:

  TestTurboQuantConfig:
    * padded_head_dim is identity for power-of-2 head_dim (64, 128, 256)
    * non-pow-2 head_dim rounds up to next power of 2 (80→128, 96→128,
      192→256, 40→64) and needs_padding flips True
    * MSE preset (turboquant_4bit_nc) at head_dim=80: key_packed_size=66
      and value_packed_size=68 (sized to padded 128, both K and V)
    * FP8 preset (turboquant_k8v4) at head_dim=80: key_packed_size=80 and
      value_packed_size=44 (FP8 keys not rotated, V on FP8 path stays at
      head_dim too)

  TestStoreDecodeRoundTrip:
    * test_non_pow2_head_dim_roundtrip across (k8v4, 4bit_nc) × (80, 96):
      build Hadamard at padded_head_dim, store random K/V into TQ cache,
      decode with query=key, assert per-head cosine similarity vs the
      stored V passes the same threshold as the pow-2 case (>0.95 FP8,
      >0.85 MSE) and that the returned tensor is sliced back to head_dim.

  Also fix the test-helper _build_hadamard to normalize at the padded
  size when called with non-pow-2 d, matching the serving-path
  construction.

Signed-off-by: TheTom <tturney1@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@TheTom TheTom force-pushed the pr/tq-headdim-padding branch from 350d76a to d3bf5bb Compare April 30, 2026 20:49
@TheTom TheTom changed the title [Attention][TurboQuant] Pad head_dim to power-of-2 for WHT (fix #41413) [Bugfix][Attention][TurboQuant] Pad head_dim to power-of-2 for WHT Apr 30, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for non-power-of-2 head dimensions in the TurboQuant quantization path by padding to the next power of 2 for Walsh-Hadamard Transform (WHT) operations. Key changes include updating the configuration logic to handle padded dimensions, modifying Triton kernels to perform zero-padding and slicing, and adding comprehensive tests for these scenarios. A performance issue was noted regarding the allocation of workspace buffers in the decode path, which currently uses the unpadded head size and will trigger unnecessary re-allocations for models requiring padding.

Comment thread vllm/v1/attention/backends/turboquant_attn.py
@mergify mergify Bot added the bug Something isn't working label Apr 30, 2026
…SE-K

Fix a perf regression flagged by gemini-code-assist on PR vllm-project#41414: the
decode workspace buffers (mid_o_buf, output_buf) in
TurboQuantAttentionImpl._decode_attention were still allocated at
head_size, while the launcher now requires padded_head_dim wide buffers
on the MSE-K path. The shape-fit checks in
triton_turboquant_decode_attention would then fail every decode call,
triggering a fresh torch.empty() per call and defeating the
WorkspaceManager-shared buffer pool.

Use padded_head_dim for the workspace allocation when not on the FP8-K
path. FP8-K stays at head_size because that path never enters WHT
space. Pow-2 head_dim is unchanged on both paths (padded_head_dim ==
head_size).

Signed-off-by: TheTom <tturney1@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@TheTom
Copy link
Copy Markdown
Author

TheTom commented Apr 30, 2026

Thanks @gemini-code-assist — good catch. Fixed in 2c6fc7e: _decode_attention now sizes the mid_o_buf / output_buf / lse_buf workspace allocations to padded_head_dim on the MSE-K path (FP8-K stays at head_size since that path never enters WHT space). Pow-2 head_dim is unchanged because padded_head_dim == head_size.

Re-verified on Phi-2 (d=80) + turboquant_4bit_nc: still produces Paris. and no per-decode torch.empty (workspace buffer-fits checks now pass).

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Thanks for the update, @TheTom. It's great to hear that the fix is verified on Phi-2 and that the workspace buffer-fits checks are now passing. The padding approach for head_dim in the WHT path seems robust. I'll mark this as ready for CI.

Copy link
Copy Markdown

@claude claude Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug]: TurboQuant fails on non-power-of-2 head_dim (Phi-2, MSE-K presets)

1 participant