From 158226925f3f8e0cb1bb8ea951af2dcb1781c56f Mon Sep 17 00:00:00 2001
From: hoseung-kim
Date: Fri, 24 Apr 2026 02:32:49 +0000
Subject: [PATCH 1/3] perf: turboquant GQA head grouping

Signed-off-by: hoseung-kim
---
 .../attention/ops/triton_turboquant_decode.py | 404 ++++++++++++++++--
 1 file changed, 367 insertions(+), 37 deletions(-)

diff --git a/vllm/v1/attention/ops/triton_turboquant_decode.py b/vllm/v1/attention/ops/triton_turboquant_decode.py
index a789f9be7bb2..55d7d152fc6c 100644
--- a/vllm/v1/attention/ops/triton_turboquant_decode.py
+++ b/vllm/v1/attention/ops/triton_turboquant_decode.py
@@ -313,6 +313,291 @@ def _tq_decode_stage1(
     tl.store(Mid_o_ptr + out_base + HEAD_DIM, lse)
 
 
+# ---------------------------------------------------------------------------
+# Stage 1 (grouped): GQA head grouping + tl.dot tensor-core scoring
+# ---------------------------------------------------------------------------
+
+
+@triton.jit
+def _tq_grouped_decode_stage1(
+    Q_rot_ptr,
+    KV_cache_ptr,
+    Block_table_ptr,
+    Seq_lens_ptr,
+    Centroids_ptr,
+    Mid_o_ptr,
+    stride_qb,
+    stride_qh,
+    stride_cache_block,
+    stride_cache_pos,
+    stride_cache_head,
+    stride_bt_b,
+    stride_mid_b,
+    stride_mid_h,
+    stride_mid_s,
+    NUM_KV_HEADS: tl.constexpr,
+    HEAD_DIM: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    NUM_KV_SPLITS: tl.constexpr,
+    KV_GROUP_SIZE: tl.constexpr,
+    Q_HEAD_NUM: tl.constexpr,
+    MSE_BITS: tl.constexpr,
+    MSE_BYTES: tl.constexpr,
+    KPS: tl.constexpr,
+    VQB: tl.constexpr,
+    VAL_DATA_BYTES: tl.constexpr,
+    ATTN_SCALE: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+    BLOCK_KV: tl.constexpr,
+    BLOCK_H: tl.constexpr,
+    KEY_FP8: tl.constexpr,
+    NORM_CORRECTION: tl.constexpr = 0,
+    FP8_E4B15: tl.constexpr = 0,
+):
+    """GQA-grouped TQ decode stage1.
+
+    Each CTA processes min(BLOCK_H, KV_GROUP_SIZE) Q heads that share
+    one KV head, loading K/V once and computing scores via tl.dot.
+ """ + bid = tl.program_id(0) + head_group_id = tl.program_id(1) + sid = tl.program_id(2) + + # Map head_group_id → KV head + Q head range + VALID_BLOCK_H: tl.constexpr = BLOCK_H if KV_GROUP_SIZE > BLOCK_H else KV_GROUP_SIZE + kv_head = head_group_id // tl.cdiv(KV_GROUP_SIZE, BLOCK_H) + cur_head = head_group_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H) + mask_h = cur_head < (head_group_id + 1) * VALID_BLOCK_H + mask_h = mask_h & (cur_head < Q_HEAD_NUM) + + seq_len = tl.load(Seq_lens_ptr + bid) + split_len = tl.cdiv(seq_len, NUM_KV_SPLITS) + split_start = split_len * sid + split_end = tl.minimum(split_start + split_len, seq_len) + + if split_start >= split_end: + # Still must write valid -inf LSE for masked heads + out_base = bid * stride_mid_b + cur_head * stride_mid_h + sid * stride_mid_s + tl.store(Mid_o_ptr + out_base + HEAD_DIM, float("-inf"), mask=mask_h) + return + + d_offs = tl.arange(0, BLOCK_D) + d_mask = d_offs < HEAD_DIM + kv_range = tl.arange(0, BLOCK_KV) + + # Load Q: [BLOCK_H, BLOCK_D] + q_base = bid * stride_qb + cur_head[:, None] * stride_qh + d_offs[None, :] + q_rot = tl.load( + Q_rot_ptr + q_base, + mask=mask_h[:, None] & d_mask[None, :], + other=0.0, + ).to(tl.float32) + + # Precompute MSE bit/byte index vectors (loop-invariant) + if not KEY_FP8: + mse_bit_off = d_offs * MSE_BITS + mse_byte_idx = mse_bit_off // 8 + mse_bit_shift = mse_bit_off % 8 + mse_mask = (1 << MSE_BITS) - 1 + + if VQB == 3: + val_bit_off = d_offs * 3 + val_byte_idx = val_bit_off // 8 + val_bit_shift = val_bit_off % 8 + + # Online softmax accumulators: [BLOCK_H] + m_prev = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + l_prev = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32) + + bt_base = bid * stride_bt_b + + for start_n in range(split_start, split_end, BLOCK_KV): + kv_offs = start_n + kv_range + kv_mask = kv_offs < split_end + + page_idx = kv_offs // BLOCK_SIZE + page_off = kv_offs % BLOCK_SIZE + block_nums = tl.load( + Block_table_ptr + bt_base + page_idx, mask=kv_mask, other=0 + ).to(tl.int64) + + slot_bases = ( + block_nums * stride_cache_block + + page_off.to(tl.int64) * stride_cache_pos + + tl.cast(kv_head, tl.int64) * stride_cache_head + ) + + # ============================================================ + # K DEQUANT → k_float [BLOCK_KV, BLOCK_D] + # ============================================================ + if KEY_FP8: + k_addrs = slot_bases[:, None] + d_offs[None, :] + k_raw = tl.load( + KV_cache_ptr + k_addrs, + mask=kv_mask[:, None] & d_mask[None, :], + other=0, + ) + if FP8_E4B15: + # SM < 8.9: SW emulation requires float32 intermediate + k_float = k_raw.to(tl.float8e4b15, bitcast=True).to(tl.float32) + else: + k_float = k_raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) + + # scores = q_rot @ k_float^T : [BLOCK_H, BLOCK_KV] + scores = tl.dot(q_rot.to(tl.float16), tl.trans(k_float.to(tl.float16))) + scores = (scores * ATTN_SCALE).to(tl.float32) + scores = tl.where(mask_h[:, None] & kv_mask[None, :], scores, -float("inf")) + else: + # MSE unpack → centroid gather → k_dequant [BLOCK_KV, BLOCK_D] + mse_addrs0 = slot_bases[:, None] + mse_byte_idx[None, :] + mse_raw0 = tl.load( + KV_cache_ptr + mse_addrs0, + mask=kv_mask[:, None] & d_mask[None, :], + other=0, + ).to(tl.int32) + mse_raw1 = tl.load( + KV_cache_ptr + mse_addrs0 + 1, + mask=kv_mask[:, None] & d_mask[None, :], + other=0, + ).to(tl.int32) + raw16 = mse_raw0 | (mse_raw1 << 8) + mse_idx = (raw16 >> mse_bit_shift[None, :]) & mse_mask + + c_vals = tl.load( + Centroids_ptr + 
mse_idx, + mask=kv_mask[:, None] & d_mask[None, :], + other=0.0, + ) + + if NORM_CORRECTION: + c_norm_sq = tl.sum( + tl.where(d_mask[None, :], c_vals * c_vals, 0.0), axis=1 + ) + c_inv_norm = 1.0 / tl.sqrt(c_norm_sq + 1e-16) + c_vals = c_vals * c_inv_norm[:, None] + + # term1 = q_rot @ c_vals^T : [BLOCK_H, BLOCK_KV] + term1 = tl.dot(q_rot.to(tl.float16), tl.trans(c_vals.to(tl.float16))) + + norm_bases = slot_bases + MSE_BYTES + n_lo = tl.load(KV_cache_ptr + norm_bases, mask=kv_mask, other=0).to( + tl.uint16 + ) + n_hi = tl.load(KV_cache_ptr + norm_bases + 1, mask=kv_mask, other=0).to( + tl.uint16 + ) + vec_norms = (n_lo | (n_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) + + scores = vec_norms[None, :] * term1.to(tl.float32) * ATTN_SCALE + scores = tl.where(mask_h[:, None] & kv_mask[None, :], scores, -float("inf")) + + # ============================================================ + # ONLINE SOFTMAX: [BLOCK_H] + # ============================================================ + n_e_max = tl.maximum(tl.max(scores, 1), m_prev) + re_scale = tl.exp(m_prev - n_e_max) + p = tl.exp(scores - n_e_max[:, None]) + + # ============================================================ + # V DEQUANT → values [BLOCK_KV, BLOCK_D] + # ============================================================ + val_bases = slot_bases + KPS + + if VQB == 3: + val_addrs0 = val_bases[:, None] + val_byte_idx[None, :] + val_raw0 = tl.load( + KV_cache_ptr + val_addrs0, + mask=kv_mask[:, None] & d_mask[None, :], + other=0, + ).to(tl.int32) + val_raw1 = tl.load( + KV_cache_ptr + val_addrs0 + 1, + mask=kv_mask[:, None] & d_mask[None, :], + other=0, + ).to(tl.int32) + raw16_val = val_raw0 | (val_raw1 << 8) + v_idx = ((raw16_val >> val_bit_shift[None, :]) & 0x7).to(tl.float32) + + sc_bases = val_bases + VAL_DATA_BYTES + sc_lo = tl.load(KV_cache_ptr + sc_bases, mask=kv_mask, other=0).to( + tl.uint16 + ) + sc_hi = tl.load(KV_cache_ptr + sc_bases + 1, mask=kv_mask, other=0).to( + tl.uint16 + ) + v_scales = ( + (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) + ) + zr_lo = tl.load(KV_cache_ptr + sc_bases + 2, mask=kv_mask, other=0).to( + tl.uint16 + ) + zr_hi = tl.load(KV_cache_ptr + sc_bases + 3, mask=kv_mask, other=0).to( + tl.uint16 + ) + v_zeros = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) + values = v_idx * v_scales[:, None] + v_zeros[:, None] + else: # VQB == 4 + vb_idx = d_offs // 2 + vb_shift = (d_offs % 2) * 4 + val_addrs = val_bases[:, None] + vb_idx[None, :] + val_raw = tl.load( + KV_cache_ptr + val_addrs, + mask=kv_mask[:, None] & d_mask[None, :], + other=0, + ).to(tl.int32) + v_idx = ((val_raw >> vb_shift[None, :]) & 0xF).to(tl.float32) + + sc_bases = val_bases + VAL_DATA_BYTES + sc_lo = tl.load(KV_cache_ptr + sc_bases, mask=kv_mask, other=0).to( + tl.uint16 + ) + sc_hi = tl.load(KV_cache_ptr + sc_bases + 1, mask=kv_mask, other=0).to( + tl.uint16 + ) + v_scales = ( + (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) + ) + zr_lo = tl.load(KV_cache_ptr + sc_bases + 2, mask=kv_mask, other=0).to( + tl.uint16 + ) + zr_hi = tl.load(KV_cache_ptr + sc_bases + 3, mask=kv_mask, other=0).to( + tl.uint16 + ) + v_zeros = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) + values = v_idx * v_scales[:, None] + v_zeros[:, None] + + # ============================================================ + # ACCUMULATE: acc += p @ values via tl.dot + # ============================================================ + acc = acc * re_scale[:, None] + tl.dot( + p.to(tl.float16), 
values.to(tl.float16) + ).to(tl.float32) + l_prev = l_prev * re_scale + tl.sum(p, 1) + m_prev = n_e_max + + # Store partial results per Q head + safe_l = tl.where(l_prev > 0.0, l_prev, 1.0) + out_base = ( + bid * stride_mid_b + cur_head[:, None] * stride_mid_h + sid * stride_mid_s + ) + tl.store( + Mid_o_ptr + out_base + d_offs[None, :], + acc / safe_l[:, None], + mask=mask_h[:, None] & d_mask[None, :], + ) + lse = m_prev + tl.log(safe_l) + tl.store( + Mid_o_ptr + + bid * stride_mid_b + + cur_head * stride_mid_h + + sid * stride_mid_s + + HEAD_DIM, + lse, + mask=mask_h, + ) + + # --------------------------------------------------------------------------- # Pre-dequant kernel: Bulk dequant K (MSE+norms) and V to fp16 # --------------------------------------------------------------------------- @@ -549,43 +834,88 @@ def triton_turboquant_decode_attention( # Stage 1: split-KV tiled attention scoring + value accumulation fp8_e4b15 = _use_fp8_e4b15(device.index or 0) - BLOCK_KV = 4 - grid = (B, Hq, NUM_KV_SPLITS) - _tq_decode_stage1[grid]( - q_rot, - kv_cache, - block_table, - seq_lens, - centroids, - mid_o, - q_rot.stride(0), - q_rot.stride(1), - kv_cache.stride(0), - kv_cache.stride(1), - kv_cache.stride(2), - block_table.stride(0), - mid_o.stride(0), - mid_o.stride(1), - mid_o.stride(2), - NUM_KV_HEADS=Hk, - HEAD_DIM=D, - BLOCK_SIZE=block_size, - NUM_KV_SPLITS=NUM_KV_SPLITS, - KV_GROUP_SIZE=kv_group_size, - MSE_BITS=mse_bits, - MSE_BYTES=cfg["mse_bytes"], - KPS=key_packed_size, - VQB=value_quant_bits, - VAL_DATA_BYTES=cfg["val_data_bytes"], - ATTN_SCALE=scale, - BLOCK_D=cfg["BLOCK_D"], - BLOCK_KV=BLOCK_KV, - KEY_FP8=1 if key_fp8 else 0, - NORM_CORRECTION=1 if norm_correction else 0, - FP8_E4B15=fp8_e4b15, - num_warps=1, - num_stages=1, - ) + BLOCK_H = 16 + BLOCK_KV_GROUPED = 16 + VALID_BLOCK_H = min(BLOCK_H, kv_group_size) + head_groups = triton.cdiv(Hq, VALID_BLOCK_H) + + if kv_group_size > 1 and key_fp8: + grid = (B, head_groups, NUM_KV_SPLITS) + _tq_grouped_decode_stage1[grid]( + q_rot, + kv_cache, + block_table, + seq_lens, + centroids, + mid_o, + q_rot.stride(0), + q_rot.stride(1), + kv_cache.stride(0), + kv_cache.stride(1), + kv_cache.stride(2), + block_table.stride(0), + mid_o.stride(0), + mid_o.stride(1), + mid_o.stride(2), + NUM_KV_HEADS=Hk, + HEAD_DIM=D, + BLOCK_SIZE=block_size, + NUM_KV_SPLITS=NUM_KV_SPLITS, + KV_GROUP_SIZE=kv_group_size, + Q_HEAD_NUM=Hq, + MSE_BITS=mse_bits, + MSE_BYTES=cfg["mse_bytes"], + KPS=key_packed_size, + VQB=value_quant_bits, + VAL_DATA_BYTES=cfg["val_data_bytes"], + ATTN_SCALE=scale, + BLOCK_D=cfg["BLOCK_D"], + BLOCK_KV=BLOCK_KV_GROUPED, + BLOCK_H=BLOCK_H, + KEY_FP8=1 if key_fp8 else 0, + NORM_CORRECTION=1 if norm_correction else 0, + FP8_E4B15=fp8_e4b15, + num_warps=4, + num_stages=2, + ) + else: + BLOCK_KV = 4 + grid = (B, Hq, NUM_KV_SPLITS) + _tq_decode_stage1[grid]( + q_rot, + kv_cache, + block_table, + seq_lens, + centroids, + mid_o, + q_rot.stride(0), + q_rot.stride(1), + kv_cache.stride(0), + kv_cache.stride(1), + kv_cache.stride(2), + block_table.stride(0), + mid_o.stride(0), + mid_o.stride(1), + mid_o.stride(2), + NUM_KV_HEADS=Hk, + HEAD_DIM=D, + BLOCK_SIZE=block_size, + NUM_KV_SPLITS=NUM_KV_SPLITS, + KV_GROUP_SIZE=kv_group_size, + MSE_BITS=mse_bits, + MSE_BYTES=cfg["mse_bytes"], + KPS=key_packed_size, + VQB=value_quant_bits, + VAL_DATA_BYTES=cfg["val_data_bytes"], + ATTN_SCALE=scale, + BLOCK_D=cfg["BLOCK_D"], + BLOCK_KV=BLOCK_KV, + KEY_FP8=1 if key_fp8 else 0, + NORM_CORRECTION=1 if norm_correction else 0, + FP8_E4B15=fp8_e4b15, + num_warps=1, 
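+            # (num_warps=1 / num_stages=1: the scalar kernel scores each
+            # Q head independently without tl.dot, unlike the grouped
+            # launch above, which uses num_warps=4 / num_stages=2.)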
+ num_stages=1, + ) # Stage 2: Reduce across KV splits if output_buf is not None and output_buf.shape[0] >= B: From 4059a054a28951dcef67dff4b6bd9ea27f22dbd7 Mon Sep 17 00:00:00 2001 From: hoseung-kim Date: Fri, 24 Apr 2026 04:54:14 +0000 Subject: [PATCH 2/3] test: add turboquant gqa grouping test Signed-off-by: hoseung-kim --- tests/quantization/test_turboquant.py | 376 ++++++++++++++++++++++++++ 1 file changed, 376 insertions(+) diff --git a/tests/quantization/test_turboquant.py b/tests/quantization/test_turboquant.py index 90beb64a474c..f5c8fdff1c23 100644 --- a/tests/quantization/test_turboquant.py +++ b/tests/quantization/test_turboquant.py @@ -542,3 +542,379 @@ def test_single_token_roundtrip(self, preset): assert cos_sim > threshold, ( f"Preset {preset} head {h}: cosine_sim={cos_sim:.4f} < {threshold}" ) + + @pytest.mark.parametrize("kv_group_size", [4, 8]) + def test_gqa_roundtrip_k8v4(self, kv_group_size): + """GQA round-trip for the grouped decode kernel path. + + Only turboquant_k8v4 (FP8 keys) uses the grouped kernel; the MSE + presets route to the original scalar kernel, which is already + covered by test_single_token_roundtrip. + """ + preset = "turboquant_k8v4" + from vllm.model_executor.layers.quantization.turboquant.centroids import ( + solve_lloyd_max, + ) + from vllm.v1.attention.ops.triton_turboquant_decode import ( + triton_turboquant_decode_attention, + ) + from vllm.v1.attention.ops.triton_turboquant_store import ( + triton_turboquant_store, + ) + + cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128) + D = 128 + Hk = 4 + Hq = Hk * kv_group_size + B = 2 + seq_len = 32 + block_size = 16 + num_blocks = (seq_len + block_size - 1) // block_size + + device = torch.device(DEVICE_TYPE) + + H = _build_hadamard(D, DEVICE_TYPE) + PiT = H + Pi = H + + centroids, _ = solve_lloyd_max(D, cfg.centroid_bits) + centroids = centroids.float().to(device) + c_sorted, _ = centroids.sort() + midpoints = ((c_sorted[:-1] + c_sorted[1:]) / 2).to(device) + + torch.manual_seed(42) + # Store multiple tokens + keys = torch.randn(seq_len, Hk, D, device=device, dtype=torch.float16) + values = torch.randn(seq_len, Hk, D, device=device, dtype=torch.float16) + + padded_slot = cfg.slot_size_aligned + kv_cache = torch.zeros( + num_blocks, + block_size, + Hk, + padded_slot, + device=device, + dtype=torch.uint8, + ) + slot_mapping = torch.arange(seq_len, device=device, dtype=torch.int32) + + triton_turboquant_store( + keys, + values, + kv_cache, + slot_mapping, + PiT, + midpoints, + mse_bits=cfg.key_mse_bits, + key_packed_size=cfg.key_packed_size, + value_quant_bits=cfg.effective_value_quant_bits, + key_fp8=cfg.key_fp8, + ) + + # Decode: use last key as query for each batch + query_keys = keys[-B:] # [B, Hk, D] + query = ( + query_keys[:, :, None, :] + .expand(B, Hk, kv_group_size, D) + .reshape(B, Hq, D) + .contiguous() + .to(torch.float16) + ) + + block_table = ( + torch.arange(num_blocks, device=device, dtype=torch.int32) + .unsqueeze(0) + .expand(B, -1) + .contiguous() + ) + seq_lens = torch.full((B,), seq_len, device=device, dtype=torch.int32) + + output = triton_turboquant_decode_attention( + query=query, + kv_cache=kv_cache, + block_table=block_table, + seq_lens=seq_lens, + Pi=Pi, + centroids=centroids, + scale=1.0 / math.sqrt(D), + mse_bits=cfg.key_mse_bits, + key_packed_size=cfg.key_packed_size, + value_quant_bits=cfg.effective_value_quant_bits, + key_fp8=cfg.key_fp8, + norm_correction=cfg.norm_correction, + PiT=PiT, + max_num_kv_splits=8, + ) + + # Grouped Q heads sharing same KV head 
should produce similar + # outputs. Check that output is finite and has reasonable norm. + assert output.isfinite().all(), ( + f"Preset {preset} GQA={kv_group_size}: non-finite output" + ) + out_norms = output.float().norm(dim=-1) + assert (out_norms > 0.01).all(), ( + f"Preset {preset} GQA={kv_group_size}: near-zero output" + ) + + # Q heads within same GQA group used the same query key, + # so their outputs should be identical (same KV, same Q). + out_fp32 = output.float() + for b in range(B): + for kh in range(Hk): + base_h = kh * kv_group_size + ref = out_fp32[b, base_h] + for g in range(1, kv_group_size): + h = base_h + g + cos = torch.nn.functional.cosine_similarity( + ref.unsqueeze(0), out_fp32[b, h].unsqueeze(0) + ).item() + assert cos > 0.99, ( + f"Preset {preset} GQA={kv_group_size} " + f"batch={b} heads {base_h} vs {h}: " + f"cosine={cos:.4f} (expected >0.99 for same query)" + ) + + def test_grouped_vs_original_kernel_k8v4(self): + """Direct A/B of grouped vs scalar kernel on turboquant_k8v4. + + Forces both kernels on the same inputs and verifies outputs match + within fp16 tl.dot precision tolerance. This is the primary + correctness check for the grouped kernel change. + """ + preset = "turboquant_k8v4" + from vllm.model_executor.layers.quantization.turboquant.centroids import ( + solve_lloyd_max, + ) + from vllm.v1.attention.ops.triton_turboquant_decode import ( + _fwd_kernel_stage2, + _get_layout, + _tq_decode_stage1, + _tq_grouped_decode_stage1, + _use_fp8_e4b15, + ) + from vllm.v1.attention.ops.triton_turboquant_store import ( + triton_turboquant_store, + ) + + cfg = TurboQuantConfig.from_cache_dtype(preset, head_dim=128) + D = 128 + Hk = 4 + kv_group_size = 4 + Hq = Hk * kv_group_size + B = 2 + seq_len = 48 + block_size = 16 + num_blocks = (seq_len + block_size - 1) // block_size + NUM_KV_SPLITS = 8 + device = torch.device(DEVICE_TYPE) + + H = _build_hadamard(D, DEVICE_TYPE) + PiT = H + + centroids, _ = solve_lloyd_max(D, cfg.centroid_bits) + centroids = centroids.float().to(device) + c_sorted, _ = centroids.sort() + midpoints = ((c_sorted[:-1] + c_sorted[1:]) / 2).to(device) + + torch.manual_seed(99) + keys = torch.randn(seq_len, Hk, D, device=device, dtype=torch.float16) + values = torch.randn(seq_len, Hk, D, device=device, dtype=torch.float16) + + padded_slot = cfg.slot_size_aligned + kv_cache = torch.zeros( + num_blocks, + block_size, + Hk, + padded_slot, + device=device, + dtype=torch.uint8, + ) + slot_mapping = torch.arange(seq_len, device=device, dtype=torch.int32) + triton_turboquant_store( + keys, + values, + kv_cache, + slot_mapping, + PiT, + midpoints, + mse_bits=cfg.key_mse_bits, + key_packed_size=cfg.key_packed_size, + value_quant_bits=cfg.effective_value_quant_bits, + key_fp8=cfg.key_fp8, + ) + + torch.manual_seed(77) + query = torch.randn(B, Hq, D, device=device, dtype=torch.float16) + + if cfg.key_fp8: + q_rot = query.contiguous() + else: + q_rot = (query.float() @ PiT).contiguous() + + layout = _get_layout( + D, + cfg.key_mse_bits, + cfg.effective_value_quant_bits, + cfg.key_packed_size, + ) + fp8_e4b15 = _use_fp8_e4b15(device.index or 0) + + block_table = ( + torch.arange(num_blocks, device=device, dtype=torch.int32) + .unsqueeze(0) + .expand(B, -1) + .contiguous() + ) + seq_lens = torch.full((B,), seq_len, device=device, dtype=torch.int32) + + # --- Run original (scalar) kernel --- + mid_o_orig = torch.empty( + B, + Hq, + NUM_KV_SPLITS, + D + 1, + dtype=torch.float32, + device=device, + ) + grid_orig = (B, Hq, NUM_KV_SPLITS) + 
_tq_decode_stage1[grid_orig]( + q_rot, + kv_cache, + block_table, + seq_lens, + centroids, + mid_o_orig, + q_rot.stride(0), + q_rot.stride(1), + kv_cache.stride(0), + kv_cache.stride(1), + kv_cache.stride(2), + block_table.stride(0), + mid_o_orig.stride(0), + mid_o_orig.stride(1), + mid_o_orig.stride(2), + NUM_KV_HEADS=Hk, + HEAD_DIM=D, + BLOCK_SIZE=block_size, + NUM_KV_SPLITS=NUM_KV_SPLITS, + KV_GROUP_SIZE=kv_group_size, + MSE_BITS=cfg.key_mse_bits, + MSE_BYTES=layout["mse_bytes"], + KPS=cfg.key_packed_size, + VQB=cfg.effective_value_quant_bits, + VAL_DATA_BYTES=layout["val_data_bytes"], + ATTN_SCALE=1.0 / math.sqrt(D), + BLOCK_D=layout["BLOCK_D"], + BLOCK_KV=4, + KEY_FP8=1 if cfg.key_fp8 else 0, + NORM_CORRECTION=1 if cfg.norm_correction else 0, + FP8_E4B15=fp8_e4b15, + num_warps=1, + num_stages=1, + ) + out_orig = torch.empty(B, Hq, D, dtype=torch.float32, device=device) + lse_orig = torch.empty(B, Hq, dtype=torch.float32, device=device) + _fwd_kernel_stage2[(B, Hq)]( + mid_o_orig, + out_orig, + lse_orig, + seq_lens, + mid_o_orig.stride(0), + mid_o_orig.stride(1), + mid_o_orig.stride(2), + out_orig.stride(0), + out_orig.stride(1), + lse_orig.stride(0), + NUM_KV_SPLITS=NUM_KV_SPLITS, + BLOCK_DV=layout["BLOCK_D"], + Lv=D, + num_warps=4, + num_stages=2, + ) + + # --- Run grouped kernel --- + import triton as _triton + + BLOCK_H = 16 + VALID_BH = min(BLOCK_H, kv_group_size) + head_groups = _triton.cdiv(Hq, VALID_BH) + + mid_o_grouped = torch.empty( + B, + Hq, + NUM_KV_SPLITS, + D + 1, + dtype=torch.float32, + device=device, + ) + grid_grouped = (B, head_groups, NUM_KV_SPLITS) + _tq_grouped_decode_stage1[grid_grouped]( + q_rot, + kv_cache, + block_table, + seq_lens, + centroids, + mid_o_grouped, + q_rot.stride(0), + q_rot.stride(1), + kv_cache.stride(0), + kv_cache.stride(1), + kv_cache.stride(2), + block_table.stride(0), + mid_o_grouped.stride(0), + mid_o_grouped.stride(1), + mid_o_grouped.stride(2), + NUM_KV_HEADS=Hk, + HEAD_DIM=D, + BLOCK_SIZE=block_size, + NUM_KV_SPLITS=NUM_KV_SPLITS, + KV_GROUP_SIZE=kv_group_size, + Q_HEAD_NUM=Hq, + MSE_BITS=cfg.key_mse_bits, + MSE_BYTES=layout["mse_bytes"], + KPS=cfg.key_packed_size, + VQB=cfg.effective_value_quant_bits, + VAL_DATA_BYTES=layout["val_data_bytes"], + ATTN_SCALE=1.0 / math.sqrt(D), + BLOCK_D=layout["BLOCK_D"], + BLOCK_KV=16, + BLOCK_H=BLOCK_H, + KEY_FP8=1 if cfg.key_fp8 else 0, + NORM_CORRECTION=1 if cfg.norm_correction else 0, + FP8_E4B15=fp8_e4b15, + num_warps=4, + num_stages=2, + ) + out_grouped = torch.empty(B, Hq, D, dtype=torch.float32, device=device) + lse_grouped = torch.empty(B, Hq, dtype=torch.float32, device=device) + _fwd_kernel_stage2[(B, Hq)]( + mid_o_grouped, + out_grouped, + lse_grouped, + seq_lens, + mid_o_grouped.stride(0), + mid_o_grouped.stride(1), + mid_o_grouped.stride(2), + out_grouped.stride(0), + out_grouped.stride(1), + lse_grouped.stride(0), + NUM_KV_SPLITS=NUM_KV_SPLITS, + BLOCK_DV=layout["BLOCK_D"], + Lv=D, + num_warps=4, + num_stages=2, + ) + + # Compare: cosine similarity per head should be very high. + # fp16 tl.dot introduces minor precision diff, so allow small gap. 
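+        # Both runs share the same stage-2 reduction (_fwd_kernel_stage2
+        # with the same NUM_KV_SPLITS), so any gap here isolates stage-1
+        # arithmetic.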
+ for b in range(B): + for h in range(Hq): + cos = torch.nn.functional.cosine_similarity( + out_orig[b, h].unsqueeze(0), + out_grouped[b, h].unsqueeze(0), + ).item() + threshold = 0.98 if cfg.key_fp8 else 0.95 + assert cos > threshold, ( + f"Preset {preset} batch={b} head={h}: " + f"orig vs grouped cosine={cos:.4f} < {threshold}" + ) From 89991de14499cf23836adfbc18d4f7d858853da8 Mon Sep 17 00:00:00 2001 From: hoseung-kim Date: Fri, 24 Apr 2026 09:21:42 +0000 Subject: [PATCH 3/3] refactor: remove dead code, unused variables Signed-off-by: hoseung-kim --- tests/quantization/test_turboquant.py | 21 +- .../attention/ops/triton_turboquant_decode.py | 217 +++++------------- 2 files changed, 67 insertions(+), 171 deletions(-) diff --git a/tests/quantization/test_turboquant.py b/tests/quantization/test_turboquant.py index f5c8fdff1c23..d4c6cb546c5b 100644 --- a/tests/quantization/test_turboquant.py +++ b/tests/quantization/test_turboquant.py @@ -543,7 +543,7 @@ def test_single_token_roundtrip(self, preset): f"Preset {preset} head {h}: cosine_sim={cos_sim:.4f} < {threshold}" ) - @pytest.mark.parametrize("kv_group_size", [4, 8]) + @pytest.mark.parametrize("kv_group_size", [4, 8, 24]) def test_gqa_roundtrip_k8v4(self, kv_group_size): """GQA round-trip for the grouped decode kernel path. @@ -573,9 +573,7 @@ def test_gqa_roundtrip_k8v4(self, kv_group_size): device = torch.device(DEVICE_TYPE) - H = _build_hadamard(D, DEVICE_TYPE) - PiT = H - Pi = H + PiT = _build_hadamard(D, DEVICE_TYPE) centroids, _ = solve_lloyd_max(D, cfg.centroid_bits) centroids = centroids.float().to(device) @@ -634,7 +632,7 @@ def test_gqa_roundtrip_k8v4(self, kv_group_size): kv_cache=kv_cache, block_table=block_table, seq_lens=seq_lens, - Pi=Pi, + Pi=PiT, centroids=centroids, scale=1.0 / math.sqrt(D), mse_bits=cfg.key_mse_bits, @@ -708,8 +706,7 @@ def test_grouped_vs_original_kernel_k8v4(self): NUM_KV_SPLITS = 8 device = torch.device(DEVICE_TYPE) - H = _build_hadamard(D, DEVICE_TYPE) - PiT = H + PiT = _build_hadamard(D, DEVICE_TYPE) centroids, _ = solve_lloyd_max(D, cfg.centroid_bits) centroids = centroids.float().to(device) @@ -836,8 +833,8 @@ def test_grouped_vs_original_kernel_k8v4(self): import triton as _triton BLOCK_H = 16 - VALID_BH = min(BLOCK_H, kv_group_size) - head_groups = _triton.cdiv(Hq, VALID_BH) + heads_per_kv_head = _triton.cdiv(kv_group_size, BLOCK_H) + head_groups = Hk * heads_per_kv_head mid_o_grouped = torch.empty( B, @@ -853,7 +850,6 @@ def test_grouped_vs_original_kernel_k8v4(self): kv_cache, block_table, seq_lens, - centroids, mid_o_grouped, q_rot.stride(0), q_rot.stride(1), @@ -864,14 +860,11 @@ def test_grouped_vs_original_kernel_k8v4(self): mid_o_grouped.stride(0), mid_o_grouped.stride(1), mid_o_grouped.stride(2), - NUM_KV_HEADS=Hk, HEAD_DIM=D, BLOCK_SIZE=block_size, NUM_KV_SPLITS=NUM_KV_SPLITS, KV_GROUP_SIZE=kv_group_size, Q_HEAD_NUM=Hq, - MSE_BITS=cfg.key_mse_bits, - MSE_BYTES=layout["mse_bytes"], KPS=cfg.key_packed_size, VQB=cfg.effective_value_quant_bits, VAL_DATA_BYTES=layout["val_data_bytes"], @@ -879,8 +872,6 @@ def test_grouped_vs_original_kernel_k8v4(self): BLOCK_D=layout["BLOCK_D"], BLOCK_KV=16, BLOCK_H=BLOCK_H, - KEY_FP8=1 if cfg.key_fp8 else 0, - NORM_CORRECTION=1 if cfg.norm_correction else 0, FP8_E4B15=fp8_e4b15, num_warps=4, num_stages=2, diff --git a/vllm/v1/attention/ops/triton_turboquant_decode.py b/vllm/v1/attention/ops/triton_turboquant_decode.py index 55d7d152fc6c..9372a02b8d14 100644 --- a/vllm/v1/attention/ops/triton_turboquant_decode.py +++ 
b/vllm/v1/attention/ops/triton_turboquant_decode.py @@ -324,7 +324,6 @@ def _tq_grouped_decode_stage1( KV_cache_ptr, Block_table_ptr, Seq_lens_ptr, - Centroids_ptr, Mid_o_ptr, stride_qb, stride_qh, @@ -335,14 +334,11 @@ def _tq_grouped_decode_stage1( stride_mid_b, stride_mid_h, stride_mid_s, - NUM_KV_HEADS: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_SIZE: tl.constexpr, NUM_KV_SPLITS: tl.constexpr, KV_GROUP_SIZE: tl.constexpr, Q_HEAD_NUM: tl.constexpr, - MSE_BITS: tl.constexpr, - MSE_BYTES: tl.constexpr, KPS: tl.constexpr, VQB: tl.constexpr, VAL_DATA_BYTES: tl.constexpr, @@ -350,25 +346,31 @@ def _tq_grouped_decode_stage1( BLOCK_D: tl.constexpr, BLOCK_KV: tl.constexpr, BLOCK_H: tl.constexpr, - KEY_FP8: tl.constexpr, - NORM_CORRECTION: tl.constexpr = 0, FP8_E4B15: tl.constexpr = 0, ): - """GQA-grouped TQ decode stage1. + """GQA-grouped TQ decode stage1 for the FP8 key path (k8v4). + + Each CTA processes up to BLOCK_H Q heads that share one KV head, + loading K/V once and computing scores via `tl.dot`. - Each CTA processes min(BLOCK_H, KV_GROUP_SIZE) Q heads that share - one KV head, loading K/V once and computing scores via tl.dot. + Scoped to FP8 keys + 4-bit values: the MSE-quantized key presets + (`turboquant_{4bit,k3v4,3bit}_nc`) retain the original scalar + kernel because their per-token dequant regresses with BLOCK_KV=16. """ bid = tl.program_id(0) head_group_id = tl.program_id(1) sid = tl.program_id(2) - # Map head_group_id → KV head + Q head range - VALID_BLOCK_H: tl.constexpr = BLOCK_H if KV_GROUP_SIZE > BLOCK_H else KV_GROUP_SIZE - kv_head = head_group_id // tl.cdiv(KV_GROUP_SIZE, BLOCK_H) - cur_head = head_group_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H) - mask_h = cur_head < (head_group_id + 1) * VALID_BLOCK_H - mask_h = mask_h & (cur_head < Q_HEAD_NUM) + # Map head_group_id → KV head + Q head range. + # CTAs are partitioned per KV head: each KV head owns + # `heads_per_kv_head = ceil(KV_GROUP_SIZE / BLOCK_H)` consecutive CTAs. + # This keeps every CTA confined to a single KV head even when + # KV_GROUP_SIZE > BLOCK_H and not a multiple of it. + heads_per_kv_head: tl.constexpr = tl.cdiv(KV_GROUP_SIZE, BLOCK_H) + kv_head = head_group_id // heads_per_kv_head + group_idx = head_group_id % heads_per_kv_head + cur_head = kv_head * KV_GROUP_SIZE + group_idx * BLOCK_H + tl.arange(0, BLOCK_H) + mask_h = (cur_head < (kv_head + 1) * KV_GROUP_SIZE) & (cur_head < Q_HEAD_NUM) seq_len = tl.load(Seq_lens_ptr + bid) split_len = tl.cdiv(seq_len, NUM_KV_SPLITS) @@ -393,18 +395,6 @@ def _tq_grouped_decode_stage1( other=0.0, ).to(tl.float32) - # Precompute MSE bit/byte index vectors (loop-invariant) - if not KEY_FP8: - mse_bit_off = d_offs * MSE_BITS - mse_byte_idx = mse_bit_off // 8 - mse_bit_shift = mse_bit_off % 8 - mse_mask = (1 << MSE_BITS) - 1 - - if VQB == 3: - val_bit_off = d_offs * 3 - val_byte_idx = val_bit_off // 8 - val_bit_shift = val_bit_off % 8 - # Online softmax accumulators: [BLOCK_H] m_prev = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") l_prev = tl.zeros([BLOCK_H], dtype=tl.float32) @@ -429,68 +419,24 @@ def _tq_grouped_decode_stage1( ) # ============================================================ - # K DEQUANT → k_float [BLOCK_KV, BLOCK_D] + # K DEQUANT → k_float [BLOCK_KV, BLOCK_D] (FP8 only; MSE keys use + # the original scalar kernel). 
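+        # (FP8_E4B15 keeps the SM < 8.9 fallback from the original
+        # kernel: e4b15 is software-emulated via a float32 intermediate.)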
# ============================================================ - if KEY_FP8: - k_addrs = slot_bases[:, None] + d_offs[None, :] - k_raw = tl.load( - KV_cache_ptr + k_addrs, - mask=kv_mask[:, None] & d_mask[None, :], - other=0, - ) - if FP8_E4B15: - # SM < 8.9: SW emulation requires float32 intermediate - k_float = k_raw.to(tl.float8e4b15, bitcast=True).to(tl.float32) - else: - k_float = k_raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) - - # scores = q_rot @ k_float^T : [BLOCK_H, BLOCK_KV] - scores = tl.dot(q_rot.to(tl.float16), tl.trans(k_float.to(tl.float16))) - scores = (scores * ATTN_SCALE).to(tl.float32) - scores = tl.where(mask_h[:, None] & kv_mask[None, :], scores, -float("inf")) + k_addrs = slot_bases[:, None] + d_offs[None, :] + k_raw = tl.load( + KV_cache_ptr + k_addrs, + mask=kv_mask[:, None] & d_mask[None, :], + other=0, + ) + if FP8_E4B15: + k_float = k_raw.to(tl.float8e4b15, bitcast=True).to(tl.float32) else: - # MSE unpack → centroid gather → k_dequant [BLOCK_KV, BLOCK_D] - mse_addrs0 = slot_bases[:, None] + mse_byte_idx[None, :] - mse_raw0 = tl.load( - KV_cache_ptr + mse_addrs0, - mask=kv_mask[:, None] & d_mask[None, :], - other=0, - ).to(tl.int32) - mse_raw1 = tl.load( - KV_cache_ptr + mse_addrs0 + 1, - mask=kv_mask[:, None] & d_mask[None, :], - other=0, - ).to(tl.int32) - raw16 = mse_raw0 | (mse_raw1 << 8) - mse_idx = (raw16 >> mse_bit_shift[None, :]) & mse_mask + k_float = k_raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) - c_vals = tl.load( - Centroids_ptr + mse_idx, - mask=kv_mask[:, None] & d_mask[None, :], - other=0.0, - ) - - if NORM_CORRECTION: - c_norm_sq = tl.sum( - tl.where(d_mask[None, :], c_vals * c_vals, 0.0), axis=1 - ) - c_inv_norm = 1.0 / tl.sqrt(c_norm_sq + 1e-16) - c_vals = c_vals * c_inv_norm[:, None] - - # term1 = q_rot @ c_vals^T : [BLOCK_H, BLOCK_KV] - term1 = tl.dot(q_rot.to(tl.float16), tl.trans(c_vals.to(tl.float16))) - - norm_bases = slot_bases + MSE_BYTES - n_lo = tl.load(KV_cache_ptr + norm_bases, mask=kv_mask, other=0).to( - tl.uint16 - ) - n_hi = tl.load(KV_cache_ptr + norm_bases + 1, mask=kv_mask, other=0).to( - tl.uint16 - ) - vec_norms = (n_lo | (n_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) - - scores = vec_norms[None, :] * term1.to(tl.float32) * ATTN_SCALE - scores = tl.where(mask_h[:, None] & kv_mask[None, :], scores, -float("inf")) + # scores = q_rot @ k_float^T : [BLOCK_H, BLOCK_KV] + scores = tl.dot(q_rot.to(tl.float16), tl.trans(k_float.to(tl.float16))) + scores = (scores * ATTN_SCALE).to(tl.float32) + scores = tl.where(mask_h[:, None] & kv_mask[None, :], scores, -float("inf")) # ============================================================ # ONLINE SOFTMAX: [BLOCK_H] @@ -500,72 +446,36 @@ def _tq_grouped_decode_stage1( p = tl.exp(scores - n_e_max[:, None]) # ============================================================ - # V DEQUANT → values [BLOCK_KV, BLOCK_D] + # V DEQUANT → values [BLOCK_KV, BLOCK_D] (4-bit uniform; VQB==3 + # is an MSE-only path handled by the original scalar kernel). 
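+        # Per-token value layout: KPS key bytes, then VAL_DATA_BYTES of
+        # packed 4-bit indices (two per byte, low nibble = even dim),
+        # then little-endian fp16 scale and fp16 zero-point.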
# ============================================================ + tl.static_assert(VQB == 4, "grouped kernel only supports 4-bit values") val_bases = slot_bases + KPS - if VQB == 3: - val_addrs0 = val_bases[:, None] + val_byte_idx[None, :] - val_raw0 = tl.load( - KV_cache_ptr + val_addrs0, - mask=kv_mask[:, None] & d_mask[None, :], - other=0, - ).to(tl.int32) - val_raw1 = tl.load( - KV_cache_ptr + val_addrs0 + 1, - mask=kv_mask[:, None] & d_mask[None, :], - other=0, - ).to(tl.int32) - raw16_val = val_raw0 | (val_raw1 << 8) - v_idx = ((raw16_val >> val_bit_shift[None, :]) & 0x7).to(tl.float32) - - sc_bases = val_bases + VAL_DATA_BYTES - sc_lo = tl.load(KV_cache_ptr + sc_bases, mask=kv_mask, other=0).to( - tl.uint16 - ) - sc_hi = tl.load(KV_cache_ptr + sc_bases + 1, mask=kv_mask, other=0).to( - tl.uint16 - ) - v_scales = ( - (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) - ) - zr_lo = tl.load(KV_cache_ptr + sc_bases + 2, mask=kv_mask, other=0).to( - tl.uint16 - ) - zr_hi = tl.load(KV_cache_ptr + sc_bases + 3, mask=kv_mask, other=0).to( - tl.uint16 - ) - v_zeros = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) - values = v_idx * v_scales[:, None] + v_zeros[:, None] - else: # VQB == 4 - vb_idx = d_offs // 2 - vb_shift = (d_offs % 2) * 4 - val_addrs = val_bases[:, None] + vb_idx[None, :] - val_raw = tl.load( - KV_cache_ptr + val_addrs, - mask=kv_mask[:, None] & d_mask[None, :], - other=0, - ).to(tl.int32) - v_idx = ((val_raw >> vb_shift[None, :]) & 0xF).to(tl.float32) + vb_idx = d_offs // 2 + vb_shift = (d_offs % 2) * 4 + val_addrs = val_bases[:, None] + vb_idx[None, :] + val_raw = tl.load( + KV_cache_ptr + val_addrs, + mask=kv_mask[:, None] & d_mask[None, :], + other=0, + ).to(tl.int32) + v_idx = ((val_raw >> vb_shift[None, :]) & 0xF).to(tl.float32) - sc_bases = val_bases + VAL_DATA_BYTES - sc_lo = tl.load(KV_cache_ptr + sc_bases, mask=kv_mask, other=0).to( - tl.uint16 - ) - sc_hi = tl.load(KV_cache_ptr + sc_bases + 1, mask=kv_mask, other=0).to( - tl.uint16 - ) - v_scales = ( - (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) - ) - zr_lo = tl.load(KV_cache_ptr + sc_bases + 2, mask=kv_mask, other=0).to( - tl.uint16 - ) - zr_hi = tl.load(KV_cache_ptr + sc_bases + 3, mask=kv_mask, other=0).to( - tl.uint16 - ) - v_zeros = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) - values = v_idx * v_scales[:, None] + v_zeros[:, None] + sc_bases = val_bases + VAL_DATA_BYTES + sc_lo = tl.load(KV_cache_ptr + sc_bases, mask=kv_mask, other=0).to(tl.uint16) + sc_hi = tl.load(KV_cache_ptr + sc_bases + 1, mask=kv_mask, other=0).to( + tl.uint16 + ) + v_scales = (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) + zr_lo = tl.load(KV_cache_ptr + sc_bases + 2, mask=kv_mask, other=0).to( + tl.uint16 + ) + zr_hi = tl.load(KV_cache_ptr + sc_bases + 3, mask=kv_mask, other=0).to( + tl.uint16 + ) + v_zeros = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32) + values = v_idx * v_scales[:, None] + v_zeros[:, None] # ============================================================ # ACCUMULATE: acc += p @ values via tl.dot @@ -836,8 +746,8 @@ def triton_turboquant_decode_attention( fp8_e4b15 = _use_fp8_e4b15(device.index or 0) BLOCK_H = 16 BLOCK_KV_GROUPED = 16 - VALID_BLOCK_H = min(BLOCK_H, kv_group_size) - head_groups = triton.cdiv(Hq, VALID_BLOCK_H) + heads_per_kv_head = triton.cdiv(kv_group_size, BLOCK_H) + head_groups = Hk * heads_per_kv_head if kv_group_size > 1 and key_fp8: grid = (B, head_groups, NUM_KV_SPLITS) @@ 
-846,7 +756,6 @@ def triton_turboquant_decode_attention( kv_cache, block_table, seq_lens, - centroids, mid_o, q_rot.stride(0), q_rot.stride(1), @@ -857,14 +766,11 @@ def triton_turboquant_decode_attention( mid_o.stride(0), mid_o.stride(1), mid_o.stride(2), - NUM_KV_HEADS=Hk, HEAD_DIM=D, BLOCK_SIZE=block_size, NUM_KV_SPLITS=NUM_KV_SPLITS, KV_GROUP_SIZE=kv_group_size, Q_HEAD_NUM=Hq, - MSE_BITS=mse_bits, - MSE_BYTES=cfg["mse_bytes"], KPS=key_packed_size, VQB=value_quant_bits, VAL_DATA_BYTES=cfg["val_data_bytes"], @@ -872,13 +778,12 @@ def triton_turboquant_decode_attention( BLOCK_D=cfg["BLOCK_D"], BLOCK_KV=BLOCK_KV_GROUPED, BLOCK_H=BLOCK_H, - KEY_FP8=1 if key_fp8 else 0, - NORM_CORRECTION=1 if norm_correction else 0, FP8_E4B15=fp8_e4b15, num_warps=4, num_stages=2, ) else: + # MHA (kv_group_size==1): use original scalar kernel BLOCK_KV = 4 grid = (B, Hq, NUM_KV_SPLITS) _tq_decode_stage1[grid](