@pytest.mark.parametrize(
    "preset",
    ["turboquant_k8v4", "turboquant_4bit_nc"],
)
def test_sparse_v_threshold_zero_matches_baseline(self, preset):
    """Sparse V with a zero cutoff must reproduce the non-sparse output.

    One token is quantized into a fresh KV cache and decoded twice with
    identical arguments: once with the default (non-sparse) path and once
    with ``sparse_v=True, sparse_v_threshold=0.0``. The kernel skips a
    tile only when ``tl.max(p) < threshold``, which cannot hold for a
    zero threshold, so both launches run the same V load + accumulate
    path. The outputs are compared with ``torch.allclose`` at
    ``atol = rtol = 1e-6``; any larger gap means the sparse-V wrapping
    re-ordered or dropped an arithmetic op.
    """
    from vllm.model_executor.layers.quantization.turboquant.centroids import (
        solve_lloyd_max,
    )
    from vllm.v1.attention.ops.triton_turboquant_decode import (
        triton_turboquant_decode_attention,
    )
    from vllm.v1.attention.ops.triton_turboquant_store import (
        triton_turboquant_store,
    )

    head_dim, num_kv_heads, num_q_heads, batch = 128, 4, 4, 1
    block_size, num_blocks = 16, 1
    cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=head_dim)
    device = torch.device(DEVICE_TYPE)

    # Rotation matrix and quantizer tables shared by store and decode.
    hadamard = _build_hadamard(head_dim, DEVICE_TYPE)
    centroids = solve_lloyd_max(head_dim, cfg.centroid_bits)[0].float().to(device)
    ordered = centroids.sort().values
    midpoints = (ordered[:-1] + ordered[1:]) / 2

    # Quantization settings that both the store and decode kernels take.
    quant_kwargs = {
        "mse_bits": cfg.key_mse_bits,
        "key_packed_size": cfg.key_packed_size,
        "value_quant_bits": cfg.effective_value_quant_bits,
        "key_fp8": cfg.key_fp8,
    }

    # Draw K first, then V, so the seeded random stream is consumed in a
    # fixed order.
    torch.manual_seed(7)
    kv_shape = (batch, num_kv_heads, head_dim)
    key = torch.randn(*kv_shape, device=device, dtype=torch.float16)
    value = torch.randn(*kv_shape, device=device, dtype=torch.float16)

    kv_cache = torch.zeros(
        num_blocks,
        block_size,
        num_kv_heads,
        cfg.slot_size_aligned,
        device=device,
        dtype=torch.uint8,
    )
    slot_mapping = torch.tensor([0], device=device, dtype=torch.int32)
    triton_turboquant_store(
        key, value, kv_cache, slot_mapping, hadamard, midpoints, **quant_kwargs
    )

    # The query is the stored key itself (already float16), one per q head.
    query = key.expand(batch, num_q_heads, head_dim).contiguous()
    decode_kwargs = {
        "kv_cache": kv_cache,
        "block_table": torch.tensor([[0]], device=device, dtype=torch.int32),
        "seq_lens": torch.tensor([1], device=device, dtype=torch.int32),
        "Pi": hadamard,
        "centroids": centroids,
        "scale": 1.0 / math.sqrt(head_dim),
        "norm_correction": cfg.norm_correction,
        "PiT": hadamard,
        "max_num_kv_splits": 4,
        **quant_kwargs,
    }

    out_off = triton_turboquant_decode_attention(query=query, **decode_kwargs)
    out_on_zero_thresh = triton_turboquant_decode_attention(
        query=query,
        sparse_v=True,
        sparse_v_threshold=0.0,
        **decode_kwargs,
    )

    # threshold=0: tl.max(p) < 0 never holds, so no tile may be skipped.
    assert torch.allclose(out_off, out_on_zero_thresh, atol=1e-6, rtol=1e-6), (
        f"Preset {preset}: sparse_v=True, threshold=0 diverges from "
        f"sparse_v=False. max|Δ|={(out_off - out_on_zero_thresh).abs().max():.3e}"
    )
