Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
367 changes: 367 additions & 0 deletions tests/quantization/test_turboquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,3 +542,370 @@ def test_single_token_roundtrip(self, preset):
assert cos_sim > threshold, (
f"Preset {preset} head {h}: cosine_sim={cos_sim:.4f} < {threshold}"
)

@pytest.mark.parametrize("kv_group_size", [4, 8, 24])
def test_gqa_roundtrip_k8v4(self, kv_group_size):
    """Round-trip store/decode through the grouped GQA decode path.

    Only the turboquant_k8v4 preset (FP8 keys) takes the grouped
    kernel; the MSE presets fall back to the scalar kernel, which
    test_single_token_roundtrip already covers.
    """
    preset = "turboquant_k8v4"
    from vllm.model_executor.layers.quantization.turboquant.centroids import (
        solve_lloyd_max,
    )
    from vllm.v1.attention.ops.triton_turboquant_decode import (
        triton_turboquant_decode_attention,
    )
    from vllm.v1.attention.ops.triton_turboquant_store import (
        triton_turboquant_store,
    )

    cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
    D = 128
    Hk = 4
    Hq = Hk * kv_group_size
    B = 2
    seq_len = 32
    block_size = 16
    num_blocks = (seq_len + block_size - 1) // block_size

    device = torch.device(DEVICE_TYPE)
    PiT = _build_hadamard(D, DEVICE_TYPE)

    # Quantizer tables: Lloyd-Max centroids plus the decision
    # midpoints between adjacent sorted levels.
    lloyd_centroids, _ = solve_lloyd_max(D, cfg.centroid_bits)
    lloyd_centroids = lloyd_centroids.float().to(device)
    levels = lloyd_centroids.sort().values
    midpoints = ((levels[:-1] + levels[1:]) / 2).to(device)

    torch.manual_seed(42)
    # Several tokens so more than one cache block gets populated.
    keys = torch.randn(seq_len, Hk, D, device=device, dtype=torch.float16)
    values = torch.randn(seq_len, Hk, D, device=device, dtype=torch.float16)

    kv_cache = torch.zeros(
        num_blocks,
        block_size,
        Hk,
        cfg.slot_size_aligned,
        device=device,
        dtype=torch.uint8,
    )
    slot_mapping = torch.arange(seq_len, device=device, dtype=torch.int32)

    triton_turboquant_store(
        keys,
        values,
        kv_cache,
        slot_mapping,
        PiT,
        midpoints,
        mse_bits=cfg.key_mse_bits,
        key_packed_size=cfg.key_packed_size,
        value_quant_bits=cfg.effective_value_quant_bits,
        key_fp8=cfg.key_fp8,
    )

    # Use the last B keys as queries; every Q head in a GQA group
    # receives a copy of its KV head's key (consecutive-head layout).
    query_keys = keys[-B:]  # [B, Hk, D]
    query = (
        torch.repeat_interleave(query_keys, kv_group_size, dim=1)
        .contiguous()
        .to(torch.float16)
    )

    block_table = torch.arange(
        num_blocks, device=device, dtype=torch.int32
    ).repeat(B, 1)
    seq_lens = torch.full((B,), seq_len, device=device, dtype=torch.int32)

    output = triton_turboquant_decode_attention(
        query=query,
        kv_cache=kv_cache,
        block_table=block_table,
        seq_lens=seq_lens,
        Pi=PiT,
        centroids=lloyd_centroids,
        scale=1.0 / math.sqrt(D),
        mse_bits=cfg.key_mse_bits,
        key_packed_size=cfg.key_packed_size,
        value_quant_bits=cfg.effective_value_quant_bits,
        key_fp8=cfg.key_fp8,
        norm_correction=cfg.norm_correction,
        PiT=PiT,
        max_num_kv_splits=8,
    )

    # Sanity: every value finite, every head with non-degenerate norm.
    assert output.isfinite().all(), (
        f"Preset {preset} GQA={kv_group_size}: non-finite output"
    )
    assert (output.float().norm(dim=-1) > 0.01).all(), (
        f"Preset {preset} GQA={kv_group_size}: near-zero output"
    )

    # Heads that share a KV group were also given the same query
    # above, so their outputs must coincide (same Q, same KV).
    out_fp32 = output.float()
    for bi in range(B):
        for kh in range(Hk):
            base_h = kh * kv_group_size
            ref = out_fp32[bi, base_h]
            for off in range(1, kv_group_size):
                h = base_h + off
                cos = torch.nn.functional.cosine_similarity(
                    ref.unsqueeze(0), out_fp32[bi, h].unsqueeze(0)
                ).item()
                assert cos > 0.99, (
                    f"Preset {preset} GQA={kv_group_size} "
                    f"batch={bi} heads {base_h} vs {h}: "
                    f"cosine={cos:.4f} (expected >0.99 for same query)"
                )

