72 changes: 22 additions & 50 deletions vllm/model_executor/layers/sparse_attn_indexer.py
@@ -9,13 +9,7 @@
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
-from vllm.utils.deep_gemm import (
-    fp8_mqa_logits,
-    fp8_mqa_logits_torch,
-    fp8_paged_mqa_logits,
-    fp8_paged_mqa_logits_torch,
-    is_deep_gemm_supported,
-)
+from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits, has_deep_gemm
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerMetadata,
@@ -114,23 +108,14 @@ def sparse_attn_indexer(
chunk.block_table,
chunk.cu_seq_lens,
)
-            if is_deep_gemm_supported():
-                logits = fp8_mqa_logits(
-                    q_fp8[chunk.token_start : chunk.token_end],
-                    (k_fp8, k_scale.view(torch.float32).flatten()),
-                    weights[chunk.token_start : chunk.token_end],
-                    chunk.cu_seqlen_ks,
-                    chunk.cu_seqlen_ke,
-                    clean_logits=False,
-                )
-            else:
-                logits = fp8_mqa_logits_torch(
-                    q_fp8[chunk.token_start : chunk.token_end],
-                    (k_fp8, k_scale.view(torch.float32).flatten()),
-                    weights[chunk.token_start : chunk.token_end],
-                    chunk.cu_seqlen_ks,
-                    chunk.cu_seqlen_ke,
-                )
+            logits = fp8_mqa_logits(
+                q_fp8[chunk.token_start : chunk.token_end],
+                (k_fp8, k_scale.view(torch.float32).flatten()),
+                weights[chunk.token_start : chunk.token_end],
+                chunk.cu_seqlen_ks,
+                chunk.cu_seqlen_ke,
+                clean_logits=False,
+            )
num_rows = logits.shape[0]

topk_indices = topk_indices_buffer[
@@ -194,26 +179,16 @@ def sparse_attn_indexer(
next_n = padded_q_fp8_decode_tokens.shape[1]
assert batch_size == decode_metadata.seq_lens.shape[0]
num_padded_tokens = batch_size * next_n
-        if is_deep_gemm_supported():
-            logits = fp8_paged_mqa_logits(
-                padded_q_fp8_decode_tokens,
-                kv_cache,
-                weights[:num_padded_tokens],
-                decode_metadata.seq_lens,
-                decode_metadata.block_table,
-                decode_metadata.schedule_metadata,
-                max_model_len=max_model_len,
-                clean_logits=False,
-            )
-        else:
-            logits = fp8_paged_mqa_logits_torch(
-                padded_q_fp8_decode_tokens,
-                kv_cache,
-                weights[:num_padded_tokens],
-                decode_metadata.seq_lens,
-                decode_metadata.block_table,
-                max_model_len=max_model_len,
-            )
+        logits = fp8_paged_mqa_logits(
+            padded_q_fp8_decode_tokens,
+            kv_cache,
+            weights[:num_padded_tokens],
+            decode_metadata.seq_lens,
+            decode_metadata.block_table,
+            decode_metadata.schedule_metadata,
+            max_model_len=max_model_len,
+            clean_logits=False,
+        )
num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]

@@ -333,12 +308,9 @@ def __init__(
self.max_model_len = max_model_len
self.max_total_seq_len = max_total_seq_len
self.topk_indices_buffer = topk_indices_buffer
-        if current_platform.is_cuda() and not is_deep_gemm_supported():
-            logger.warning_once(
-                "DeepGEMM is not supported or available. SparseAttnIndexer will use a "
-                "less efficient PyTorch implementation. "
-                "Please make sure you have the required hardware and software setup "
-                "for DeepGEMM to achieve optimal performance."
+        if current_platform.is_cuda() and not has_deep_gemm():
+            raise RuntimeError(
+                "Sparse Attention Indexer CUDA op requires DeepGEMM to be installed."
)

def forward_native(
121 changes: 0 additions & 121 deletions vllm/utils/deep_gemm.py
@@ -415,135 +415,14 @@ def should_use_deepgemm_for_fp8_linear(
)


def fp8_mqa_logits_torch(
q: torch.Tensor,
kv: tuple[torch.Tensor, torch.Tensor],
weights: torch.Tensor,
cu_seqlen_ks: torch.Tensor,
cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
"""Compute FP8 MQA logits for a single sequence without KV paging (CUDA fallback).

This is a pure PyTorch fallback for CUDA when DeepGEMM is not available.

Args:
q: Query tensor of shape [M, H, D]. Casted to
`torch.float8_e4m3fn` by caller.
kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
[N, 1]) with dtype `torch.float32`.
weights: weights of shape [M, H], dtype `torch.float32`.
cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
shape [M], dtype int32.
cu_seqlen_ke: End indices (exclusive) for valid K per query position,
shape [M], dtype int32.

Returns:
Logits tensor of shape [M, N], dtype `torch.float32`.
"""
kv_fp8, scale = kv
seq_len_kv = kv_fp8.shape[0]
k = kv_fp8.to(torch.bfloat16)
q = q.to(torch.bfloat16)

mask_lo = (
torch.arange(0, seq_len_kv, device=q.device)[None, :] >= cu_seqlen_ks[:, None]
)
mask_hi = (
torch.arange(0, seq_len_kv, device=q.device)[None, :] < cu_seqlen_ke[:, None]
)
mask = mask_lo & mask_hi

score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
logits = logits.masked_fill(~mask, float("-inf"))

return logits


def fp8_paged_mqa_logits_torch(
q: torch.Tensor,
kv_cache: torch.Tensor,
weights: torch.Tensor,
context_lens: torch.Tensor,
block_tables: torch.Tensor,
max_model_len: int,
) -> torch.Tensor:
"""Compute FP8 MQA logits using paged KV-cache (CUDA fallback).

This is a pure PyTorch fallback for CUDA when DeepGEMM is not available.
Handles head_dim = 132 (128 + 4 for RoPE).

Args:
q: Query tensor of shape [B, next_n, H, D].
kv_cache: Paged KV-cache in packed FP8+scale layout with shape
[num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
4 bytes per (block,pos) store the `float` dequant scale.
weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
context_lens: Tensor of shape [B], dtype int32; effective context length
for each batch element.
block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
block indices to physical blocks in the paged cache.
max_model_len: Maximum sequence length used to size the logits output.

Returns:
Logits tensor of shape [B * next_n, max_model_len], dtype
`torch.float32`.
"""
fp8_dtype = current_platform.fp8_dtype()
batch_size, next_n, heads, dim = q.size()
kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:]
scale = scale.contiguous().view(torch.float)
q = q.float()
kv_cache = kv_cache.view(fp8_dtype).float() * scale
num_blocks, block_size, _, dim = kv_cache.size()
logits = torch.full(
[batch_size * next_n, max_model_len],
float("-inf"),
device=q.device,
dtype=torch.float32,
)
for i in range(batch_size):
context_len = context_lens[i].item()
q_offsets = torch.arange(context_len - next_n, context_len, device=q.device)
weight_slice = (
weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
)
for block_idx in range(cdiv(context_len, block_size)):
block_id = block_tables[i][block_idx]
qx, kx = q[i], kv_cache[block_id]
k_offsets = torch.arange(
block_idx * block_size, (block_idx + 1) * block_size, device=q.device
)
mask = (k_offsets[None, :] < context_len) & (
k_offsets[None, :] <= q_offsets[:, None]
)
s = torch.where(
mask[None, :, :],
(qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
logits.dtype
),
float("-inf"),
)
s = torch.relu(s) * weight_slice[..., None]
s = s.sum(dim=0)
logits[
i * next_n : (i + 1) * next_n,
block_idx * block_size : (block_idx + 1) * block_size,
] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
return logits


__all__ = [
"calc_diff",
"DeepGemmQuantScaleFMT",
"fp8_gemm_nt",
"m_grouped_fp8_gemm_nt_contiguous",
"fp8_m_grouped_gemm_nt_masked",
"fp8_mqa_logits",
"fp8_mqa_logits_torch",
"fp8_paged_mqa_logits",
"fp8_paged_mqa_logits_torch",
"get_paged_mqa_logits_metadata",
"per_block_cast_to_fp8",
"is_deep_gemm_e8m0_used",
3 changes: 2 additions & 1 deletion vllm/v1/attention/backends/mla/indexer.py
@@ -9,6 +9,7 @@
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
get_paged_mqa_logits_metadata,
+    has_deep_gemm,
is_deep_gemm_supported,
)
from vllm.utils.math_utils import cdiv
@@ -449,7 +450,7 @@ def build(
batch_size = num_decodes

# DeepGEMM is required for the paged MQA logits on CUDA devices
-        if current_platform.is_cuda() and is_deep_gemm_supported():
+        if current_platform.is_cuda() and has_deep_gemm():
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
seq_lens,
self.kv_cache_spec.block_size,
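Note on the is_deep_gemm_supported() -> has_deep_gemm() swap: with this change, the CUDA sparse-attention-indexer path no longer warns and falls back to the now-deleted pure-PyTorch reference kernels; a missing DeepGEMM install raises a RuntimeError when the op is constructed. The sketch below illustrates the assumed difference between the two helpers. It is not the vLLM implementation (the real definitions live in vllm/utils/deep_gemm.py), and the SM90 cutoff is an assumption for illustration only.

    # Hedged sketch of the assumed helper semantics, not vLLM's actual code.
    import importlib.util

    import torch


    def has_deep_gemm() -> bool:
        # Weakest check: the deep_gemm package is importable at all.
        return importlib.util.find_spec("deep_gemm") is not None


    def is_deep_gemm_supported() -> bool:
        # Stronger check: the package is importable *and* the current GPU is one
        # the DeepGEMM kernels can target (assumed Hopper / SM90 or newer here).
        if not has_deep_gemm() or not torch.cuda.is_available():
            return False
        major, _ = torch.cuda.get_device_capability()
        return major >= 9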