69 changes: 69 additions & 0 deletions tests/quantization/test_turboquant.py
@@ -545,3 +545,72 @@ def test_single_token_roundtrip(self, preset):
assert cos_sim > threshold, (
f"Preset {preset} head {h}: cosine_sim={cos_sim:.4f} < {threshold}"
)

@pytest.mark.parametrize(
"preset",
["turboquant_k8v4", "turboquant_4bit_nc"],
)
def test_sparse_v_threshold_zero_matches_baseline(self, preset):
"""Sparse V with threshold=0 must produce the same output as off.

When the threshold is below any softmax probability, no tile is
skipped, so the sparse-V path must be byte-equivalent to the
non-sparse path. Any divergence indicates the wrapping logic
re-orders or skips an arithmetic op.
"""
from vllm.model_executor.layers.quantization.turboquant.centroids import (
solve_lloyd_max,
)
from vllm.v1.attention.ops.triton_turboquant_decode import (
triton_turboquant_decode_attention,
)
from vllm.v1.attention.ops.triton_turboquant_store import (
triton_turboquant_store,
)

cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128)
D, Hk, Hq, B = 128, 4, 4, 1
block_size, num_blocks = 16, 1
device = torch.device(DEVICE_TYPE)

H = _build_hadamard(D, DEVICE_TYPE)
centroids, _ = solve_lloyd_max(D, cfg.centroid_bits)
centroids = centroids.float().to(device)
c_sorted, _ = centroids.sort()
midpoints = ((c_sorted[:-1] + c_sorted[1:]) / 2).to(device)

torch.manual_seed(7)
key = torch.randn(B, Hk, D, device=device, dtype=torch.float16)
value = torch.randn(B, Hk, D, device=device, dtype=torch.float16)
kv_cache = torch.zeros(
num_blocks, block_size, Hk, cfg.slot_size_aligned,
device=device, dtype=torch.uint8,
)
slot_mapping = torch.tensor([0], device=device, dtype=torch.int32)
triton_turboquant_store(
key, value, kv_cache, slot_mapping, H, midpoints,
mse_bits=cfg.key_mse_bits, key_packed_size=cfg.key_packed_size,
value_quant_bits=cfg.effective_value_quant_bits,
key_fp8=cfg.key_fp8,
)

query = key.expand(B, Hq, D).contiguous().to(torch.float16)
block_table = torch.tensor([[0]], device=device, dtype=torch.int32)
seq_lens = torch.tensor([1], device=device, dtype=torch.int32)
common = dict(
kv_cache=kv_cache, block_table=block_table, seq_lens=seq_lens,
Pi=H, centroids=centroids, scale=1.0 / math.sqrt(D),
mse_bits=cfg.key_mse_bits, key_packed_size=cfg.key_packed_size,
value_quant_bits=cfg.effective_value_quant_bits,
key_fp8=cfg.key_fp8, norm_correction=cfg.norm_correction,
PiT=H, max_num_kv_splits=4,
)
out_off = triton_turboquant_decode_attention(query=query, **common)
out_on_zero_thresh = triton_turboquant_decode_attention(
query=query, sparse_v=True, sparse_v_threshold=0.0, **common,
)
# threshold=0 means tl.max(p) < 0.0 is never true (p = exp(...) > 0),
# so no tile is ever skipped
assert torch.allclose(out_off, out_on_zero_thresh, atol=1e-6, rtol=1e-6), (
f"Preset {preset}: sparse_v=True, threshold=0 diverges from "
f"sparse_v=False. max|Δ|={(out_off - out_on_zero_thresh).abs().max():.3e}"
)
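
The zero-threshold case pins down byte-equivalence. A natural companion check (hypothetical, not part of this diff) would use a realistic nonzero threshold together with the cosine-similarity bound the surrounding tests already use, since skipped tiles then legitimately perturb the output:

```python
# Hypothetical follow-up, reusing the fixtures built above; the 0.999
# bound is illustrative, not benchmarked.
out_on = triton_turboquant_decode_attention(
    query=query, sparse_v=True, sparse_v_threshold=1e-3, **common,
)
cos_sim = torch.nn.functional.cosine_similarity(
    out_off.flatten().float(), out_on.flatten().float(), dim=0
).item()
assert cos_sim > 0.999, f"sparse_v drift too large: cos_sim={cos_sim:.4f}"
```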
39 changes: 39 additions & 0 deletions vllm/v1/attention/backends/turboquant_attn.py
@@ -18,6 +18,7 @@

import functools
import math
import os
from dataclasses import dataclass
from typing import Any, ClassVar

@@ -68,6 +69,40 @@
# per continuation, eliminating the O(N²/chunk_size) collapse at long context.
_CONTINUATION_DECODE_THRESHOLD = 128

# Sparse V: skip the per-tile V load + dequant when the tile's softmax
# probability is entirely below threshold. The kernel-side branch costs a
# tl.max + comparison; the savings come from avoiding the V-load and
# dequant arithmetic on tiles that contribute negligibly. Only worthwhile
# at long context where many tiles are far from the query.
#
# Off by default. Per-platform validation is needed before flipping the
# default to on (the current benchmark covers AMD MI300X only). Users
# can opt in:
#
# Env-var overrides per-process:
# VLLM_TQ_SPARSE_V "1" | "0" | "auto" (default "0")
# "auto" turns on at seq_len >= ctx threshold
# VLLM_TQ_SPARSE_V_THRESHOLD softmax-prob cutoff (default 0.001)
# VLLM_TQ_SPARSE_V_CTX_THRESHOLD min seq_len for "auto" mode (default 8192)
_TQ_SPARSE_V_MODE = os.environ.get("VLLM_TQ_SPARSE_V", "0").lower()
_TQ_SPARSE_V_THRESHOLD = float(os.environ.get("VLLM_TQ_SPARSE_V_THRESHOLD", "0.001"))
_TQ_SPARSE_V_CTX_THRESHOLD = int(
os.environ.get("VLLM_TQ_SPARSE_V_CTX_THRESHOLD", "8192")
)


def _tq_sparse_v_enabled(max_seq_len: int) -> bool:
"""Whether sparse V is engaged for this forward pass.

Off by default ('0'). '1' forces on, 'auto' gates on context length
so the per-tile branch overhead only kicks in when the cache is large
enough for the savings to pay for it.
"""
if _TQ_SPARSE_V_MODE == "1":
return True
if _TQ_SPARSE_V_MODE == "auto":
return max_seq_len >= _TQ_SPARSE_V_CTX_THRESHOLD
return False
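
Because the three `VLLM_TQ_SPARSE_V*` variables are read once at module level (the `os.environ.get` calls above), they must be set before this backend module is imported. A minimal sketch of opting in to "auto" mode:

```python
import os

# Set before vllm.v1.attention.backends.turboquant_attn is imported;
# the values are snapshotted at import time, not per forward pass.
os.environ["VLLM_TQ_SPARSE_V"] = "auto"
os.environ["VLLM_TQ_SPARSE_V_CTX_THRESHOLD"] = "8192"

from vllm.v1.attention.backends import turboquant_attn

assert not turboquant_attn._tq_sparse_v_enabled(4096)   # below ctx threshold
assert turboquant_attn._tq_sparse_v_enabled(16384)      # auto engages
```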


def _build_hadamard(d: int, device_str: str) -> torch.Tensor:
"""Orthonormal Hadamard matrix (Sylvester construction), cached per (d, device).
@@ -662,6 +697,8 @@ def _prefill_attention(
key_fp8=self.tq_config.key_fp8,
norm_correction=self.tq_config.norm_correction,
PiT=PiT,
sparse_v=_tq_sparse_v_enabled(seq_len),
sparse_v_threshold=_TQ_SPARSE_V_THRESHOLD,
)
else:
# Large continuation: dequant cached K/V and use
@@ -874,5 +911,7 @@ def _decode_attention(
lse_buf=lse_buf,
buf_holder=layer,
max_num_kv_splits=self.max_num_kv_splits,
sparse_v=_tq_sparse_v_enabled(int(attn_metadata.max_seq_len)),
sparse_v_threshold=_TQ_SPARSE_V_THRESHOLD,
)
return result
158 changes: 91 additions & 67 deletions vllm/v1/attention/ops/triton_turboquant_decode.py
@@ -83,6 +83,8 @@ def _tq_decode_stage1(
KEY_FP8: tl.constexpr, # 1 if K is stored as FP8
NORM_CORRECTION: tl.constexpr = 0, # 1 = re-normalize centroids
FP8_E4B15: tl.constexpr = 0, # 1 = use e4b15 (Ampere/Ada), 0 = e4nv (Hopper+)
SPARSE_V: tl.constexpr = 0, # 1 = skip V load+accum on tiles below softmax threshold
SPARSE_V_THRESHOLD: tl.constexpr = 0.001,
):
bid = tl.program_id(0) # batch index
hid = tl.program_id(1) # q_head index
@@ -231,77 +233,95 @@
p = tl.exp(scores - n_e_max)

# ============================================================
# SPARSE V (opt-in): skip the V load + dequant + weighted sum
# for tiles whose softmax probability is entirely below
# SPARSE_V_THRESHOLD. The skip path still decays the running
# accumulator and updates l_prev / m_prev, so the online-softmax
# totals stay exact. The per-tile branch costs a tl.max plus a
# comparison; the win comes from avoiding ~10 tl.load calls and
# the dequant arithmetic on tiles that contribute negligibly.
# ============================================================
skip_v_tile = False
if SPARSE_V:
    skip_v_tile = tl.max(p) < SPARSE_V_THRESHOLD

if skip_v_tile:
    acc = acc * re_scale
else:
    # ========================================================
    # VALUE LOAD + DEQUANTIZE: [BLOCK_KV, BLOCK_D]
    # ========================================================
    val_bases = slot_bases + KPS

    if VQB == 3:
        val_addrs0 = val_bases[:, None] + val_byte_idx[None, :]
        val_raw0 = tl.load(
            KV_cache_ptr + val_addrs0,
            mask=kv_mask[:, None] & d_mask[None, :],
            other=0,
        ).to(tl.int32)
        val_raw1 = tl.load(
            KV_cache_ptr + val_addrs0 + 1,
            mask=kv_mask[:, None] & d_mask[None, :],
            other=0,
        ).to(tl.int32)
        raw16 = val_raw0 | (val_raw1 << 8)
        v_idx = ((raw16 >> val_bit_shift[None, :]) & 0x7).to(tl.float32)
    else:  # VQB == 4
        vb_idx = d_offs // 2
        vb_shift = (d_offs % 2) * 4
        val_addrs = val_bases[:, None] + vb_idx[None, :]
        val_raw = tl.load(
            KV_cache_ptr + val_addrs,
            mask=kv_mask[:, None] & d_mask[None, :],
            other=0,
        ).to(tl.int32)
        v_idx = ((val_raw >> vb_shift[None, :]) & 0xF).to(tl.float32)

    # Per-slot fp16 scale / zero-point, shared by both bit widths.
    sc_bases = val_bases + VAL_DATA_BYTES
    sc_lo = tl.load(KV_cache_ptr + sc_bases, mask=kv_mask, other=0).to(
        tl.uint16
    )
    sc_hi = tl.load(KV_cache_ptr + sc_bases + 1, mask=kv_mask, other=0).to(
        tl.uint16
    )
    v_scales = (
        (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
    )
    zr_lo = tl.load(KV_cache_ptr + sc_bases + 2, mask=kv_mask, other=0).to(
        tl.uint16
    )
    zr_hi = tl.load(KV_cache_ptr + sc_bases + 3, mask=kv_mask, other=0).to(
        tl.uint16
    )
    v_zeros = (
        (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
    )
    values = v_idx * v_scales[:, None] + v_zeros[:, None]

    # ========================================================
    # WEIGHTED VALUE ACCUMULATION
    # ========================================================
    acc = acc * re_scale + tl.sum(p[:, None] * values, 0)
l_prev = l_prev * re_scale + tl.sum(p, 0)
m_prev = n_e_max
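
Why the skip path stays exact: a skipped tile still rescales `acc` and still adds its probability mass to `l_prev`, so the final `acc / l` normalization is identical to the dense path up to the below-threshold V contribution that was dropped. A minimal eager sketch of one tile update, assuming the kernel computes `n_e_max` and `re_scale` in the standard online-softmax way (those lines are outside this hunk):

```python
import torch

def tile_update(acc, l_prev, m_prev, scores, values, threshold=0.0):
    # Mirrors one iteration of the stage-1 loop above; names match the kernel.
    n_e_max = torch.maximum(scores.max(), m_prev)
    re_scale = torch.exp(m_prev - n_e_max)
    p = torch.exp(scores - n_e_max)                  # [BLOCK_KV]
    if threshold > 0.0 and p.max() < threshold:      # SPARSE_V skip path
        acc = acc * re_scale                         # no V load, no dequant
    else:
        acc = acc * re_scale + (p[:, None] * values).sum(dim=0)
    l_prev = l_prev * re_scale + p.sum()             # mass counted either way
    return acc, l_prev, n_e_max                      # caller sets m_prev
```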

Expand Down Expand Up @@ -503,6 +523,8 @@ def triton_turboquant_decode_attention(
lse_buf: torch.Tensor | None = None,
buf_holder: Any = None,
max_num_kv_splits: int = 32, # fixed split count (must be constant for cudagraph)
sparse_v: bool = False,
sparse_v_threshold: float = 0.001,
) -> torch.Tensor:
"""Launch fused TQ decode attention (Triton stage1 + stage2).

@@ -583,6 +605,8 @@
KEY_FP8=1 if key_fp8 else 0,
NORM_CORRECTION=1 if norm_correction else 0,
FP8_E4B15=fp8_e4b15,
SPARSE_V=1 if sparse_v else 0,
SPARSE_V_THRESHOLD=sparse_v_threshold,
num_warps=1,
num_stages=1,
)