Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/design/attention_backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ MLA decode backends are selected using the standard
| `FLASHINFER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x |
| `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | 576 | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x |
| `FLASHMLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x |
| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 512, 576 | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
| `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
| `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %1 | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 1, 64 | Any | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
Expand Down
4 changes: 2 additions & 2 deletions tests/kernels/attention/test_rocm_triton_attn_dsv4.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def _ref_combine_topk_swa_ragged(

@torch.inference_mode()
def test_compute_global_topk_ragged_indices_and_indptr() -> None:
from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import (
from vllm.models.deepseek_v4.amd.rocm import (
compute_global_topk_ragged_indices_and_indptr,
)

Expand Down Expand Up @@ -330,7 +330,7 @@ def test_sparse_attn_decode_ragged_kernel() -> None:

@torch.inference_mode()
def test_combine_topk_swa_indices_ragged() -> None:
from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import (
from vllm.models.deepseek_v4.amd.rocm import (
combine_topk_swa_indices_ragged,
)

Expand Down
5 changes: 3 additions & 2 deletions tests/kernels/test_fused_inv_rope_fp8_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
the existing separate operations (inverse RoPE via rotate_neox + FP8 quant
via per_token_group_quant_fp8).

The reference faithfully reproduces the exact flow in deepseek_v4_attention.py:295-310:
The reference faithfully reproduces the exact flow in
deepseek_v4/nvidia/ops/attention.py:295-310:
1. Apply inverse RoPE (NeoX style, last rope_dim=64 dims of each head)
2. Reshape [T, H, head_dim] -> [T, G, D]
3. Transpose+flatten to [G*T, D], quantize, reshape back
Expand Down Expand Up @@ -667,7 +668,7 @@ def _unfused_inv_rope_fp8_quant(
nope_dim: int = NOPE_DIM,
rope_dim: int = ROPE_DIM,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Unfused path matching deepseek_v4_attention.py:295-310.
"""Unfused path matching deepseek_v4/nvidia/ops/attention.py:295-310.

Uses the production CUDA RoPE kernel + per_token_group_quant_fp8.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@

from vllm.forward_context import get_forward_context
from vllm.models.deepseek_v4.common.ops import dequantize_and_gather_k_cache
from vllm.models.deepseek_v4.nvidia.flashmla import (
DeepseekV4FlashMLASparseBackend,
DeepseekV4SparseMLAAttentionImpl,
)
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import (
AttentionLayer,
CommonAttentionMetadata,
SparseMLAAttentionImpl,
)
from vllm.v1.attention.backends.mla.flashmla_sparse import (
DeepseekV4FlashMLASparseBackend,
FlashMLASparseMetadata,
FlashMLASparseMetadataBuilder,
)
Expand All @@ -31,7 +32,9 @@
from vllm.v1.worker.workspace import current_workspace_manager

if TYPE_CHECKING:
from vllm.models.deepseek_v4.attention import DeepseekV4MLAAttention
from vllm.models.deepseek_v4.nvidia.ops.attention import (
DeepseekV4MLAAttention,
)


def _build_indptr_from_lengths(lengths: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -449,47 +452,27 @@ def build(
)


class DeepseekV4ROCMAiterMLASparseImpl(
SparseMLAAttentionImpl[DeepseekV4ROCMAiterMLASparseMetadata]
):
"""ROCm sparse MLA implementation used by DeepSeek V4's custom MLA layer."""
class DeepseekV4ROCMAiterMLASparseBackend(DeepseekV4FlashMLASparseBackend):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

DeepseekV4ROCMAiterMLASparseBackend should not inherit directly from DeepseekV4FlashMLASparseBackend if the latter is defined in a CUDA-specific file. This creates a hard dependency on the CUDA implementation path. As suggested in the previous comments, move the common backend logic to a shared base class in a platform-agnostic file to ensure ROCm support remains functional and isolated.

@staticmethod
def get_name() -> str:
return "ROCM_V4_FLASHMLA_SPARSE"

_PREFILL_CHUNK_SIZE = 4
@staticmethod
def get_builder_cls() -> type["DeepseekV4ROCMAiterMLASparseMetadataBuilder"]:
return DeepseekV4ROCMAiterMLASparseMetadataBuilder

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: str | None,
**_: object,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
@staticmethod
def get_impl_cls() -> type["DeepseekV4SparseMLAAttentionImpl"]:
return DeepseekV4ROCMAiterMLASparseImpl

def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: DeepseekV4ROCMAiterMLASparseMetadata,
layer: AttentionLayer,
) -> tuple[torch.Tensor, torch.Tensor | None]:
raise NotImplementedError(
"DeepseekV4ROCMAiterMLASparseImpl is driven by "
"DeepseekV4MLAAttention.forward."
)

class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
"""ROCm sparse MLA implementation used by DeepSeek V4's custom MLA layer."""

backend_cls = DeepseekV4ROCMAiterMLASparseBackend

@classmethod
def forward(
def forward_mqa( # type: ignore[override]
cls,
layer: "DeepseekV4MLAAttention",
q: torch.Tensor,
Expand All @@ -506,6 +489,25 @@ def forward(

forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata

if attn_metadata is None:
# Warmup dummy run: no real metadata. Reserve the same bf16
# gather workspace _forward_prefill would; the dequantize / topk
# / sparse_fwd kernels are skipped this step.
swa_only = layer.compress_ratio <= 1
N = (
0
if swa_only
else (layer.max_model_len + layer.compress_ratio - 1)
// layer.compress_ratio
)
M = N + layer.window_size + layer.max_num_batched_tokens
current_workspace_manager().get_simultaneous(
((cls.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
)
output.zero_()
return

assert isinstance(attn_metadata, dict)
rocm_metadata = cast(
DeepseekV4ROCMAiterMLASparseMetadata | None,
Expand Down Expand Up @@ -658,17 +660,17 @@ def _forward_prefill(
N = 0

M = N + layer.window_size + layer.max_num_batched_tokens
num_chunks = (num_prefills + cls._PREFILL_CHUNK_SIZE - 1) // (
cls._PREFILL_CHUNK_SIZE
num_chunks = (num_prefills + cls.PREFILL_CHUNK_SIZE - 1) // (
cls.PREFILL_CHUNK_SIZE
)

workspace_manager = current_workspace_manager()
kv = workspace_manager.get_simultaneous(
((cls._PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
((cls.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
)[0]
for chunk_idx in range(num_chunks):
chunk_start = chunk_idx * cls._PREFILL_CHUNK_SIZE
chunk_end = min(chunk_start + cls._PREFILL_CHUNK_SIZE, num_prefills)
chunk_start = chunk_idx * cls.PREFILL_CHUNK_SIZE
chunk_end = min(chunk_start + cls.PREFILL_CHUNK_SIZE, num_prefills)
chunk_size = chunk_end - chunk_start
if not swa_only:
assert attn_metadata is not None
Expand Down Expand Up @@ -737,17 +739,3 @@ def _forward_prefill(
ragged_indices=combined_ragged_indices,
ragged_indptr=combined_ragged_indptr,
)


class DeepseekV4ROCMAiterMLASparseBackend(DeepseekV4FlashMLASparseBackend):
@staticmethod
def get_name() -> str:
return "ROCM_V4_FLASHMLA_SPARSE"

@staticmethod
def get_builder_cls() -> type["DeepseekV4ROCMAiterMLASparseMetadataBuilder"]:
return DeepseekV4ROCMAiterMLASparseMetadataBuilder

@staticmethod
def get_impl_cls() -> type["DeepseekV4ROCMAiterMLASparseImpl"]:
return DeepseekV4ROCMAiterMLASparseImpl
Loading
Loading