# Sparse V runtime configuration, read once at import time.
#
# The decode kernel can skip the V load + dequant for a KV tile whose
# largest softmax weight is below a cutoff. That is only worthwhile at
# long context, where many tiles are far from the query.
#
# Off by default. Per-platform validation is needed before flipping the
# default to on (current bench is AMD MI300X only). Users can opt in with
# per-process env vars:
#   VLLM_TQ_SPARSE_V                "1" | "0" | "auto"  (default "0")
#                                   "auto" turns on at seq_len >= ctx threshold
#   VLLM_TQ_SPARSE_V_THRESHOLD      softmax-prob cutoff (default 0.001)
#   VLLM_TQ_SPARSE_V_CTX_THRESHOLD  min seq_len for "auto" mode (default 8192)
_TQ_SPARSE_V_VALID_MODES = ("0", "1", "auto")


def _tq_parse_sparse_v_mode(raw: str) -> str:
    """Normalize a ``VLLM_TQ_SPARSE_V`` value to ``"0"``, ``"1"`` or ``"auto"``.

    Surrounding whitespace and letter case are ignored. An unrecognized
    value keeps the feature off, as before, but now emits a
    ``RuntimeWarning`` so a typo does not silently disable sparse V.
    """
    mode = raw.strip().lower()
    if mode in _TQ_SPARSE_V_VALID_MODES:
        return mode

    import warnings  # local import: only needed on the misconfiguration path

    warnings.warn(
        f"Ignoring VLLM_TQ_SPARSE_V={raw!r}: expected one of "
        f"{_TQ_SPARSE_V_VALID_MODES}; sparse V stays off.",
        RuntimeWarning,
        stacklevel=2,
    )
    return "0"


def _tq_parse_number(name: str, raw: str, cast):
    """Convert the raw text of env var ``name`` with ``cast`` (``int``/``float``).

    Raises:
        ValueError: if ``raw`` cannot be converted. The message names the
            variable so a bad setting is easy to trace.
    """
    try:
        return cast(raw)
    except ValueError as err:
        raise ValueError(f"{name}={raw!r} is not a valid {cast.__name__}") from err


def _tq_parse_sparse_v_threshold(raw: str) -> float:
    """Parse and range-check ``VLLM_TQ_SPARSE_V_THRESHOLD``.

    The cutoff is compared against per-tile softmax weights, so it must be
    a probability in [0, 1].
    NOTE(review): a cutoff above 1 would presumably skip every tile and
    zero the output - this assumes the kernel's weights never exceed 1;
    confirm against the decode kernel.

    Raises:
        ValueError: if the value is not a number, is NaN, or lies outside
            [0, 1].
    """
    name = "VLLM_TQ_SPARSE_V_THRESHOLD"
    cutoff = _tq_parse_number(name, raw, float)
    # The chained comparison is False for NaN, so NaN is rejected too.
    if not 0.0 <= cutoff <= 1.0:
        raise ValueError(f"{name}={raw!r} must be within [0, 1]")
    return cutoff


_TQ_SPARSE_V_MODE = _tq_parse_sparse_v_mode(os.environ.get("VLLM_TQ_SPARSE_V", "0"))
_TQ_SPARSE_V_THRESHOLD = _tq_parse_sparse_v_threshold(
    os.environ.get("VLLM_TQ_SPARSE_V_THRESHOLD", "0.001")
)
_TQ_SPARSE_V_CTX_THRESHOLD = _tq_parse_number(
    "VLLM_TQ_SPARSE_V_CTX_THRESHOLD",
    os.environ.get("VLLM_TQ_SPARSE_V_CTX_THRESHOLD", "8192"),
    int,
)


def _tq_sparse_v_enabled(max_seq_len: int) -> bool:
    """Return whether sparse V is engaged for this forward pass.

    Args:
        max_seq_len: longest sequence length in the batch being processed.

    Returns:
        ``True`` when the mode is ``"1"``; in ``"auto"`` mode, ``True`` only
        once ``max_seq_len`` reaches ``_TQ_SPARSE_V_CTX_THRESHOLD``, so the
        per-tile branch overhead is paid only when the cache is large
        enough for the savings to cover it; ``False`` otherwise (mode
        ``"0"``, the default).
    """
    if _TQ_SPARSE_V_MODE == "1":
        return True
    if _TQ_SPARSE_V_MODE == "auto":
        # bool() keeps the declared return type even if the caller passes a
        # non-int scalar whose comparison yields a non-bool result.
        return bool(max_seq_len >= _TQ_SPARSE_V_CTX_THRESHOLD)
    return False
