Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
545 changes: 545 additions & 0 deletions tests/kernels/attention/test_rocm_fp8_fp4_paged_mqa_logits.py

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion vllm/config/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ class AttentionConfig:
"""If set, quantize query for attention in prefill."""

use_fp4_indexer_cache: bool = False
"""If set, use fp4 indexer cache for dsv32 family model (not support yet)"""
"""If set, use fp4 indexer cache for dsv32 and dsv4 family models.
But fp4 indexer cache is not supported for dsv32 family models.
Supported on CUDA SM100 datacenter GPUs and ROCm gfx95x GPUs."""

use_non_causal: bool = False
"""Whether to use non-causal (bidirectional) attention."""
Expand Down
108 changes: 89 additions & 19 deletions vllm/model_executor/layers/deepseek_v4_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,11 @@

return self.wo_b(z.flatten(1))

def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]:
def attn_gemm_parallel_execute(
self,
hidden_states: torch.Tensor,
skip_indexer_scores: bool = False,
) -> tuple[Any, ...]:
aux_streams = self.aux_stream_list
if aux_streams is not None:
assert len(aux_streams) >= 3
Expand Down Expand Up @@ -390,7 +394,8 @@
out_dtype=torch.float32,
)

aux_fns[1] = indexer_weights_proj
if not skip_indexer_scores:
aux_fns[1] = indexer_weights_proj
aux_fns[2] = indexer_compressor_kv_score

def fused_wqa_wkv() -> torch.Tensor:
Expand All @@ -410,6 +415,43 @@

return qr_kv, kv_score, indexer_kv_score, indexer_weights

def _can_skip_c4a_indexer(
self,
attn_metadata: (
dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None
),
) -> bool:
if (
not current_platform.is_rocm()
or self.compress_ratio != 4
or self.indexer is None
or not isinstance(attn_metadata, dict)
):
return False

indexer = self.indexer
if not getattr(indexer, "use_fp4_kv", False):
return False

metadata = attn_metadata.get(indexer.k_cache.prefix)
if metadata is None:
return False

topk_tokens = indexer.topk_tokens
if metadata.num_prefills > 0:

Check failure on line 441 in vllm/model_executor/layers/deepseek_v4_attention.py

View workflow job for this annotation

GitHub Actions / pre-commit

"AttentionMetadata" has no attribute "num_prefills" [attr-defined]

Check failure on line 441 in vllm/model_executor/layers/deepseek_v4_attention.py

View workflow job for this annotation

GitHub Actions / pre-commit

"AttentionMetadata" has no attribute "num_prefills" [attr-defined]
prefill = metadata.prefill

Check failure on line 442 in vllm/model_executor/layers/deepseek_v4_attention.py

View workflow job for this annotation

GitHub Actions / pre-commit

"AttentionMetadata" has no attribute "prefill" [attr-defined]

Check failure on line 442 in vllm/model_executor/layers/deepseek_v4_attention.py

View workflow job for this annotation

GitHub Actions / pre-commit

"AttentionMetadata" has no attribute "prefill" [attr-defined]
if prefill is None:
return False
if any(chunk.max_valid_len > topk_tokens for chunk in prefill.chunks):
return False

if metadata.num_decodes > 0:

Check failure on line 448 in vllm/model_executor/layers/deepseek_v4_attention.py

View workflow job for this annotation

GitHub Actions / pre-commit

"AttentionMetadata" has no attribute "num_decodes" [attr-defined]

Check failure on line 448 in vllm/model_executor/layers/deepseek_v4_attention.py

View workflow job for this annotation

GitHub Actions / pre-commit

"AttentionMetadata" has no attribute "num_decodes" [attr-defined]
decode = metadata.decode

Check failure on line 449 in vllm/model_executor/layers/deepseek_v4_attention.py

View workflow job for this annotation

GitHub Actions / pre-commit

"AttentionMetadata" has no attribute "decode" [attr-defined]

Check failure on line 449 in vllm/model_executor/layers/deepseek_v4_attention.py

View workflow job for this annotation

GitHub Actions / pre-commit

"AttentionMetadata" has no attribute "decode" [attr-defined]
if decode is None or decode.max_seq_len > topk_tokens:
return False

return True

def attention_impl(
self,
hidden_states: torch.Tensor,
Expand All @@ -418,9 +460,14 @@
) -> None:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
skip_indexer_scores = self._can_skip_c4a_indexer(attn_metadata)
self.mla_attn.use_trivial_c4a_topk = skip_indexer_scores

