tests/kernels/test_compressor_kv_cache.py (232 changes: 228 additions & 4 deletions)
@@ -3,12 +3,11 @@
"""
Round-trip tests for compressor → FP8 quant + KV cache insert → gather + dequant.

Two paths tested:
Four test functions cover five paths:
A) DeepseekV4 Attention: head_dim=512 (448 FP8 nope + 64 bf16 rope), quant_block=64
B) Indexer: head_dim=128 (all FP8), quant_block=128

These serve as golden references for validating the future fused
compressor+quant+cache kernel.
C) DeepseekV4 Attention magnitude range: correctness across small/large values
D) Indexer fused Triton kernels (FP8 and MXFP4): compress+norm+rope+quant+insert
"""

import math
@@ -21,6 +20,12 @@
dequantize_and_gather_k_cache,
quantize_and_insert_k_cache,
)
from vllm.v1.attention.ops.deepseek_v4_ops.fused_compress_quant_cache import (
_fused_kv_compress_norm_rope_insert_indexer_attn,
_fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn,
)

from .test_fused_indexer_q_rope_quant import quantize_to_mxfp4


def _ue8m0_reference(x: torch.Tensor, block_size: int, fp8_max: float):
@@ -309,3 +314,222 @@ def test_deepseek_v4_quant_magnitude_range():
f"Token {t}: rel_err={rel_err:.4f}, abs_diff={abs_diff:.6f}, "
f"magnitude={magnitude:.4f}"
)


# ── Test D: Indexer fused K-cache insert (Triton kernels) ────────────────────
#
# Both kernels share the same Triton signature; use_fp4 selects between them.
# Full pipeline: state-cache gather → softmax-weighted compress → RMSNorm →
# GPT-J RoPE → quant (MXFP4 or FP8) → paged cache insert.


def _reference_kv_compress_norm_rope(
state_cache: torch.Tensor,
block_table: torch.Tensor,
positions: torch.Tensor,
rms_weight: torch.Tensor,
cos_sin_cache: torch.Tensor,
compress_ratio: int = 1,
overlap: int = 0,
use_fp4: bool = False,
rms_eps: float = 1e-6,
fp8_max: float = 448.0,
):
"""Compress → RMSNorm → GPT-J RoPE → quantize.

Gathers (1+overlap)*compress_ratio state entries per output token, applies
per-element softmax over the scores, and computes the weighted kv sum.
Returns (quantized_values, scale) matching the kernel's output layout.
"""
device = state_cache.device
head_dim = rms_weight.shape[0]
rope_dim = cos_sin_cache.shape[-1]
state_block_size = state_cache.shape[1]
state_width = state_cache.shape[-1] // 2
nope_dim = head_dim - rope_dim
total = (1 + overlap) * compress_ratio
results = []
for pos in positions.tolist():
src = torch.arange(pos - total + 1, pos + 1, dtype=torch.int64, device=device)
valid = src >= 0
idx = src.clamp(min=0)
pages = block_table[0, idx // state_block_size]
offsets = idx % state_block_size
raw = state_cache[pages, offsets].float() # [total, state_dim]

# Group 0 (tokens 0..cr-1): kv[:H], score[SW:SW+H]
# Group 1 (tokens cr..2cr-1): kv[H:2H], score[SW+H:SW+2H]
if overlap:
sw = state_width
g0_kv = raw[:compress_ratio, :head_dim]
g1_kv = raw[compress_ratio:, head_dim : 2 * head_dim]
g0_scores = raw[:compress_ratio, sw : sw + head_dim]
g1_scores = raw[compress_ratio:, sw + head_dim : sw + 2 * head_dim]
kv = torch.cat([g0_kv, g1_kv])
scores = torch.cat([g0_scores, g1_scores])
else:
kv = raw[:, :head_dim]
scores = raw[:, state_width : state_width + head_dim]

scores[~valid] = float("-inf")
kv[~valid] = 0.0
weights = torch.softmax(scores, dim=0)
compressed = (kv * weights).sum(dim=0) # [H]
var = (compressed * compressed).mean()
normed = compressed * torch.rsqrt(var + rms_eps) * rms_weight.float()
compressed_pos = (pos // compress_ratio) * compress_ratio
cos, sin = cos_sin_cache[compressed_pos].float().chunk(2)
nope, rope = normed.split([nope_dim, rope_dim])
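# GPT-J (interleaved) RoPE: rotate adjacent even/odd element pairs, then
# re-interleave them via stack + reshape.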
rope = torch.stack(
[rope[0::2] * cos - rope[1::2] * sin, rope[1::2] * cos + rope[0::2] * sin],
dim=-1,
).reshape(rope_dim)
results.append(torch.cat([nope, rope]).to(state_cache.dtype))
result = torch.stack(results)

if use_fp4:
return quantize_to_mxfp4(result)
else:
pairs = [
_ue8m0_reference(result[t], head_dim, fp8_max) for t in range(len(result))
]
quants, scales = zip(*pairs)
return torch.stack(quants), torch.cat(scales)


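# Output layout of the reference, as exercised by the checks at the end of the
# test below:
#   FP8 path  : quants [num_tokens, 128] FP8 (e4m3, FP8_MAX=448) plus one
#               float32 power-of-two scale per token (QUANT_BLOCK = HEAD_DIM).
#   MXFP4 path: quants [num_tokens, 64] uint8 (two e2m1 nibbles per byte) plus
#               4 ue8m0 scale bytes per token (one per 32-element block); the
#               exact packing follows quantize_to_mxfp4 from
#               test_fused_indexer_q_rope_quant.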
@pytest.mark.parametrize("num_tokens", [1, 7, 32])
@pytest.mark.parametrize("kv_block_size", [16, 32])
@pytest.mark.parametrize("use_fp4", [False, True])
def test_fused_kv_insert_indexer(num_tokens: int, kv_block_size: int, use_fp4: bool):
"""Fused K compress+norm+rope+quant+insert for the indexer KV cache."""
HEAD_DIM = 128
ROPE_DIM = 64
BLOCK_SIZE = 16
RMS_EPS = 1e-6
FP8_MAX = 448.0

device = "cuda"
torch.manual_seed(42)
compress_ratio = 4

if use_fp4:
TOKEN_STRIDE = HEAD_DIM // 2 # packed nibbles: 64 bytes
SCALE_DIM = HEAD_DIM // 32 # ue8m0 bytes: 4
QUANT_BLOCK = 32
kernel = _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn
else:
TOKEN_STRIDE = HEAD_DIM # FP8 bytes: 128
SCALE_DIM = 4 # 1 float32: 4 bytes
QUANT_BLOCK = HEAD_DIM
kernel = _fused_kv_compress_norm_rope_insert_indexer_attn

# overlap=1 whenever compress_ratio==4, matching DeepseekCompressor logic.
overlap = 1 if compress_ratio == 4 else 0
coff = 1 + overlap # multiplier for state_dim per entry

num_pages = (compress_ratio * num_tokens - 1) // BLOCK_SIZE + 2
state_cache = torch.randn(
num_pages,
BLOCK_SIZE,
2 * coff * HEAD_DIM, # kv_state + score_state, each coff*HEAD_DIM wide
dtype=torch.bfloat16,
device=device,
)
block_table = torch.arange(num_pages, dtype=torch.int32, device=device).unsqueeze(0)
token_to_req = torch.zeros(num_tokens, dtype=torch.int32, device=device)
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
positions = torch.arange(
compress_ratio - 1,
compress_ratio * num_tokens,
compress_ratio,
dtype=torch.int64,
device=device,
)
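# Each output token sits at the last (newest) source position of its
# compression window: compress_ratio-1, 2*compress_ratio-1, ...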
rms_weight = torch.randn(HEAD_DIM, dtype=torch.bfloat16, device=device)
cos_sin_cache = torch.randn(compress_ratio * num_tokens, ROPE_DIM, device=device)

kv_n_blocks = (num_tokens + kv_block_size - 1) // kv_block_size + 1
kv_cache = torch.zeros(
kv_n_blocks,
kv_block_size * (TOKEN_STRIDE + SCALE_DIM),
dtype=torch.uint8,
device=device,
)
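# Each kv_cache block row packs all value bytes first, then all scale bytes:
#   [kv_block_size * TOKEN_STRIDE value bytes | kv_block_size * SCALE_DIM scale bytes]
# so a token at in-block slot `pos` writes its values at pos * TOKEN_STRIDE and
# its scales at kv_block_size * TOKEN_STRIDE + pos * SCALE_DIM, the same
# offsets the assertions below read back.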

kernel[(num_tokens,)](
state_cache,
state_cache.stride(0),
state_cache.stride(1),
token_to_req,
positions,
slot_mapping,
block_table,
block_table.stride(0),
BLOCK_SIZE,
rms_weight,
RMS_EPS,
cos_sin_cache,
cos_sin_cache.stride(0),
kv_cache,
slot_mapping,
kv_block_size,
HEAD_SIZE=HEAD_DIM,
TRITON_BLOCK_SIZE=HEAD_DIM,
STATE_WIDTH=coff * HEAD_DIM,
COMPRESS_RATIO=compress_ratio,
OVERLAP=overlap,
ROPE_HEAD_DIM=ROPE_DIM,
FP8_MAX=FP8_MAX,
QUANT_BLOCK=QUANT_BLOCK,
TOKEN_STRIDE=TOKEN_STRIDE,
SCALE_DIM=SCALE_DIM,
KV_BLOCK_STRIDE=kv_cache.stride(0),
num_warps=1,
)

k_quant, scale = _reference_kv_compress_norm_rope(
state_cache,
block_table,
positions,
rms_weight,
cos_sin_cache,
compress_ratio,
overlap,
use_fp4,
rms_eps=RMS_EPS,
fp8_max=FP8_MAX,
)

if use_fp4:
for i in range(num_tokens):
blk, pos = i // kv_block_size, i % kv_block_size
val_off = pos * TOKEN_STRIDE
fp4_actual = kv_cache[blk, val_off : val_off + TOKEN_STRIDE]
assert torch.equal(k_quant[i], fp4_actual), (
f"token {i}: packed nibbles differ, "
f"{(k_quant[i] != fp4_actual).sum()} "
f"/ {TOKEN_STRIDE}"
)

scale_off = kv_block_size * TOKEN_STRIDE + pos * SCALE_DIM
scale_actual = kv_cache[blk, scale_off : scale_off + SCALE_DIM]
assert torch.equal(scale_actual, scale[i]), (
f"token {i}: ue8m0 {scale_actual.tolist()} != {scale[i].tolist()}"
)

else:
k_quant = k_quant.view(torch.uint8)
for i in range(num_tokens):
blk, pos = i // kv_block_size, i % kv_block_size
val_off = pos * TOKEN_STRIDE
assert torch.equal(
k_quant[i], kv_cache[blk, val_off : val_off + TOKEN_STRIDE]
), f"token {i}: FP8 bytes differ"

scale_off = kv_block_size * TOKEN_STRIDE + pos * SCALE_DIM
actual_scale = kv_cache[blk, scale_off : scale_off + SCALE_DIM].view(
torch.float32
)
assert torch.equal(actual_scale, scale[i : i + 1]), (
f"token {i}: scale {actual_scale.item()} != {scale[i].item()}"
)