@@ -662,6 +697,8 @@ def _prefill_attention( key_fp8=self.tq_config.key_fp8, norm_correction=self.tq_config.norm_correction, PiT=PiT, + sparse_v=_tq_sparse_v_enabled(seq_len), + sparse_v_threshold=_TQ_SPARSE_V_THRESHOLD, ) else: # Large continuation: dequant cached K/V and use @@ -874,5 +911,7 @@ def _decode_attention( lse_buf=lse_buf, buf_holder=layer, max_num_kv_splits=self.max_num_kv_splits, + sparse_v=_tq_sparse_v_enabled(int(attn_metadata.max_seq_len)), + sparse_v_threshold=_TQ_SPARSE_V_THRESHOLD, ) return result diff --git a/vllm/v1/attention/ops/triton_turboquant_decode.py b/vllm/v1/attention/ops/triton_turboquant_decode.py index 3adaf2610d8d..708a28181589 100644 --- a/vllm/v1/attention/ops/triton_turboquant_decode.py +++ b/vllm/v1/attention/ops/triton_turboquant_decode.py @@ -83,6 +83,8 @@ def _tq_decode_stage1( KEY_FP8: tl.constexpr, # 1 if K is stored as FP8 NORM_CORRECTION: tl.constexpr = 0, # 1 = re-normalize centroids FP8_E4B15: tl.constexpr = 0, # 1 = use e4b15 (Ampere/Ada), 0 = e4nv (Hopper+) + SPARSE_V: tl.constexpr = 0, # 1 = skip V load+accum on tiles below softmax threshold + SPARSE_V_THRESHOLD: tl.constexpr = 0.001, ): bid = tl.program_id(0) # batch index hid = tl.program_id(1) # q_head index @@ -231,77 +233,95 @@ def _tq_decode_stage1( p = tl.exp(scores - n_e_max) # ============================================================ - # VALUE LOAD + DEQUANTIZE: [BLOCK_KV, BLOCK_D] + # SPARSE V (opt-in): skip the V load + dequant + weighted-sum + # for tiles whose softmax probability is entirely below + # SPARSE_V_THRESHOLD. The skip path still decays the running + # accumulator and updates l_prev / m_prev so the online softmax + # totals stay exact. The per-tile branch costs a tl.max + a + # comparison; the win comes from avoiding ~10 tl.load calls and + # the dequant arithmetic on tiles that contribute negligibly. 
# ============================================================ - val_bases = slot_bases + KPS + skip_v_tile = False + if SPARSE_V: + skip_v_tile = tl.max(p) < SPARSE_V_THRESHOLD - if VQB == 3: - val_addrs0 = val_bases[:, None] + val_byte_idx[None, :] - val_raw0 = tl.load( - KV_cache_ptr + val_addrs0, - mask=kv_mask[:, None] & d_mask[None, :], - other=0, - ).to(tl.int32) - val_raw1 = tl.load( - KV_cache_ptr + val_addrs0 + 1, - mask=kv_mask[:, None] & d_mask[None, :], - other=0, - ).to(tl.int32) - raw16 = val_raw0 | (val_raw1 << 8) - v_idx = ((raw16 >> val_bit_shift[None, :]) & 0x7).to(tl.float32) - - sc_bases = val_bases + VAL_DATA_BYTES - sc_lo = tl.load(KV_cache_ptr + sc_bases, mask=kv_mask, other=0).to( - tl.uint16 - ) - sc_hi = tl.load(KV_cache_ptr + sc_bases + 1, mask=kv_mask, other=0).to( - tl.uint16 - ) - v_scales = ( - (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) - ) - zr_lo = tl.load(KV_cache_ptr + sc_bases + 2, mask=kv_mask, other=0).to( - tl.uint16 - ) - zr_hi = tl.load(KV_cache_ptr + sc_bases + 3, mask=kv_mask, other=0).to( - tl.uint16 - ) - v_zeros = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) - values = v_idx * v_scales[:, None] + v_zeros[:, None] - else: # VQB == 4 - vb_idx = d_offs // 2 - vb_shift = (d_offs % 2) * 4 - val_addrs = val_bases[:, None] + vb_idx[None, :] - val_raw = tl.load( - KV_cache_ptr + val_addrs, - mask=kv_mask[:, None] & d_mask[None, :], - other=0, - ).to(tl.int32) - v_idx = ((val_raw >> vb_shift[None, :]) & 0xF).to(tl.float32) + if skip_v_tile: + acc = acc * re_scale + else: + # ======================================================== + # VALUE LOAD + DEQUANTIZE: [BLOCK_KV, BLOCK_D] + # ======================================================== + val_bases = slot_bases + KPS + + if VQB == 3: + val_addrs0 = val_bases[:, None] + val_byte_idx[None, :] + val_raw0 = tl.load( + KV_cache_ptr + val_addrs0, + mask=kv_mask[:, None] & d_mask[None, :], + other=0, + ).to(tl.int32) + val_raw1 = 
tl.load( + KV_cache_ptr + val_addrs0 + 1, + mask=kv_mask[:, None] & d_mask[None, :], + other=0, + ).to(tl.int32) + raw16 = val_raw0 | (val_raw1 << 8) + v_idx = ((raw16 >> val_bit_shift[None, :]) & 0x7).to(tl.float32) + + sc_bases = val_bases + VAL_DATA_BYTES + sc_lo = tl.load(KV_cache_ptr + sc_bases, mask=kv_mask, other=0).to( + tl.uint16 + ) + sc_hi = tl.load(KV_cache_ptr + sc_bases + 1, mask=kv_mask, other=0).to( + tl.uint16 + ) + v_scales = ( + (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) + ) + zr_lo = tl.load(KV_cache_ptr + sc_bases + 2, mask=kv_mask, other=0).to( + tl.uint16 + ) + zr_hi = tl.load(KV_cache_ptr + sc_bases + 3, mask=kv_mask, other=0).to( + tl.uint16 + ) + v_zeros = ( + (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) + ) + values = v_idx * v_scales[:, None] + v_zeros[:, None] + else: # VQB == 4 + vb_idx = d_offs // 2 + vb_shift = (d_offs % 2) * 4 + val_addrs = val_bases[:, None] + vb_idx[None, :] + val_raw = tl.load( + KV_cache_ptr + val_addrs, + mask=kv_mask[:, None] & d_mask[None, :], + other=0, + ).to(tl.int32) + v_idx = ((val_raw >> vb_shift[None, :]) & 0xF).to(tl.float32) + + sc_bases = val_bases + VAL_DATA_BYTES + sc_lo = tl.load(KV_cache_ptr + sc_bases, mask=kv_mask, other=0).to( + tl.uint16 + ) + sc_hi = tl.load(KV_cache_ptr + sc_bases + 1, mask=kv_mask, other=0).to( + tl.uint16 + ) + v_scales = ( + (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) + ) + zr_lo = tl.load(KV_cache_ptr + sc_bases + 2, mask=kv_mask, other=0).to( + tl.uint16 + ) + zr_hi = tl.load(KV_cache_ptr + sc_bases + 3, mask=kv_mask, other=0).to( + tl.uint16 + ) + v_zeros = ( + (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) + ) + values = v_idx * v_scales[:, None] + v_zeros[:, None] - sc_bases = val_bases + VAL_DATA_BYTES - sc_lo = tl.load(KV_cache_ptr + sc_bases, mask=kv_mask, other=0).to( - tl.uint16 - ) - sc_hi = tl.load(KV_cache_ptr + sc_bases + 1, mask=kv_mask, other=0).to( - tl.uint16 - ) - 
v_scales = ( - (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) - ) - zr_lo = tl.load(KV_cache_ptr + sc_bases + 2, mask=kv_mask, other=0).to( - tl.uint16 - ) - zr_hi = tl.load(KV_cache_ptr + sc_bases + 3, mask=kv_mask, other=0).to( - tl.uint16 - ) - v_zeros = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) - values = v_idx * v_scales[:, None] + v_zeros[:, None] + acc = acc * re_scale + tl.sum(p[:, None] * values, 0) - # ============================================================ - # WEIGHTED VALUE ACCUMULATION - # ============================================================ - acc = acc * re_scale + tl.sum(p[:, None] * values, 0) l_prev = l_prev * re_scale + tl.sum(p, 0) m_prev = n_e_max @@ -503,6 +523,8 @@ def triton_turboquant_decode_attention( lse_buf: torch.Tensor | None = None, buf_holder: Any = None, max_num_kv_splits: int = 32, # fixed split count (must be constant for cudagraph) + sparse_v: bool = False, + sparse_v_threshold: float = 0.001, ) -> torch.Tensor: """Launch fused TQ decode attention (Triton stage1 + stage2). @@ -583,6 +605,8 @@ def triton_turboquant_decode_attention( KEY_FP8=1 if key_fp8 else 0, NORM_CORRECTION=1 if norm_correction else 0, FP8_E4B15=fp8_e4b15, + SPARSE_V=1 if sparse_v else 0, + SPARSE_V_THRESHOLD=sparse_v_threshold, num_warps=1, num_stages=1, )