192 changes: 192 additions & 0 deletions tests/kernels/attention/test_deepgemm_attention.py
@@ -11,6 +11,7 @@
     calc_diff,
     fp8_mqa_logits,
     fp8_paged_mqa_logits,
+    fp8_paged_mqa_logits_torch,
     get_num_sms,
     get_paged_mqa_logits_metadata,
 )
@@ -200,6 +201,89 @@ def _ref_fp8_paged_mqa_logits(
     return logits
 
 
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
+@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
+@pytest.mark.skipif(
+    not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
+)
+def test_fp8_paged_mqa_logits_torch_matches_deepgemm():
+    torch.manual_seed(0)
+    random.seed(0)
+
+    max_model_len = 4096
+    num_blocks, blocksize = max_model_len * 2, 64
+    batch_size, next_n = 3, 2
+    heads, index_dim = 32, 128
+
+    q = torch.randn(
+        (batch_size, next_n, heads, index_dim),
+        device="cuda",
+        dtype=torch.bfloat16,
+    )
+    kv_cache = torch.randn(
+        (num_blocks, blocksize, 1, index_dim),
+        device="cuda",
+        dtype=torch.bfloat16,
+    )
+    weights = torch.randn(
+        (batch_size * next_n, heads),
+        device="cuda",
+        dtype=torch.float32,
+    )
+    context_lens = torch.tensor([1537, 2049, 3073], device="cuda", dtype=torch.int32)
+    max_block_len = cdiv(int(context_lens.max().item()), blocksize)
+    block_tables = torch.zeros(
+        (batch_size, max_block_len),
+        device="cuda",
+        dtype=torch.int32,
+    )
+
+    block_idx_pool = list(range(num_blocks))
+    random.shuffle(block_idx_pool)
+    counter = 0
+    for i, ctx_len in enumerate(context_lens.tolist()):
+        for j in range(cdiv(ctx_len, blocksize)):
+            block_tables[i, j] = block_idx_pool[counter]
+            counter += 1
+
+    q_fp8 = q.to(torch.float8_e4m3fn)
+    kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)
+    schedule_metadata = get_paged_mqa_logits_metadata(
+        context_lens, blocksize, get_num_sms()
+    )
+
+    fallback_logits = fp8_paged_mqa_logits_torch(
+        q_fp8,
+        kv_cache_fp8,
+        weights,
+        context_lens,
+        block_tables,
+        max_model_len,
+    )
+    deepgemm_logits = fp8_paged_mqa_logits(
+        q_fp8,
+        kv_cache_fp8,
+        weights,
+        context_lens,
+        block_tables,
+        schedule_metadata,
+        max_model_len,
+        clean_logits=True,
+    )
+
+    positions = torch.arange(max_model_len, device="cuda").unsqueeze(0)
+    row_indices = torch.arange(batch_size * next_n, device="cuda") // next_n
+    next_n_offset = torch.arange(batch_size * next_n, device="cuda") % next_n
+    valid_mask = positions <= (
+        context_lens[row_indices] - next_n + next_n_offset
+    ).unsqueeze(1)
+
+    fallback_logits = fallback_logits.masked_fill(~valid_mask, 0)
+    deepgemm_logits = deepgemm_logits.masked_fill(~valid_mask, 0)
+    diff = calc_diff(fallback_logits, deepgemm_logits)
+    assert diff < 1e-3, f"{diff=}"
+
+
 @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
 @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
 @pytest.mark.skipif(
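Both comparison tests gate on `calc_diff`, imported from `vllm.utils.deep_gemm`. For readers without the source handy, here is a minimal sketch of the kind of metric this is: a cosine-similarity-style relative error, in the spirit of DeepGEMM's testing helper. The exact implementation in vLLM may differ; this re-implementation is for illustration only.

```python
import torch

# Hypothetical re-implementation of calc_diff, for illustration only; the
# real helper lives in vllm.utils.deep_gemm. It returns 1 - cosine
# similarity, so 0.0 means identical tensors, and the tests' 1e-3 bound
# tolerates small FP8 rounding differences between kernels.
def calc_diff_sketch(x: torch.Tensor, y: torch.Tensor) -> float:
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return float(1 - sim)
```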
@@ -298,3 +382,111 @@ def test_deepgemm_fp8_paged_mqa_logits(clean_logits: bool):
     ref_logits = ref_logits.masked_fill(~mask, 0)
     diff = calc_diff(logits, ref_logits)
     assert diff < 1e-3, f"{diff=}"
+
+
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
+def test_fp8_paged_mqa_logits_torch_cuda_graph_capture():
+    torch.manual_seed(0)
+    random.seed(0)
+
+    batch_size, next_n = 2, 2
+    heads, index_dim = 4, 16
+    max_model_len, blocksize = 128, 16
+    num_blocks = 32
+
+    q = torch.randn(
+        (batch_size, next_n, heads, index_dim),
+        device="cuda",
+        dtype=torch.bfloat16,
+    )
+    kv_cache = torch.randn(
+        (num_blocks, blocksize, 1, index_dim),
+        device="cuda",
+        dtype=torch.bfloat16,
+    )
+    weights = torch.randn(
+        (batch_size * next_n, heads),
+        device="cuda",
+        dtype=torch.float32,
+    )
+    context_lens = torch.tensor([33, 58], device="cuda", dtype=torch.int32)
+    block_tables = torch.zeros(
+        (batch_size, cdiv(max_model_len, blocksize)),
+        device="cuda",
+        dtype=torch.int32,
+    )
+
+    next_block = 0
+    for i, ctx_len in enumerate(context_lens.tolist()):
+        num_ctx_blocks = cdiv(ctx_len, blocksize)
+        block_tables[i, :num_ctx_blocks] = torch.tensor(
+            list(range(next_block, next_block + num_ctx_blocks)),
+            device="cuda",
+            dtype=torch.int32,
+        )
+        next_block += num_ctx_blocks
+
+    q_fp8 = q.to(torch.float8_e4m3fn)
+    kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)
+
+    expected = fp8_paged_mqa_logits_torch(
+        q_fp8,
+        kv_cache_fp8,
+        weights,
+        context_lens,
+        block_tables,
+        max_model_len,
+    )
+    torch.accelerator.synchronize()
+
+    graph = torch.cuda.CUDAGraph()
+    with torch.cuda.graph(graph):
+        logits = fp8_paged_mqa_logits_torch(
+            q_fp8,
+            kv_cache_fp8,
+            weights,
+            context_lens,
+            block_tables,
+            max_model_len,
+        )
+
+    # Update inputs in-place after capture and verify replay uses the new values.
+    q_new = torch.randn_like(q)
+    kv_cache_new = torch.randn_like(kv_cache)
+    weights_new = torch.randn_like(weights)
+    context_lens_new = torch.tensor([41, 71], device="cuda", dtype=torch.int32)
+    block_tables_new = torch.zeros_like(block_tables)
+
+    next_block = 0
+    for i, ctx_len in enumerate(context_lens_new.tolist()):
+        num_ctx_blocks = cdiv(ctx_len, blocksize)
+        block_tables_new[i, :num_ctx_blocks] = torch.tensor(
+            list(range(next_block, next_block + num_ctx_blocks)),
+            device="cuda",
+            dtype=torch.int32,
+        )
+        next_block += num_ctx_blocks
+
+    q_fp8.copy_(q_new.to(torch.float8_e4m3fn))
+    kv_cache_fp8.copy_(kv_cache_cast_to_fp8(kv_cache_new))
+    weights.copy_(weights_new)
+    context_lens.copy_(context_lens_new)
+    block_tables.copy_(block_tables_new)
+
+    expected_updated = fp8_paged_mqa_logits_torch(
+        q_fp8,
+        kv_cache_fp8,
+        weights,
+        context_lens,
+        block_tables,
+        max_model_len,
+    )
+
+    logits.zero_()
+    graph.replay()
+    torch.accelerator.synchronize()
+
+    torch.testing.assert_close(logits, expected_updated, equal_nan=True)
+
+    # Sanity check: output changed after replacing inputs.
+    assert not torch.allclose(expected, expected_updated, equal_nan=True)
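The capture/replay test above relies on a standard CUDA Graphs property: the graph records device memory addresses, not values, so a replay re-reads whatever data currently sits in the captured buffers. A minimal standalone sketch of that pattern follows (illustrative only; assumes a CUDA device, and production code typically adds a warm-up run on a side stream before capture):

```python
import torch

# Buffers must exist before capture; the graph records pointers, not values.
static_a = torch.randn(8, device="cuda")
static_b = torch.randn(8, device="cuda")
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    # Tensors allocated during capture (static_out) live in the graph's
    # private memory pool and are rewritten on every replay.
    static_out = static_a + static_b

# Overwrite the captured inputs in place, then replay: the graph recomputes
# with the new data because it reads the same memory locations.
static_a.copy_(torch.ones(8, device="cuda"))
static_b.copy_(torch.full((8,), 2.0, device="cuda"))
graph.replay()
torch.cuda.synchronize()
assert torch.allclose(static_out, torch.full((8,), 3.0, device="cuda"))
```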
56 changes: 10 additions & 46 deletions vllm/utils/deep_gemm.py
@@ -472,13 +472,11 @@ def fp8_paged_mqa_logits_torch(
     block_tables: torch.Tensor,
     max_model_len: int,
 ) -> torch.Tensor:
-    """Compute FP8 MQA logits using paged KV-cache (CUDA fallback).
-
-    This is a pure PyTorch fallback for CUDA when DeepGEMM is not available.
-    Handles head_dim = 132 (128 + 4 for RoPE).
+    """Compute FP8 MQA logits using paged KV-cache (Triton kernel).
 
     Args:
-        q: Query tensor of shape [B, next_n, H, D].
+        q: Query tensor of shape [B, next_n, H, D]. Cast to
+            `torch.float8_e4m3fn` by the caller.
         kv_cache: Paged KV-cache in packed FP8+scale layout with shape
             [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
             4 bytes per (block,pos) store the `float` dequant scale.
@@ -493,48 +491,14 @@
         Logits tensor of shape [B * next_n, max_model_len], dtype
             `torch.float32`.
     """
-    fp8_dtype = current_platform.fp8_dtype()
-    batch_size, next_n, heads, dim = q.size()
-    kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:]
-    scale = scale.contiguous().view(torch.float)
-    q = q.float()
-    kv_cache = kv_cache.view(fp8_dtype).float() * scale
-    num_blocks, block_size, _, dim = kv_cache.size()
-    logits = torch.full(
-        [batch_size * next_n, max_model_len],
-        float("-inf"),
-        device=q.device,
-        dtype=torch.float32,
+    # Import here to avoid circular dependency
+    from vllm.v1.attention.ops.triton_fp8_paged_mqa_logits import (
+        fp8_paged_mqa_logits_triton,
     )
+
+    return fp8_paged_mqa_logits_triton(
+        q, kv_cache, weights, context_lens, block_tables, max_model_len
+    )
-    for i in range(batch_size):
-        context_len = context_lens[i].item()
-        q_offsets = torch.arange(context_len - next_n, context_len, device=q.device)
-        weight_slice = (
-            weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
-        )
-        for block_idx in range(cdiv(context_len, block_size)):
-            block_id = block_tables[i][block_idx]
-            qx, kx = q[i], kv_cache[block_id]
-            k_offsets = torch.arange(
-                block_idx * block_size, (block_idx + 1) * block_size, device=q.device
-            )
-            mask = (k_offsets[None, :] < context_len) & (
-                k_offsets[None, :] <= q_offsets[:, None]
-            )
-            s = torch.where(
-                mask[None, :, :],
-                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
-                    logits.dtype
-                ),
-                float("-inf"),
-            )
-            s = torch.relu(s) * weight_slice[..., None]
-            s = s.sum(dim=0)
-            logits[
-                i * next_n : (i + 1) * next_n,
-                block_idx * block_size : (block_idx + 1) * block_size,
-            ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
-    return logits
 
 
 __all__ = [
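For readers following the packed layout named in the docstring (`[num_blocks, block_size, 1, D+4]` bytes of `torch.uint8`, with the trailing 4 bytes per position holding one `float32` dequant scale), here is a small sketch of how such a buffer can be dequantized in plain PyTorch. It mirrors the pure-PyTorch reference removed in this diff; the helper name is hypothetical and not part of vLLM.

```python
import torch

def dequant_packed_kv(kv_cache: torch.Tensor, dim: int) -> torch.Tensor:
    # kv_cache: uint8 tensor of shape [num_blocks, block_size, 1, dim + 4].
    # First `dim` bytes per entry are FP8 (e4m3fn on CUDA) values; the last
    # 4 bytes reinterpret as one float32 scale for that (block, pos).
    fp8_part, scale_part = kv_cache[..., :dim], kv_cache[..., dim:]
    scale = scale_part.contiguous().view(torch.float)    # [..., 4] bytes -> [..., 1] float32
    values = fp8_part.view(torch.float8_e4m3fn).float()  # reinterpret bytes as FP8, upcast
    return values * scale                                # dequantized float32 values
```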