163 changes: 160 additions & 3 deletions tests/quantization/test_turboquant.py
@@ -197,6 +197,43 @@ def test_boundary_skip_layers_cap_at_half(self):
layers = TurboQuantConfig.get_boundary_skip_layers(8, 10)
assert len(layers) == 8

# ---- Non-power-of-2 head_dim padding ----

@pytest.mark.parametrize("head_dim", [64, 128, 256])
def test_padded_head_dim_pow2_passthrough(self, head_dim):
"""Power-of-2 head_dim is unchanged: padded_head_dim == head_dim."""
cfg = TurboQuantConfig.from_cache_dtype("turboquant_4bit_nc", head_dim=head_dim)
assert cfg.padded_head_dim == head_dim
assert cfg.needs_padding is False

@pytest.mark.parametrize(
"head_dim,expected_padded",
[(80, 128), (96, 128), (192, 256), (40, 64)],
)
def test_padded_head_dim_non_pow2(self, head_dim, expected_padded):
"""Non-power-of-2 head_dim rounds up to the next power of 2."""
cfg = TurboQuantConfig.from_cache_dtype("turboquant_4bit_nc", head_dim=head_dim)
assert cfg.padded_head_dim == expected_padded
assert cfg.needs_padding is True

def test_mse_key_packed_size_uses_padded(self):
"""MSE keys live in WHT space, so byte count tracks padded_head_dim."""
cfg = TurboQuantConfig.from_cache_dtype("turboquant_4bit_nc", head_dim=80)
assert cfg.padded_head_dim == 128
# 4-bit MSE: ceil(128 * 4 / 8) + 2 (vec_norm fp16) = 66
assert cfg.key_packed_size == 66
# 4-bit V on MSE-K path uses padded too: ceil(128 * 4 / 8) + 4 = 68
assert cfg.value_packed_size == 68

def test_fp8_key_packed_size_stays_at_head_dim(self):
"""FP8 keys are not rotated, so byte count tracks head_dim directly."""
cfg = TurboQuantConfig.from_cache_dtype("turboquant_k8v4", head_dim=80)
assert cfg.padded_head_dim == 128
assert cfg.key_packed_size == 80 # 1 byte per element, no rotation
# FP8 K + uniform V: V also stays at head_dim (kernel reads raw)
# ceil(80 * 4 / 8) + 4 = 44
assert cfg.value_packed_size == 44


# ============================================================================
# Centroids tests (CPU-only)
@@ -398,11 +435,18 @@ def test_rotation_matrix_det_is_pm1(self):


 def _build_hadamard(d: int, device: str = "cpu") -> torch.Tensor:
-    """Reproduce the serving-path Hadamard construction."""
+    """Reproduce the serving-path Hadamard construction.
+
+    For non-power-of-2 d, the Sylvester construction overshoots and
+    produces a matrix at next_power_of_2(d). Normalize by sqrt of the
+    matrix size so the result is orthonormal at that size; the serving
+    path uses padded_head_dim throughout for the same reason.
+    """
+    target = next_power_of_2(d)
     H = torch.tensor([[1.0]])
-    while H.shape[0] < d:
+    while H.shape[0] < target:
         H = torch.cat([torch.cat([H, H], 1), torch.cat([H, -H], 1)], 0)
-    return (H / math.sqrt(d)).to(torch.device(device))
+    return (H / math.sqrt(target)).to(torch.device(device))
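
A quick way to see why the construction targets next_power_of_2(d): Sylvester's recursion only ever lands on power-of-2 sizes, and normalizing by the square root of that size makes the matrix orthonormal there. A minimal standalone sketch (the local _next_power_of_2 helper is a stand-in for vllm.utils.math_utils.next_power_of_2, which is not reproduced here):

import math
import torch

def _next_power_of_2(n: int) -> int:
    # Stand-in for vllm.utils.math_utils.next_power_of_2.
    return 1 << (n - 1).bit_length()

d = 80                              # Qwen3-4B head_dim
target = _next_power_of_2(d)        # 128
H = torch.tensor([[1.0]])
while H.shape[0] < target:
    H = torch.cat([torch.cat([H, H], 1), torch.cat([H, -H], 1)], 0)
H = H / math.sqrt(target)
assert H.shape == (target, target)
# Orthonormal at the padded size: rotating then un-rotating is lossless.
assert torch.allclose(H @ H.T, torch.eye(target), atol=1e-5)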


@pytest.mark.skipif(not GPGPU_AVAILABLE, reason="GPGPU not available")
@@ -545,3 +589,116 @@ def test_single_token_roundtrip(self, preset):
assert cos_sim > threshold, (
f"Preset {preset} head {h}: cosine_sim={cos_sim:.4f} < {threshold}"
)

@pytest.mark.parametrize(
"preset,head_dim",
[
("turboquant_k8v4", 80),
("turboquant_4bit_nc", 80),
("turboquant_k8v4", 96),
("turboquant_4bit_nc", 96),
],
)
def test_non_pow2_head_dim_roundtrip(self, preset, head_dim):
"""Store + decode at non-power-of-2 head_dim (Qwen3-4B = 80)."""
from vllm.model_executor.layers.quantization.turboquant.centroids import (
solve_lloyd_max,
)
from vllm.v1.attention.ops.triton_turboquant_decode import (
triton_turboquant_decode_attention,
)
from vllm.v1.attention.ops.triton_turboquant_store import (
triton_turboquant_store,
)

cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=head_dim)
D = head_dim
D_pad = cfg.padded_head_dim
assert cfg.needs_padding, f"head_dim={head_dim} should need padding"

Hk = 4
Hq = 4
B = 1
block_size = 16
num_blocks = 1

device = torch.device(DEVICE_TYPE)

# Hadamard at padded dim — this is what the serving path does.
H = _build_hadamard(D_pad, DEVICE_TYPE)
PiT = H
Pi = H

centroids, _ = solve_lloyd_max(D_pad, cfg.centroid_bits)
centroids = centroids.float().to(device)
c_sorted, _ = centroids.sort()
midpoints = ((c_sorted[:-1] + c_sorted[1:]) / 2).to(device)

torch.manual_seed(123)
key = torch.randn(B, Hk, D, device=device, dtype=torch.float16)
value = torch.randn(B, Hk, D, device=device, dtype=torch.float16)

padded_slot = cfg.slot_size_aligned
kv_cache = torch.zeros(
num_blocks,
block_size,
Hk,
padded_slot,
device=device,
dtype=torch.uint8,
)
slot_mapping = torch.tensor([0], device=device, dtype=torch.int32)

triton_turboquant_store(
key,
value,
kv_cache,
slot_mapping,
PiT,
midpoints,
mse_bits=cfg.key_mse_bits,
key_packed_size=cfg.key_packed_size,
value_quant_bits=cfg.effective_value_quant_bits,
key_fp8=cfg.key_fp8,
padded_head_dim=cfg.padded_head_dim,
)

query = key.expand(B, Hq, D).contiguous().to(torch.float16)
block_table = torch.tensor([[0]], device=device, dtype=torch.int32)
seq_lens = torch.tensor([1], device=device, dtype=torch.int32)

output = triton_turboquant_decode_attention(
query=query,
kv_cache=kv_cache,
block_table=block_table,
seq_lens=seq_lens,
Pi=Pi,
centroids=centroids,
scale=1.0 / math.sqrt(D),
mse_bits=cfg.key_mse_bits,
key_packed_size=cfg.key_packed_size,
value_quant_bits=cfg.effective_value_quant_bits,
key_fp8=cfg.key_fp8,
norm_correction=cfg.norm_correction,
PiT=PiT,
max_num_kv_splits=4,
padded_head_dim=cfg.padded_head_dim,
)

# Output is sliced back to head_dim by the launcher.
assert output.shape == (B, Hq, D), (
f"Expected output shape {(B, Hq, D)}, got {tuple(output.shape)}"
)

out_fp32 = output.float()
val_fp32 = value.expand(B, Hq, D).float()
for h in range(Hq):
cos_sim = torch.nn.functional.cosine_similarity(
out_fp32[0, h].unsqueeze(0),
val_fp32[0, h].unsqueeze(0),
).item()
threshold = 0.95 if cfg.key_fp8 else 0.85
assert cos_sim > threshold, (
f"Preset {preset} d={head_dim} head {h}: "
f"cosine_sim={cos_sim:.4f} < {threshold}"
)
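
The roundtrip above leans on the launchers' shape contract: inputs at head_dim are zero-padded to padded_head_dim before the WHT, and the decode output is sliced back so callers never see the padded tail. A minimal sketch of that contract only (illustrative; the real padding and slicing happen inside the Triton store/decode launchers, and pad_to/slice_back are hypothetical helpers, not vLLM APIs):

import torch
import torch.nn.functional as F

def pad_to(x: torch.Tensor, d_pad: int) -> torch.Tensor:
    # Zero-pad the last dim from head_dim up to padded_head_dim.
    return F.pad(x, (0, d_pad - x.shape[-1]))

def slice_back(x: torch.Tensor, d: int) -> torch.Tensor:
    # Drop the padded tail so callers only ever see head_dim.
    return x[..., :d]

key = torch.randn(1, 4, 80)        # (B, Hk, head_dim), head_dim=80
key_pad = pad_to(key, 128)         # rotated/quantized at padded_head_dim=128
assert key_pad.shape == (1, 4, 128)
out = slice_back(key_pad, 80)      # decode output sliced back to head_dim
assert torch.equal(out, key)       # the zero tail carries no information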
44 changes: 37 additions & 7 deletions vllm/model_executor/layers/quantization/turboquant/config.py
@@ -5,6 +5,8 @@
import math
from dataclasses import dataclass

from vllm.utils.math_utils import next_power_of_2

# Named TQ presets: each maps to frozen config parameters.
# key_quant_bits: 8 = FP8 keys, 3-4 = MSE (Lloyd-Max) quantized keys.
# value_quant_bits: 3-4 = uniform quantized values.
@@ -81,6 +83,22 @@ def key_fp8(self) -> bool:
"""Whether keys are stored as FP8 — no rotation/quantization needed."""
return self.key_quant_bits == 8

@property
def padded_head_dim(self) -> int:
"""Head dimension used for the WHT rotation.

Sylvester Hadamard construction requires a power-of-2 dimension.
Models with non-power-of-2 head_dim (e.g. Qwen3-4B at 80) are padded
to the next power of 2 for the WHT and sliced back at the I/O
boundary. Pow-2 head_dims pass through unchanged.
"""
return next_power_of_2(self.head_dim)

@property
def needs_padding(self) -> bool:
"""Whether head_dim requires zero-padding for the WHT."""
return self.padded_head_dim != self.head_dim

@property
def mse_bits(self) -> int:
"""MSE quantizer bit-width (determines centroid count: 2^mse_bits).
Expand Down Expand Up @@ -114,15 +132,19 @@ def key_packed_size(self) -> int:
"""Packed bytes for a single KEY vector.

FP8 mode (key_quant_bits=8):
-            head_dim bytes (1 byte per element, no overhead).
+            head_dim bytes (1 byte per element, no overhead). FP8 keys are not
+            rotated, so storage tracks the model's real head_dim.

-        TQ mode:
-            - MSE indices: ceil(head_dim * key_mse_bits / 8) bytes
+        TQ mode (MSE keys with WHT rotation):
+            - MSE indices: ceil(padded_head_dim * key_mse_bits / 8) bytes
             - vec_norm: 2 bytes (float16)

+        MSE indices live in WHT space, so the byte count is sized to the
+        padded dimension when head_dim is not a power of 2.
"""
if self.key_fp8:
-            return self.head_dim  # 1 byte per element
-        mse_bytes = math.ceil(self.head_dim * self.key_mse_bits / 8)
+            return self.head_dim  # 1 byte per element, no rotation
+        mse_bytes = math.ceil(self.padded_head_dim * self.key_mse_bits / 8)
norm_bytes = 2 # vec_norm fp16
return mse_bytes + norm_bytes

@@ -135,9 +157,17 @@ def effective_value_quant_bits(self) -> int:
def value_packed_size(self) -> int:
"""Packed bytes for a single VALUE vector.

-        Uniform quantization: ceil(head_dim * bits / 8) + 4 bytes (scale + zero fp16).
+        Uniform quantization: ceil(D * bits / 8) + 4 bytes (scale + zero fp16).

+        On the FP8 K path the kernel operates on raw head_dim (no WHT), so
+        D = head_dim. On the MSE K path the kernel iterates a unified D for
+        both K and V; when head_dim is not a power of 2 the K-side requires
+        padded_head_dim for the WHT, so V is sized to padded_head_dim too
+        (zero-padded at store time, sliced at decode time). Pow-2 head_dims
+        pass through unchanged.
         """
-        data_bytes = math.ceil(self.head_dim * self.value_quant_bits / 8)
+        d = self.head_dim if self.key_fp8 else self.padded_head_dim
+        data_bytes = math.ceil(d * self.value_quant_bits / 8)
return data_bytes + 4 # +2 scale(fp16) +2 zero(fp16)

@property
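
Taken together, the two packed-size rules reproduce the byte counts asserted in the tests above. A standalone worked example for head_dim=80 (packed_sizes is a hypothetical helper written only to make the arithmetic explicit; it is not TurboQuantConfig, and the bit-widths assume the 4-bit MSE-K and FP8-K/4-bit-V presets used in the tests):

import math

def next_power_of_2(n: int) -> int:
    return 1 << (n - 1).bit_length()

def packed_sizes(head_dim: int, key_fp8: bool, key_mse_bits: int = 4, value_bits: int = 4):
    padded = next_power_of_2(head_dim)
    if key_fp8:
        key_size = head_dim                                  # 1 byte/elem, no rotation
    else:
        key_size = math.ceil(padded * key_mse_bits / 8) + 2  # MSE indices + vec_norm fp16
    d = head_dim if key_fp8 else padded                      # V follows the K path's D
    value_size = math.ceil(d * value_bits / 8) + 4           # + scale/zero fp16
    return key_size, value_size

# MSE-K preset (turboquant_4bit_nc) at head_dim=80: sized to padded_head_dim=128.
assert packed_sizes(80, key_fp8=False) == (66, 68)
# FP8-K preset (turboquant_k8v4) at head_dim=80: no rotation, raw head_dim.
assert packed_sizes(80, key_fp8=True) == (80, 44)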