qr_kv, kv_score, indexer_kv_score, indexer_weights = (
self.attn_gemm_parallel_execute(hidden_states)
self.attn_gemm_parallel_execute(
hidden_states,
skip_indexer_scores=skip_indexer_scores,
)
)

qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1)
Expand Down Expand Up @@ -451,20 +498,42 @@
compressor(kv_score, positions, self.rotary_emb)
return q

q, _ = maybe_execute_in_parallel(
wq_b_kv_insert_and_compress,
lambda: indexer(
hidden_states,
qr,
indexer_kv_score,
indexer_weights,
positions,
self.indexer_rotary_emb,
),
self.ln_events[0],
self.ln_events[1],
aux_stream,
)
if skip_indexer_scores:
assert indexer_kv_score is not None

def compress_indexer_kv_cache() -> torch.Tensor:
indexer.compressor(
indexer_kv_score,
positions,
self.indexer_rotary_emb,
)
assert indexer.topk_indices_buffer is not None
return indexer.topk_indices_buffer

q, _ = maybe_execute_in_parallel(
wq_b_kv_insert_and_compress,
compress_indexer_kv_cache,
self.ln_events[0],
self.ln_events[1],
aux_stream,
)
else:
assert indexer_kv_score is not None
assert indexer_weights is not None
q, _ = maybe_execute_in_parallel(
wq_b_kv_insert_and_compress,
lambda: indexer(
hidden_states,
qr,
indexer_kv_score,
indexer_weights,
positions,
self.indexer_rotary_emb,
),
self.ln_events[0],
self.ln_events[1],
aux_stream,
)
elif self.compressor is not None:
# wq_b + kv_insert on default, compressor on aux.
aux_stream = (
Expand Down Expand Up @@ -656,6 +725,7 @@
self.rope_head_dim = qk_rope_head_dim
self.indexer = indexer
self.topk_indices_buffer = topk_indices_buffer
self.use_trivial_c4a_topk = False

self.prefix = prefix # Alias for compatibility with compressor

Expand Down Expand Up @@ -1130,8 +1200,8 @@
assert cache_config is not None, "Deepseek V4 indexer requires cache_config"
# NOTE(yifan): FP8 indxer cache use the same layout as V3.2:
# head_dim bytes = 128 fp8 + 4 fp32 scale = 132.
# For FP4 indexer cache, we still allocate the same amount of memory as FP8,
# but only use the first half of the memory.
# For FP4 indexer cache, keep the same allocation width as CUDA and
# only use the MXFP4 value/scale prefix.
k_cache_head_dim = self.head_dim + self.head_dim // self.quant_block_size * 4
self.k_cache = DeepseekV4IndexerCache(
head_dim=k_cache_head_dim,
Expand Down
39 changes: 35 additions & 4 deletions vllm/model_executor/layers/sparse_attn_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,10 +501,41 @@ def forward_hip(
k: torch.Tensor,
weights: torch.Tensor,
):
assert not self.use_fp4_cache, "AMD platform doesn't support fp4 cache yet"
assert isinstance(q_quant, torch.Tensor), (
"AMD sparse_attn_indexer expects a single FP8 q_quant tensor"
)
if self.use_fp4_cache:
assert isinstance(q_quant, tuple), (
"AMD MXFP4 sparse_attn_indexer expects (q_values, q_scales)"
)
else:
assert isinstance(q_quant, torch.Tensor), (
"AMD FP8 sparse_attn_indexer expects a single q_quant tensor"
)

if (
self.use_fp4_cache
or self.skip_k_cache_insert
or not rocm_aiter_ops.is_enabled()
):
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
rocm_aiter_sparse_attn_indexer_native,
)

return rocm_aiter_sparse_attn_indexer_native(
hidden_states,
_encode_layer_name(self.k_cache.prefix),
self.k_cache.kv_cache,
q_quant,
k,
weights,
self.quant_block_size,
self.scale_fmt,
self.topk_tokens,
self.head_dim,
self.max_model_len,
self.max_total_seq_len,
self.topk_indices_buffer,
skip_k_cache_insert=self.skip_k_cache_insert,
use_fp4_cache=self.use_fp4_cache,
)
if rocm_aiter_ops.is_enabled():
return torch.ops.vllm.rocm_aiter_sparse_attn_indexer(
hidden_states,
Expand Down
72 changes: 62 additions & 10 deletions vllm/utils/deep_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,8 @@ def _lazy_init() -> None:


def get_num_sms() -> int:
if current_platform.is_rocm():
return int(current_platform.num_compute_units())
_lazy_init()
dg = _import_deep_gemm()
if dg is None:
Expand Down Expand Up @@ -370,6 +372,20 @@ def fp8_fp4_mqa_logits(
Returns:
Logits tensor of shape [M, N], dtype `torch.float32`.
"""
if current_platform.is_rocm() and q[1] is not None:
from vllm.v1.attention.ops.rocm_fp8_fp4_paged_mqa_logits import (
rocm_fp4_mqa_logits,
)

return rocm_fp4_mqa_logits(
q,
kv,
weights,
cu_seqlen_ks,
cu_seqlen_ke,
clean_logits=clean_logits,
)

_lazy_init()
if _fp8_fp4_mqa_logits_impl is None:
return _missing()
Expand All @@ -384,20 +400,33 @@ def fp8_fp4_mqa_logits(


def get_paged_mqa_logits_metadata(
context_lens: torch.Tensor, block_size: int, num_sms: int
context_lens: torch.Tensor,
block_size: int,
num_sms: int,
) -> torch.Tensor:
"""Build scheduling metadata for paged MQA logits.

Args:
context_lens: Tensor of shape [B], dtype int32; effective context length
per batch element.
context_lens: Tensor of shape [B] or [B, next_n], dtype int32;
effective context length per batch element or per decoded token.
block_size: KV-cache block size in tokens (e.g., 64).
num_sms: Number of SMs available. 132 for Hopper
num_sms: Number of SMs/CUs available.

Returns:
Backend-specific tensor consumed by `fp8_fp4_paged_mqa_logits` to
schedule work across SMs.
"""
if current_platform.is_rocm():
from vllm.v1.attention.ops.rocm_fp8_fp4_paged_mqa_logits import (
rocm_get_paged_mqa_logits_metadata,
)

return rocm_get_paged_mqa_logits_metadata(
context_lens,
block_size,
num_sms,
)

_lazy_init()
if _get_paged_mqa_logits_metadata_impl is None:
return _missing()
Expand All @@ -413,6 +442,7 @@ def fp8_fp4_paged_mqa_logits(
schedule_metadata: torch.Tensor,
max_model_len: int,
clean_logits: bool,
logits_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""Compute MQA logits using a paged KV-cache.

Expand All @@ -426,25 +456,47 @@ def fp8_fp4_paged_mqa_logits(
q_values is packed uint8 and q_scale is the companion
block-scale tensor.
kv_cache: Paged KV-cache. FP8 layout is [num_blocks, block_size, 1,
D+4], dtype `torch.uint8`, with the last 4 bytes per (block, pos)
storing the float dequant scale.
D+4], dtype `torch.uint8`, with one float32 scale per token.
FP4 layout is [num_blocks, block_size, 1, D/2+4], with packed
MXFP4 values and one packed int32 UE8M0 scale word per token.
weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
context_lens: Tensor of shape [B], dtype int32; effective context length
for each batch element.
context_lens: Tensor of shape [B] or [B, next_n], dtype int32;
effective context length for each batch element/token.
block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
block indices to physical blocks in the paged cache.
schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
used to distribute work across SMs.
max_model_len: Maximum sequence length used to size the logits output.
clean_logits: Whether to clean the unfilled logits into `-inf`.
logits_dtype: Output dtype, matching DeepGEMM's float32/bfloat16 API.

Returns:
Logits tensor of shape [B * next_n, max_model_len], dtype
`torch.float32`.
`logits_dtype`.
"""
if current_platform.is_rocm():
from vllm.v1.attention.ops.rocm_fp8_fp4_paged_mqa_logits import (
rocm_fp8_fp4_paged_mqa_logits,
)

return rocm_fp8_fp4_paged_mqa_logits(
q,
kv_cache,
weights,
context_lens,
block_tables,
schedule_metadata,
max_model_len,
clean_logits=clean_logits,
logits_dtype=logits_dtype,
)

_lazy_init()
if _fp8_fp4_paged_mqa_logits_impl is None:
return _missing()
kwargs: dict[str, Any] = {"clean_logits": clean_logits}
if logits_dtype is not torch.float32:
kwargs["logits_dtype"] = logits_dtype
return _fp8_fp4_paged_mqa_logits_impl(
q,
kv_cache,
Expand All @@ -453,7 +505,7 @@ def fp8_fp4_paged_mqa_logits(
block_tables,
schedule_metadata,
max_model_len,
clean_logits=clean_logits,
**kwargs,
)


Expand Down
Loading
Loading