diff --git a/tests/kernels/test_compressor_kv_cache.py b/tests/kernels/test_compressor_kv_cache.py index 592b58fbe430..122254bc3c41 100644 --- a/tests/kernels/test_compressor_kv_cache.py +++ b/tests/kernels/test_compressor_kv_cache.py @@ -3,12 +3,11 @@ """ Round-trip tests for compressor → FP8 quant + KV cache insert → gather + dequant. -Two paths tested: +Four test functions cover five paths: A) DeepseekV4 Attention: head_dim=512 (448 FP8 nope + 64 bf16 rope), quant_block=64 B) Indexer: head_dim=128 (all FP8), quant_block=128 - -These serve as golden references for validating the future fused -compressor+quant+cache kernel. + C) DeepseekV4 Attention magnitude range: correctness across small/large values + D) Indexer fused Triton kernel: compress+norm+rope+quant+insert """ import math @@ -21,6 +20,12 @@ dequantize_and_gather_k_cache, quantize_and_insert_k_cache, ) +from vllm.v1.attention.ops.deepseek_v4_ops.fused_compress_quant_cache import ( + _fused_kv_compress_norm_rope_insert_indexer_attn, + _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn, +) + +from .test_fused_indexer_q_rope_quant import quantize_to_mxfp4 def _ue8m0_reference(x: torch.Tensor, block_size: int, fp8_max: float): @@ -309,3 +314,222 @@ def test_deepseek_v4_quant_magnitude_range(): f"Token {t}: rel_err={rel_err:.4f}, abs_diff={abs_diff:.6f}, " f"magnitude={magnitude:.4f}" ) + + +# ── Test D: Indexer fused K-cache insert (Triton kernels) ──────────────────── +# +# Both kernels share the same Triton signature; use_fp4 selects between them. +# Full pipeline: state-cache gather → softmax-weighted compress → RMSNorm → +# GPT-J RoPE → quant (MXFP4 or FP8) → paged cache insert. + + +def _reference_kv_compress_norm_rope( + state_cache: torch.Tensor, + block_table: torch.Tensor, + positions: torch.Tensor, + rms_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + compress_ratio: int = 1, + overlap: int = 0, + use_fp4: bool = False, + rms_eps: float = 1e-6, + fp8_max: float = 448.0, +): + """Compress → RMSNorm → GPT-J RoPE → quantize. + + Gathers (1+overlap)*compress_ratio state entries per output token, applies + per-element softmax over the scores, and computes the weighted kv sum. + Returns (quantized_values, scale) matching the kernel's output layout. 
+ """ + device = state_cache.device + head_dim = rms_weight.shape[0] + rope_dim = cos_sin_cache.shape[-1] + state_block_size = state_cache.shape[1] + state_width = state_cache.shape[-1] // 2 + nope_dim = head_dim - rope_dim + total = (1 + overlap) * compress_ratio + results = [] + for pos in positions.tolist(): + src = torch.arange(pos - total + 1, pos + 1, dtype=torch.int64, device=device) + valid = src >= 0 + idx = src.clamp(min=0) + pages = block_table[0, idx // state_block_size] + offsets = idx % state_block_size + raw = state_cache[pages, offsets].float() # [total, state_dim] + + # Group 0 (tokens 0..cr-1): kv[:H], score[SW:SW+H] + # Group 1 (tokens cr..2cr-1): kv[H:2H], score[SW+H:SW+2H] + if overlap: + sw = state_width + g0_kv = raw[:compress_ratio, :head_dim] + g1_kv = raw[compress_ratio:, head_dim : 2 * head_dim] + g0_scores = raw[:compress_ratio, sw : sw + head_dim] + g1_scores = raw[compress_ratio:, sw + head_dim : sw + 2 * head_dim] + kv = torch.cat([g0_kv, g1_kv]) + scores = torch.cat([g0_scores, g1_scores]) + else: + kv = raw[:, :head_dim] + scores = raw[:, state_width : state_width + head_dim] + + scores[~valid] = float("-inf") + kv[~valid] = 0.0 + weights = torch.softmax(scores, dim=0) + compressed = (kv * weights).sum(dim=0) # [H] + var = (compressed * compressed).mean() + normed = compressed * torch.rsqrt(var + rms_eps) * rms_weight.float() + compressed_pos = (pos // compress_ratio) * compress_ratio + cos, sin = cos_sin_cache[compressed_pos].float().chunk(2) + nope, rope = normed.split([nope_dim, rope_dim]) + rope = torch.stack( + [rope[0::2] * cos - rope[1::2] * sin, rope[1::2] * cos + rope[0::2] * sin], + dim=-1, + ).reshape(rope_dim) + results.append(torch.cat([nope, rope]).to(state_cache.dtype)) + result = torch.stack(results) + + if use_fp4: + return quantize_to_mxfp4(result) + else: + pairs = [ + _ue8m0_reference(result[t], head_dim, fp8_max) for t in range(len(result)) + ] + quants, scales = zip(*pairs) + return torch.stack(quants), torch.cat(scales) + + +@pytest.mark.parametrize("num_tokens", [1, 7, 32]) +@pytest.mark.parametrize("kv_block_size", [16, 32]) +@pytest.mark.parametrize("use_fp4", [False, True]) +def test_fused_kv_insert_indexer(num_tokens: int, kv_block_size: int, use_fp4: bool): + """Fused K compress+norm+rope+quant+insert for the indexer KV cache.""" + HEAD_DIM = 128 + ROPE_DIM = 64 + BLOCK_SIZE = 16 + RMS_EPS = 1e-6 + FP8_MAX = 448.0 + + device = "cuda" + torch.manual_seed(42) + compress_ratio = 4 + + if use_fp4: + TOKEN_STRIDE = HEAD_DIM // 2 # packed nibbles: 64 bytes + SCALE_DIM = HEAD_DIM // 32 # ue8m0 bytes: 4 + QUANT_BLOCK = 32 + kernel = _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn + else: + TOKEN_STRIDE = HEAD_DIM # FP8 bytes: 128 + SCALE_DIM = 4 # 1 float32: 4 bytes + QUANT_BLOCK = HEAD_DIM + kernel = _fused_kv_compress_norm_rope_insert_indexer_attn + + # overlap=1 whenever compress_ratio==4, matching DeepseekCompressor logic. 
+ overlap = 1 if compress_ratio == 4 else 0 + coff = 1 + overlap # multiplier for state_dim per entry + + num_pages = (compress_ratio * num_tokens - 1) // BLOCK_SIZE + 2 + state_cache = torch.randn( + num_pages, + BLOCK_SIZE, + 2 * coff * HEAD_DIM, # kv_state + score_state, each coff*HEAD_DIM wide + dtype=torch.bfloat16, + device=device, + ) + block_table = torch.arange(num_pages, dtype=torch.int32, device=device).unsqueeze(0) + token_to_req = torch.zeros(num_tokens, dtype=torch.int32, device=device) + slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device) + positions = torch.arange( + compress_ratio - 1, + compress_ratio * num_tokens, + compress_ratio, + dtype=torch.int64, + device=device, + ) + rms_weight = torch.randn(HEAD_DIM, dtype=torch.bfloat16, device=device) + cos_sin_cache = torch.randn(compress_ratio * num_tokens, ROPE_DIM, device=device) + + kv_n_blocks = (num_tokens + kv_block_size - 1) // kv_block_size + 1 + kv_cache = torch.zeros( + kv_n_blocks, + kv_block_size * (TOKEN_STRIDE + SCALE_DIM), + dtype=torch.uint8, + device=device, + ) + + kernel[(num_tokens,)]( + state_cache, + state_cache.stride(0), + state_cache.stride(1), + token_to_req, + positions, + slot_mapping, + block_table, + block_table.stride(0), + BLOCK_SIZE, + rms_weight, + RMS_EPS, + cos_sin_cache, + cos_sin_cache.stride(0), + kv_cache, + slot_mapping, + kv_block_size, + HEAD_SIZE=HEAD_DIM, + TRITON_BLOCK_SIZE=HEAD_DIM, + STATE_WIDTH=coff * HEAD_DIM, + COMPRESS_RATIO=compress_ratio, + OVERLAP=overlap, + ROPE_HEAD_DIM=ROPE_DIM, + FP8_MAX=FP8_MAX, + QUANT_BLOCK=QUANT_BLOCK, + TOKEN_STRIDE=TOKEN_STRIDE, + SCALE_DIM=SCALE_DIM, + KV_BLOCK_STRIDE=kv_cache.stride(0), + num_warps=1, + ) + + k_quant, scale = _reference_kv_compress_norm_rope( + state_cache, + block_table, + positions, + rms_weight, + cos_sin_cache, + compress_ratio, + overlap, + use_fp4, + rms_eps=RMS_EPS, + fp8_max=FP8_MAX, + ) + + if use_fp4: + for i in range(num_tokens): + blk, pos = i // kv_block_size, i % kv_block_size + val_off = pos * TOKEN_STRIDE + fp4_actual = kv_cache[blk, val_off : val_off + TOKEN_STRIDE] + assert torch.equal(k_quant[i], fp4_actual), ( + f"token {i}: packed nibbles differ, " + f"{(k_quant[i] != fp4_actual).sum()} " + f"/ {TOKEN_STRIDE}" + ) + + scale_off = kv_block_size * TOKEN_STRIDE + pos * SCALE_DIM + scale_actual = kv_cache[blk, scale_off : scale_off + SCALE_DIM] + assert torch.equal(scale_actual, scale[i]), ( + f"token {i}: ue8m0 {scale_actual.tolist()} != {scale[i].tolist()}" + ) + + else: + k_quant = k_quant.view(torch.uint8) + for i in range(num_tokens): + blk, pos = i // kv_block_size, i % kv_block_size + val_off = pos * TOKEN_STRIDE + assert torch.equal( + k_quant[i], kv_cache[blk, val_off : val_off + TOKEN_STRIDE] + ), f"token {i}: FP8 bytes differ" + + scale_off = kv_block_size * TOKEN_STRIDE + pos * SCALE_DIM + actual_scale = kv_cache[blk, scale_off : scale_off + SCALE_DIM].view( + torch.float32 + ) + assert torch.equal(actual_scale, scale[i : i + 1]), ( + f"token {i}: scale {actual_scale.item()} != {scale[i].item()}" + ) diff --git a/tests/kernels/test_fused_indexer_q_rope_quant.py b/tests/kernels/test_fused_indexer_q_rope_quant.py index 03d5ad4c8ac7..be2039ce513e 100644 --- a/tests/kernels/test_fused_indexer_q_rope_quant.py +++ b/tests/kernels/test_fused_indexer_q_rope_quant.py @@ -30,6 +30,56 @@ MAX_POS = 4096 +def quantize_to_mxfp4( + x: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Reference MXFP4 quantization. 
+ + Args: + x: [..., head_dim] where head_dim is divisible by 32 + Returns: + packed: [..., head_dim//2] uint8 2 E2M1 nibbles/byte, low nibble = even index + scales: [..., head_dim//32] uint8 1 ue8m0 byte + """ + MXFP4_BLOCK_SIZE = 32 + orig_shape = x.shape + head_dim = orig_shape[-1] + n_blocks = head_dim // MXFP4_BLOCK_SIZE + + x_f32 = x.float().reshape(-1, n_blocks, MXFP4_BLOCK_SIZE) + + # Per-block ue8m0 scale: 2^ceil(log2(amax / 6.0)), stored as byte = exp + 127 + # 6 * 2^-126 is from https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro/blob/main/inference/kernel.py#L163 + amax = x_f32.abs().amax(dim=-1, keepdim=True).clamp(min=6 * (2**-126)) + log2_ratio = (amax * (1.0 / 6.0)).log2().ceil().clamp(-127.0, 127.0) + scale = log2_ratio.exp2() + ue8m0 = (log2_ratio + 127.0).to(torch.uint8) # [*, n_blocks] + + # E2M1 round-to-nearest-even: midpoints round to the even code. + # E2M1 values: [0.00, 0.50, 1.00, 1.50, 2.00, 3.00, 4.00, 6.00] + # boundaries: [ 0.25, 0.75, 1.25, 1.75, 2.50, 3.50, 5.00] + x_scaled = (x_f32 / scale).clamp(-6.0, 6.0) + abs_x = x_scaled.abs() + code = torch.zeros_like(abs_x, dtype=torch.int32) + code = torch.where(abs_x > 0.25, 1, code) + code = torch.where(abs_x >= 0.75, 2, code) + code = torch.where(abs_x > 1.25, 3, code) + code = torch.where(abs_x >= 1.75, 4, code) + code = torch.where(abs_x > 2.5, 5, code) + code = torch.where(abs_x >= 3.5, 6, code) + code = torch.where(abs_x > 5.0, 7, code) + sign = ((x_scaled.view(torch.int32) >> 31) & 1).to(torch.uint8) + nibble = code.to(torch.uint8) | (sign << 3) + + # Pack: even-index element → low nibble, odd-index → high nibble + nibble_flat = nibble.reshape(-1, head_dim) + packed = (nibble_flat[:, 0::2] | (nibble_flat[:, 1::2] << 4)).contiguous() + packed = packed.reshape(*orig_shape[:-1], head_dim // 2) + + scales = ue8m0.view(*orig_shape[:-1], n_blocks) + return packed, scales + + def _reference( positions: torch.Tensor, q: torch.Tensor, @@ -37,6 +87,7 @@ def _reference( weights: torch.Tensor, softmax_scale: float, head_scale: float, + use_fp4: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: q_rot = q.clone() ops.rotary_embedding( @@ -49,22 +100,33 @@ def _reference( HEAD_DIM - ROPE_DIM, # rope_dim_offset → rotate the tail False, ) - q_fp8, q_scale = per_token_group_quant_fp8( - q_rot.view(-1, HEAD_DIM).contiguous(), - HEAD_DIM, - use_ue8m0=True, - ) - q_fp8 = q_fp8.view(-1, N_HEAD, HEAD_DIM) - q_scale = q_scale.view(-1, N_HEAD) - weights_out = weights.to(torch.float32) * q_scale * softmax_scale * head_scale - return q_fp8, weights_out + if use_fp4: + q_packed, ue8m0 = quantize_to_mxfp4(q_rot.view(-1, N_HEAD, HEAD_DIM)) + # Pack 4 ue8m0 bytes into 1 int32 + q_scale = ue8m0.view(torch.int32).squeeze(-1) + # FP4 path: q_scale stays separate (cannot be folded into a per-token scalar) + weights_out = weights.to(torch.float32) * softmax_scale * head_scale + return (q_packed, q_scale), weights_out + + else: + q_fp8, q_scale = per_token_group_quant_fp8( + q_rot.view(-1, HEAD_DIM).contiguous(), + HEAD_DIM, + use_ue8m0=True, + ) + q_fp8 = q_fp8.view(-1, N_HEAD, HEAD_DIM) + q_scale = q_scale.view(-1, N_HEAD) + + weights_out = weights.to(torch.float32) * q_scale * softmax_scale * head_scale + return q_fp8, weights_out @pytest.mark.parametrize("num_tokens", [1, 7, 32, 257]) @pytest.mark.parametrize("cache_dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("use_fp4", [False, True]) @torch.inference_mode() -def test_fused_indexer_q_rope_quant_matches_unfused(num_tokens, cache_dtype): +def 
test_fused_indexer_q_rope_quant_matches_unfused(num_tokens, cache_dtype, use_fp4): device = "cuda" torch.manual_seed(0) @@ -77,21 +139,32 @@ def test_fused_indexer_q_rope_quant_matches_unfused(num_tokens, cache_dtype): softmax_scale = HEAD_DIM**-0.5 head_scale = N_HEAD**-0.5 - q_fp8_ref, weights_ref = _reference( - positions, q, cos_sin_cache, weights, softmax_scale, head_scale + q_quant_ref, weights_ref = _reference( + positions, q, cos_sin_cache, weights, softmax_scale, head_scale, use_fp4 ) - q_fp8_fused, weights_fused = fused_indexer_q_rope_quant( - positions, q.clone(), cos_sin_cache, weights, softmax_scale, head_scale + q_quant_fused, weights_fused = fused_indexer_q_rope_quant( + positions, q.clone(), cos_sin_cache, weights, softmax_scale, head_scale, use_fp4 ) + if use_fp4: + q_quant_ref, q_scale_ref = q_quant_ref + q_quant_fused, q_scale_fused = q_quant_fused + + assert torch.equal(q_scale_ref, q_scale_fused), ( + f"q_scale mismatch: " + f"{(q_scale_ref != q_scale_fused).sum().item()} " + f"/ {q_scale_ref.numel()} bytes differ" + ) + # fp8 tensors aren't directly comparable via torch.equal — reinterpret as int8. - ref_bits = q_fp8_ref.view(torch.int8) - fused_bits = q_fp8_fused.view(torch.int8) + ref_bits = q_quant_ref.view(torch.int8) + fused_bits = q_quant_fused.view(torch.int8) assert torch.equal(ref_bits, fused_bits), ( - f"q_fp8 mismatch: " + f"q_quant_fused mismatch: " f"{(ref_bits != fused_bits).sum().item()} / {ref_bits.numel()} bytes differ" ) + assert weights_fused.dtype == torch.float32 assert torch.equal(weights_ref, weights_fused), ( f"weights mismatch: max abs diff " f"{(weights_ref - weights_fused).abs().max().item()}" diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py index 26b076f34238..2f97d8733c95 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py @@ -21,7 +21,7 @@ from vllm.triton_utils import tl, triton -from .fused_indexer_q import _e2m1_nibble +from .fused_indexer_q import _fp32x2_to_fp4x2 # ============================================================================= @@ -566,18 +566,18 @@ def _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn( tl.max(tl.abs(even_2d), axis=1), tl.max(tl.abs(odd_2d), axis=1), ) - amax = tl.maximum(amax, 1e-4) + amax = tl.maximum(amax, 6.0 * (2**-126)) # ue8m0 block scale: 2^ceil(log2(amax / 6.0)), stored as (exp + 127) byte. 
- log2_ratio = tl.ceil(tl.log2(amax / 6.0)) + log2_ratio = tl.ceil(tl.log2(amax * (1.0 / 6.0))) log2_ratio = tl.minimum(tl.maximum(log2_ratio, -127.0), 127.0) inv_scale = tl.exp2(-log2_ratio) ue8m0 = (log2_ratio + 127.0).to(tl.uint8) # [N_QUANT_BLOCKS] inv_scale_col = tl.reshape(inv_scale, (N_QUANT_BLOCKS, 1)) - lo_nib = _e2m1_nibble(even_2d * inv_scale_col) # (N_BLOCKS, HALF_BLOCK) uint8 - hi_nib = _e2m1_nibble(odd_2d * inv_scale_col) - packed = lo_nib | (hi_nib << 4) + packed = _fp32x2_to_fp4x2( + even_2d * inv_scale_col, odd_2d * inv_scale_col + ) # (N_BLOCKS, HALF_BLOCK) uint8 packed_flat = tl.reshape(packed, (TOKEN_STRIDE,)) tl.store(val_ptr + tl.arange(0, TOKEN_STRIDE), packed_flat) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py index 0254a46752c6..f94fc013f5c6 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py @@ -24,36 +24,22 @@ def _get_cos_sin( @triton.jit -def _e2m1_nibble(x): - """Quantize fp32 x (already scale-divided) to E2M1 4-bit nibble in uint8. - Matches torch.bucketize with boundaries - [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0] and right=False (each boundary - belongs to the lower bucket), plus sign bit.""" - abs_x = tl.minimum(tl.abs(x), 6.0) - code = tl.where( - abs_x <= 0.25, - 0.0, - tl.where( - abs_x <= 0.75, - 1.0, - tl.where( - abs_x <= 1.25, - 2.0, - tl.where( - abs_x <= 1.75, - 3.0, - tl.where( - abs_x <= 2.5, - 4.0, - tl.where(abs_x <= 3.5, 5.0, tl.where(abs_x <= 5.0, 6.0, 7.0)), - ), - ), - ), - ), - ) - code_u8 = code.to(tl.uint8) - sign = ((x < 0) & (code_u8 != 0)).to(tl.uint8) - return code_u8 | (sign << 3) +def _fp32x2_to_fp4x2(x_lo, x_hi): + # NOTE: $1 is high nibble, $2 is low nibble + return tl.inline_asm_elementwise( + """ + { + .reg .b8 tmp; + cvt.rn.satfinite.e2m1x2.f32 tmp, $1, $2; + cvt.u32.u8 $0, tmp; + } + """, + constraints="=r,f,f", + args=[x_hi, x_lo], + dtype=tl.uint32, + is_pure=True, + pack=1, + ).to(tl.uint8) @triton.jit @@ -65,17 +51,16 @@ def _quantize_mxfp4_pair(x_lo, x_hi): - ue8m0 : scalar uint8 (block scale = 2^(ue8m0 - 127)) """ amax = tl.maximum(tl.max(tl.abs(x_lo)), tl.max(tl.abs(x_hi))) - amax = tl.maximum(amax, 1e-4) + # 6 * 2^-126 is from https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro/blob/main/inference/kernel.py#L163 + amax = tl.maximum(amax, 6.0 * (2**-126)) # ue8m0 block scale: 2^ceil(log2(amax/6.0)). - log2_ratio = tl.math.ceil(tl.math.log2(amax / 6.0)) + log2_ratio = tl.math.ceil(tl.math.log2(amax * (1.0 / 6.0))) log2_ratio = tl.minimum(tl.maximum(log2_ratio, -127.0), 127.0) scale = tl.math.exp2(log2_ratio) ue8m0 = (log2_ratio + 127.0).to(tl.uint8) inv_scale = 1.0 / scale - lo_nib = _e2m1_nibble(x_lo * inv_scale) - hi_nib = _e2m1_nibble(x_hi * inv_scale) - packed = lo_nib | (hi_nib << 4) + packed = _fp32x2_to_fp4x2(x_lo * inv_scale, x_hi * inv_scale) return packed, ue8m0
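
The ue8m0 block-scale formula that now appears in both Triton kernels and in the Python reference, 2^ceil(log2(amax / 6.0)) clamped to [-127, 127] and stored as exponent + 127, is easiest to sanity-check with concrete numbers. A standalone sketch of that arithmetic follows; ue8m0_scale is an illustrative name, not a helper in the patch:

import math

def ue8m0_scale(amax: float) -> tuple[float, int]:
    """Return (block scale, stored byte) for a per-block absolute maximum."""
    amax = max(amax, 6.0 * 2.0**-126)          # same floor as the kernels use
    log2_ratio = math.ceil(math.log2(amax / 6.0))
    log2_ratio = min(max(log2_ratio, -127), 127)
    return 2.0 ** log2_ratio, log2_ratio + 127

# amax = 5.0  -> ceil(log2(0.833...)) = 0 -> scale 1.0,    byte 127 (5.0 / 1.0 <= 6)
# amax = 20.0 -> ceil(log2(3.333...)) = 2 -> scale 4.0,    byte 129 (20.0 / 4.0 = 5.0 <= 6)
# amax = 0.0  -> floored to 6 * 2^-126    -> scale 2^-126, byte 1   (avoids log2(0))

Dividing by 6.0 before taking the ceiling guarantees that amax / scale never exceeds 6.0, the largest E2M1 magnitude, so the scaled values always land inside the representable range of the nibble encoding.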
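
A quick way to read the packed-nibble assertions in test D is to dequantize both sides. The sketch below is not part of the patch and the name dequantize_mxfp4_reference is made up here; it only assumes the layout already documented in quantize_to_mxfp4: low nibble holds the even-index element, the sign sits in bit 3 of each nibble, and each 32-element block carries one ue8m0 byte with scale 2^(byte - 127).

import torch

_E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def dequantize_mxfp4_reference(packed: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    """packed: [..., D//2] uint8, scales: [..., D//32] uint8 -> [..., D] float32."""
    lo = packed & 0x0F                                   # even-index elements
    hi = packed >> 4                                     # odd-index elements
    nibbles = torch.stack([lo, hi], dim=-1).reshape(*packed.shape[:-1], -1)
    magnitude = _E2M1_VALUES.to(packed.device)[(nibbles & 0x7).long()]
    sign = 1.0 - 2.0 * (nibbles >> 3).float()            # bit 3 is the sign bit
    block_scale = torch.exp2(scales.float() - 127.0)     # ue8m0: 2^(byte - 127)
    return sign * magnitude * block_scale.repeat_interleave(32, dim=-1)

Round-tripping the Python reference, e.g. packed, scales = quantize_to_mxfp4(x) followed by dequantize_mxfp4_reference(packed, scales), should recover x to within half an E2M1 step times the block scale, which is handy when a byte-level assertion fires and you want to know whether the mismatch is one rounding step or something structural.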
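
The boundary ladder in quantize_to_mxfp4 is documented as round-to-nearest-even over the E2M1 value set, which should also be what cvt.rn.satfinite.e2m1x2.f32 produces (.rn is round-to-nearest-even in PTX). A brute-force way to check the ladder against that claim, as a sketch only: nearest_even_code is an illustrative helper, and it assumes quantize_to_mxfp4 is importable from tests/kernels/test_fused_indexer_q_rope_quant.py.

import torch

E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def nearest_even_code(a: float) -> int:
    """Nearest E2M1 code for a magnitude already scaled into [0, 6]; ties go to the even code."""
    diffs = [abs(v - a) for v in E2M1]
    best = min(diffs)
    ties = [i for i, d in enumerate(diffs) if d == best]
    return ties[0] if ties[0] % 2 == 0 else ties[-1]

xs = torch.linspace(0.0, 6.0, steps=4801)     # dense sweep of magnitudes
blocks = torch.zeros(xs.numel(), 32)          # one 32-element block per sample
blocks[:, 0] = xs
blocks[:, 1] = 6.0                            # pin amax to 6.0 so the block scale is 1.0
packed, _ = quantize_to_mxfp4(blocks)
ladder_codes = (packed[:, 0] & 0x7).tolist()  # low nibble of element 0 = its E2M1 code
assert all(c == nearest_even_code(x) for c, x in zip(ladder_codes, xs.tolist()))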