def test_grouped_vs_original_kernel_k8v4(self):
    """A/B the grouped decode kernel against the scalar kernel.

    Launches both stage-1 variants on identical turboquant_k8v4
    inputs, reduces each through the shared stage-2 kernel, and
    requires near-identical per-head outputs. This is the primary
    correctness check for the grouped kernel change; the tolerance
    absorbs fp16 tl.dot precision differences.
    """
    preset = "turboquant_k8v4"
    from vllm.model_executor.layers.quantization.turboquant.centroids import (
        solve_lloyd_max,
    )
    from vllm.v1.attention.ops.triton_turboquant_decode import (
        _fwd_kernel_stage2,
        _get_layout,
        _tq_decode_stage1,
        _tq_grouped_decode_stage1,
        _use_fp8_e4b15,
    )
    from vllm.v1.attention.ops.triton_turboquant_store import (
        triton_turboquant_store,
    )

    cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
    D = 128
    Hk = 4
    kv_group_size = 4
    Hq = Hk * kv_group_size
    B = 2
    seq_len = 48
    block_size = 16
    num_blocks = (seq_len + block_size - 1) // block_size
    NUM_KV_SPLITS = 8
    device = torch.device(DEVICE_TYPE)

    PiT = _build_hadamard(D, DEVICE_TYPE)

    # Quantizer tables shared by the store and both decode variants.
    lloyd_centroids, _ = solve_lloyd_max(D, cfg.centroid_bits)
    lloyd_centroids = lloyd_centroids.float().to(device)
    levels = lloyd_centroids.sort().values
    midpoints = ((levels[:-1] + levels[1:]) / 2).to(device)

    torch.manual_seed(99)
    keys = torch.randn(seq_len, Hk, D, device=device, dtype=torch.float16)
    values = torch.randn(seq_len, Hk, D, device=device, dtype=torch.float16)

    kv_cache = torch.zeros(
        num_blocks,
        block_size,
        Hk,
        cfg.slot_size_aligned,
        device=device,
        dtype=torch.uint8,
    )
    slot_mapping = torch.arange(seq_len, device=device, dtype=torch.int32)
    triton_turboquant_store(
        keys,
        values,
        kv_cache,
        slot_mapping,
        PiT,
        midpoints,
        mse_bits=cfg.key_mse_bits,
        key_packed_size=cfg.key_packed_size,
        value_quant_bits=cfg.effective_value_quant_bits,
        key_fp8=cfg.key_fp8,
    )

    torch.manual_seed(77)
    query = torch.randn(B, Hq, D, device=device, dtype=torch.float16)

    # key_fp8 path consumes the raw query; the MSE path pre-rotates
    # the query by PiT before stage 1.
    q_rot = (
        query.contiguous()
        if cfg.key_fp8
        else (query.float() @ PiT).contiguous()
    )

    layout = _get_layout(
        D,
        cfg.key_mse_bits,
        cfg.effective_value_quant_bits,
        cfg.key_packed_size,
    )
    fp8_e4b15 = _use_fp8_e4b15(device.index or 0)

    block_table = torch.arange(
        num_blocks, device=device, dtype=torch.int32
    ).repeat(B, 1)
    seq_lens = torch.full((B,), seq_len, device=device, dtype=torch.int32)

    attn_scale = 1.0 / math.sqrt(D)

    def _mid_buffer():
        # Per-split stage-1 scratch: [B, Hq, splits, D + 1] fp32.
        return torch.empty(
            B,
            Hq,
            NUM_KV_SPLITS,
            D + 1,
            dtype=torch.float32,
            device=device,
        )

    def _run_stage2(mid_o, out, lse):
        # Split reduction — identical launch for both variants.
        _fwd_kernel_stage2[(B, Hq)](
            mid_o,
            out,
            lse,
            seq_lens,
            mid_o.stride(0),
            mid_o.stride(1),
            mid_o.stride(2),
            out.stride(0),
            out.stride(1),
            lse.stride(0),
            NUM_KV_SPLITS=NUM_KV_SPLITS,
            BLOCK_DV=layout["BLOCK_D"],
            Lv=D,
            num_warps=4,
            num_stages=2,
        )

    # --- Original (scalar) kernel ---
    mid_o_orig = _mid_buffer()
    _tq_decode_stage1[(B, Hq, NUM_KV_SPLITS)](
        q_rot,
        kv_cache,
        block_table,
        seq_lens,
        lloyd_centroids,
        mid_o_orig,
        q_rot.stride(0),
        q_rot.stride(1),
        kv_cache.stride(0),
        kv_cache.stride(1),
        kv_cache.stride(2),
        block_table.stride(0),
        mid_o_orig.stride(0),
        mid_o_orig.stride(1),
        mid_o_orig.stride(2),
        NUM_KV_HEADS=Hk,
        HEAD_DIM=D,
        BLOCK_SIZE=block_size,
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        KV_GROUP_SIZE=kv_group_size,
        MSE_BITS=cfg.key_mse_bits,
        MSE_BYTES=layout["mse_bytes"],
        KPS=cfg.key_packed_size,
        VQB=cfg.effective_value_quant_bits,
        VAL_DATA_BYTES=layout["val_data_bytes"],
        ATTN_SCALE=attn_scale,
        BLOCK_D=layout["BLOCK_D"],
        BLOCK_KV=4,
        KEY_FP8=1 if cfg.key_fp8 else 0,
        NORM_CORRECTION=1 if cfg.norm_correction else 0,
        FP8_E4B15=fp8_e4b15,
        num_warps=1,
        num_stages=1,
    )
    out_orig = torch.empty(B, Hq, D, dtype=torch.float32, device=device)
    lse_orig = torch.empty(B, Hq, dtype=torch.float32, device=device)
    _run_stage2(mid_o_orig, out_orig, lse_orig)

    # --- Grouped kernel ---
    import triton as _triton

    BLOCK_H = 16
    # One program per (kv head, BLOCK_H-chunk of its Q heads).
    head_groups = Hk * _triton.cdiv(kv_group_size, BLOCK_H)

    mid_o_grouped = _mid_buffer()
    _tq_grouped_decode_stage1[(B, head_groups, NUM_KV_SPLITS)](
        q_rot,
        kv_cache,
        block_table,
        seq_lens,
        mid_o_grouped,
        q_rot.stride(0),
        q_rot.stride(1),
        kv_cache.stride(0),
        kv_cache.stride(1),
        kv_cache.stride(2),
        block_table.stride(0),
        mid_o_grouped.stride(0),
        mid_o_grouped.stride(1),
        mid_o_grouped.stride(2),
        HEAD_DIM=D,
        BLOCK_SIZE=block_size,
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        KV_GROUP_SIZE=kv_group_size,
        Q_HEAD_NUM=Hq,
        KPS=cfg.key_packed_size,
        VQB=cfg.effective_value_quant_bits,
        VAL_DATA_BYTES=layout["val_data_bytes"],
        ATTN_SCALE=attn_scale,
        BLOCK_D=layout["BLOCK_D"],
        BLOCK_KV=16,
        BLOCK_H=BLOCK_H,
        FP8_E4B15=fp8_e4b15,
        num_warps=4,
        num_stages=2,
    )
    out_grouped = torch.empty(B, Hq, D, dtype=torch.float32, device=device)
    lse_grouped = torch.empty(B, Hq, dtype=torch.float32, device=device)
    _run_stage2(mid_o_grouped, out_grouped, lse_grouped)

    # Per-head cosine similarity must be very high; fp16 tl.dot in
    # the grouped path leaves only a small precision gap.
    threshold = 0.98 if cfg.key_fp8 else 0.95
    for bi in range(B):
        for h in range(Hq):
            cos = torch.nn.functional.cosine_similarity(
                out_orig[bi, h].unsqueeze(0),
                out_grouped[bi, h].unsqueeze(0),
            ).item()
            assert cos > threshold, (
                f"Preset {preset} batch={bi} head={h}: "
                f"orig vs grouped cosine={cos:.4f} < {threshold}"
            )
Loading
Loading