diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index 0b0df945a885..b665fde2ccb1 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -219,7 +219,7 @@ MLA decode backends are selected using the standard | `FLASHINFER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x | | `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | 576 | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x | | `FLASHMLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x | -| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 512, 576 | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x | +| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x | | `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x | | `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %1 | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 1, 64 | Any | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | N/A | diff --git a/tests/kernels/attention/test_rocm_triton_attn_dsv4.py b/tests/kernels/attention/test_rocm_triton_attn_dsv4.py index aefcfeee4d7e..d4fa9697cb7f 100644 --- a/tests/kernels/attention/test_rocm_triton_attn_dsv4.py +++ b/tests/kernels/attention/test_rocm_triton_attn_dsv4.py @@ -198,7 +198,7 @@ def _ref_combine_topk_swa_ragged( @torch.inference_mode() def test_compute_global_topk_ragged_indices_and_indptr() -> None: - from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import ( + from vllm.models.deepseek_v4.amd.rocm import ( compute_global_topk_ragged_indices_and_indptr, ) @@ -330,7 +330,7 @@ def test_sparse_attn_decode_ragged_kernel() -> None: @torch.inference_mode() def test_combine_topk_swa_indices_ragged() -> None: - from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import ( + from vllm.models.deepseek_v4.amd.rocm import ( combine_topk_swa_indices_ragged, ) diff --git a/tests/kernels/test_fused_inv_rope_fp8_quant.py b/tests/kernels/test_fused_inv_rope_fp8_quant.py index 84d64bbf83bf..563661e441c1 100644 --- a/tests/kernels/test_fused_inv_rope_fp8_quant.py +++ b/tests/kernels/test_fused_inv_rope_fp8_quant.py @@ -7,7 +7,8 @@ the existing separate operations (inverse RoPE via rotate_neox + FP8 quant via per_token_group_quant_fp8). -The reference faithfully reproduces the exact flow in deepseek_v4_attention.py:295-310: +The reference faithfully reproduces the exact flow in +deepseek_v4/nvidia/ops/attention.py:295-310: 1. Apply inverse RoPE (NeoX style, last rope_dim=64 dims of each head) 2. Reshape [T, H, head_dim] -> [T, G, D] 3. Transpose+flatten to [G*T, D], quantize, reshape back @@ -667,7 +668,7 @@ def _unfused_inv_rope_fp8_quant( nope_dim: int = NOPE_DIM, rope_dim: int = ROPE_DIM, ) -> tuple[torch.Tensor, torch.Tensor]: - """Unfused path matching deepseek_v4_attention.py:295-310. + """Unfused path matching deepseek_v4/nvidia/ops/attention.py:295-310. Uses the production CUDA RoPE kernel + per_token_group_quant_fp8. """ diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse_dsv4.py b/vllm/models/deepseek_v4/amd/rocm.py similarity index 93% rename from vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse_dsv4.py rename to vllm/models/deepseek_v4/amd/rocm.py index c29ad2bd1ace..e71742faffd6 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse_dsv4.py +++ b/vllm/models/deepseek_v4/amd/rocm.py @@ -8,14 +8,15 @@ from vllm.forward_context import get_forward_context from vllm.models.deepseek_v4.common.ops import dequantize_and_gather_k_cache +from vllm.models.deepseek_v4.nvidia.flashmla import ( + DeepseekV4FlashMLASparseBackend, + DeepseekV4SparseMLAAttentionImpl, +) from vllm.triton_utils import tl, triton from vllm.v1.attention.backend import ( - AttentionLayer, CommonAttentionMetadata, - SparseMLAAttentionImpl, ) from vllm.v1.attention.backends.mla.flashmla_sparse import ( - DeepseekV4FlashMLASparseBackend, FlashMLASparseMetadata, FlashMLASparseMetadataBuilder, ) @@ -31,7 +32,9 @@ from vllm.v1.worker.workspace import current_workspace_manager if TYPE_CHECKING: - from vllm.models.deepseek_v4.attention import DeepseekV4MLAAttention + from vllm.models.deepseek_v4.nvidia.ops.attention import ( + DeepseekV4MLAAttention, + ) def _build_indptr_from_lengths(lengths: torch.Tensor) -> torch.Tensor: @@ -449,47 +452,27 @@ def build( ) -class DeepseekV4ROCMAiterMLASparseImpl( - SparseMLAAttentionImpl[DeepseekV4ROCMAiterMLASparseMetadata] -): - """ROCm sparse MLA implementation used by DeepSeek V4's custom MLA layer.""" +class DeepseekV4ROCMAiterMLASparseBackend(DeepseekV4FlashMLASparseBackend): + @staticmethod + def get_name() -> str: + return "ROCM_V4_FLASHMLA_SPARSE" - _PREFILL_CHUNK_SIZE = 4 + @staticmethod + def get_builder_cls() -> type["DeepseekV4ROCMAiterMLASparseMetadataBuilder"]: + return DeepseekV4ROCMAiterMLASparseMetadataBuilder - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: list[float] | None, - sliding_window: int | None, - kv_cache_dtype: str, - logits_soft_cap: float | None, - attn_type: str, - kv_sharing_target_layer_name: str | None, - **_: object, - ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - self.kv_cache_dtype = kv_cache_dtype + @staticmethod + def get_impl_cls() -> type["DeepseekV4SparseMLAAttentionImpl"]: + return DeepseekV4ROCMAiterMLASparseImpl - def forward_mqa( - self, - q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: DeepseekV4ROCMAiterMLASparseMetadata, - layer: AttentionLayer, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - raise NotImplementedError( - "DeepseekV4ROCMAiterMLASparseImpl is driven by " - "DeepseekV4MLAAttention.forward." - ) + +class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): + """ROCm sparse MLA implementation used by DeepSeek V4's custom MLA layer.""" + + backend_cls = DeepseekV4ROCMAiterMLASparseBackend @classmethod - def forward( + def forward_mqa( # type: ignore[override] cls, layer: "DeepseekV4MLAAttention", q: torch.Tensor, @@ -506,6 +489,25 @@ def forward( forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata + + if attn_metadata is None: + # Warmup dummy run: no real metadata. Reserve the same bf16 + # gather workspace _forward_prefill would; the dequantize / topk + # / sparse_fwd kernels are skipped this step. + swa_only = layer.compress_ratio <= 1 + N = ( + 0 + if swa_only + else (layer.max_model_len + layer.compress_ratio - 1) + // layer.compress_ratio + ) + M = N + layer.window_size + layer.max_num_batched_tokens + current_workspace_manager().get_simultaneous( + ((cls.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), + ) + output.zero_() + return + assert isinstance(attn_metadata, dict) rocm_metadata = cast( DeepseekV4ROCMAiterMLASparseMetadata | None, @@ -658,17 +660,17 @@ def _forward_prefill( N = 0 M = N + layer.window_size + layer.max_num_batched_tokens - num_chunks = (num_prefills + cls._PREFILL_CHUNK_SIZE - 1) // ( - cls._PREFILL_CHUNK_SIZE + num_chunks = (num_prefills + cls.PREFILL_CHUNK_SIZE - 1) // ( + cls.PREFILL_CHUNK_SIZE ) workspace_manager = current_workspace_manager() kv = workspace_manager.get_simultaneous( - ((cls._PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), + ((cls.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), )[0] for chunk_idx in range(num_chunks): - chunk_start = chunk_idx * cls._PREFILL_CHUNK_SIZE - chunk_end = min(chunk_start + cls._PREFILL_CHUNK_SIZE, num_prefills) + chunk_start = chunk_idx * cls.PREFILL_CHUNK_SIZE + chunk_end = min(chunk_start + cls.PREFILL_CHUNK_SIZE, num_prefills) chunk_size = chunk_end - chunk_start if not swa_only: assert attn_metadata is not None @@ -737,17 +739,3 @@ def _forward_prefill( ragged_indices=combined_ragged_indices, ragged_indptr=combined_ragged_indptr, ) - - -class DeepseekV4ROCMAiterMLASparseBackend(DeepseekV4FlashMLASparseBackend): - @staticmethod - def get_name() -> str: - return "ROCM_V4_FLASHMLA_SPARSE" - - @staticmethod - def get_builder_cls() -> type["DeepseekV4ROCMAiterMLASparseMetadataBuilder"]: - return DeepseekV4ROCMAiterMLASparseMetadataBuilder - - @staticmethod - def get_impl_cls() -> type["DeepseekV4ROCMAiterMLASparseImpl"]: - return DeepseekV4ROCMAiterMLASparseImpl diff --git a/vllm/models/deepseek_v4/nvidia/flashmla.py b/vllm/models/deepseek_v4/nvidia/flashmla.py new file mode 100644 index 000000000000..41630456561a --- /dev/null +++ b/vllm/models/deepseek_v4/nvidia/flashmla.py @@ -0,0 +1,402 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import abstractmethod +from typing import TYPE_CHECKING, ClassVar, cast + +import torch + +from vllm.forward_context import get_forward_context +from vllm.models.deepseek_v4.common.ops import ( + combine_topk_swa_indices, + compute_global_topk_indices_and_lens, + dequantize_and_gather_k_cache, +) +from vllm.v1.attention.backend import ( + AttentionBackend, + MultipleOf, + SparseMLAAttentionImpl, +) +from vllm.v1.attention.backends.mla.flashmla_sparse import ( + FlashMLASparseBackend, + FlashMLASparseMetadata, +) +from vllm.v1.attention.ops.flashmla import ( + flash_mla_sparse_fwd, + flash_mla_with_kvcache, +) +from vllm.v1.worker.workspace import current_workspace_manager + +if TYPE_CHECKING: + from vllm.models.deepseek_v4.nvidia.ops.attention import ( + DeepseekV4MLAAttention, + ) + from vllm.v1.attention.backends.mla.sparse_swa import DeepseekSparseSWAMetadata + + +class DeepseekV4SparseMLAAttentionImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]): + """Abstract parent for DeepseekV4 sparse MLA impls. + + V4 sparse MLA is driven by the layer (``DeepseekV4MLAAttention.forward``) + rather than the v1 framework, so ``forward_mqa`` is overridden with a + classmethod that takes the layer as its first argument. This Liskov-broken + override is intentional: the grandparent's instance-method ``forward_mqa`` + is never called on V4 layers. + """ + + backend_cls: ClassVar[type[AttentionBackend]] + + # Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather + # workspace allocated in _forward_prefill and is also read by the V4 layer's + # dummy-run path to pre-reserve that workspace. + PREFILL_CHUNK_SIZE: ClassVar[int] = 4 + + @classmethod + @abstractmethod + def forward_mqa( # type: ignore[override] + cls, + layer: "DeepseekV4MLAAttention", + q: torch.Tensor, + kv: torch.Tensor, + positions: torch.Tensor, + output: torch.Tensor, + ) -> None: + raise NotImplementedError + + +class DeepseekV4FlashMLASparseBackend(FlashMLASparseBackend): + @staticmethod + def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: + return [256] + + @staticmethod + def get_name() -> str: + return "V4_FLASHMLA_SPARSE" + + @staticmethod + def get_impl_cls() -> type["DeepseekV4SparseMLAAttentionImpl"]: + return DeepseekV4FlashMLASparseImpl + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + # DeepSeek V4 layout: 448 NoPE + 64 RoPE = 512 (overrides the + # V3.2 default of 576 from FlashMLASparseBackend). + return [512] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + if cache_dtype_str == "fp8_ds_mla": + # DeepseekV4 main MLA: 584B per token (448 NoPE + 128 RoPE + 8 fp8 scale). + # head_size passed in is the semantic head_dim (512). + return (num_blocks, block_size, 584) + else: + return (num_blocks, block_size, head_size) + + +class DeepseekV4FlashMLASparseImpl(DeepseekV4SparseMLAAttentionImpl): + """FlashMLA sparse MLA implementation for DeepSeek V4's custom MLA layer.""" + + backend_cls = DeepseekV4FlashMLASparseBackend + + @classmethod + def forward_mqa( # type: ignore[override] + cls, + layer: "DeepseekV4MLAAttention", + q: torch.Tensor, + kv: torch.Tensor, + positions: torch.Tensor, + output: torch.Tensor, + ) -> None: + assert output.shape == q.shape, ( + f"output buffer shape {output.shape} must match q shape {q.shape}" + ) + assert output.dtype == q.dtype, ( + f"output buffer dtype {output.dtype} must match q dtype {q.dtype}" + ) + + # Get SWA and indexer metadata from forward context + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + + if attn_metadata is None: + # Warmup dummy run: no real metadata. Reserve the same bf16 + # gather workspace _forward_prefill would; the dequantize / topk + # / sparse_fwd kernels are skipped this step. + swa_only = layer.compress_ratio <= 1 + N = ( + 0 + if swa_only + else (layer.max_model_len + layer.compress_ratio - 1) + // layer.compress_ratio + ) + M = N + layer.window_size + layer.max_num_batched_tokens + current_workspace_manager().get_simultaneous( + ((cls.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), + ) + output.zero_() + return + + assert isinstance(attn_metadata, dict) + flashmla_metadata = cast( + FlashMLASparseMetadata | None, attn_metadata.get(layer.prefix) + ) + swa_metadata = cast( + "DeepseekSparseSWAMetadata | None", + attn_metadata.get(layer.swa_cache_layer.prefix), + ) + assert swa_metadata is not None + + swa_only = layer.compress_ratio <= 1 + # SWA-only layers (compress_ratio <= 1) don't have their own KV cache + # allocation, so layer.kv_cache may be empty after profiling cleanup. + self_kv_cache = layer.kv_cache if not swa_only else None + swa_kv_cache = layer.swa_cache_layer.kv_cache + + # Split prefill and decode + num_decodes = swa_metadata.num_decodes + num_prefills = swa_metadata.num_prefills + num_decode_tokens = swa_metadata.num_decode_tokens + + if num_prefills > 0: + cls._forward_prefill( + layer=layer, + q=q[num_decode_tokens:], + positions=positions[num_decode_tokens:], + compressed_k_cache=self_kv_cache, + swa_k_cache=swa_kv_cache, + output=output[num_decode_tokens:], + attn_metadata=flashmla_metadata, + swa_metadata=swa_metadata, + ) + if num_decodes > 0: + cls._forward_decode( + layer=layer, + q=q[:num_decode_tokens], + kv_cache=self_kv_cache, + swa_metadata=swa_metadata, + attn_metadata=flashmla_metadata, + swa_only=swa_only, + output=output[:num_decode_tokens], + ) + + @classmethod + def _forward_decode( + cls, + layer: "DeepseekV4MLAAttention", + q: torch.Tensor, + kv_cache: torch.Tensor | None, # Only used when compress_ratio > 1 + swa_metadata: "DeepseekSparseSWAMetadata", + attn_metadata: FlashMLASparseMetadata | None, + swa_only: bool, + output: torch.Tensor, + ) -> None: + num_decodes = swa_metadata.num_decodes + num_decode_tokens = swa_metadata.num_decode_tokens + + topk_indices = None + topk_lens = None + if not swa_only: + assert attn_metadata is not None + assert swa_metadata.is_valid_token is not None + block_size = attn_metadata.block_size // layer.compress_ratio + is_valid = swa_metadata.is_valid_token[:num_decode_tokens] + if layer.compress_ratio == 4: + # C4A: local indices differ per layer (filled by Indexer). + assert layer.topk_indices_buffer is not None + global_indices, topk_lens = compute_global_topk_indices_and_lens( + layer.topk_indices_buffer[:num_decode_tokens], + swa_metadata.token_to_req_indices, + attn_metadata.block_table[:num_decodes], + block_size, + is_valid, + ) + topk_indices = global_indices.view(num_decode_tokens, 1, -1) + else: + # C128A: pre-computed during metadata build. + topk_indices = attn_metadata.c128a_global_decode_topk_indices + topk_lens = attn_metadata.c128a_decode_topk_lens + + swa_indices = swa_metadata.decode_swa_indices + swa_lens = swa_metadata.decode_swa_lens + + # We treat queries in the same seq as different queries + # and later we only attend by generated indices. + # q arrives pre-padded to layer.padded_heads by the outer wrapper. + q = q.unsqueeze(1) + + # Prepare SWA cache (num_blocks, swa_block_size, 1, head_bytes) + # Use unsqueeze to preserve strides (handles padded blocks correctly) + swa_cache = layer.swa_cache_layer.kv_cache.unsqueeze(-2) + # Reshape KV cache to (num_blocks, block_size, 1, head_bytes) + if kv_cache is not None: + kv_cache = kv_cache.unsqueeze(-2) + + # One FlashMLASchedMeta per layer type, shared across all same-type + # layers within this decode step. The first forward call per type + # triggers the in-kernel planner (allocating tile_scheduler_metadata + # and num_splits via PyTorch's graph-aware allocator so CUDA graph + # capture reuses the same addresses on replay); subsequent same-type + # layers see have_initialized=True and skip the planner. + if layer.compress_ratio <= 1: + tile_metadata = swa_metadata.tile_sched_swaonly + elif layer.compress_ratio == 4: + tile_metadata = swa_metadata.tile_sched_c4a + elif layer.compress_ratio == 128: + tile_metadata = swa_metadata.tile_sched_c128a + else: + raise ValueError( + f"Unsupported compress_ratio={layer.compress_ratio}; " + "expected 1, 4, or 128." + ) + assert tile_metadata is not None, ( + "swa_metadata missing tile_sched entry for " + f"compress_ratio={layer.compress_ratio}; " + "DeepseekSparseSWAMetadataBuilder.build_tile_scheduler did not " + "allocate one for this layer type." + ) + + out, _ = flash_mla_with_kvcache( + q=q, + k_cache=swa_cache, + block_table=None, + head_dim_v=512, + tile_scheduler_metadata=tile_metadata, + cache_seqlens=None, + is_fp8_kvcache=True, + indices=swa_indices, + topk_length=swa_lens, + softmax_scale=layer.scale, + attn_sink=layer.attn_sink, + extra_k_cache=kv_cache if not swa_only else None, + extra_indices_in_kvcache=topk_indices, + extra_topk_length=topk_lens, + out=output.unsqueeze(1), + ) + + @classmethod + def _forward_prefill( + cls, + layer: "DeepseekV4MLAAttention", + q: torch.Tensor, + positions: torch.Tensor, + compressed_k_cache: torch.Tensor | None, # Only used when compress_ratio > 1 + swa_k_cache: torch.Tensor, + output: torch.Tensor, + attn_metadata: FlashMLASparseMetadata | None, + swa_metadata: "DeepseekSparseSWAMetadata", + ) -> None: + swa_only = attn_metadata is None + + num_prefills = swa_metadata.num_prefills + num_prefill_tokens = swa_metadata.num_prefill_tokens + num_decodes = swa_metadata.num_decodes + num_decode_tokens = swa_metadata.num_decode_tokens + + # Use pre-computed prefill metadata. + seq_lens = swa_metadata.prefill_seq_lens + gather_lens = swa_metadata.prefill_gather_lens + assert seq_lens is not None + assert gather_lens is not None + + # Derive prefill-local token offsets from the full query_start_loc_cpu. + query_start_loc_cpu = swa_metadata.query_start_loc_cpu + query_start_loc = swa_metadata.query_start_loc + assert query_start_loc_cpu is not None + assert query_start_loc is not None + prefill_token_base = query_start_loc_cpu[num_decodes] + + if not swa_only: + if layer.compress_ratio == 4: + assert layer.topk_indices_buffer is not None + topk_indices = layer.topk_indices_buffer[num_decode_tokens:] + topk_indices = topk_indices[:num_prefill_tokens] + else: + # C128A: pre-computed during metadata build. + assert attn_metadata is not None + topk_indices = attn_metadata.c128a_prefill_topk_indices + top_k = topk_indices.shape[-1] + # Compressed region must fit the full compressed pool (seq_len // + # compress_ratio), not just top_k. top_k bounds how many indices + # the indexer selects, not the pool size it indexes into. + N = (layer.max_model_len + layer.compress_ratio - 1) // layer.compress_ratio + else: + # NOTE(woosuk): topk_indices will not be used for SWA-only layers. + assert layer.topk_indices_buffer is not None + topk_indices = layer.topk_indices_buffer[num_decode_tokens:] + top_k = 0 + N = 0 + + M = N + layer.window_size + layer.max_num_batched_tokens + chunk_size_const = cls.PREFILL_CHUNK_SIZE + num_chunks = (num_prefills + chunk_size_const - 1) // chunk_size_const + + workspace_manager = current_workspace_manager() + kv = workspace_manager.get_simultaneous( + ((chunk_size_const, M, q.shape[-1]), torch.bfloat16), + )[0] + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size_const + chunk_end = min(chunk_start + chunk_size_const, num_prefills) + chunk_size = chunk_end - chunk_start + if not swa_only: + # Gather compressed KV + assert attn_metadata is not None + block_table = attn_metadata.block_table[num_decodes:] + dequantize_and_gather_k_cache( + kv[:chunk_size], + compressed_k_cache, + seq_lens=seq_lens[chunk_start:chunk_end] // layer.compress_ratio, + gather_lens=None, + block_table=block_table[chunk_start:chunk_end], + block_size=attn_metadata.block_size // layer.compress_ratio, + offset=0, + ) + + # Gather SWA KV + swa_block_table = swa_metadata.block_table[num_decodes:] + dequantize_and_gather_k_cache( + kv[:chunk_size], + swa_k_cache, + seq_lens=seq_lens[chunk_start:chunk_end], + gather_lens=gather_lens[chunk_start:chunk_end], + block_table=swa_block_table[chunk_start:chunk_end], + block_size=swa_metadata.block_size, + offset=N, + ) + + # Combine the topk indices and SWA indices for gathered KV cache + query_start = ( + query_start_loc_cpu[num_decodes + chunk_start] - prefill_token_base + ) + query_end = ( + query_start_loc_cpu[num_decodes + chunk_end] - prefill_token_base + ) + + combined_indices, combined_lens = combine_topk_swa_indices( + topk_indices[query_start:query_end], + query_start_loc[ + num_decodes + chunk_start : num_decodes + chunk_end + 1 + ], + seq_lens[chunk_start:chunk_end], + gather_lens[chunk_start:chunk_end], + layer.window_size, + layer.compress_ratio, + top_k, + M, + N, + ) + flash_mla_sparse_fwd( + q=q[query_start:query_end], + kv=kv.view(-1, 1, q.shape[-1]), + indices=combined_indices.unsqueeze(1), + sm_scale=layer.scale, + attn_sink=layer.attn_sink, + topk_length=combined_lens, + out=output[query_start:query_end], + ) diff --git a/vllm/models/deepseek_v4/nvidia/model.py b/vllm/models/deepseek_v4/nvidia/model.py index 395cd0201b89..d69bad8d38de 100644 --- a/vllm/models/deepseek_v4/nvidia/model.py +++ b/vllm/models/deepseek_v4/nvidia/model.py @@ -56,7 +56,7 @@ maybe_prefix, ) from vllm.model_executor.utils import set_weight_attrs -from vllm.models.deepseek_v4.attention import ( +from vllm.models.deepseek_v4.nvidia.ops.attention import ( DeepseekV4Indexer, DeepseekV4MLAModules, DeepseekV4MultiHeadLatentAttentionWrapper, diff --git a/vllm/models/deepseek_v4/attention.py b/vllm/models/deepseek_v4/nvidia/ops/attention.py similarity index 70% rename from vllm/models/deepseek_v4/attention.py rename to vllm/models/deepseek_v4/nvidia/ops/attention.py index b773f9cbbc30..d23b5a578bfa 100644 --- a/vllm/models/deepseek_v4/attention.py +++ b/vllm/models/deepseek_v4/nvidia/ops/attention.py @@ -20,9 +20,6 @@ ) from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer from vllm.models.deepseek_v4.common.ops import ( - combine_topk_swa_indices, - compute_global_topk_indices_and_lens, - dequantize_and_gather_k_cache, fused_indexer_q_rope_quant, fused_inv_rope_fp8_quant, fused_q_kv_rmsnorm, @@ -62,28 +59,36 @@ ) from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata from vllm.v1.attention.backends.mla.flashmla_sparse import ( - DeepseekV4FlashMLASparseBackend, FlashMLASparseBackend, - FlashMLASparseMetadata, ) from vllm.v1.attention.backends.mla.indexer import ( DeepseekV4IndexerBackend, get_max_prefill_buffer_size, ) from vllm.v1.attention.backends.mla.sparse_swa import DeepseekV4SWACache -from vllm.v1.attention.ops.flashmla import ( - flash_mla_sparse_fwd, - flash_mla_with_kvcache, -) from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec -from vllm.v1.worker.workspace import current_workspace_manager + +if TYPE_CHECKING: + from vllm.models.deepseek_v4.nvidia.flashmla import ( + DeepseekV4SparseMLAAttentionImpl, + ) logger = init_logger(__name__) -# Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather -# workspace allocated at _forward_prefill (and the matching profile-time -# reservation in attention_impl's dummy-run branch). -PREFILL_CHUNK_SIZE = 4 + +def _select_v4_sparse_impl() -> "type[DeepseekV4SparseMLAAttentionImpl]": + """Pick the platform-specific V4 sparse MLA impl class. Sole platform check.""" + if current_platform.is_rocm(): + from vllm.models.deepseek_v4.amd.rocm import ( + DeepseekV4ROCMAiterMLASparseImpl, + ) + + return DeepseekV4ROCMAiterMLASparseImpl + from vllm.models.deepseek_v4.nvidia.flashmla import ( + DeepseekV4FlashMLASparseImpl, + ) + + return DeepseekV4FlashMLASparseImpl @dataclass @@ -489,25 +494,6 @@ def wq_b_kv_insert() -> torch.Tensor: q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) - # Handle dummy run (no metadata). - if not isinstance(attn_metadata, dict): - # Reserve _forward_prefill's bf16-gather workspace; the dummy - # run returns before mla_attn runs, so without this the shared - # workspace locks below the real prefill size. - sub = self.mla_attn - swa_only = sub.compress_ratio <= 1 - N = ( - 0 - if swa_only - else (sub.max_model_len + sub.compress_ratio - 1) // sub.compress_ratio - ) - M = N + sub.window_size + sub.max_num_batched_tokens - current_workspace_manager().get_simultaneous( - ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), - ) - out.zero_() - return - # Pad q to FlashMLA-required head count (64 or 128) if self.n_local_heads < self.padded_heads: pad_size = self.padded_heads - self.n_local_heads @@ -643,6 +629,8 @@ def __init__( **extra_impl_args, ) -> None: super().__init__() + self.impl_cls = _select_v4_sparse_impl() + self.backend_cls = self.impl_cls.backend_cls self.num_heads = num_heads self.num_kv_heads = 1 self.head_dim = head_dim @@ -723,13 +711,7 @@ def __init__( self.kv_cache = torch.tensor([]) def get_attn_backend(self) -> type[AttentionBackend]: - if current_platform.is_rocm(): - from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import ( - DeepseekV4ROCMAiterMLASparseBackend, - ) - - return DeepseekV4ROCMAiterMLASparseBackend - return DeepseekV4FlashMLASparseBackend + return self.backend_cls def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: if ( @@ -754,275 +736,7 @@ def forward( positions: torch.Tensor, output: torch.Tensor, ) -> None: - assert output.shape == q.shape, ( - f"output buffer shape {output.shape} must match q shape {q.shape}" - ) - assert output.dtype == q.dtype, ( - f"output buffer dtype {output.dtype} must match q dtype {q.dtype}" - ) - - if current_platform.is_rocm(): - from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import ( - DeepseekV4ROCMAiterMLASparseImpl, - ) - - DeepseekV4ROCMAiterMLASparseImpl.forward(self, q, kv, positions, output) - return - - # Get SWA and indexer metadata from forward context - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata - assert isinstance(attn_metadata, dict) - flashmla_metadata = cast( - FlashMLASparseMetadata | None, attn_metadata.get(self.prefix) - ) - swa_metadata = cast( - "DeepseekSparseSWAMetadata | None", - attn_metadata.get(self.swa_cache_layer.prefix), - ) - assert swa_metadata is not None - - swa_only = self.compress_ratio <= 1 - # SWA-only layers (compress_ratio <= 1) don't have their own KV cache - # allocation, so self.kv_cache may be empty after profiling cleanup. - self_kv_cache = self.kv_cache if not swa_only else None - swa_kv_cache = self.swa_cache_layer.kv_cache - - # Split prefill and decode - num_decodes = swa_metadata.num_decodes - num_prefills = swa_metadata.num_prefills - num_decode_tokens = swa_metadata.num_decode_tokens - - if num_prefills > 0: - self._forward_prefill( - q=q[num_decode_tokens:], - positions=positions[num_decode_tokens:], - compressed_k_cache=self_kv_cache, - swa_k_cache=swa_kv_cache, - output=output[num_decode_tokens:], - attn_metadata=flashmla_metadata, - swa_metadata=swa_metadata, - ) - if num_decodes > 0: - self._forward_decode( - q=q[:num_decode_tokens], - kv_cache=self_kv_cache, - swa_metadata=swa_metadata, - attn_metadata=flashmla_metadata, - swa_only=swa_only, - output=output[:num_decode_tokens], - ) - - def _forward_decode( - self, - q: torch.Tensor, - kv_cache: torch.Tensor | None, # Only used when compress_ratio > 1 - swa_metadata: "DeepseekSparseSWAMetadata", - attn_metadata: FlashMLASparseMetadata | None, - swa_only: bool, - output: torch.Tensor, - ) -> None: - num_decodes = swa_metadata.num_decodes - num_decode_tokens = swa_metadata.num_decode_tokens - - topk_indices = None - topk_lens = None - if not swa_only: - assert attn_metadata is not None - assert swa_metadata.is_valid_token is not None - block_size = attn_metadata.block_size // self.compress_ratio - is_valid = swa_metadata.is_valid_token[:num_decode_tokens] - if self.compress_ratio == 4: - # C4A: local indices differ per layer (filled by Indexer). - assert self.topk_indices_buffer is not None - global_indices, topk_lens = compute_global_topk_indices_and_lens( - self.topk_indices_buffer[:num_decode_tokens], - swa_metadata.token_to_req_indices, - attn_metadata.block_table[:num_decodes], - block_size, - is_valid, - ) - topk_indices = global_indices.view(num_decode_tokens, 1, -1) - else: - # C128A: pre-computed during metadata build. - topk_indices = attn_metadata.c128a_global_decode_topk_indices - topk_lens = attn_metadata.c128a_decode_topk_lens - - swa_indices = swa_metadata.decode_swa_indices - swa_lens = swa_metadata.decode_swa_lens - - # We treat queries in the same seq as different queries - # and later we only attend by generated indices. - # q arrives pre-padded to self.padded_heads by the outer wrapper. - q = q.unsqueeze(1) - - # Prepare SWA cache (num_blocks, swa_block_size, 1, head_bytes) - # Use unsqueeze to preserve strides (handles padded blocks correctly) - swa_cache = self.swa_cache_layer.kv_cache.unsqueeze(-2) - # Reshape KV cache to (num_blocks, block_size, 1, head_bytes) - if kv_cache is not None: - kv_cache = kv_cache.unsqueeze(-2) - - # One FlashMLASchedMeta per layer type, shared across all same-type - # layers within this decode step. The first forward call per type - # triggers the in-kernel planner (allocating tile_scheduler_metadata - # and num_splits via PyTorch's graph-aware allocator so CUDA graph - # capture reuses the same addresses on replay); subsequent same-type - # layers see have_initialized=True and skip the planner. - if self.compress_ratio <= 1: - tile_metadata = swa_metadata.tile_sched_swaonly - elif self.compress_ratio == 4: - tile_metadata = swa_metadata.tile_sched_c4a - elif self.compress_ratio == 128: - tile_metadata = swa_metadata.tile_sched_c128a - else: - raise ValueError( - f"Unsupported compress_ratio={self.compress_ratio}; " - "expected 1, 4, or 128." - ) - assert tile_metadata is not None, ( - "swa_metadata missing tile_sched entry for " - f"compress_ratio={self.compress_ratio}; " - "DeepseekSparseSWAMetadataBuilder.build_tile_scheduler did not " - "allocate one for this layer type." - ) - - out, _ = flash_mla_with_kvcache( - q=q, - k_cache=swa_cache, - block_table=None, - head_dim_v=512, - tile_scheduler_metadata=tile_metadata, - cache_seqlens=None, - is_fp8_kvcache=True, - indices=swa_indices, - topk_length=swa_lens, - softmax_scale=self.scale, - attn_sink=self.attn_sink, - extra_k_cache=kv_cache if not swa_only else None, - extra_indices_in_kvcache=topk_indices, - extra_topk_length=topk_lens, - out=output.unsqueeze(1), - ) - - def _forward_prefill( - self, - q: torch.Tensor, - positions: torch.Tensor, - compressed_k_cache: torch.Tensor | None, # Only used when compress_ratio > 1 - swa_k_cache: torch.Tensor, - output: torch.Tensor, - attn_metadata: FlashMLASparseMetadata | None, - swa_metadata: "DeepseekSparseSWAMetadata", - ) -> None: - swa_only = attn_metadata is None - - num_prefills = swa_metadata.num_prefills - num_prefill_tokens = swa_metadata.num_prefill_tokens - num_decodes = swa_metadata.num_decodes - num_decode_tokens = swa_metadata.num_decode_tokens - - # Use pre-computed prefill metadata. - seq_lens = swa_metadata.prefill_seq_lens - gather_lens = swa_metadata.prefill_gather_lens - assert seq_lens is not None - assert gather_lens is not None - - # Derive prefill-local token offsets from the full query_start_loc_cpu. - query_start_loc_cpu = swa_metadata.query_start_loc_cpu - query_start_loc = swa_metadata.query_start_loc - assert query_start_loc_cpu is not None - assert query_start_loc is not None - prefill_token_base = query_start_loc_cpu[num_decodes] - - if not swa_only: - if self.compress_ratio == 4: - assert self.topk_indices_buffer is not None - topk_indices = self.topk_indices_buffer[num_decode_tokens:] - topk_indices = topk_indices[:num_prefill_tokens] - else: - # C128A: pre-computed during metadata build. - assert attn_metadata is not None - topk_indices = attn_metadata.c128a_prefill_topk_indices - top_k = topk_indices.shape[-1] - # Compressed region must fit the full compressed pool (seq_len // - # compress_ratio), not just top_k. top_k bounds how many indices - # the indexer selects, not the pool size it indexes into. - N = (self.max_model_len + self.compress_ratio - 1) // self.compress_ratio - else: - # NOTE(woosuk): topk_indices will not be used for SWA-only layers. - assert self.topk_indices_buffer is not None - topk_indices = self.topk_indices_buffer[num_decode_tokens:] - top_k = 0 - N = 0 - - M = N + self.window_size + self.max_num_batched_tokens - num_chunks = (num_prefills + PREFILL_CHUNK_SIZE - 1) // PREFILL_CHUNK_SIZE - - workspace_manager = current_workspace_manager() - kv = workspace_manager.get_simultaneous( - ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), - )[0] - for chunk_idx in range(num_chunks): - chunk_start = chunk_idx * PREFILL_CHUNK_SIZE - chunk_end = min(chunk_start + PREFILL_CHUNK_SIZE, num_prefills) - chunk_size = chunk_end - chunk_start - if not swa_only: - # Gather compressed KV - assert attn_metadata is not None - block_table = attn_metadata.block_table[num_decodes:] - dequantize_and_gather_k_cache( - kv[:chunk_size], - compressed_k_cache, - seq_lens=seq_lens[chunk_start:chunk_end] // self.compress_ratio, - gather_lens=None, - block_table=block_table[chunk_start:chunk_end], - block_size=attn_metadata.block_size // self.compress_ratio, - offset=0, - ) - - # Gather SWA KV - swa_block_table = swa_metadata.block_table[num_decodes:] - dequantize_and_gather_k_cache( - kv[:chunk_size], - swa_k_cache, - seq_lens=seq_lens[chunk_start:chunk_end], - gather_lens=gather_lens[chunk_start:chunk_end], - block_table=swa_block_table[chunk_start:chunk_end], - block_size=swa_metadata.block_size, - offset=N, - ) - - # Combine the topk indices and SWA indices for gathered KV cache - query_start = ( - query_start_loc_cpu[num_decodes + chunk_start] - prefill_token_base - ) - query_end = ( - query_start_loc_cpu[num_decodes + chunk_end] - prefill_token_base - ) - - combined_indices, combined_lens = combine_topk_swa_indices( - topk_indices[query_start:query_end], - query_start_loc[ - num_decodes + chunk_start : num_decodes + chunk_end + 1 - ], - seq_lens[chunk_start:chunk_end], - gather_lens[chunk_start:chunk_end], - self.window_size, - self.compress_ratio, - top_k, - M, - N, - ) - flash_mla_sparse_fwd( - q=q[query_start:query_end], - kv=kv.view(-1, 1, q.shape[-1]), - indices=combined_indices.unsqueeze(1), - sm_scale=self.scale, - attn_sink=self.attn_sink, - topk_length=combined_lens, - out=output[query_start:query_end], - ) + self.impl_cls.forward_mqa(self, q, kv, positions, output) class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase): diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 797179076969..9140a6fccd55 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -117,8 +117,11 @@ def get_impl_cls() -> type[SparseMLAAttentionImpl[Any]]: @classmethod def get_supported_head_sizes(cls) -> list[int]: - # V3.2: 576 (512 NoPE + 64 RoPE); DeepseekV4: 512 (448 NoPE + 64 RoPE) - return [512, 576] + # DeepSeek V3.2 layout: 512 NoPE + 64 RoPE = 576. + # DeepSeek V4 uses 448 NoPE + 64 RoPE = 512 and overrides this in + # vllm/models/deepseek_v4/nvidia/flashmla.py: + # DeepseekV4FlashMLASparseBackend.get_supported_head_sizes. + return [576] @classmethod def is_mla(cls) -> bool: @@ -147,31 +150,6 @@ def get_kv_cache_shape( return (num_blocks, block_size, head_size) -class DeepseekV4FlashMLASparseBackend(FlashMLASparseBackend): - @staticmethod - def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: - return [256] - - @staticmethod - def get_name() -> str: - return "V4_FLASHMLA_SPARSE" - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - if cache_dtype_str == "fp8_ds_mla": - # DeepseekV4 main MLA: 584B per token (448 NoPE + 128 RoPE + 8 fp8 scale). - # head_size passed in is the semantic head_dim (512). - return (num_blocks, block_size, 584) - else: - return (num_blocks, block_size, head_size) - - @dataclass class FlashMLASparseMetadata(AttentionMetadata): num_reqs: int diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py index bfa3b7285dbd..09e4100df1d3 100644 --- a/vllm/v1/attention/backends/mla/sparse_swa.py +++ b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -113,7 +113,7 @@ def get_supported_head_sizes(cls) -> list[int]: @staticmethod def get_builder_cls() -> type["DeepseekSparseSWAMetadataBuilder"]: if current_platform.is_rocm(): - from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import ( + from vllm.models.deepseek_v4.amd.rocm import ( DeepseekV4ROCMAiterSparseSWAMetadataBuilder, )