diff --git a/tests/kernels/attention/test_deepgemm_attention.py b/tests/kernels/attention/test_deepgemm_attention.py index 2dc522598e4e..e718f403b049 100644 --- a/tests/kernels/attention/test_deepgemm_attention.py +++ b/tests/kernels/attention/test_deepgemm_attention.py @@ -11,6 +11,7 @@ calc_diff, fp8_mqa_logits, fp8_paged_mqa_logits, + fp8_paged_mqa_logits_torch, get_num_sms, get_paged_mqa_logits_metadata, ) @@ -200,6 +201,89 @@ def _ref_fp8_paged_mqa_logits( return logits +@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only") +@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available") +@pytest.mark.skipif( + not current_platform.has_device_capability(90), reason="SM90 and SM100 only" +) +def test_fp8_paged_mqa_logits_torch_matches_deepgemm(): + torch.manual_seed(0) + random.seed(0) + + max_model_len = 4096 + num_blocks, blocksize = max_model_len * 2, 64 + batch_size, next_n = 3, 2 + heads, index_dim = 32, 128 + + q = torch.randn( + (batch_size, next_n, heads, index_dim), + device="cuda", + dtype=torch.bfloat16, + ) + kv_cache = torch.randn( + (num_blocks, blocksize, 1, index_dim), + device="cuda", + dtype=torch.bfloat16, + ) + weights = torch.randn( + (batch_size * next_n, heads), + device="cuda", + dtype=torch.float32, + ) + context_lens = torch.tensor([1537, 2049, 3073], device="cuda", dtype=torch.int32) + max_block_len = cdiv(int(context_lens.max().item()), blocksize) + block_tables = torch.zeros( + (batch_size, max_block_len), + device="cuda", + dtype=torch.int32, + ) + + block_idx_pool = list(range(num_blocks)) + random.shuffle(block_idx_pool) + counter = 0 + for i, ctx_len in enumerate(context_lens.tolist()): + for j in range(cdiv(ctx_len, blocksize)): + block_tables[i, j] = block_idx_pool[counter] + counter += 1 + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache) + schedule_metadata = get_paged_mqa_logits_metadata( + context_lens, blocksize, get_num_sms() + ) + + fallback_logits = fp8_paged_mqa_logits_torch( + q_fp8, + kv_cache_fp8, + weights, + context_lens, + block_tables, + max_model_len, + ) + deepgemm_logits = fp8_paged_mqa_logits( + q_fp8, + kv_cache_fp8, + weights, + context_lens, + block_tables, + schedule_metadata, + max_model_len, + clean_logits=True, + ) + + positions = torch.arange(max_model_len, device="cuda").unsqueeze(0) + row_indices = torch.arange(batch_size * next_n, device="cuda") // next_n + next_n_offset = torch.arange(batch_size * next_n, device="cuda") % next_n + valid_mask = positions <= ( + context_lens[row_indices] - next_n + next_n_offset + ).unsqueeze(1) + + fallback_logits = fallback_logits.masked_fill(~valid_mask, 0) + deepgemm_logits = deepgemm_logits.masked_fill(~valid_mask, 0) + diff = calc_diff(fallback_logits, deepgemm_logits) + assert diff < 1e-3, f"{diff=}" + + @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only") @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available") @pytest.mark.skipif( @@ -298,3 +382,111 @@ def test_deepgemm_fp8_paged_mqa_logits(clean_logits: bool): ref_logits = ref_logits.masked_fill(~mask, 0) diff = calc_diff(logits, ref_logits) assert diff < 1e-3, f"{diff=}" + + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only") +def test_fp8_paged_mqa_logits_torch_cuda_graph_capture(): + torch.manual_seed(0) + random.seed(0) + + batch_size, next_n = 2, 2 + heads, index_dim = 4, 16 + max_model_len, blocksize = 128, 16 + num_blocks = 32 + + q = torch.randn( + (batch_size, next_n, heads, index_dim), + device="cuda", 
+        dtype=torch.bfloat16,
+    )
+    kv_cache = torch.randn(
+        (num_blocks, blocksize, 1, index_dim),
+        device="cuda",
+        dtype=torch.bfloat16,
+    )
+    weights = torch.randn(
+        (batch_size * next_n, heads),
+        device="cuda",
+        dtype=torch.float32,
+    )
+    context_lens = torch.tensor([33, 58], device="cuda", dtype=torch.int32)
+    block_tables = torch.zeros(
+        (batch_size, cdiv(max_model_len, blocksize)),
+        device="cuda",
+        dtype=torch.int32,
+    )
+
+    next_block = 0
+    for i, ctx_len in enumerate(context_lens.tolist()):
+        num_ctx_blocks = cdiv(ctx_len, blocksize)
+        block_tables[i, :num_ctx_blocks] = torch.tensor(
+            list(range(next_block, next_block + num_ctx_blocks)),
+            device="cuda",
+            dtype=torch.int32,
+        )
+        next_block += num_ctx_blocks
+
+    q_fp8 = q.to(torch.float8_e4m3fn)
+    kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)
+
+    expected = fp8_paged_mqa_logits_torch(
+        q_fp8,
+        kv_cache_fp8,
+        weights,
+        context_lens,
+        block_tables,
+        max_model_len,
+    )
+    torch.accelerator.synchronize()
+
+    graph = torch.cuda.CUDAGraph()
+    with torch.cuda.graph(graph):
+        logits = fp8_paged_mqa_logits_torch(
+            q_fp8,
+            kv_cache_fp8,
+            weights,
+            context_lens,
+            block_tables,
+            max_model_len,
+        )
+
+    # Update inputs in-place after capture and verify replay uses the new values.
+    q_new = torch.randn_like(q)
+    kv_cache_new = torch.randn_like(kv_cache)
+    weights_new = torch.randn_like(weights)
+    context_lens_new = torch.tensor([41, 71], device="cuda", dtype=torch.int32)
+    block_tables_new = torch.zeros_like(block_tables)
+
+    next_block = 0
+    for i, ctx_len in enumerate(context_lens_new.tolist()):
+        num_ctx_blocks = cdiv(ctx_len, blocksize)
+        block_tables_new[i, :num_ctx_blocks] = torch.tensor(
+            list(range(next_block, next_block + num_ctx_blocks)),
+            device="cuda",
+            dtype=torch.int32,
+        )
+        next_block += num_ctx_blocks
+
+    q_fp8.copy_(q_new.to(torch.float8_e4m3fn))
+    kv_cache_fp8.copy_(kv_cache_cast_to_fp8(kv_cache_new))
+    weights.copy_(weights_new)
+    context_lens.copy_(context_lens_new)
+    block_tables.copy_(block_tables_new)
+
+    expected_updated = fp8_paged_mqa_logits_torch(
+        q_fp8,
+        kv_cache_fp8,
+        weights,
+        context_lens,
+        block_tables,
+        max_model_len,
+    )
+
+    logits.zero_()
+    graph.replay()
+    torch.accelerator.synchronize()
+
+    torch.testing.assert_close(logits, expected_updated, equal_nan=True)
+
+    # Sanity check: output changed after replacing inputs.
+    assert not torch.allclose(expected, expected_updated, equal_nan=True)
diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py
index ee104a6cc75c..dce1451e7904 100644
--- a/vllm/utils/deep_gemm.py
+++ b/vllm/utils/deep_gemm.py
@@ -472,13 +472,11 @@ def fp8_paged_mqa_logits_torch(
     block_tables: torch.Tensor,
     max_model_len: int,
 ) -> torch.Tensor:
-    """Compute FP8 MQA logits using paged KV-cache (CUDA fallback).
-
-    This is a pure PyTorch fallback for CUDA when DeepGEMM is not available.
-    Handles head_dim = 132 (128 + 4 for RoPE).
+    """Compute FP8 MQA logits using paged KV-cache (Triton kernel).
 
     Args:
-        q: Query tensor of shape [B, next_n, H, D].
+        q: Query tensor of shape [B, next_n, H, D]. Cast to
+            `torch.float8_e4m3fn` by the caller.
         kv_cache: Paged KV-cache in packed FP8+scale layout with shape
             [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The
             last 4 bytes per (block,pos) store the `float` dequant scale.
@@ -493,48 +491,14 @@ def fp8_paged_mqa_logits_torch(
         Logits tensor of shape [B * next_n, max_model_len], dtype
         `torch.float32`.
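+
+    Example (illustrative; `kv_cache_fp8` is assumed to already be in the
+    packed layout described above):
+
+        >>> q_fp8 = q.to(torch.float8_e4m3fn)  # [B, next_n, H, D]
+        >>> logits = fp8_paged_mqa_logits_torch(
+        ...     q_fp8, kv_cache_fp8, weights, context_lens, block_tables, 4096
+        ... )  # -> [B * next_n, 4096], torch.float32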
""" - fp8_dtype = current_platform.fp8_dtype() - batch_size, next_n, heads, dim = q.size() - kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:] - scale = scale.contiguous().view(torch.float) - q = q.float() - kv_cache = kv_cache.view(fp8_dtype).float() * scale - num_blocks, block_size, _, dim = kv_cache.size() - logits = torch.full( - [batch_size * next_n, max_model_len], - float("-inf"), - device=q.device, - dtype=torch.float32, + # Import here to avoid circular dependency + from vllm.v1.attention.ops.triton_fp8_paged_mqa_logits import ( + fp8_paged_mqa_logits_triton, + ) + + return fp8_paged_mqa_logits_triton( + q, kv_cache, weights, context_lens, block_tables, max_model_len ) - for i in range(batch_size): - context_len = context_lens[i].item() - q_offsets = torch.arange(context_len - next_n, context_len, device=q.device) - weight_slice = ( - weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous() - ) - for block_idx in range(cdiv(context_len, block_size)): - block_id = block_tables[i][block_idx] - qx, kx = q[i], kv_cache[block_id] - k_offsets = torch.arange( - block_idx * block_size, (block_idx + 1) * block_size, device=q.device - ) - mask = (k_offsets[None, :] < context_len) & ( - k_offsets[None, :] <= q_offsets[:, None] - ) - s = torch.where( - mask[None, :, :], - (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to( - logits.dtype - ), - float("-inf"), - ) - s = torch.relu(s) * weight_slice[..., None] - s = s.sum(dim=0) - logits[ - i * next_n : (i + 1) * next_n, - block_idx * block_size : (block_idx + 1) * block_size, - ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf")) - return logits __all__ = [ diff --git a/vllm/v1/attention/ops/triton_fp8_paged_mqa_logits.py b/vllm/v1/attention/ops/triton_fp8_paged_mqa_logits.py new file mode 100644 index 000000000000..639ba1d1e38a --- /dev/null +++ b/vllm/v1/attention/ops/triton_fp8_paged_mqa_logits.py @@ -0,0 +1,227 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Triton kernel for FP8 paged MQA logits.""" + +import torch + +from vllm.triton_utils import tl, triton + + +@triton.jit +def _fp8_paged_mqa_logits_kernel( + q_ptr, # [B, next_n, H, D], float16 + k_data_ptr, # [num_blocks, block_size, D], uint8(fp8 bitcast) + k_scale_ptr, # [num_blocks, block_size], float32 + weights_ptr, # [B * next_n, H], float32 + context_lens_ptr, # [B], int32 + block_tables_ptr, # [B, max_blocks], int32/int64 + logits_ptr, # [B * next_n, max_model_len], float32 + next_n, + num_heads, + head_dim, + block_size, + max_model_len, + # q strides + stride_q_b, + stride_q_n, + stride_q_h, + stride_q_d, + # k_data strides + stride_kd_blk, + stride_kd_pos, + stride_kd_d, + # k_scale strides + stride_ks_blk, + stride_ks_pos, + # weights strides + stride_w_row, + stride_w_h, + # context_lens stride + stride_ctx, + # block_tables strides + stride_bt_b, + stride_bt_blk, + # logits strides + stride_o_row, + stride_o_col, + BLOCK_K: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_H: tl.constexpr, +): + pid_row = tl.program_id(0) + pid_block = tl.program_id(1) + + b = pid_row // next_n + n = pid_row % next_n + + context_len = tl.load(context_lens_ptr + b * stride_ctx) + q_pos = context_len - next_n + n + + block_start = pid_block * block_size + + offs_k = tl.arange(0, BLOCK_K) + offs_d = tl.arange(0, BLOCK_D) + offs_h = tl.arange(0, BLOCK_H) + + kv_pos = block_start + offs_k + in_block = offs_k < block_size + in_ctx = kv_pos < context_len + out_bounds = 
in_block & (kv_pos < max_model_len) + + block_active = block_start < context_len + physical_block_id = tl.full((), 0, dtype=tl.int64) + if block_active: + physical_block_id = tl.load( + block_tables_ptr + b * stride_bt_b + pid_block * stride_bt_blk + ).to(tl.int64) + + token_valid = block_active & in_block & in_ctx + + # Load K tile [BLOCK_K, BLOCK_D] from packed FP8 bytes. + k_ptrs = ( + k_data_ptr + + physical_block_id * stride_kd_blk + + offs_k[:, None] * stride_kd_pos + + offs_d[None, :] * stride_kd_d + ) + k_mask = token_valid[:, None] & (offs_d[None, :] < head_dim) + k_u8 = tl.load(k_ptrs, mask=k_mask, other=0).to(tl.uint8) + k_vals = k_u8.to(tl.float8e4nv, bitcast=True).to(tl.float16) + + # Load scales [BLOCK_K] and apply dequantization in fp16 for MMA throughput. + k_scale = tl.load( + k_scale_ptr + physical_block_id * stride_ks_blk + offs_k * stride_ks_pos, + mask=token_valid, + other=0.0, + ).to(tl.float16) + k_vals = k_vals * k_scale[:, None] + + # Load Q tile [BLOCK_H, BLOCK_D]. + q_ptrs = ( + q_ptr + + b * stride_q_b + + n * stride_q_n + + offs_h[:, None] * stride_q_h + + offs_d[None, :] * stride_q_d + ) + q_mask = (offs_h[:, None] < num_heads) & (offs_d[None, :] < head_dim) + q_vals = tl.load(q_ptrs, mask=q_mask, other=0.0).to(tl.float16) + + # Weights [BLOCK_H]. + w = tl.load( + weights_ptr + pid_row * stride_w_row + offs_h * stride_w_h, + mask=offs_h < num_heads, + other=0.0, + ) + + # scores: [BLOCK_H, BLOCK_K] = Q @ K^T + scores = tl.dot(q_vals, tl.trans(k_vals)) + scores = tl.maximum(scores, 0.0) * w[:, None] + acc = tl.sum(scores, axis=0) + + causal = kv_pos <= q_pos + write_valid = token_valid & causal + out_vals = tl.where(write_valid, acc, float("-inf")) + + out_ptrs = logits_ptr + pid_row * stride_o_row + kv_pos * stride_o_col + tl.store(out_ptrs, out_vals, mask=out_bounds) + + +def fp8_paged_mqa_logits_triton( + q_fp8: torch.Tensor, + kv_cache_fp8: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + """Compute FP8 paged MQA logits using Triton.""" + if not q_fp8.is_cuda: + raise ValueError("fp8_paged_mqa_logits_triton requires CUDA tensors") + + if q_fp8.ndim != 4: + raise ValueError(f"q_fp8 must be 4D [B, next_n, H, D], got {q_fp8.shape}") + if kv_cache_fp8.ndim != 4: + raise ValueError( + f"kv_cache_fp8 must be 4D [num_blocks, block_size, 1, D+4], " + f"got {kv_cache_fp8.shape}" + ) + + batch_size, next_n, num_heads, head_dim = q_fp8.shape + num_blocks, block_size, _, packed_dim = kv_cache_fp8.shape + + if packed_dim != head_dim + 4: + raise ValueError( + f"kv_cache_fp8 last dim must be head_dim + 4 ({head_dim + 4}), " + f"got {packed_dim}" + ) + + # DeepGEMM-compatible packed layout expects contiguous memory. + if not kv_cache_fp8.is_contiguous(): + kv_cache_fp8 = kv_cache_fp8.contiguous() + + # Convert Q once outside the kernel to avoid repeated per-block conversion. 
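+    # (The kernel likewise dequantizes K tiles to fp16, so `tl.dot` runs as a
+    # plain fp16 x fp16 MMA; casting Q here keeps the conversion out of the
+    # per-(row, block) grid.)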
+    q = q_fp8.to(torch.float16)
+
+    # Split fused KV cache as zero-copy views:
+    #   [num_blocks, block_size, head_dim] uint8 FP8 bytes
+    #   [num_blocks, block_size, 1] float32 scales (4 bytes per position)
+    kv_flat = kv_cache_fp8.view(num_blocks, -1)
+    split = block_size * head_dim
+    k_data = kv_flat[:, :split].view(num_blocks, block_size, head_dim)
+    k_scale = kv_flat[:, split:].view(num_blocks, block_size, 4).view(torch.float32)
+
+    logits = torch.full(
+        (batch_size * next_n, max_model_len),
+        float("-inf"),
+        device=q_fp8.device,
+        dtype=torch.float32,
+    )
+
+    block_cols = block_tables.shape[1]
+    grid = (batch_size * next_n, block_cols)
+
+    # `tl.dot` requires every tile dimension to be at least 16; masked loads
+    # zero-fill the padding, so clamping is safe for small heads/dims.
+    block_k = max(16, triton.next_power_of_2(block_size))
+    block_d = max(16, triton.next_power_of_2(head_dim))
+    block_h = max(16, triton.next_power_of_2(num_heads))
+
+    # Warp heuristic: wide tiles (D >= 128 or H >= 64) get 8 warps; smaller
+    # decode tiles (typically H <= 64, D = 128, KV block of 64) get 4.
+    num_warps = 8 if (block_d >= 128 or block_h >= 64) else 4
+
+    _fp8_paged_mqa_logits_kernel[grid](
+        q,
+        k_data,
+        k_scale,
+        weights,
+        context_lens,
+        block_tables,
+        logits,
+        next_n,
+        num_heads,
+        head_dim,
+        block_size,
+        max_model_len,
+        q.stride(0),
+        q.stride(1),
+        q.stride(2),
+        q.stride(3),
+        k_data.stride(0),
+        k_data.stride(1),
+        k_data.stride(2),
+        k_scale.stride(0),
+        k_scale.stride(1),
+        weights.stride(0),
+        weights.stride(1),
+        context_lens.stride(0),
+        block_tables.stride(0),
+        block_tables.stride(1),
+        logits.stride(0),
+        logits.stride(1),
+        BLOCK_K=block_k,
+        BLOCK_D=block_d,
+        BLOCK_H=block_h,
+        num_warps=num_warps,
+        num_stages=2,
+    )
+
+    return logits
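+
+
+# Packed-layout sketch for one cache block (illustrative, block_size=64, D=128):
+#   bytes [0, 64*128)       -> FP8 K values, position-major
+#   bytes [64*128, 64*132)  -> 64 float32 dequant scales, one per position
+# The zero-copy split in `fp8_paged_mqa_logits_triton` relies on this
+# data-then-scales ordering within each block.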