From 0fb42331d12a38041abcd3d6cb4c4ea0d329467b Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Sat, 9 May 2026 00:32:12 -0700 Subject: [PATCH 01/24] Port DeepSeek V4 FlashInfer sparse MLA --- .../layers/deepseek_compressor.py | 46 +++- .../layers/deepseek_v4_attention.py | 258 +++++++++++++++--- vllm/model_executor/models/deepseek_v4.py | 13 + vllm/transformers_utils/config.py | 38 +++ vllm/utils/flashinfer.py | 14 + .../attention/backends/mla/flashmla_sparse.py | 2 + vllm/v1/attention/backends/mla/sparse_swa.py | 7 +- .../attention/ops/deepseek_v4_ops/__init__.py | 2 + .../ops/deepseek_v4_ops/cache_utils.py | 254 +++++++++++++++++ .../fused_compress_quant_cache.py | 125 +++++++++ vllm/v1/core/kv_cache_utils.py | 6 +- vllm/v1/kv_cache_interface.py | 7 +- 12 files changed, 730 insertions(+), 42 deletions(-) diff --git a/vllm/model_executor/layers/deepseek_compressor.py b/vllm/model_executor/layers/deepseek_compressor.py index 48628fec46e0..09632cee9dee 100644 --- a/vllm/model_executor/layers/deepseek_compressor.py +++ b/vllm/model_executor/layers/deepseek_compressor.py @@ -27,6 +27,7 @@ _fused_kv_compress_norm_rope_insert_indexer_attn, _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn, _fused_kv_compress_norm_rope_insert_sparse_attn, + _fused_kv_compress_norm_rope_insert_sparse_attn_full_cache, ) from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import ( MXFP4_BLOCK_SIZE, @@ -129,11 +130,13 @@ def __init__( dtype: torch.dtype, compress_ratio: int, prefix: str, + alignment: int | None = 576, ): super().__init__() self.state_dim = state_dim self.dtype = dtype self.prefix = prefix + self.alignment = alignment self.kv_cache = torch.tensor([]) compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: @@ -165,7 +168,7 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: head_size=self.state_dim, dtype=self.dtype, sliding_window=self.sliding_window, - alignment=576, # NOTE: FlashMLA requires 576B alignment + alignment=self.alignment, ) def forward(self): ... @@ -185,6 +188,7 @@ def __init__( prefix: str = "", k_cache_prefix="", use_fp4_cache: bool = False, + state_cache_alignment: int | None = 576, ): super().__init__() self.compress_ratio = compress_ratio @@ -232,6 +236,7 @@ def __init__( dtype=state_dtype, compress_ratio=compress_ratio, prefix=f"{prefix}.state_cache", + alignment=state_cache_alignment, ) # Save reference to static_forward_context for forward-time KV cache lookup. @@ -339,6 +344,45 @@ def forward( k_cache_metadata = cast(Any, attn_metadata[self.k_cache_prefix]) kv_cache = self._static_forward_context[self.k_cache_prefix].kv_cache + if self.head_dim == 512 and kv_cache.dtype != torch.uint8: + assert kv_cache.dtype in (torch.bfloat16, torch.float8_e4m3fn) + _fused_kv_compress_norm_rope_insert_sparse_attn_full_cache[(num_actual,)]( + # state cache + state_cache, + state_cache.stride(0), + state_cache.stride(1), + # metadata + token_to_req_indices, + positions, + slot_mapping, + block_table, + block_table.stride(0), + block_size, + # RMSNorm + self.norm.weight, + self.rms_norm_eps, + # RoPE + cos_sin_cache, + cos_sin_cache.stride(0), + # KV cache + kv_cache, + k_cache_metadata.slot_mapping, + kv_cache.shape[1], + # constexprs + HEAD_SIZE=self.head_dim, + TRITON_BLOCK_SIZE=triton.next_power_of_2(self.head_dim), + STATE_WIDTH=state_width, + COMPRESS_RATIO=self.compress_ratio, + OVERLAP=self.overlap, + ROPE_HEAD_DIM=self.rope_head_dim, + KV_BLOCK_STRIDE=kv_cache.stride(0), + KV_TOKEN_STRIDE=kv_cache.stride(1), + STORE_FP8=kv_cache.dtype == torch.float8_e4m3fn, + num_warps=self._num_warps, + **pdl_kwargs, + ) + return + self._fused_kernel[(num_actual,)]( # state cache state_cache, diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index e9a6ec9a587c..fc25cb727bbb 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -27,8 +27,13 @@ fused_indexer_q_rope_quant, fused_inv_rope_fp8_quant, fused_q_kv_rmsnorm, + qnorm_rope_and_insert_full_k_cache, +) +from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( + rocm_forward_decode_fallback, + rocm_inv_rope_einsum, + rocm_sparse_attn_prefill, ) -from vllm.v1.attention.ops.rocm_aiter_mla_sparse import rocm_inv_rope_einsum if TYPE_CHECKING: from vllm.v1.attention.backends.mla.sparse_swa import ( @@ -55,10 +60,12 @@ GroupShape, ) from vllm.platforms import current_platform +from vllm.utils.flashinfer import flashinfer_trtllm_batch_decode_sparse_mla_dsv4 from vllm.utils.multi_stream_utils import ( execute_in_parallel, maybe_execute_in_parallel, ) +from vllm.utils.torch_utils import kv_cache_dtype_str_to_dtype from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata from vllm.v1.attention.backends.mla.flashmla_sparse import ( DeepseekV4FlashMLASparseBackend, @@ -79,12 +86,60 @@ logger = init_logger(__name__) +_FLASHINFER_DSV4_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 +_flashinfer_dsv4_workspace_by_device: dict[torch.device, torch.Tensor] = {} + # Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather # workspace allocated at _forward_prefill (and the matching profile-time # reservation in attention_impl's dummy-run branch). PREFILL_CHUNK_SIZE = 4 +def _normalize_dsv4_kv_cache_dtype( + cache_config: CacheConfig | None, +) -> str: + kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "auto" + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + assert cache_config is not None + cache_config.cache_dtype = "fp8_ds_mla" + return "fp8_ds_mla" + return kv_cache_dtype + + +def _dsv4_kv_cache_torch_dtype( + kv_cache_dtype: str, + vllm_config: VllmConfig, +) -> torch.dtype: + if kv_cache_dtype == "fp8_ds_mla": + return torch.uint8 + if kv_cache_dtype == "fp8_inc": + return torch.float8_e4m3fn + if kv_cache_dtype == "bfloat16": + return torch.bfloat16 + if kv_cache_dtype == "auto": + dtype = kv_cache_dtype_str_to_dtype(kv_cache_dtype, vllm_config.model_config) + if dtype == torch.bfloat16: + return dtype + raise ValueError( + "DeepSeek V4 FlashInfer sparse MLA supports only BF16 or per-tensor " + f"FP8 E4M3 KV cache; got kv_cache_dtype={kv_cache_dtype}. Use " + "`bfloat16`/`auto` for BF16, `fp8_inc` for per-tensor FP8, or " + "`fp8`/`fp8_ds_mla` for the legacy UE8M0 FlashMLA path." + ) + + +def _get_flashinfer_dsv4_workspace(device: torch.device) -> torch.Tensor: + workspace = _flashinfer_dsv4_workspace_by_device.get(device) + if workspace is None: + workspace = torch.zeros( + _FLASHINFER_DSV4_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=device, + ) + _flashinfer_dsv4_workspace_by_device[device] = workspace + return workspace + + @dataclass class DeepseekV4MLAModules: """Modules used in DeepseekV4 MLA.""" @@ -230,10 +285,16 @@ def __init__( self.ln_events = [torch.cuda.Event() for _ in range(4)] assert cache_config is not None, "DeepseekV4 attention requires cache_config" + kv_cache_dtype = _normalize_dsv4_kv_cache_dtype(cache_config) + kv_cache_torch_dtype = _dsv4_kv_cache_torch_dtype( + kv_cache_dtype, mla_modules.vllm_config + ) + if kv_cache_dtype == "fp8_ds_mla": + logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.") self.swa_cache_layer = DeepseekV4SWACache( head_dim=self.head_dim, window_size=self.window_size, - dtype=torch.uint8, + dtype=kv_cache_torch_dtype, prefix=f"{prefix}.swa_cache", cache_config=cache_config, ) @@ -278,6 +339,12 @@ def __init__( rotate=True, prefix=f"{prefix}.compressor", k_cache_prefix=self.mla_attn.prefix, + # Legacy FlashMLA state/KV pages need 576B alignment. The + # FlashInfer BF16/per-tensor FP8 path shares state pages with + # contiguous C4 KV pages, so padding would break page matching. + state_cache_alignment=( + 576 if self.mla_attn.kv_cache_dtype == "fp8_ds_mla" else None + ), ) def forward( @@ -535,22 +602,35 @@ def _fused_qnorm_rope_kv_insert( assert swa_metadata is not None swa_kv_cache = self.swa_cache_layer.kv_cache - swa_kv_cache_2d = swa_kv_cache.view(swa_kv_cache.shape[0], -1) # Horizontally fused: # Q side: q_head_norm (per-head RMSNorm, no weight) + GPT-J RoPE - # KV side: GPT-J RoPE + UE8M0 FP8 quant + paged cache insert + # KV side: GPT-J RoPE + paged cache insert. The uint8 cache keeps the + # legacy UE8M0 layout; BF16/FP8 caches store the full 512-wide vector. # kv is unchanged; mla_attn reads kv solely via swa_kv_cache. - torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( - q, - kv, - swa_kv_cache_2d, - swa_metadata.slot_mapping, - positions.to(torch.int64), - self.rotary_emb.cos_sin_cache, - self.eps, - swa_metadata.block_size, - ) + if swa_kv_cache.dtype == torch.uint8: + swa_kv_cache_2d = swa_kv_cache.view(swa_kv_cache.shape[0], -1) + torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( + q, + kv, + swa_kv_cache_2d, + swa_metadata.slot_mapping, + positions.to(torch.int64), + self.rotary_emb.cos_sin_cache, + self.eps, + swa_metadata.block_size, + ) + else: + qnorm_rope_and_insert_full_k_cache( + q, + kv, + swa_kv_cache, + swa_metadata.slot_mapping, + positions.to(torch.int64), + self.rotary_emb.cos_sin_cache, + self.eps, + swa_metadata.block_size, + ) def deepseek_v4_attention( @@ -687,27 +767,13 @@ def __init__( vllm_config.scheduler_config.max_num_batched_tokens ) self.max_model_len = vllm_config.model_config.max_model_len - # DeepseekV4 only supports fp8 kv-cache format for now. - kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "fp8" - - assert kv_cache_dtype.startswith("fp8"), ( - f"DeepseekV4 only supports fp8 kv-cache format for now, " - f"got {kv_cache_dtype}" - ) assert issubclass(self.get_attn_backend(), FlashMLASparseBackend), ( - "Only FlashMLA Sparse Attention backend is supported for DeepseekV4 for now" + "DeepseekV4 requires the sparse MLA metadata/cache backend" + ) + kv_cache_dtype = _normalize_dsv4_kv_cache_dtype(cache_config) + self.kv_cache_torch_dtype = _dsv4_kv_cache_torch_dtype( + kv_cache_dtype, vllm_config ) - # FlashMLA Sparse Attention fp8 backend uses "fp8_ds_mla" kv-cache format - # Automatically convert fp8 kv-cache format to "fp8_ds_mla" - if ( - issubclass(self.get_attn_backend(), FlashMLASparseBackend) - and kv_cache_dtype.startswith("fp8") - and kv_cache_dtype != "fp8_ds_mla" - ): - assert cache_config is not None - cache_config.cache_dtype = "fp8_ds_mla" - kv_cache_dtype = "fp8_ds_mla" - logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.") self.kv_cache_dtype = kv_cache_dtype @@ -738,10 +804,14 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: block_size=vllm_config.cache_config.block_size, num_kv_heads=1, head_size=self.head_dim, - dtype=torch.uint8, + dtype=self.kv_cache_torch_dtype, compress_ratio=self.compress_ratio, cache_dtype_str=self.kv_cache_dtype, - alignment=576, # NOTE: FlashMLA requires 576B alignment + # FlashMLA's legacy fp8_ds_mla layout needs 576B page alignment. + # FlashInfer DSV4 BF16/per-tensor FP8 sparse decode treats the KV + # pool as a flat contiguous token array, so padding would skew + # physical sparse indices after the first page. + alignment=576 if self.kv_cache_dtype == "fp8_ds_mla" else None, model_version="deepseek_v4", ) @@ -849,6 +919,38 @@ def _forward_decode( swa_indices = swa_metadata.decode_swa_indices swa_lens = swa_metadata.decode_swa_lens + if current_platform.is_rocm(): + rocm_forward_decode_fallback( + q=q, + kv_cache=kv_cache, + swa_k_cache=self.swa_cache_layer.kv_cache, + swa_only=swa_only, + topk_indices=topk_indices, + topk_lens=topk_lens, + swa_indices=swa_indices, + swa_lens=swa_lens, + attn_sink=self.attn_sink, + scale=self.scale, + head_dim=self.head_dim, + nope_head_dim=self.nope_head_dim, + rope_head_dim=self.rope_head_dim, + output=output, + ) + return + + if self.kv_cache_torch_dtype != torch.uint8: + self._forward_decode_flashinfer( + q=q, + kv_cache=kv_cache, + swa_metadata=swa_metadata, + swa_only=swa_only, + topk_indices=topk_indices, + topk_lens=topk_lens, + swa_indices=swa_indices, + output=output, + ) + return + # We treat queries in the same seq as different queries # and later we only attend by generated indices. # q arrives pre-padded to self.padded_heads by the outer wrapper. @@ -903,6 +1005,92 @@ def _forward_decode( out=output.unsqueeze(1), ) + def _forward_decode_flashinfer( + self, + q: torch.Tensor, + kv_cache: torch.Tensor | None, + swa_metadata: "DeepseekSparseSWAMetadata", + swa_only: bool, + topk_indices: torch.Tensor | None, + topk_lens: torch.Tensor | None, + swa_indices: torch.Tensor, + output: torch.Tensor, + ) -> None: + assert self.kv_cache_torch_dtype in (torch.bfloat16, torch.float8_e4m3fn) + num_decodes = swa_metadata.num_decodes + num_decode_tokens = swa_metadata.num_decode_tokens + assert swa_metadata.seq_lens is not None + assert swa_metadata.query_start_loc is not None + assert swa_metadata.query_start_loc_cpu is not None + + swa_indices_2d = swa_indices.view(num_decode_tokens, -1).contiguous() + if swa_indices_2d.shape[-1] != self.window_size: + raise ValueError( + f"DeepSeek V4 FlashInfer path expects {self.window_size} SWA " + f"indices, got {swa_indices_2d.shape[-1]}" + ) + + if swa_only: + compressed_kv_cache = self.swa_cache_layer.kv_cache + compressed_indices = torch.full( + (num_decode_tokens, 4), + -1, + dtype=torch.int32, + device=q.device, + ) + sparse_topk_lens = torch.full( + (num_decode_tokens,), + self.window_size, + dtype=torch.int32, + device=q.device, + ) + else: + assert kv_cache is not None + assert topk_indices is not None + assert topk_lens is not None + compressed_kv_cache = kv_cache + compressed_indices = topk_indices.view(num_decode_tokens, -1).contiguous() + sparse_topk_lens = (topk_lens + self.window_size).to(torch.int32) + + sparse_indices = torch.cat((swa_indices_2d, compressed_indices), dim=-1) + if sparse_indices.shape[-1] % 4 != 0: + pad = 4 - sparse_indices.shape[-1] % 4 + sparse_indices = F.pad(sparse_indices, (0, pad), value=-1) + sparse_indices = sparse_indices.contiguous() + + query_start_loc = swa_metadata.query_start_loc[: num_decodes + 1] + query_start_loc_cpu = swa_metadata.query_start_loc_cpu[: num_decodes + 1] + query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + max_q_len = int(query_lens_cpu.max().item()) + + seq_lens = swa_metadata.seq_lens[:num_decodes].to(torch.int32) + query = q.contiguous() + bmm1_scale: float | torch.Tensor = self.scale + bmm2_scale: float | torch.Tensor = 1.0 + if self.kv_cache_torch_dtype == torch.float8_e4m3fn: + query = query.clamp(-448.0, 448.0).to(torch.float8_e4m3fn) + bmm1_scale = torch.tensor( + [self.scale], dtype=torch.float32, device=q.device + ) + bmm2_scale = torch.ones(1, dtype=torch.float32, device=q.device) + + flashinfer_trtllm_batch_decode_sparse_mla_dsv4( + query=query, + swa_kv_cache=self.swa_cache_layer.kv_cache, + workspace_buffer=_get_flashinfer_dsv4_workspace(q.device), + sparse_indices=sparse_indices, + compressed_kv_cache=compressed_kv_cache, + sparse_topk_lens=sparse_topk_lens.contiguous(), + seq_lens=seq_lens.contiguous(), + out=output, + bmm1_scale=bmm1_scale, + bmm2_scale=bmm2_scale, + sinks=self.attn_sink, + kv_layout="HND", + cum_seq_lens_q=query_start_loc.contiguous(), + max_q_len=max_q_len, + ) + def _forward_prefill( self, q: torch.Tensor, diff --git a/vllm/model_executor/models/deepseek_v4.py b/vllm/model_executor/models/deepseek_v4.py index 791d9b1bf5ed..45b33035a76f 100644 --- a/vllm/model_executor/models/deepseek_v4.py +++ b/vllm/model_executor/models/deepseek_v4.py @@ -172,6 +172,19 @@ def expert_dtype(self) -> str: @property def is_scale_e8m0(self) -> bool: + try: + hf_config = get_current_vllm_config().model_config.hf_config + except Exception: + hf_config = None + + scale_fmt = getattr(hf_config, "scale_fmt", None) + if scale_fmt is None and hf_config is not None: + quantization_config = getattr(hf_config, "quantization_config", None) + if isinstance(quantization_config, dict): + scale_fmt = quantization_config.get("scale_fmt") + if scale_fmt is not None: + return scale_fmt == "ue8m0" + # FP4 checkpoints store FP8 linear scales as e8m0fnu; FP8 expert # checkpoints (Flash-Base) store them as float32. return self.expert_dtype == "fp4" diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index e6c497c0b450..5db672537ad6 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -63,6 +63,8 @@ from transformers import AutoConfig MISTRAL_CONFIG_NAME = "params.json" +DEEPSEEK_V4_INFERENCE_CONFIG_NAME = "inference/config.json" +DEEPSEEK_V4_GLOBAL_INFERENCE_FIELDS = ("expert_dtype", "scale_fmt") logger = init_logger(__name__) @@ -84,6 +86,40 @@ def __getitem__(self, key): return getattr(configs, value) +def _maybe_apply_deepseek_v4_inference_config( + config: PretrainedConfig, + model: str, + revision: str | None, +) -> None: + """Promote DeepSeek V4 inference config fields into hf_config.""" + if getattr(config, "model_type", None) != "deepseek_v4": + return + if not file_or_path_exists(model, DEEPSEEK_V4_INFERENCE_CONFIG_NAME, revision): + return + + inference_config = get_hf_file_to_dict( + DEEPSEEK_V4_INFERENCE_CONFIG_NAME, model, revision + ) + updates = { + key: inference_config[key] + for key in DEEPSEEK_V4_GLOBAL_INFERENCE_FIELDS + if key in inference_config and not hasattr(config, key) + } + if "scale_fmt" not in updates and not hasattr(config, "scale_fmt"): + quantization_config = getattr(config, "quantization_config", None) + if isinstance(quantization_config, dict): + scale_fmt = quantization_config.get("scale_fmt") + if scale_fmt is not None: + updates["scale_fmt"] = scale_fmt + + if updates: + config.update(updates) + logger.info_once( + "Applied DeepSeek V4 inference globals to hf_config: %s", + tuple(sorted(updates)), + ) + + _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( afmoe="AfmoeConfig", bagel="BagelConfig", @@ -815,6 +851,8 @@ def apply_gguf_default(key: str, gguf_default: Any): scale_fmt, ) + _maybe_apply_deepseek_v4_inference_config(config, model, revision) + if hf_overrides_kw: logger.debug("Overriding HF config with %s", hf_overrides_kw) config.update(hf_overrides_kw) diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 44fcc19c2d2b..886edf19d898 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -71,6 +71,14 @@ def _missing(*_: Any, **__: Any) -> NoReturn: ) +def _missing_dsv4_sparse_mla(*_: Any, **__: Any) -> NoReturn: + raise RuntimeError( + "flashinfer.mla.trtllm_batch_decode_sparse_mla_dsv4 is not available. " + "Install a FlashInfer build that includes DeepSeek V4 sparse MLA " + "TRTLLM-GEN support." + ) + + def _get_submodule(module_name: str) -> Any | None: """Safely import a submodule and return it, or None if not available.""" try: @@ -137,6 +145,11 @@ def wrapper(*args, **kwargs): trtllm_fp4_block_scale_moe = _lazy_import_wrapper( "flashinfer", "trtllm_fp4_block_scale_moe" ) +flashinfer_trtllm_batch_decode_sparse_mla_dsv4 = _lazy_import_wrapper( + "flashinfer.mla", + "trtllm_batch_decode_sparse_mla_dsv4", + fallback_fn=_missing_dsv4_sparse_mla, +) # Special case for autotune since it returns a context manager autotune = _lazy_import_wrapper( "flashinfer.autotuner", @@ -891,6 +904,7 @@ def is_flashinfer_cudnn_fp8_prefill_attn_supported() -> bool: "flashinfer_cute_dsl_fused_moe_nvfp4", "flashinfer_convert_sf_to_mma_layout", "trtllm_fp4_block_scale_moe", + "flashinfer_trtllm_batch_decode_sparse_mla_dsv4", "autotune", "has_flashinfer_moe", "has_flashinfer_comm", diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 797179076969..d0cd637ac33b 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -95,8 +95,10 @@ class FlashMLASparseBackend(AttentionBackend): supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ "auto", "bfloat16", + "fp8_inc", "fp8_ds_mla", "fp8", # alias for fp8_ds_mla + "fp8_e4m3", # alias for fp8_ds_mla ] @staticmethod diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py index bfa3b7285dbd..ccf2acab3b61 100644 --- a/vllm/v1/attention/backends/mla/sparse_swa.py +++ b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -73,7 +73,7 @@ def __init__( # determines the SWA block size of 64 tokens per block. # TODO(yifan): make SWA block size automatically determined and configurable. self.block_size = 64 - assert self.dtype == torch.uint8 + assert self.dtype in (torch.uint8, torch.bfloat16, torch.float8_e4m3fn) def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: return SlidingWindowMLASpec( @@ -83,7 +83,10 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: dtype=self.dtype, sliding_window=self.window_size, cache_dtype_str=self.cache_config.cache_dtype, - alignment=576, # NOTE: FlashMLA requires 576B alignment + # The legacy fp8_ds_mla FlashMLA layout needs 576B alignment. + # FlashInfer DSV4 BF16/per-tensor FP8 reads sparse indices as + # flat token offsets, so those cache pages must remain contiguous. + alignment=576 if self.cache_config.cache_dtype == "fp8_ds_mla" else None, model_version="deepseek_v4", ) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py b/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py index 959a79f292a5..50aa3c6865dd 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py @@ -5,6 +5,7 @@ combine_topk_swa_indices, compute_global_topk_indices_and_lens, dequantize_and_gather_k_cache, + qnorm_rope_and_insert_full_k_cache, quantize_and_insert_k_cache, ) from .fused_indexer_q import MXFP4_BLOCK_SIZE, fused_indexer_q_rope_quant @@ -19,5 +20,6 @@ "fused_indexer_q_rope_quant", "fused_inv_rope_fp8_quant", "fused_q_kv_rmsnorm", + "qnorm_rope_and_insert_full_k_cache", "quantize_and_insert_k_cache", ] diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py index dfb107b515eb..2060e274ef8d 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py @@ -20,6 +20,184 @@ from vllm.utils.import_utils import has_cutedsl +@triton.jit +def _apply_gptj_rope_512( + values, + position, + cos_sin_cache_ptr, + cos_sin_stride, + HEAD_SIZE: tl.constexpr, + ROPE_HEAD_DIM: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + NUM_PAIRS: tl.constexpr = BLOCK_SIZE // 2 + NOPE_PAIRS: tl.constexpr = (HEAD_SIZE - ROPE_HEAD_DIM) // 2 + HALF_ROPE: tl.constexpr = ROPE_HEAD_DIM // 2 + + pairs = tl.reshape(values, (NUM_PAIRS, 2)) + even, odd = tl.split(pairs) + + pair_idx = tl.arange(0, NUM_PAIRS) + rope_pair_local = pair_idx - NOPE_PAIRS + is_rope_pair = rope_pair_local >= 0 + cs_idx = tl.maximum(rope_pair_local, 0) + + cache_base = cos_sin_cache_ptr + position * cos_sin_stride + cos_v = tl.load(cache_base + cs_idx, mask=is_rope_pair, other=1.0) + sin_v = tl.load(cache_base + HALF_ROPE + cs_idx, mask=is_rope_pair, other=0.0) + + new_even = tl.where(is_rope_pair, even * cos_v - odd * sin_v, even) + new_odd = tl.where(is_rope_pair, odd * cos_v + even * sin_v, odd) + return tl.interleave(new_even, new_odd) + + +@triton.jit +def _qnorm_rope_kernel( + q_ptr, + q_stride0, + q_stride1, + positions_ptr, + cos_sin_cache_ptr, + cos_sin_stride, + eps, + HEAD_SIZE: tl.constexpr, + ROPE_HEAD_DIM: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < HEAD_SIZE + + q_row = q_ptr + token_idx * q_stride0 + head_idx * q_stride1 + values = tl.load(q_row + offsets, mask=mask, other=0.0).to(tl.float32) + variance = tl.sum(values * values, axis=0) / HEAD_SIZE + values *= tl.rsqrt(variance + eps) + + position = tl.load(positions_ptr + token_idx) + values = _apply_gptj_rope_512( + values, + position, + cos_sin_cache_ptr, + cos_sin_stride, + HEAD_SIZE, + ROPE_HEAD_DIM, + BLOCK_SIZE, + ) + tl.store(q_row + offsets, values.to(tl.bfloat16), mask=mask) + + +@triton.jit +def _kv_rope_insert_full_cache_kernel( + kv_ptr, + kv_stride0, + slot_mapping_ptr, + positions_ptr, + cos_sin_cache_ptr, + cos_sin_stride, + k_cache_ptr, + cache_stride0, + cache_stride1, + cache_block_size, + HEAD_SIZE: tl.constexpr, + ROPE_HEAD_DIM: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + STORE_FP8: tl.constexpr, +): + token_idx = tl.program_id(0) + slot_idx = tl.load(slot_mapping_ptr + token_idx) + if slot_idx < 0: + return + + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < HEAD_SIZE + values = tl.load( + kv_ptr + token_idx * kv_stride0 + offsets, mask=mask, other=0.0 + ).to(tl.float32) + + position = tl.load(positions_ptr + token_idx) + values = _apply_gptj_rope_512( + values, + position, + cos_sin_cache_ptr, + cos_sin_stride, + HEAD_SIZE, + ROPE_HEAD_DIM, + BLOCK_SIZE, + ) + + block_idx = slot_idx // cache_block_size + pos_in_block = slot_idx % cache_block_size + cache_row = ( + k_cache_ptr + + block_idx.to(tl.int64) * cache_stride0 + + (pos_in_block * cache_stride1) + ) + if STORE_FP8: + values = tl.clamp(values, -448.0, 448.0) + tl.store(cache_row + offsets, values.to(tl.float8e4nv), mask=mask) + else: + tl.store(cache_row + offsets, values.to(tl.bfloat16), mask=mask) + + +def qnorm_rope_and_insert_full_k_cache( + q: torch.Tensor, + kv: torch.Tensor, + k_cache: torch.Tensor, + slot_mapping: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + eps: float, + cache_block_size: int, +) -> None: + """Apply DeepSeek V4 Q RMSNorm/RoPE and insert full-width BF16/FP8 KV. + + This path is for FlashInfer's DeepSeek V4 sparse MLA launcher, which accepts + full 512-wide BF16 or per-tensor FP8 E4M3 KV pools. The existing 584-byte + UE8M0 cache path remains handled by the CUDA fused op. + """ + assert q.dim() == 3 and q.shape[-1] == 512 + assert kv.dim() == 2 and kv.shape[-1] == 512 + assert q.dtype == torch.bfloat16 + assert kv.dtype == torch.bfloat16 + assert k_cache.dtype in (torch.bfloat16, torch.float8_e4m3fn) + assert cos_sin_cache.dtype == torch.float32 + + num_tokens_full, num_heads, _ = q.shape + _qnorm_rope_kernel[(num_tokens_full, num_heads)]( + q, + q.stride(0), + q.stride(1), + positions, + cos_sin_cache, + cos_sin_cache.stride(0), + eps, + HEAD_SIZE=512, + ROPE_HEAD_DIM=64, + BLOCK_SIZE=512, + num_warps=8, + ) + + num_tokens_insert = slot_mapping.shape[0] + _kv_rope_insert_full_cache_kernel[(num_tokens_insert,)]( + kv, + kv.stride(0), + slot_mapping, + positions, + cos_sin_cache, + cos_sin_cache.stride(0), + k_cache, + k_cache.stride(0), + k_cache.stride(1), + cache_block_size, + HEAD_SIZE=512, + ROPE_HEAD_DIM=64, + BLOCK_SIZE=512, + STORE_FP8=k_cache.dtype == torch.float8_e4m3fn, + num_warps=8, + ) + + @triton.jit def quantize_and_insert_k_kernel( # Input tensors @@ -304,6 +482,57 @@ def _dequantize_and_gather_k_kernel( tl.store(output_row_ptr + bf16_output_offset + chunk_offsets, bf16_vals) +@triton.jit +def _gather_full_k_cache_kernel( + out_ptr, + out_stride0, + out_stride1, + k_cache_ptr, + k_cache_stride0, + k_cache_stride1, + seq_lens_ptr, + block_table_ptr, + offset, + gather_lens_ptr, + max_blocks_per_seq: tl.constexpr, + cache_block_size: tl.constexpr, + output_dim: tl.constexpr, + STORE_FP8: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + batch_idx = tl.program_id(0) + worker_id = tl.program_id(1) + num_workers = tl.num_programs(1) + + seq_len = tl.load(seq_lens_ptr + batch_idx) + if gather_lens_ptr is not None: + gather_len = tl.load(gather_lens_ptr + batch_idx) + else: + gather_len = seq_len + start_pos = seq_len - gather_len + + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < output_dim + for i in range(worker_id, gather_len, num_workers): + pos = start_pos + i + block_in_seq = pos // cache_block_size + pos_in_block = pos % cache_block_size + block_table_row_ptr = block_table_ptr + batch_idx * max_blocks_per_seq + physical_block_idx = tl.load(block_table_row_ptr + block_in_seq) + + cache_row = ( + k_cache_ptr + + physical_block_idx.to(tl.int64) * k_cache_stride0 + + pos_in_block * k_cache_stride1 + ) + values = tl.load(cache_row + offsets, mask=mask, other=0.0) + if STORE_FP8: + values = values.to(tl.float32) + + out_row = out_ptr + batch_idx * out_stride0 + (offset + i) * out_stride1 + tl.store(out_row + offsets, values.to(tl.bfloat16), mask=mask) + + def dequantize_and_gather_k_cache_triton( # [num_reqs, max_num_tokens, head_size] out: torch.Tensor, @@ -364,6 +593,31 @@ def dequantize_and_gather_k_cache( block_size: int, offset: int, ) -> None: + if k_cache.dtype != torch.uint8: + assert k_cache.dtype in (torch.bfloat16, torch.float8_e4m3fn) + assert k_cache.dim() == 3 and k_cache.shape[-1] == 512 + num_reqs = seq_lens.shape[0] + NUM_WORKERS = 128 + _gather_full_k_cache_kernel[(num_reqs, NUM_WORKERS)]( + out, + out.stride(0), + out.stride(1), + k_cache, + k_cache.stride(0), + k_cache.stride(1), + seq_lens, + block_table, + offset, + gather_lens, + max_blocks_per_seq=block_table.shape[-1], + cache_block_size=block_size, + output_dim=512, + STORE_FP8=k_cache.dtype == torch.float8_e4m3fn, + BLOCK_SIZE=512, + num_warps=8, + ) + return + if has_cutedsl(): # lazily import, otherwise some tests fail due to CUDA driver init failure. from .dequant_gather_k_cutedsl import dequantize_and_gather_k_cache_cutedsl diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py index 2f97d8733c95..5c817e789477 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py @@ -214,6 +214,131 @@ def _fused_kv_compress_norm_rope_insert_sparse_attn( tl.store(bf16_ptr + rope_local, result.to(tl.bfloat16), mask=is_rope) +@triton.jit +def _fused_kv_compress_norm_rope_insert_sparse_attn_full_cache( + # ── state cache (compressor internal state) ── + state_cache_ptr, + state_cache_stride0, + state_cache_stride1, + # ── metadata ── + token_to_req_indices_ptr, + positions_ptr, + slot_mapping_ptr, + block_table_ptr, + block_table_stride, + block_size, + # ── RMSNorm ── + rms_norm_weight_ptr, + rms_norm_eps, + # ── RoPE ── + cos_sin_cache_ptr, + cos_sin_stride, + # ── KV cache output ── + k_cache_ptr, + kv_slot_mapping_ptr, + kv_cache_block_size, + # ── constexprs ── + HEAD_SIZE: tl.constexpr, + TRITON_BLOCK_SIZE: tl.constexpr, + STATE_WIDTH: tl.constexpr, + COMPRESS_RATIO: tl.constexpr, + OVERLAP: tl.constexpr, + ROPE_HEAD_DIM: tl.constexpr, + KV_BLOCK_STRIDE: tl.constexpr, + KV_TOKEN_STRIDE: tl.constexpr, + STORE_FP8: tl.constexpr, +): + """Fused compress/RMSNorm/RoPE store for BF16 or per-tensor FP8 caches.""" + token_idx = tl.program_id(0) + + slot_id = tl.load(slot_mapping_ptr + token_idx) + if slot_id < 0: + return + + position = tl.load(positions_ptr + token_idx) + if (position + 1) % COMPRESS_RATIO != 0: + return + + req_idx = tl.load(token_to_req_indices_ptr + token_idx) + + start = position - (1 + OVERLAP) * COMPRESS_RATIO + 1 + tokens = tl.arange(0, (1 + OVERLAP) * COMPRESS_RATIO) + pos = start + tokens + mask_pos = pos >= 0 + + block_indices = pos // block_size + block_numbers = tl.load( + block_table_ptr + req_idx * block_table_stride + block_indices, + mask=mask_pos, + other=0, + ) + block_offsets = pos % block_size + head_offset = (tokens >= COMPRESS_RATIO).to(tl.int32) * HEAD_SIZE + + block = tl.arange(0, TRITON_BLOCK_SIZE) + mask = block < HEAD_SIZE + block_numbers_i64 = block_numbers.to(tl.int64) + row_base = ( + state_cache_ptr + + block_numbers_i64 * state_cache_stride0 + + block_offsets * state_cache_stride1 + + head_offset + ) + combined_mask = mask_pos[:, None] & mask[None, :] + + score = tl.load( + row_base[:, None] + STATE_WIDTH + block[None, :], + mask=combined_mask, + other=float("-inf"), + ) + score = tl.softmax(score, dim=0) + kv = tl.load(row_base[:, None] + block[None, :], mask=combined_mask, other=0.0) + compressed_kv = tl.sum(kv * score, axis=0) + + rms_w = tl.load(rms_norm_weight_ptr + block, mask=mask, other=0.0) + variance = tl.sum(compressed_kv * compressed_kv, axis=0) / HEAD_SIZE + normed = compressed_kv * tl.rsqrt(variance + rms_norm_eps) * rms_w + + kv_slot_idx = tl.load(kv_slot_mapping_ptr + token_idx) + if kv_slot_idx < 0: + return + kv_block_idx = kv_slot_idx // kv_cache_block_size + kv_pos_in_block = kv_slot_idx % kv_cache_block_size + cache_row = ( + k_cache_ptr + + kv_block_idx.to(tl.int64) * KV_BLOCK_STRIDE + + kv_pos_in_block * KV_TOKEN_STRIDE + ) + + NOPE_HEAD_DIM: tl.constexpr = HEAD_SIZE - ROPE_HEAD_DIM + HALF_ROPE: tl.constexpr = ROPE_HEAD_DIM // 2 + NUM_PAIRS: tl.constexpr = TRITON_BLOCK_SIZE // 2 + NOPE_PAIRS: tl.constexpr = NOPE_HEAD_DIM // 2 + + pair_2d = tl.reshape(normed, (NUM_PAIRS, 2)) + even, odd = tl.split(pair_2d) + pair_idx = tl.arange(0, NUM_PAIRS) + rope_pair_local = pair_idx - NOPE_PAIRS + is_rope_pair = rope_pair_local >= 0 + cs_idx = tl.maximum(rope_pair_local, 0) + + compressed_pos = (position // COMPRESS_RATIO) * COMPRESS_RATIO + cache_base = cos_sin_cache_ptr + compressed_pos * cos_sin_stride + cos_v = tl.load(cache_base + cs_idx, mask=is_rope_pair, other=1.0) + sin_v = tl.load(cache_base + HALF_ROPE + cs_idx, mask=is_rope_pair, other=0.0) + + new_even = tl.where(is_rope_pair, even * cos_v - odd * sin_v, even) + new_odd = tl.where(is_rope_pair, odd * cos_v + even * sin_v, odd) + result = tl.interleave(new_even, new_odd) + result = result.to(tl.bfloat16).to(tl.float32) + + if STORE_FP8: + result = tl.clamp(result, -448.0, 448.0) + tl.store(cache_row + block, result.to(tl.float8e4nv), mask=mask) + else: + tl.store(cache_row + block, result.to(tl.bfloat16), mask=mask) + + # ============================================================================= # Indexer path (head=128, all FP8, single quant block) # ============================================================================= diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index b57e10b67faa..e0fa89b2c1c8 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1541,7 +1541,11 @@ def _get_kv_cache_groups_uniform_groups( for sm_spec in swa_mla_specs: sm_page_sizes = sm_spec.get_page_sizes() layers_per_size: dict[int, list[str]] = defaultdict(list) - assert max(sm_page_sizes) <= max(all_page_sizes) + if max(sm_page_sizes) > max(all_page_sizes): + raise AssertionError( + "DeepseekV4 SWA page size exceeds full-MLA page sizes: " + f"swa={sorted(sm_page_sizes)}, full={sorted(all_page_sizes)}" + ) # Unify page size by padding layers' page_size to the nearest larger page_size. # Compute candidate (nearest larger page_size) for each unique page size. diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index cf50dbff179a..56e8e1dfab83 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -482,10 +482,11 @@ def storage_block_size(self) -> int: @property def real_page_size_bytes(self) -> int: - if self.model_version == "deepseek_v4": - # DeepseekV4: 448B NoPE + 128B RoPE + 8B fp8 scale = 584B per token. + if self.model_version == "deepseek_v4" and self.cache_dtype_str == "fp8_ds_mla": + # DeepseekV4 legacy UE8M0 layout: + # 448B NoPE + 128B RoPE + 8B fp8 scale = 584B per token. return self.storage_block_size * 584 - assert self.model_version is None, ( + assert self.model_version in (None, "deepseek_v4"), ( f"Unsupported model version: {self.model_version}" ) return ( From 6d0ccbdb13ce7250357e5280c1985d4f32d38f98 Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Sat, 9 May 2026 00:52:59 -0700 Subject: [PATCH 02/24] Optimize DeepSeek V4 FlashInfer decode --- .../layers/deepseek_v4_attention.py | 61 +++++++------- vllm/utils/flashinfer.py | 84 +++++++++++++++++++ .../attention/backends/mla/flashmla_sparse.py | 13 ++- vllm/v1/attention/backends/mla/sparse_swa.py | 10 +++ .../ops/deepseek_v4_ops/cache_utils.py | 11 ++- 5 files changed, 146 insertions(+), 33 deletions(-) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index fc25cb727bbb..ed0ffbd1fffb 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -60,7 +60,7 @@ GroupShape, ) from vllm.platforms import current_platform -from vllm.utils.flashinfer import flashinfer_trtllm_batch_decode_sparse_mla_dsv4 +from vllm.utils.flashinfer import flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw from vllm.utils.multi_stream_utils import ( execute_in_parallel, maybe_execute_in_parallel, @@ -776,6 +776,16 @@ def __init__( ) self.kv_cache_dtype = kv_cache_dtype + self.register_buffer( + "_flashinfer_fp8_bmm1_scale_log2", + torch.tensor([self.scale * 1.4426950408889634], dtype=torch.float32), + persistent=False, + ) + self.register_buffer( + "_flashinfer_fp8_bmm2_scale", + torch.ones(1, dtype=torch.float32), + persistent=False, + ) # Register with compilation context for metadata lookup compilation_config = vllm_config.compilation_config @@ -895,6 +905,9 @@ def _forward_decode( topk_indices = None topk_lens = None + use_flashinfer_dsv4 = ( + self.kv_cache_torch_dtype != torch.uint8 and not current_platform.is_rocm() + ) if not swa_only: assert attn_metadata is not None assert swa_metadata.is_valid_token is not None @@ -909,6 +922,7 @@ def _forward_decode( attn_metadata.block_table[:num_decodes], block_size, is_valid, + self.window_size if use_flashinfer_dsv4 else 0, ) topk_indices = global_indices.view(num_decode_tokens, 1, -1) else: @@ -1023,7 +1037,7 @@ def _forward_decode_flashinfer( assert swa_metadata.query_start_loc is not None assert swa_metadata.query_start_loc_cpu is not None - swa_indices_2d = swa_indices.view(num_decode_tokens, -1).contiguous() + swa_indices_2d = swa_indices.view(num_decode_tokens, -1) if swa_indices_2d.shape[-1] != self.window_size: raise ValueError( f"DeepSeek V4 FlashInfer path expects {self.window_size} SWA " @@ -1031,32 +1045,22 @@ def _forward_decode_flashinfer( ) if swa_only: + assert swa_metadata.decode_swa_sparse_topk_lens is not None compressed_kv_cache = self.swa_cache_layer.kv_cache - compressed_indices = torch.full( - (num_decode_tokens, 4), - -1, - dtype=torch.int32, - device=q.device, - ) - sparse_topk_lens = torch.full( - (num_decode_tokens,), - self.window_size, - dtype=torch.int32, - device=q.device, - ) + sparse_indices = swa_indices_2d + sparse_topk_lens = swa_metadata.decode_swa_sparse_topk_lens else: assert kv_cache is not None assert topk_indices is not None assert topk_lens is not None compressed_kv_cache = kv_cache compressed_indices = topk_indices.view(num_decode_tokens, -1).contiguous() - sparse_topk_lens = (topk_lens + self.window_size).to(torch.int32) - - sparse_indices = torch.cat((swa_indices_2d, compressed_indices), dim=-1) - if sparse_indices.shape[-1] % 4 != 0: - pad = 4 - sparse_indices.shape[-1] % 4 - sparse_indices = F.pad(sparse_indices, (0, pad), value=-1) - sparse_indices = sparse_indices.contiguous() + sparse_topk_lens = topk_lens + sparse_indices = torch.cat((swa_indices_2d, compressed_indices), dim=-1) + if sparse_indices.shape[-1] % 4 != 0: + pad = 4 - sparse_indices.shape[-1] % 4 + sparse_indices = F.pad(sparse_indices, (0, pad), value=-1) + sparse_indices = sparse_indices.contiguous() query_start_loc = swa_metadata.query_start_loc[: num_decodes + 1] query_start_loc_cpu = swa_metadata.query_start_loc_cpu[: num_decodes + 1] @@ -1069,25 +1073,22 @@ def _forward_decode_flashinfer( bmm2_scale: float | torch.Tensor = 1.0 if self.kv_cache_torch_dtype == torch.float8_e4m3fn: query = query.clamp(-448.0, 448.0).to(torch.float8_e4m3fn) - bmm1_scale = torch.tensor( - [self.scale], dtype=torch.float32, device=q.device - ) - bmm2_scale = torch.ones(1, dtype=torch.float32, device=q.device) + bmm1_scale = self._flashinfer_fp8_bmm1_scale_log2 + bmm2_scale = self._flashinfer_fp8_bmm2_scale - flashinfer_trtllm_batch_decode_sparse_mla_dsv4( + flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw( query=query, swa_kv_cache=self.swa_cache_layer.kv_cache, workspace_buffer=_get_flashinfer_dsv4_workspace(q.device), sparse_indices=sparse_indices, compressed_kv_cache=compressed_kv_cache, - sparse_topk_lens=sparse_topk_lens.contiguous(), - seq_lens=seq_lens.contiguous(), + sparse_topk_lens=sparse_topk_lens, + seq_lens=seq_lens, out=output, bmm1_scale=bmm1_scale, bmm2_scale=bmm2_scale, sinks=self.attn_sink, - kv_layout="HND", - cum_seq_lens_q=query_start_loc.contiguous(), + cum_seq_lens_q=query_start_loc, max_q_len=max_q_len, ) diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 886edf19d898..31e12883a1c6 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -150,6 +150,90 @@ def wrapper(*args, **kwargs): "trtllm_batch_decode_sparse_mla_dsv4", fallback_fn=_missing_dsv4_sparse_mla, ) + + +@functools.cache +def _get_dsv4_sparse_mla_raw_impl(): + if not has_flashinfer(): + return None + core = _get_submodule("flashinfer.mla._core") + if core is None: + return None + op = core.get_trtllm_gen_fmha_module() + run_func = getattr(op, "trtllm_paged_attention_decode_sparse_mla_dsv4", None) + if run_func is None: + return None + return run_func, core.device_support_pdl, core.get_device_sm_count + + +def flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw( + *, + query: torch.Tensor, + swa_kv_cache: torch.Tensor, + workspace_buffer: torch.Tensor, + sparse_indices: torch.Tensor, + compressed_kv_cache: torch.Tensor, + sparse_topk_lens: torch.Tensor, + seq_lens: torch.Tensor, + out: torch.Tensor, + bmm1_scale: float | torch.Tensor = 1.0, + bmm2_scale: float | torch.Tensor = 1.0, + sinks: torch.Tensor | None = None, + cum_seq_lens_q: torch.Tensor | None = None, + max_q_len: int | None = None, + enable_pdl: bool | None = None, +) -> torch.Tensor: + """Unchecked DeepSeek V4 sparse MLA launcher for hot vLLM decode paths. + + The caller must provide HND-compatible 3D/4D KV caches, contiguous INT32 + metadata, a BF16 output tensor, and launcher-ready scale tensors. This skips + FlashInfer's Python validation, which otherwise adds syncs and pointwise + kernels on every attention layer. + """ + impl = _get_dsv4_sparse_mla_raw_impl() + if impl is None: + return _missing_dsv4_sparse_mla() + + run_func, device_support_pdl, get_device_sm_count = impl + if enable_pdl is None: + enable_pdl = device_support_pdl(query.device) + + if swa_kv_cache.ndim == 3: + swa_kv_cache = swa_kv_cache.unsqueeze(1) + if compressed_kv_cache.ndim == 3: + compressed_kv_cache = compressed_kv_cache.unsqueeze(1) + + if cum_seq_lens_q is None: + batch_size, q_len_per_request = query.shape[:2] + query_flat = query.flatten(0, 1) + else: + batch_size = cum_seq_lens_q.numel() - 1 + assert max_q_len is not None + q_len_per_request = max_q_len + query_flat = query + + run_func( + out, + query_flat, + compressed_kv_cache, + swa_kv_cache, + workspace_buffer, + sparse_indices, + seq_lens, + sparse_topk_lens, + bmm1_scale, + bmm2_scale, + batch_size, + q_len_per_request, + get_device_sm_count(query.device), + enable_pdl, + workspace_buffer.numel() * workspace_buffer.element_size(), + sinks, + cum_seq_lens_q, + ) + return out + + # Special case for autotune since it returns a context manager autotune = _lazy_import_wrapper( "flashinfer.autotuner", diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index d0cd637ac33b..15b01d1f0e21 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -358,6 +358,11 @@ def __init__( if self.is_deepseek_v4: assert hasattr(self.kv_cache_spec, "compress_ratio") self.compress_ratio = self.kv_cache_spec.compress_ratio + self.sliding_window = hf_config.sliding_window + self.use_dsv4_flashinfer_decode = ( + self.kv_cache_spec.dtype != torch.uint8 + and not current_platform.is_rocm() + ) # Pre-allocate compressed slot mapping buffer for CUDA graph # address stability when compress_ratio > 1. if self.compress_ratio > 1: @@ -697,6 +702,9 @@ def _build_c128a_metadata( self.c128a_decode_lens_buffer, self.c128a_prefill_buffer, max_compressed_tokens=self.c128a_max_compressed, + decode_lens_base=( + self.sliding_window if self.use_dsv4_flashinfer_decode else 0 + ), ) result: dict[str, torch.Tensor | None] = {} @@ -1065,6 +1073,7 @@ def build_c128a_topk_metadata( decode_lens_buffer: torch.Tensor, prefill_buffer: torch.Tensor, max_compressed_tokens: int = 8192, + decode_lens_base: int = 0, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Single kernel for all C128A tokens (decode + prefill). @@ -1099,6 +1108,7 @@ def build_c128a_topk_metadata( block_table.stride(0), block_size, slot_mapping, + decode_lens_base, BLOCK_SIZE=1024, ) return global_decode, decode_lens, prefill_local @@ -1123,6 +1133,7 @@ def _build_c128a_topk_metadata_kernel( block_table_stride, block_size, slot_mapping_ptr, + decode_lens_base: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): token_idx = tl.program_id(0) @@ -1158,7 +1169,7 @@ def _build_c128a_topk_metadata_kernel( tl.store( decode_lens_ptr + token_idx, - tl.where(is_valid_token, count, 0), + tl.where(is_valid_token, count, 0) + decode_lens_base, ) else: # --- Prefill: write local indices --- diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py index ccf2acab3b61..860a5be57c1b 100644 --- a/vllm/v1/attention/backends/mla/sparse_swa.py +++ b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -161,6 +161,7 @@ class DeepseekSparseSWAMetadata: token_to_req_indices: torch.Tensor | None = None # [num_tokens] decode_swa_indices: torch.Tensor | None = None # [num_decode_tokens, window_size] decode_swa_lens: torch.Tensor | None = None # [num_decode_tokens] + decode_swa_sparse_topk_lens: torch.Tensor | None = None # [num_decode_tokens] # Number of decode/prefill requests/tokens (batch is reordered: decodes first) num_decodes: int = 0 @@ -255,6 +256,12 @@ def __init__(self, *args, **kwargs): dtype=torch.int32, device=self.device, ) + self.decode_swa_sparse_topk_lens = torch.full( + (max_tokens,), + self.window_size, + dtype=torch.int32, + device=self.device, + ) self.is_valid_token = torch.zeros( max_tokens, dtype=torch.bool, @@ -340,6 +347,9 @@ def build( token_to_req_indices=token_to_req_indices, decode_swa_indices=self.decode_swa_indices[:num_decode_tokens], decode_swa_lens=self.decode_swa_lens[:num_decode_tokens], + decode_swa_sparse_topk_lens=( + self.decode_swa_sparse_topk_lens[:num_decode_tokens] + ), block_size=self.block_size, num_decodes=num_decodes, num_prefills=num_prefills, diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py index 2060e274ef8d..d1763d9fae9e 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py @@ -638,6 +638,7 @@ def compute_global_topk_indices_and_lens( block_table: torch.Tensor, block_size: int, is_valid_token: torch.Tensor, + topk_lens_base: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: """Map local topk indices to global KV cache slots and count valid entries. @@ -645,6 +646,7 @@ def compute_global_topk_indices_and_lens( 1. Block-table lookup (local index → global slot id) 2. Valid-entry counting (topk_lens per token) 3. Masking padding tokens to length 0 + 4. Optional constant top-k length base for callers with fixed prefixes """ num_tokens = topk_indices.shape[0] global_topk_indices = torch.empty_like(topk_indices) @@ -661,6 +663,7 @@ def compute_global_topk_indices_and_lens( block_table.stride(0), block_size, is_valid_token, + topk_lens_base, TRITON_BLOCK_SIZE=1024, ) return global_topk_indices, topk_lens @@ -679,6 +682,7 @@ def _compute_global_topk_indices_and_lens_kernel( block_table_stride, block_size, is_valid_token_ptr, + topk_lens_base: tl.constexpr, TRITON_BLOCK_SIZE: tl.constexpr, ): token_idx = tl.program_id(0) @@ -713,8 +717,11 @@ def _compute_global_topk_indices_and_lens_kernel( ) count += tl.sum(is_valid.to(tl.int32), axis=0) - # Zero out length for padding tokens. - tl.store(topk_lens_ptr + token_idx, tl.where(is_valid_token, count, 0)) + # Mask compressed entries for padding tokens, then add any fixed prefix. + tl.store( + topk_lens_ptr + token_idx, + tl.where(is_valid_token, count, 0) + topk_lens_base, + ) # FlashMLA sparse prefill asserts `params.topk % B_TOPK == 0` (see From 1e2a6853e4b9e7314f7b20f978c1551bfd20d475 Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Sat, 9 May 2026 03:27:36 -0700 Subject: [PATCH 03/24] Fix DeepSeek V4 FlashInfer FP8 cache scaling --- vllm/envs.py | 20 ++++++++ .../layers/deepseek_compressor.py | 4 +- .../layers/deepseek_v4_attention.py | 50 +++++++++++++++++-- .../ops/deepseek_v4_ops/cache_utils.py | 13 ++++- .../fused_compress_quant_cache.py | 4 +- 5 files changed, 84 insertions(+), 7 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 03230eed0688..45a1cbd9e133 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -177,6 +177,9 @@ VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False VLLM_USE_FLASHINFER_MOE_INT4: bool = False + VLLM_DSV4_FLASHINFER_FP8_SCALE: float | None = None + VLLM_DSV4_FLASHINFER_FP8_Q_SCALE: float | None = None + VLLM_DSV4_FLASHINFER_FP8_KV_SCALE: float | None = None VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = ( "latency" ) @@ -289,6 +292,12 @@ def maybe_convert_int(value: str | None) -> int | None: return int(value) +def maybe_convert_float(value: str | None) -> float | None: + if value is None: + return None + return float(value) + + def maybe_convert_bool(value: str | None) -> bool | None: if value is None: return None @@ -1319,6 +1328,17 @@ def _get_or_set_default() -> str: "VLLM_USE_FLASHINFER_MOE_INT4": lambda: bool( int(os.getenv("VLLM_USE_FLASHINFER_MOE_INT4", "0")) ), + # Global and Q/KV-specific per-tensor scales for DeepSeek V4 FlashInfer + # sparse MLA FP8 cache/query tensors. + "VLLM_DSV4_FLASHINFER_FP8_SCALE": lambda: maybe_convert_float( + os.getenv("VLLM_DSV4_FLASHINFER_FP8_SCALE") + ), + "VLLM_DSV4_FLASHINFER_FP8_Q_SCALE": lambda: maybe_convert_float( + os.getenv("VLLM_DSV4_FLASHINFER_FP8_Q_SCALE") + ), + "VLLM_DSV4_FLASHINFER_FP8_KV_SCALE": lambda: maybe_convert_float( + os.getenv("VLLM_DSV4_FLASHINFER_FP8_KV_SCALE") + ), # If set to 1, use the FlashInfer # MXFP8 (activation) x MXFP4 (weight) MoE backend. "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8": lambda: bool( diff --git a/vllm/model_executor/layers/deepseek_compressor.py b/vllm/model_executor/layers/deepseek_compressor.py index 09632cee9dee..edc71a686db3 100644 --- a/vllm/model_executor/layers/deepseek_compressor.py +++ b/vllm/model_executor/layers/deepseek_compressor.py @@ -342,7 +342,8 @@ def forward( # - position used: (positions // compress_ratio) * compress_ratio cos_sin_cache = rotary_emb.cos_sin_cache k_cache_metadata = cast(Any, attn_metadata[self.k_cache_prefix]) - kv_cache = self._static_forward_context[self.k_cache_prefix].kv_cache + k_cache_layer = self._static_forward_context[self.k_cache_prefix] + kv_cache = k_cache_layer.kv_cache if self.head_dim == 512 and kv_cache.dtype != torch.uint8: assert kv_cache.dtype in (torch.bfloat16, torch.float8_e4m3fn) @@ -368,6 +369,7 @@ def forward( kv_cache, k_cache_metadata.slot_mapping, kv_cache.shape[1], + k_cache_layer._flashinfer_fp8_kv_scale, # constexprs HEAD_SIZE=self.head_dim, TRITON_BLOCK_SIZE=triton.next_power_of_2(self.head_dim), diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index ed0ffbd1fffb..f35fe382419c 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -4,6 +4,7 @@ DeepseekV4 MLA Attention Layer """ +import os from collections.abc import Callable from dataclasses import dataclass from typing import TYPE_CHECKING, Any, cast @@ -88,6 +89,8 @@ _FLASHINFER_DSV4_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 _flashinfer_dsv4_workspace_by_device: dict[torch.device, torch.Tensor] = {} +_FLASHINFER_FP8_LOG2E = 1.4426950408889634 +_DEFAULT_FLASHINFER_DSV4_FP8_SCALE = 1.0 / 32.0 # Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather # workspace allocated at _forward_prefill (and the matching profile-time @@ -95,6 +98,19 @@ PREFILL_CHUNK_SIZE = 4 +def _get_dsv4_flashinfer_fp8_scale(kind: str) -> float: + specific_name = f"VLLM_DSV4_FLASHINFER_FP8_{kind.upper()}_SCALE" + for env_name in (specific_name, "VLLM_DSV4_FLASHINFER_FP8_SCALE"): + value = os.environ.get(env_name) + if value is None: + continue + scale = float(value) + if scale <= 0.0: + raise ValueError(f"{env_name} must be positive, got {value!r}") + return scale + return _DEFAULT_FLASHINFER_DSV4_FP8_SCALE + + def _normalize_dsv4_kv_cache_dtype( cache_config: CacheConfig | None, ) -> str: @@ -630,6 +646,7 @@ def _fused_qnorm_rope_kv_insert( self.rotary_emb.cos_sin_cache, self.eps, swa_metadata.block_size, + self.mla_attn._flashinfer_fp8_kv_scale, ) @@ -776,14 +793,37 @@ def __init__( ) self.kv_cache_dtype = kv_cache_dtype + fp8_q_scale = 1.0 + fp8_kv_scale = 1.0 + if self.kv_cache_torch_dtype == torch.float8_e4m3fn: + fp8_q_scale = _get_dsv4_flashinfer_fp8_scale("q") + fp8_kv_scale = _get_dsv4_flashinfer_fp8_scale("kv") + self.register_buffer( + "_flashinfer_fp8_q_scale", + torch.tensor([fp8_q_scale], dtype=torch.float32), + persistent=False, + ) + self.register_buffer( + "_flashinfer_fp8_q_scale_inv", + torch.tensor([1.0 / fp8_q_scale], dtype=torch.float32), + persistent=False, + ) + self.register_buffer( + "_flashinfer_fp8_kv_scale", + torch.tensor([fp8_kv_scale], dtype=torch.float32), + persistent=False, + ) self.register_buffer( "_flashinfer_fp8_bmm1_scale_log2", - torch.tensor([self.scale * 1.4426950408889634], dtype=torch.float32), + torch.tensor( + [self.scale * fp8_q_scale * fp8_kv_scale * _FLASHINFER_FP8_LOG2E], + dtype=torch.float32, + ), persistent=False, ) self.register_buffer( "_flashinfer_fp8_bmm2_scale", - torch.ones(1, dtype=torch.float32), + torch.tensor([fp8_kv_scale], dtype=torch.float32), persistent=False, ) @@ -1072,7 +1112,9 @@ def _forward_decode_flashinfer( bmm1_scale: float | torch.Tensor = self.scale bmm2_scale: float | torch.Tensor = 1.0 if self.kv_cache_torch_dtype == torch.float8_e4m3fn: - query = query.clamp(-448.0, 448.0).to(torch.float8_e4m3fn) + query = (query * self._flashinfer_fp8_q_scale_inv).clamp( + -448.0, 448.0 + ).to(torch.float8_e4m3fn) bmm1_scale = self._flashinfer_fp8_bmm1_scale_log2 bmm2_scale = self._flashinfer_fp8_bmm2_scale @@ -1166,6 +1208,7 @@ def _forward_prefill( block_table=block_table[chunk_start:chunk_end], block_size=attn_metadata.block_size // self.compress_ratio, offset=0, + fp8_scale=self._flashinfer_fp8_kv_scale, ) # Gather SWA KV @@ -1178,6 +1221,7 @@ def _forward_prefill( block_table=swa_block_table[chunk_start:chunk_end], block_size=swa_metadata.block_size, offset=N, + fp8_scale=self._flashinfer_fp8_kv_scale, ) # Combine the topk indices and SWA indices for gathered KV cache diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py index d1763d9fae9e..26d99ee5f550 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py @@ -99,6 +99,7 @@ def _kv_rope_insert_full_cache_kernel( cache_stride0, cache_stride1, cache_block_size, + fp8_scale_ptr, HEAD_SIZE: tl.constexpr, ROPE_HEAD_DIM: tl.constexpr, BLOCK_SIZE: tl.constexpr, @@ -134,7 +135,8 @@ def _kv_rope_insert_full_cache_kernel( + (pos_in_block * cache_stride1) ) if STORE_FP8: - values = tl.clamp(values, -448.0, 448.0) + fp8_scale = tl.load(fp8_scale_ptr) + values = tl.clamp(values / fp8_scale, -448.0, 448.0) tl.store(cache_row + offsets, values.to(tl.float8e4nv), mask=mask) else: tl.store(cache_row + offsets, values.to(tl.bfloat16), mask=mask) @@ -149,6 +151,7 @@ def qnorm_rope_and_insert_full_k_cache( cos_sin_cache: torch.Tensor, eps: float, cache_block_size: int, + fp8_scale: torch.Tensor, ) -> None: """Apply DeepSeek V4 Q RMSNorm/RoPE and insert full-width BF16/FP8 KV. @@ -190,6 +193,7 @@ def qnorm_rope_and_insert_full_k_cache( k_cache.stride(0), k_cache.stride(1), cache_block_size, + fp8_scale, HEAD_SIZE=512, ROPE_HEAD_DIM=64, BLOCK_SIZE=512, @@ -494,6 +498,7 @@ def _gather_full_k_cache_kernel( block_table_ptr, offset, gather_lens_ptr, + fp8_scale_ptr, max_blocks_per_seq: tl.constexpr, cache_block_size: tl.constexpr, output_dim: tl.constexpr, @@ -527,7 +532,7 @@ def _gather_full_k_cache_kernel( ) values = tl.load(cache_row + offsets, mask=mask, other=0.0) if STORE_FP8: - values = values.to(tl.float32) + values = values.to(tl.float32) * tl.load(fp8_scale_ptr) out_row = out_ptr + batch_idx * out_stride0 + (offset + i) * out_stride1 tl.store(out_row + offsets, values.to(tl.bfloat16), mask=mask) @@ -592,10 +597,13 @@ def dequantize_and_gather_k_cache( block_table: torch.Tensor, block_size: int, offset: int, + fp8_scale: torch.Tensor | None = None, ) -> None: if k_cache.dtype != torch.uint8: assert k_cache.dtype in (torch.bfloat16, torch.float8_e4m3fn) assert k_cache.dim() == 3 and k_cache.shape[-1] == 512 + if k_cache.dtype == torch.float8_e4m3fn: + assert fp8_scale is not None num_reqs = seq_lens.shape[0] NUM_WORKERS = 128 _gather_full_k_cache_kernel[(num_reqs, NUM_WORKERS)]( @@ -609,6 +617,7 @@ def dequantize_and_gather_k_cache( block_table, offset, gather_lens, + fp8_scale if fp8_scale is not None else k_cache, max_blocks_per_seq=block_table.shape[-1], cache_block_size=block_size, output_dim=512, diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py index 5c817e789477..7a9cc17b6e03 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py @@ -237,6 +237,7 @@ def _fused_kv_compress_norm_rope_insert_sparse_attn_full_cache( k_cache_ptr, kv_slot_mapping_ptr, kv_cache_block_size, + fp8_scale_ptr, # ── constexprs ── HEAD_SIZE: tl.constexpr, TRITON_BLOCK_SIZE: tl.constexpr, @@ -333,7 +334,8 @@ def _fused_kv_compress_norm_rope_insert_sparse_attn_full_cache( result = result.to(tl.bfloat16).to(tl.float32) if STORE_FP8: - result = tl.clamp(result, -448.0, 448.0) + fp8_scale = tl.load(fp8_scale_ptr) + result = tl.clamp(result / fp8_scale, -448.0, 448.0) tl.store(cache_row + block, result.to(tl.float8e4nv), mask=mask) else: tl.store(cache_row + block, result.to(tl.bfloat16), mask=mask) From 32934d18554da982c410bff61ac2a8b265a6794f Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Sat, 9 May 2026 03:34:44 -0700 Subject: [PATCH 04/24] Optimize DeepSeek V4 FlashInfer FP8 decode quantization --- .../layers/deepseek_v4_attention.py | 27 ++++-------- .../attention/ops/deepseek_v4_ops/__init__.py | 2 + .../ops/deepseek_v4_ops/cache_utils.py | 44 +++++++++++++++++++ 3 files changed, 54 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index f35fe382419c..791bc2f9817d 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -25,6 +25,7 @@ combine_topk_swa_indices, compute_global_topk_indices_and_lens, dequantize_and_gather_k_cache, + fp8_per_tensor_quant, fused_indexer_q_rope_quant, fused_inv_rope_fp8_quant, fused_q_kv_rmsnorm, @@ -89,7 +90,6 @@ _FLASHINFER_DSV4_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 _flashinfer_dsv4_workspace_by_device: dict[torch.device, torch.Tensor] = {} -_FLASHINFER_FP8_LOG2E = 1.4426950408889634 _DEFAULT_FLASHINFER_DSV4_FP8_SCALE = 1.0 / 32.0 # Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather @@ -813,19 +813,8 @@ def __init__( torch.tensor([fp8_kv_scale], dtype=torch.float32), persistent=False, ) - self.register_buffer( - "_flashinfer_fp8_bmm1_scale_log2", - torch.tensor( - [self.scale * fp8_q_scale * fp8_kv_scale * _FLASHINFER_FP8_LOG2E], - dtype=torch.float32, - ), - persistent=False, - ) - self.register_buffer( - "_flashinfer_fp8_bmm2_scale", - torch.tensor([fp8_kv_scale], dtype=torch.float32), - persistent=False, - ) + self._flashinfer_fp8_bmm1_scale = self.scale * fp8_q_scale * fp8_kv_scale + self._flashinfer_fp8_bmm2_scale = fp8_kv_scale # Register with compilation context for metadata lookup compilation_config = vllm_config.compilation_config @@ -1108,15 +1097,15 @@ def _forward_decode_flashinfer( max_q_len = int(query_lens_cpu.max().item()) seq_lens = swa_metadata.seq_lens[:num_decodes].to(torch.int32) - query = q.contiguous() + query = q bmm1_scale: float | torch.Tensor = self.scale bmm2_scale: float | torch.Tensor = 1.0 if self.kv_cache_torch_dtype == torch.float8_e4m3fn: - query = (query * self._flashinfer_fp8_q_scale_inv).clamp( - -448.0, 448.0 - ).to(torch.float8_e4m3fn) - bmm1_scale = self._flashinfer_fp8_bmm1_scale_log2 + query = fp8_per_tensor_quant(query, self._flashinfer_fp8_q_scale_inv) + bmm1_scale = self._flashinfer_fp8_bmm1_scale bmm2_scale = self._flashinfer_fp8_bmm2_scale + else: + query = query.contiguous() flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw( query=query, diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py b/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py index 50aa3c6865dd..994a012fd3f1 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py @@ -5,6 +5,7 @@ combine_topk_swa_indices, compute_global_topk_indices_and_lens, dequantize_and_gather_k_cache, + fp8_per_tensor_quant, qnorm_rope_and_insert_full_k_cache, quantize_and_insert_k_cache, ) @@ -17,6 +18,7 @@ "combine_topk_swa_indices", "compute_global_topk_indices_and_lens", "dequantize_and_gather_k_cache", + "fp8_per_tensor_quant", "fused_indexer_q_rope_quant", "fused_inv_rope_fp8_quant", "fused_q_kv_rmsnorm", diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py index 26d99ee5f550..6da7e5cb1316 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py @@ -20,6 +20,50 @@ from vllm.utils.import_utils import has_cutedsl +@triton.jit +def _fp8_per_tensor_quant_kernel( + input_ptr, + output_ptr, + scale_inv_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + values = tl.load(input_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + scale_inv = tl.load(scale_inv_ptr) + values = tl.clamp(values * scale_inv, -448.0, 448.0) + tl.store(output_ptr + offsets, values.to(tl.float8e4nv), mask=mask) + + +def fp8_per_tensor_quant( + input_tensor: torch.Tensor, + scale_inv: torch.Tensor, +) -> torch.Tensor: + assert input_tensor.dtype == torch.bfloat16 + assert scale_inv.dtype == torch.float32 and scale_inv.numel() == 1 + + input_tensor = input_tensor.contiguous() + output = torch.empty_like(input_tensor, dtype=torch.float8_e4m3fn) + n_elements = input_tensor.numel() + if n_elements == 0: + return output + + block_size = 1024 + grid = (triton.cdiv(n_elements, block_size),) + _fp8_per_tensor_quant_kernel[grid]( + input_tensor, + output, + scale_inv, + n_elements, + BLOCK_SIZE=block_size, + num_warps=4, + ) + return output + + @triton.jit def _apply_gptj_rope_512( values, From 53fe44323211cbf4ca024d6356167678dd492cf1 Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Sun, 10 May 2026 20:01:06 -0700 Subject: [PATCH 05/24] Optimize DeepSeek V4 FlashInfer FP8 sparse MLA --- .../layers/deepseek_v4_attention.py | 168 +++++++-- vllm/v1/attention/backends/mla/sparse_swa.py | 23 +- .../attention/ops/deepseek_v4_ops/__init__.py | 6 +- .../ops/deepseek_v4_ops/cache_utils.py | 333 +++++++++++++++--- 4 files changed, 443 insertions(+), 87 deletions(-) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 791bc2f9817d..9078dd515469 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -22,10 +22,11 @@ from vllm.utils.deep_gemm import fp8_einsum from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.ops.deepseek_v4_ops import ( + build_flashinfer_decode_sparse_indices, + build_flashinfer_prefill_sparse_indices, combine_topk_swa_indices, compute_global_topk_indices_and_lens, dequantize_and_gather_k_cache, - fp8_per_tensor_quant, fused_indexer_q_rope_quant, fused_inv_rope_fp8_quant, fused_q_kv_rmsnorm, @@ -527,13 +528,15 @@ def attention_impl( assert self.compressor is not None compressor = self.compressor - def wq_b_kv_insert_and_compress() -> torch.Tensor: + def wq_b_kv_insert_and_compress() -> tuple[torch.Tensor, torch.Tensor | None]: q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) - self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) + q_fp8 = self._fused_qnorm_rope_kv_insert( + q, kv, positions, attn_metadata + ) compressor(kv_score, positions, self.rotary_emb) - return q + return q, q_fp8 - q, _ = maybe_execute_in_parallel( + (q, q_fp8), _ = maybe_execute_in_parallel( wq_b_kv_insert_and_compress, lambda: indexer( hidden_states, @@ -554,12 +557,14 @@ def wq_b_kv_insert_and_compress() -> torch.Tensor: ) compressor = self.compressor - def wq_b_kv_insert() -> torch.Tensor: + def wq_b_kv_insert() -> tuple[torch.Tensor, torch.Tensor | None]: q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) - self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) - return q + q_fp8 = self._fused_qnorm_rope_kv_insert( + q, kv, positions, attn_metadata + ) + return q, q_fp8 - q, _ = maybe_execute_in_parallel( + (q, q_fp8), _ = maybe_execute_in_parallel( wq_b_kv_insert, lambda: compressor(kv_score, positions, self.rotary_emb), self.ln_events[0], @@ -569,35 +574,42 @@ def wq_b_kv_insert() -> torch.Tensor: else: # SWA-only layer: no compressor, no overlap. q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) - self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) + q_fp8 = self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) # Handle dummy run (no metadata). if not isinstance(attn_metadata, dict): # Reserve _forward_prefill's bf16-gather workspace; the dummy # run returns before mla_attn runs, so without this the shared - # workspace locks below the real prefill size. + # workspace locks below the real prefill size. The per-tensor FP8 + # FlashInfer path reads the cache directly and does not need it. sub = self.mla_attn - swa_only = sub.compress_ratio <= 1 - N = ( - 0 - if swa_only - else (sub.max_model_len + sub.compress_ratio - 1) // sub.compress_ratio - ) - M = N + sub.window_size + sub.max_num_batched_tokens - current_workspace_manager().get_simultaneous( - ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), - ) + if sub.kv_cache_torch_dtype != torch.float8_e4m3fn: + swa_only = sub.compress_ratio <= 1 + N = ( + 0 + if swa_only + else (sub.max_model_len + sub.compress_ratio - 1) + // sub.compress_ratio + ) + M = N + sub.window_size + sub.max_num_batched_tokens + current_workspace_manager().get_simultaneous( + ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), + ) out.zero_() return - # Pad q to FlashMLA-required head count (64 or 128) - if self.n_local_heads < self.padded_heads: + q_for_attn = q_fp8 if q_fp8 is not None else q + + # Pad q to FlashMLA-required head count (64 or 128). The per-tensor + # FP8 FlashInfer path emits a padded q_fp8 tensor directly from the + # qnorm/RoPE kernel. + if q_fp8 is None and self.n_local_heads < self.padded_heads: pad_size = self.padded_heads - self.n_local_heads - q = F.pad(q, (0, 0, 0, pad_size), value=0.0) + q_for_attn = F.pad(q_for_attn, (0, 0, 0, pad_size), value=0.0) # MLA attention writes into the pre-allocated `out` buffer # ([num_tokens, padded_heads, head_dim]). - self.mla_attn(q, kv, positions, output=out) + self.mla_attn(q_for_attn, kv, positions, output=out) def _fused_qnorm_rope_kv_insert( self, @@ -607,9 +619,9 @@ def _fused_qnorm_rope_kv_insert( attn_metadata: ( dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None ), - ) -> None: + ) -> torch.Tensor | None: if not isinstance(attn_metadata, dict): - return + return None swa_metadata = cast( "DeepseekSparseSWAMetadata | None", @@ -618,6 +630,13 @@ def _fused_qnorm_rope_kv_insert( assert swa_metadata is not None swa_kv_cache = self.swa_cache_layer.kv_cache + q_fp8 = None + if swa_kv_cache.dtype == torch.float8_e4m3fn: + q_fp8 = torch.empty( + (q.shape[0], self.padded_heads, q.shape[-1]), + dtype=torch.float8_e4m3fn, + device=q.device, + ) # Horizontally fused: # Q side: q_head_norm (per-head RMSNorm, no weight) + GPT-J RoPE @@ -647,7 +666,10 @@ def _fused_qnorm_rope_kv_insert( self.eps, swa_metadata.block_size, self.mla_attn._flashinfer_fp8_kv_scale, + q_fp8=q_fp8, + q_fp8_scale_inv=self.mla_attn._flashinfer_fp8_q_scale_inv, ) + return q_fp8 def deepseek_v4_attention( @@ -864,8 +886,12 @@ def forward( assert output.shape == q.shape, ( f"output buffer shape {output.shape} must match q shape {q.shape}" ) - assert output.dtype == q.dtype, ( - f"output buffer dtype {output.dtype} must match q dtype {q.dtype}" + expected_output_dtype = ( + torch.bfloat16 if q.dtype == torch.float8_e4m3fn else q.dtype + ) + assert output.dtype == expected_output_dtype, ( + f"output buffer dtype {output.dtype} must match expected attention " + f"output dtype {expected_output_dtype} for q dtype {q.dtype}" ) if current_platform.is_rocm(): @@ -1076,32 +1102,37 @@ def _forward_decode_flashinfer( if swa_only: assert swa_metadata.decode_swa_sparse_topk_lens is not None compressed_kv_cache = self.swa_cache_layer.kv_cache - sparse_indices = swa_indices_2d + sparse_indices = build_flashinfer_decode_sparse_indices( + swa_indices_2d, + None, + self.window_size, + ) sparse_topk_lens = swa_metadata.decode_swa_sparse_topk_lens else: assert kv_cache is not None assert topk_indices is not None assert topk_lens is not None compressed_kv_cache = kv_cache - compressed_indices = topk_indices.view(num_decode_tokens, -1).contiguous() + compressed_indices = topk_indices.view(num_decode_tokens, -1) sparse_topk_lens = topk_lens - sparse_indices = torch.cat((swa_indices_2d, compressed_indices), dim=-1) - if sparse_indices.shape[-1] % 4 != 0: - pad = 4 - sparse_indices.shape[-1] % 4 - sparse_indices = F.pad(sparse_indices, (0, pad), value=-1) - sparse_indices = sparse_indices.contiguous() + sparse_indices = build_flashinfer_decode_sparse_indices( + swa_indices_2d, + compressed_indices, + self.window_size, + ) query_start_loc = swa_metadata.query_start_loc[: num_decodes + 1] query_start_loc_cpu = swa_metadata.query_start_loc_cpu[: num_decodes + 1] query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] max_q_len = int(query_lens_cpu.max().item()) - seq_lens = swa_metadata.seq_lens[:num_decodes].to(torch.int32) + assert swa_metadata.seq_lens_int32 is not None + seq_lens = swa_metadata.seq_lens_int32[:num_decodes] query = q bmm1_scale: float | torch.Tensor = self.scale bmm2_scale: float | torch.Tensor = 1.0 if self.kv_cache_torch_dtype == torch.float8_e4m3fn: - query = fp8_per_tensor_quant(query, self._flashinfer_fp8_q_scale_inv) + assert query.dtype == torch.float8_e4m3fn bmm1_scale = self._flashinfer_fp8_bmm1_scale bmm2_scale = self._flashinfer_fp8_bmm2_scale else: @@ -1174,6 +1205,67 @@ def _forward_prefill( top_k = 0 N = 0 + if ( + self.kv_cache_torch_dtype == torch.float8_e4m3fn + and not current_platform.is_rocm() + ): + assert q.dtype == torch.float8_e4m3fn + assert swa_metadata.prefill_query_start_loc is not None + assert swa_metadata.prefill_query_start_loc_cpu is not None + assert swa_metadata.seq_lens_int32 is not None + + prefill_query_start_loc = swa_metadata.prefill_query_start_loc + prefill_query_start_loc_cpu = swa_metadata.prefill_query_start_loc_cpu + query_lens_cpu = ( + prefill_query_start_loc_cpu[1:] - prefill_query_start_loc_cpu[:-1] + ) + max_q_len = int(query_lens_cpu.max().item()) + seq_lens_int32 = swa_metadata.seq_lens_int32[ + num_decodes : num_decodes + num_prefills + ] + swa_block_table = swa_metadata.block_table[num_decodes:] + + if swa_only: + compressed_kv_cache = swa_k_cache + compressed_block_table = None + compressed_block_size = swa_metadata.block_size + else: + assert compressed_k_cache is not None + assert attn_metadata is not None + compressed_kv_cache = compressed_k_cache + compressed_block_table = attn_metadata.block_table[num_decodes:] + compressed_block_size = attn_metadata.block_size // self.compress_ratio + + sparse_indices, sparse_topk_lens = build_flashinfer_prefill_sparse_indices( + topk_indices[:num_prefill_tokens], + prefill_query_start_loc, + seq_lens_int32, + swa_block_table, + swa_metadata.block_size, + compressed_block_table, + compressed_block_size, + self.window_size, + self.compress_ratio, + top_k, + ) + + flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw( + query=q, + swa_kv_cache=swa_k_cache, + workspace_buffer=_get_flashinfer_dsv4_workspace(q.device), + sparse_indices=sparse_indices, + compressed_kv_cache=compressed_kv_cache, + sparse_topk_lens=sparse_topk_lens, + seq_lens=seq_lens_int32, + out=output, + bmm1_scale=self._flashinfer_fp8_bmm1_scale, + bmm2_scale=self._flashinfer_fp8_bmm2_scale, + sinks=self.attn_sink, + cum_seq_lens_q=prefill_query_start_loc, + max_q_len=max_q_len, + ) + return + M = N + self.window_size + self.max_num_batched_tokens num_chunks = (num_prefills + PREFILL_CHUNK_SIZE - 1) // PREFILL_CHUNK_SIZE diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py index 860a5be57c1b..76c7dc6f7d79 100644 --- a/vllm/v1/attention/backends/mla/sparse_swa.py +++ b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -154,6 +154,7 @@ class DeepseekSparseSWAMetadata: slot_mapping: torch.Tensor block_size: int seq_lens: torch.Tensor | None = None # [num_seqs] + seq_lens_int32: torch.Tensor | None = None # [num_seqs] query_start_loc: torch.Tensor | None = None # [num_seqs + 1] query_start_loc_cpu: torch.Tensor | None = None # [num_seqs + 1] @@ -172,6 +173,8 @@ class DeepseekSparseSWAMetadata: # Pre-computed prefill metadata shared across all DeepseekV4 attention layers. prefill_seq_lens: torch.Tensor | None = None prefill_gather_lens: torch.Tensor | None = None + prefill_query_start_loc: torch.Tensor | None = None + prefill_query_start_loc_cpu: torch.Tensor | None = None # Per-layer-type FlashMLA tile-scheduler metadata. One FlashMLASchedMeta # per present DeepseekV4 layer type, shared across all ~60 layers of that type @@ -288,6 +291,9 @@ def build( query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu block_table = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping + seq_lens_int32 = seq_lens if seq_lens.dtype == torch.int32 else seq_lens.to( + torch.int32 + ) # Split into decode and prefill portions using configurable threshold (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = ( @@ -327,8 +333,9 @@ def build( deepseek_v4_fields = self._build_deepseek_v4_metadata( num_decodes, num_prefills, - seq_lens, + seq_lens_int32, query_start_loc, + query_start_loc_cpu, ) # Per-layer-type tile-scheduler plan holders. Empty FlashMLASchedMeta @@ -339,6 +346,7 @@ def build( return DeepseekSparseSWAMetadata( seq_lens=seq_lens, + seq_lens_int32=seq_lens_int32, query_start_loc=query_start_loc, query_start_loc_cpu=query_start_loc_cpu, block_table=block_table, @@ -396,6 +404,7 @@ def _build_deepseek_v4_metadata( num_prefills: int, seq_lens: torch.Tensor, query_start_loc: torch.Tensor, + query_start_loc_cpu: torch.Tensor, ) -> dict[str, torch.Tensor | None]: """Pre-compute DeepseekV4 prefill metadata during the metadata build phase. @@ -424,6 +433,18 @@ def _build_deepseek_v4_metadata( result["prefill_seq_lens"] = seq_lens[num_decodes:] result["prefill_gather_lens"] = pfx_gather_lens + prefill_query_start_loc = query_start_loc[ + num_decodes : num_decodes + num_prefills + 1 + ] + prefill_query_start_loc_cpu = query_start_loc_cpu[ + num_decodes : num_decodes + num_prefills + 1 + ] + result["prefill_query_start_loc"] = ( + prefill_query_start_loc - prefill_query_start_loc[0] + ) + result["prefill_query_start_loc_cpu"] = ( + prefill_query_start_loc_cpu - prefill_query_start_loc_cpu[0] + ) return result diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py b/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py index 994a012fd3f1..e26953c2502c 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py @@ -2,10 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from .cache_utils import ( + build_flashinfer_decode_sparse_indices, + build_flashinfer_prefill_sparse_indices, combine_topk_swa_indices, compute_global_topk_indices_and_lens, dequantize_and_gather_k_cache, - fp8_per_tensor_quant, qnorm_rope_and_insert_full_k_cache, quantize_and_insert_k_cache, ) @@ -15,10 +16,11 @@ __all__ = [ "MXFP4_BLOCK_SIZE", + "build_flashinfer_decode_sparse_indices", + "build_flashinfer_prefill_sparse_indices", "combine_topk_swa_indices", "compute_global_topk_indices_and_lens", "dequantize_and_gather_k_cache", - "fp8_per_tensor_quant", "fused_indexer_q_rope_quant", "fused_inv_rope_fp8_quant", "fused_q_kv_rmsnorm", diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py index 6da7e5cb1316..f394a745d5f4 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py @@ -20,50 +20,6 @@ from vllm.utils.import_utils import has_cutedsl -@triton.jit -def _fp8_per_tensor_quant_kernel( - input_ptr, - output_ptr, - scale_inv_ptr, - n_elements, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) - offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - - values = tl.load(input_ptr + offsets, mask=mask, other=0.0).to(tl.float32) - scale_inv = tl.load(scale_inv_ptr) - values = tl.clamp(values * scale_inv, -448.0, 448.0) - tl.store(output_ptr + offsets, values.to(tl.float8e4nv), mask=mask) - - -def fp8_per_tensor_quant( - input_tensor: torch.Tensor, - scale_inv: torch.Tensor, -) -> torch.Tensor: - assert input_tensor.dtype == torch.bfloat16 - assert scale_inv.dtype == torch.float32 and scale_inv.numel() == 1 - - input_tensor = input_tensor.contiguous() - output = torch.empty_like(input_tensor, dtype=torch.float8_e4m3fn) - n_elements = input_tensor.numel() - if n_elements == 0: - return output - - block_size = 1024 - grid = (triton.cdiv(n_elements, block_size),) - _fp8_per_tensor_quant_kernel[grid]( - input_tensor, - output, - scale_inv, - n_elements, - BLOCK_SIZE=block_size, - num_warps=4, - ) - return output - - @triton.jit def _apply_gptj_rope_512( values, @@ -100,6 +56,11 @@ def _qnorm_rope_kernel( q_ptr, q_stride0, q_stride1, + num_q_heads, + q_fp8_ptr, + q_fp8_stride0, + q_fp8_stride1, + q_fp8_scale_inv_ptr, positions_ptr, cos_sin_cache_ptr, cos_sin_stride, @@ -107,12 +68,22 @@ def _qnorm_rope_kernel( HEAD_SIZE: tl.constexpr, ROPE_HEAD_DIM: tl.constexpr, BLOCK_SIZE: tl.constexpr, + STORE_Q_FP8: tl.constexpr, ): token_idx = tl.program_id(0) head_idx = tl.program_id(1) offsets = tl.arange(0, BLOCK_SIZE) mask = offsets < HEAD_SIZE + if STORE_Q_FP8 and head_idx >= num_q_heads: + q_fp8_row = q_fp8_ptr + token_idx * q_fp8_stride0 + head_idx * q_fp8_stride1 + tl.store( + q_fp8_row + offsets, + tl.zeros((BLOCK_SIZE,), dtype=tl.float32).to(tl.float8e4nv), + mask=mask, + ) + return + q_row = q_ptr + token_idx * q_stride0 + head_idx * q_stride1 values = tl.load(q_row + offsets, mask=mask, other=0.0).to(tl.float32) variance = tl.sum(values * values, axis=0) / HEAD_SIZE @@ -128,7 +99,16 @@ def _qnorm_rope_kernel( ROPE_HEAD_DIM, BLOCK_SIZE, ) - tl.store(q_row + offsets, values.to(tl.bfloat16), mask=mask) + q_bf16 = values.to(tl.bfloat16) + tl.store(q_row + offsets, q_bf16, mask=mask) + + if STORE_Q_FP8: + q_fp8_scale_inv = tl.load(q_fp8_scale_inv_ptr) + q_fp8_row = q_fp8_ptr + token_idx * q_fp8_stride0 + head_idx * q_fp8_stride1 + q_fp8_values = tl.clamp( + q_bf16.to(tl.float32) * q_fp8_scale_inv, -448.0, 448.0 + ) + tl.store(q_fp8_row + offsets, q_fp8_values.to(tl.float8e4nv), mask=mask) @triton.jit @@ -196,6 +176,8 @@ def qnorm_rope_and_insert_full_k_cache( eps: float, cache_block_size: int, fp8_scale: torch.Tensor, + q_fp8: torch.Tensor | None = None, + q_fp8_scale_inv: torch.Tensor | None = None, ) -> None: """Apply DeepSeek V4 Q RMSNorm/RoPE and insert full-width BF16/FP8 KV. @@ -209,12 +191,25 @@ def qnorm_rope_and_insert_full_k_cache( assert kv.dtype == torch.bfloat16 assert k_cache.dtype in (torch.bfloat16, torch.float8_e4m3fn) assert cos_sin_cache.dtype == torch.float32 + if q_fp8 is not None: + assert q_fp8.dtype == torch.float8_e4m3fn + assert q_fp8.dim() == 3 and q_fp8.shape[0] == q.shape[0] + assert q_fp8.shape[-1] == q.shape[-1] + assert q_fp8_scale_inv is not None + assert q_fp8_scale_inv.dtype == torch.float32 + assert q_fp8_scale_inv.numel() == 1 num_tokens_full, num_heads, _ = q.shape - _qnorm_rope_kernel[(num_tokens_full, num_heads)]( + q_heads_for_grid = q_fp8.shape[1] if q_fp8 is not None else num_heads + _qnorm_rope_kernel[(num_tokens_full, q_heads_for_grid)]( q, q.stride(0), q.stride(1), + num_heads, + q_fp8 if q_fp8 is not None else q, + q_fp8.stride(0) if q_fp8 is not None else q.stride(0), + q_fp8.stride(1) if q_fp8 is not None else q.stride(1), + q_fp8_scale_inv if q_fp8_scale_inv is not None else fp8_scale, positions, cos_sin_cache, cos_sin_cache.stride(0), @@ -222,6 +217,7 @@ def qnorm_rope_and_insert_full_k_cache( HEAD_SIZE=512, ROPE_HEAD_DIM=64, BLOCK_SIZE=512, + STORE_Q_FP8=q_fp8 is not None, num_warps=8, ) @@ -722,6 +718,251 @@ def compute_global_topk_indices_and_lens( return global_topk_indices, topk_lens +def build_flashinfer_decode_sparse_indices( + swa_indices: torch.Tensor, + compressed_indices: torch.Tensor | None, + window_size: int, +) -> torch.Tensor: + """Build FlashInfer DSV4 SWA-first sparse indices without torch.cat/pad.""" + assert swa_indices.dtype == torch.int32 + assert swa_indices.dim() == 2 and swa_indices.shape[-1] == window_size + if compressed_indices is None: + return swa_indices + assert compressed_indices.dtype == torch.int32 + assert compressed_indices.dim() == 2 + assert compressed_indices.shape[0] == swa_indices.shape[0] + + num_tokens = swa_indices.shape[0] + compressed_topk = compressed_indices.shape[-1] + padded_compressed_topk = (compressed_topk + 3) // 4 * 4 + sparse_indices = torch.empty( + (num_tokens, window_size + padded_compressed_topk), + dtype=torch.int32, + device=swa_indices.device, + ) + if num_tokens == 0: + return sparse_indices + + _merge_flashinfer_sparse_indices_kernel[(num_tokens,)]( + sparse_indices, + sparse_indices.stride(0), + swa_indices, + swa_indices.stride(0), + compressed_indices, + compressed_indices.stride(0), + WINDOW_SIZE=window_size, + COMPRESSED_TOPK=compressed_topk, + PADDED_COMPRESSED_TOPK=padded_compressed_topk, + BLOCK_SIZE=1024, + ) + return sparse_indices + + +@triton.jit +def _merge_flashinfer_sparse_indices_kernel( + sparse_indices_ptr, + sparse_indices_stride, + swa_indices_ptr, + swa_indices_stride, + compressed_indices_ptr, + compressed_indices_stride, + WINDOW_SIZE: tl.constexpr, + COMPRESSED_TOPK: tl.constexpr, + PADDED_COMPRESSED_TOPK: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + token_idx = tl.program_id(0) + + for i in range(0, WINDOW_SIZE, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + mask = offset < WINDOW_SIZE + values = tl.load( + swa_indices_ptr + token_idx * swa_indices_stride + offset, + mask=mask, + other=-1, + ) + tl.store( + sparse_indices_ptr + token_idx * sparse_indices_stride + offset, + values, + mask=mask, + ) + + for i in range(0, PADDED_COMPRESSED_TOPK, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + mask = offset < PADDED_COMPRESSED_TOPK + values = tl.load( + compressed_indices_ptr + token_idx * compressed_indices_stride + offset, + mask=offset < COMPRESSED_TOPK, + other=-1, + ) + tl.store( + sparse_indices_ptr + + token_idx * sparse_indices_stride + + WINDOW_SIZE + + offset, + values, + mask=mask, + ) + + +def build_flashinfer_prefill_sparse_indices( + topk_indices: torch.Tensor, + query_start_loc: torch.Tensor, + seq_lens: torch.Tensor, + swa_block_table: torch.Tensor, + swa_block_size: int, + compressed_block_table: torch.Tensor | None, + compressed_block_size: int, + window_size: int, + compress_ratio: int, + topk: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Build FlashInfer DSV4 prefill sparse indices from physical KV slots. + + FlashInfer's DSV4 launcher expects the first `window_size` columns to refer + to the SWA KV pool and the remaining columns to refer to the compressed KV + pool. Unlike FlashMLA prefill, these are physical cache slots rather than + offsets into a gathered BF16 workspace. + """ + assert topk_indices.dtype == torch.int32 + assert query_start_loc.dtype == torch.int32 + assert seq_lens.dtype == torch.int32 + assert swa_block_table.dtype == torch.int32 + assert topk_indices.dim() == 2 + + num_tokens = topk_indices.shape[0] + num_reqs = seq_lens.shape[0] + padded_topk = (topk + 3) // 4 * 4 + sparse_indices = torch.empty( + (num_tokens, window_size + padded_topk), + dtype=torch.int32, + device=topk_indices.device, + ) + sparse_topk_lens = torch.empty( + num_tokens, dtype=torch.int32, device=topk_indices.device + ) + if num_tokens == 0: + return sparse_indices, sparse_topk_lens + + if compressed_block_table is None: + compressed_block_table = swa_block_table + assert compressed_block_table.dtype == torch.int32 + + NUM_WORKERS = 128 + _build_flashinfer_prefill_sparse_indices_kernel[(num_reqs, NUM_WORKERS)]( + sparse_indices, + sparse_indices.stride(0), + sparse_topk_lens, + topk_indices, + topk_indices.stride(0), + query_start_loc, + seq_lens, + swa_block_table, + swa_block_table.stride(0), + swa_block_size, + compressed_block_table, + compressed_block_table.stride(0), + compressed_block_size, + WINDOW_SIZE=window_size, + COMPRESS_RATIO=compress_ratio, + TOP_K=topk, + PADDED_TOP_K=padded_topk, + TOPK_STRIDE=topk_indices.shape[-1], + BLOCK_SIZE=1024, + ) + return sparse_indices, sparse_topk_lens + + +@triton.jit +def _build_flashinfer_prefill_sparse_indices_kernel( + sparse_indices_ptr, + sparse_indices_stride, + sparse_topk_lens_ptr, + topk_indices_ptr, + topk_indices_stride, + query_start_loc_ptr, + seq_lens_ptr, + swa_block_table_ptr, + swa_block_table_stride, + swa_block_size, + compressed_block_table_ptr, + compressed_block_table_stride, + compressed_block_size, + WINDOW_SIZE: tl.constexpr, + COMPRESS_RATIO: tl.constexpr, + TOP_K: tl.constexpr, + PADDED_TOP_K: tl.constexpr, + TOPK_STRIDE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + batch_idx = tl.program_id(0) + worker_id = tl.program_id(1) + num_workers = tl.num_programs(1) + + query_start = tl.load(query_start_loc_ptr + batch_idx) + query_end = tl.load(query_start_loc_ptr + batch_idx + 1) + query_len = query_end - query_start + seq_len = tl.load(seq_lens_ptr + batch_idx) + start_pos = seq_len - query_len + + for token_idx in range(query_start + worker_id, query_end, num_workers): + token_idx_in_query = token_idx - query_start + pos = start_pos + token_idx_in_query + swa_len = tl.minimum(pos + 1, WINDOW_SIZE) + swa_start_pos = pos - swa_len + 1 + topk_len = tl.minimum((pos + 1) // COMPRESS_RATIO, TOP_K) + + for i in range(0, WINDOW_SIZE, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + mask = offset < WINDOW_SIZE + pos_offset = swa_start_pos + offset + block_indices = pos_offset // swa_block_size + block_numbers = tl.load( + swa_block_table_ptr + batch_idx * swa_block_table_stride + block_indices, + mask=mask & (offset < swa_len), + other=-1, + ) + block_offsets = pos_offset % swa_block_size + slot_ids = block_numbers * swa_block_size + block_offsets + slot_ids = tl.where(offset < swa_len, slot_ids, -1) + tl.store( + sparse_indices_ptr + token_idx * sparse_indices_stride + offset, + slot_ids, + mask=mask, + ) + + for i in range(0, PADDED_TOP_K, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + mask = offset < PADDED_TOP_K + local_idx = tl.load( + topk_indices_ptr + token_idx * topk_indices_stride + offset, + mask=(offset < TOPK_STRIDE) & (offset < topk_len), + other=-1, + ) + is_valid = local_idx >= 0 + block_indices = local_idx // compressed_block_size + block_numbers = tl.load( + compressed_block_table_ptr + + batch_idx * compressed_block_table_stride + + block_indices, + mask=mask & is_valid, + other=-1, + ) + block_offsets = local_idx % compressed_block_size + slot_ids = block_numbers * compressed_block_size + block_offsets + slot_ids = tl.where((offset < topk_len) & is_valid, slot_ids, -1) + tl.store( + sparse_indices_ptr + + token_idx * sparse_indices_stride + + WINDOW_SIZE + + offset, + slot_ids, + mask=mask, + ) + + tl.store(sparse_topk_lens_ptr + token_idx, WINDOW_SIZE + topk_len) + + @triton.jit def _compute_global_topk_indices_and_lens_kernel( global_topk_indices_ptr, From cc5ec7b68c1a578b64337f8da1a91076ace14aa6 Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Sun, 10 May 2026 20:24:49 -0700 Subject: [PATCH 06/24] Use one FlashInfer call for mixed DSV4 FP8 batches --- .../layers/deepseek_v4_attention.py | 142 +++++++++++ .../attention/ops/deepseek_v4_ops/__init__.py | 2 + .../ops/deepseek_v4_ops/cache_utils.py | 229 ++++++++++++++++++ 3 files changed, 373 insertions(+) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 9078dd515469..7de25d53c863 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -23,6 +23,7 @@ from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.ops.deepseek_v4_ops import ( build_flashinfer_decode_sparse_indices, + build_flashinfer_mixed_sparse_indices, build_flashinfer_prefill_sparse_indices, combine_topk_swa_indices, compute_global_topk_indices_and_lens, @@ -926,6 +927,23 @@ def forward( num_prefills = swa_metadata.num_prefills num_decode_tokens = swa_metadata.num_decode_tokens + if ( + num_prefills > 0 + and num_decodes > 0 + and self.kv_cache_torch_dtype == torch.float8_e4m3fn + and not current_platform.is_rocm() + ): + self._forward_mixed_flashinfer( + q=q, + kv_cache=self_kv_cache, + swa_k_cache=swa_kv_cache, + swa_metadata=swa_metadata, + attn_metadata=flashmla_metadata, + swa_only=swa_only, + output=output, + ) + return + if num_prefills > 0: self._forward_prefill( q=q[num_decode_tokens:], @@ -1154,6 +1172,130 @@ def _forward_decode_flashinfer( max_q_len=max_q_len, ) + def _forward_mixed_flashinfer( + self, + q: torch.Tensor, + kv_cache: torch.Tensor | None, + swa_k_cache: torch.Tensor, + swa_metadata: "DeepseekSparseSWAMetadata", + attn_metadata: FlashMLASparseMetadata | None, + swa_only: bool, + output: torch.Tensor, + ) -> None: + assert self.kv_cache_torch_dtype == torch.float8_e4m3fn + assert q.dtype == torch.float8_e4m3fn + num_decodes = swa_metadata.num_decodes + num_prefills = swa_metadata.num_prefills + num_decode_tokens = swa_metadata.num_decode_tokens + num_prefill_tokens = swa_metadata.num_prefill_tokens + num_reqs = num_decodes + num_prefills + num_tokens = num_decode_tokens + num_prefill_tokens + + assert swa_metadata.seq_lens_int32 is not None + assert swa_metadata.query_start_loc is not None + assert swa_metadata.query_start_loc_cpu is not None + assert swa_metadata.token_to_req_indices is not None + assert swa_metadata.decode_swa_indices is not None + + decode_swa_indices = swa_metadata.decode_swa_indices.view( + num_decode_tokens, -1 + ) + if decode_swa_indices.shape[-1] != self.window_size: + raise ValueError( + f"DeepSeek V4 FlashInfer path expects {self.window_size} SWA " + f"indices, got {decode_swa_indices.shape[-1]}" + ) + + if swa_only: + assert swa_metadata.decode_swa_sparse_topk_lens is not None + assert self.topk_indices_buffer is not None + compressed_kv_cache = swa_k_cache + decode_compressed_indices = None + decode_sparse_topk_lens = swa_metadata.decode_swa_sparse_topk_lens + prefill_topk_indices = self.topk_indices_buffer[ + num_decode_tokens : num_decode_tokens + num_prefill_tokens + ] + compressed_block_table = None + compressed_block_size = swa_metadata.block_size + top_k = 0 + else: + assert kv_cache is not None + assert attn_metadata is not None + compressed_kv_cache = kv_cache + compressed_block_table = attn_metadata.block_table[:num_reqs] + compressed_block_size = attn_metadata.block_size // self.compress_ratio + + if self.compress_ratio == 4: + assert self.topk_indices_buffer is not None + assert swa_metadata.is_valid_token is not None + is_valid = swa_metadata.is_valid_token[:num_decode_tokens] + decode_global_indices, decode_sparse_topk_lens = ( + compute_global_topk_indices_and_lens( + self.topk_indices_buffer[:num_decode_tokens], + swa_metadata.token_to_req_indices, + attn_metadata.block_table[:num_decodes], + compressed_block_size, + is_valid, + self.window_size, + ) + ) + decode_compressed_indices = decode_global_indices.view( + num_decode_tokens, -1 + ) + prefill_topk_indices = self.topk_indices_buffer[ + num_decode_tokens : num_decode_tokens + num_prefill_tokens + ] + else: + assert attn_metadata.c128a_global_decode_topk_indices is not None + assert attn_metadata.c128a_decode_topk_lens is not None + assert attn_metadata.c128a_prefill_topk_indices is not None + decode_compressed_indices = ( + attn_metadata.c128a_global_decode_topk_indices.view( + num_decode_tokens, -1 + ) + ) + decode_sparse_topk_lens = attn_metadata.c128a_decode_topk_lens + prefill_topk_indices = attn_metadata.c128a_prefill_topk_indices + top_k = prefill_topk_indices.shape[-1] + + query_start_loc = swa_metadata.query_start_loc[: num_reqs + 1] + query_start_loc_cpu = swa_metadata.query_start_loc_cpu[: num_reqs + 1] + query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + max_q_len = int(query_lens_cpu.max().item()) + seq_lens = swa_metadata.seq_lens_int32[:num_reqs] + sparse_indices, sparse_topk_lens = build_flashinfer_mixed_sparse_indices( + decode_swa_indices, + decode_compressed_indices, + decode_sparse_topk_lens, + prefill_topk_indices[:num_prefill_tokens], + query_start_loc, + seq_lens, + swa_metadata.token_to_req_indices[:num_tokens], + swa_metadata.block_table[:num_reqs], + swa_metadata.block_size, + compressed_block_table, + compressed_block_size, + self.window_size, + self.compress_ratio, + top_k, + ) + + flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw( + query=q, + swa_kv_cache=swa_k_cache, + workspace_buffer=_get_flashinfer_dsv4_workspace(q.device), + sparse_indices=sparse_indices, + compressed_kv_cache=compressed_kv_cache, + sparse_topk_lens=sparse_topk_lens, + seq_lens=seq_lens, + out=output, + bmm1_scale=self._flashinfer_fp8_bmm1_scale, + bmm2_scale=self._flashinfer_fp8_bmm2_scale, + sinks=self.attn_sink, + cum_seq_lens_q=query_start_loc, + max_q_len=max_q_len, + ) + def _forward_prefill( self, q: torch.Tensor, diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py b/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py index e26953c2502c..bdd1a5955044 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py @@ -3,6 +3,7 @@ from .cache_utils import ( build_flashinfer_decode_sparse_indices, + build_flashinfer_mixed_sparse_indices, build_flashinfer_prefill_sparse_indices, combine_topk_swa_indices, compute_global_topk_indices_and_lens, @@ -17,6 +18,7 @@ __all__ = [ "MXFP4_BLOCK_SIZE", "build_flashinfer_decode_sparse_indices", + "build_flashinfer_mixed_sparse_indices", "build_flashinfer_prefill_sparse_indices", "combine_topk_swa_indices", "compute_global_topk_indices_and_lens", diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py index f394a745d5f4..163ae1dc94c8 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py @@ -873,6 +873,99 @@ def build_flashinfer_prefill_sparse_indices( return sparse_indices, sparse_topk_lens +def build_flashinfer_mixed_sparse_indices( + decode_swa_indices: torch.Tensor, + decode_compressed_indices: torch.Tensor | None, + decode_sparse_topk_lens: torch.Tensor, + prefill_topk_indices: torch.Tensor, + query_start_loc: torch.Tensor, + seq_lens: torch.Tensor, + token_to_req_indices: torch.Tensor, + swa_block_table: torch.Tensor, + swa_block_size: int, + compressed_block_table: torch.Tensor | None, + compressed_block_size: int, + window_size: int, + compress_ratio: int, + topk: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Build FlashInfer DSV4 sparse indices for decode-first mixed batches.""" + assert decode_swa_indices.dtype == torch.int32 + assert decode_swa_indices.dim() == 2 + assert decode_swa_indices.shape[-1] == window_size + assert decode_sparse_topk_lens.dtype == torch.int32 + assert prefill_topk_indices.dtype == torch.int32 + assert prefill_topk_indices.dim() == 2 + assert query_start_loc.dtype == torch.int32 + assert seq_lens.dtype == torch.int32 + assert token_to_req_indices.dtype == torch.int32 + assert swa_block_table.dtype == torch.int32 + + num_decode_tokens = decode_swa_indices.shape[0] + num_prefill_tokens = prefill_topk_indices.shape[0] + num_tokens = num_decode_tokens + num_prefill_tokens + assert token_to_req_indices.shape[0] >= num_tokens + assert decode_sparse_topk_lens.shape[0] >= num_decode_tokens + + decode_compressed_topk = 0 + if decode_compressed_indices is None: + decode_compressed_indices = prefill_topk_indices + else: + assert decode_compressed_indices.dtype == torch.int32 + assert decode_compressed_indices.dim() == 2 + assert decode_compressed_indices.shape[0] == num_decode_tokens + decode_compressed_topk = decode_compressed_indices.shape[-1] + + if compressed_block_table is None: + compressed_block_table = swa_block_table + assert compressed_block_table.dtype == torch.int32 + + padded_topk = max(topk, decode_compressed_topk) + padded_topk = (padded_topk + 3) // 4 * 4 + sparse_indices = torch.empty( + (num_tokens, window_size + padded_topk), + dtype=torch.int32, + device=decode_swa_indices.device, + ) + sparse_topk_lens = torch.empty( + num_tokens, dtype=torch.int32, device=decode_swa_indices.device + ) + if num_tokens == 0: + return sparse_indices, sparse_topk_lens + + _build_flashinfer_mixed_sparse_indices_kernel[(num_tokens,)]( + sparse_indices, + sparse_indices.stride(0), + sparse_topk_lens, + decode_swa_indices, + decode_swa_indices.stride(0), + decode_compressed_indices, + decode_compressed_indices.stride(0), + decode_sparse_topk_lens, + prefill_topk_indices, + prefill_topk_indices.stride(0), + query_start_loc, + seq_lens, + token_to_req_indices, + swa_block_table, + swa_block_table.stride(0), + swa_block_size, + compressed_block_table, + compressed_block_table.stride(0), + compressed_block_size, + NUM_DECODE_TOKENS=num_decode_tokens, + WINDOW_SIZE=window_size, + COMPRESS_RATIO=compress_ratio, + TOP_K=topk, + PADDED_TOP_K=padded_topk, + PREFILL_TOPK_STRIDE=prefill_topk_indices.shape[-1], + DECODE_COMPRESSED_TOPK=decode_compressed_topk, + BLOCK_SIZE=1024, + num_warps=8, + ) + return sparse_indices, sparse_topk_lens + + @triton.jit def _build_flashinfer_prefill_sparse_indices_kernel( sparse_indices_ptr, @@ -963,6 +1056,142 @@ def _build_flashinfer_prefill_sparse_indices_kernel( tl.store(sparse_topk_lens_ptr + token_idx, WINDOW_SIZE + topk_len) +@triton.jit +def _build_flashinfer_mixed_sparse_indices_kernel( + sparse_indices_ptr, + sparse_indices_stride, + sparse_topk_lens_ptr, + decode_swa_indices_ptr, + decode_swa_stride, + decode_compressed_indices_ptr, + decode_compressed_stride, + decode_sparse_topk_lens_ptr, + prefill_topk_indices_ptr, + prefill_topk_stride, + query_start_loc_ptr, + seq_lens_ptr, + token_to_req_indices_ptr, + swa_block_table_ptr, + swa_block_table_stride, + swa_block_size, + compressed_block_table_ptr, + compressed_block_table_stride, + compressed_block_size, + NUM_DECODE_TOKENS: tl.constexpr, + WINDOW_SIZE: tl.constexpr, + COMPRESS_RATIO: tl.constexpr, + TOP_K: tl.constexpr, + PADDED_TOP_K: tl.constexpr, + PREFILL_TOPK_STRIDE: tl.constexpr, + DECODE_COMPRESSED_TOPK: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + token_idx = tl.program_id(0) + + if token_idx < NUM_DECODE_TOKENS: + for i in range(0, WINDOW_SIZE, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + mask = offset < WINDOW_SIZE + values = tl.load( + decode_swa_indices_ptr + token_idx * decode_swa_stride + offset, + mask=mask, + other=-1, + ) + tl.store( + sparse_indices_ptr + token_idx * sparse_indices_stride + offset, + values, + mask=mask, + ) + + for i in range(0, PADDED_TOP_K, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + mask = offset < PADDED_TOP_K + values = tl.load( + decode_compressed_indices_ptr + + token_idx * decode_compressed_stride + + offset, + mask=offset < DECODE_COMPRESSED_TOPK, + other=-1, + ) + tl.store( + sparse_indices_ptr + + token_idx * sparse_indices_stride + + WINDOW_SIZE + + offset, + values, + mask=mask, + ) + + tl.store( + sparse_topk_lens_ptr + token_idx, + tl.load(decode_sparse_topk_lens_ptr + token_idx), + ) + return + + prefill_idx = token_idx - NUM_DECODE_TOKENS + req_idx = tl.load(token_to_req_indices_ptr + token_idx) + query_start = tl.load(query_start_loc_ptr + req_idx) + query_end = tl.load(query_start_loc_ptr + req_idx + 1) + query_len = query_end - query_start + seq_len = tl.load(seq_lens_ptr + req_idx) + start_pos = seq_len - query_len + token_idx_in_query = token_idx - query_start + pos = start_pos + token_idx_in_query + swa_len = tl.minimum(pos + 1, WINDOW_SIZE) + swa_start_pos = pos - swa_len + 1 + topk_len = tl.minimum((pos + 1) // COMPRESS_RATIO, TOP_K) + + for i in range(0, WINDOW_SIZE, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + mask = offset < WINDOW_SIZE + pos_offset = swa_start_pos + offset + block_indices = pos_offset // swa_block_size + block_numbers = tl.load( + swa_block_table_ptr + req_idx * swa_block_table_stride + block_indices, + mask=mask & (offset < swa_len), + other=-1, + ) + block_offsets = pos_offset % swa_block_size + slot_ids = block_numbers * swa_block_size + block_offsets + slot_ids = tl.where(offset < swa_len, slot_ids, -1) + tl.store( + sparse_indices_ptr + token_idx * sparse_indices_stride + offset, + slot_ids, + mask=mask, + ) + + for i in range(0, PADDED_TOP_K, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + mask = offset < PADDED_TOP_K + local_idx = tl.load( + prefill_topk_indices_ptr + prefill_idx * prefill_topk_stride + offset, + mask=(offset < PREFILL_TOPK_STRIDE) & (offset < topk_len), + other=-1, + ) + is_valid = local_idx >= 0 + block_indices = local_idx // compressed_block_size + block_numbers = tl.load( + compressed_block_table_ptr + + req_idx * compressed_block_table_stride + + block_indices, + mask=mask & is_valid, + other=-1, + ) + block_offsets = local_idx % compressed_block_size + slot_ids = block_numbers * compressed_block_size + block_offsets + slot_ids = tl.where((offset < topk_len) & is_valid, slot_ids, -1) + tl.store( + sparse_indices_ptr + + token_idx * sparse_indices_stride + + WINDOW_SIZE + + offset, + slot_ids, + mask=mask, + ) + + tl.store(sparse_topk_lens_ptr + token_idx, WINDOW_SIZE + topk_len) + + @triton.jit def _compute_global_topk_indices_and_lens_kernel( global_topk_indices_ptr, From db6aff8ad56b3041770d3add65fcce48db89183c Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Sun, 10 May 2026 20:36:22 -0700 Subject: [PATCH 07/24] Unify DeepSeek V4 FlashInfer attention path --- .../layers/deepseek_v4_attention.py | 277 +++++------------- 1 file changed, 72 insertions(+), 205 deletions(-) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 7de25d53c863..693829d76588 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -22,9 +22,7 @@ from vllm.utils.deep_gemm import fp8_einsum from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.ops.deepseek_v4_ops import ( - build_flashinfer_decode_sparse_indices, build_flashinfer_mixed_sparse_indices, - build_flashinfer_prefill_sparse_indices, combine_topk_swa_indices, compute_global_topk_indices_and_lens, dequantize_and_gather_k_cache, @@ -927,13 +925,8 @@ def forward( num_prefills = swa_metadata.num_prefills num_decode_tokens = swa_metadata.num_decode_tokens - if ( - num_prefills > 0 - and num_decodes > 0 - and self.kv_cache_torch_dtype == torch.float8_e4m3fn - and not current_platform.is_rocm() - ): - self._forward_mixed_flashinfer( + if self.kv_cache_torch_dtype != torch.uint8 and not current_platform.is_rocm(): + self._forward_flashinfer( q=q, kv_cache=self_kv_cache, swa_k_cache=swa_kv_cache, @@ -1025,18 +1018,7 @@ def _forward_decode( ) return - if self.kv_cache_torch_dtype != torch.uint8: - self._forward_decode_flashinfer( - q=q, - kv_cache=kv_cache, - swa_metadata=swa_metadata, - swa_only=swa_only, - topk_indices=topk_indices, - topk_lens=topk_lens, - swa_indices=swa_indices, - output=output, - ) - return + assert self.kv_cache_torch_dtype == torch.uint8 # We treat queries in the same seq as different queries # and later we only attend by generated indices. @@ -1092,87 +1074,7 @@ def _forward_decode( out=output.unsqueeze(1), ) - def _forward_decode_flashinfer( - self, - q: torch.Tensor, - kv_cache: torch.Tensor | None, - swa_metadata: "DeepseekSparseSWAMetadata", - swa_only: bool, - topk_indices: torch.Tensor | None, - topk_lens: torch.Tensor | None, - swa_indices: torch.Tensor, - output: torch.Tensor, - ) -> None: - assert self.kv_cache_torch_dtype in (torch.bfloat16, torch.float8_e4m3fn) - num_decodes = swa_metadata.num_decodes - num_decode_tokens = swa_metadata.num_decode_tokens - assert swa_metadata.seq_lens is not None - assert swa_metadata.query_start_loc is not None - assert swa_metadata.query_start_loc_cpu is not None - - swa_indices_2d = swa_indices.view(num_decode_tokens, -1) - if swa_indices_2d.shape[-1] != self.window_size: - raise ValueError( - f"DeepSeek V4 FlashInfer path expects {self.window_size} SWA " - f"indices, got {swa_indices_2d.shape[-1]}" - ) - - if swa_only: - assert swa_metadata.decode_swa_sparse_topk_lens is not None - compressed_kv_cache = self.swa_cache_layer.kv_cache - sparse_indices = build_flashinfer_decode_sparse_indices( - swa_indices_2d, - None, - self.window_size, - ) - sparse_topk_lens = swa_metadata.decode_swa_sparse_topk_lens - else: - assert kv_cache is not None - assert topk_indices is not None - assert topk_lens is not None - compressed_kv_cache = kv_cache - compressed_indices = topk_indices.view(num_decode_tokens, -1) - sparse_topk_lens = topk_lens - sparse_indices = build_flashinfer_decode_sparse_indices( - swa_indices_2d, - compressed_indices, - self.window_size, - ) - - query_start_loc = swa_metadata.query_start_loc[: num_decodes + 1] - query_start_loc_cpu = swa_metadata.query_start_loc_cpu[: num_decodes + 1] - query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] - max_q_len = int(query_lens_cpu.max().item()) - - assert swa_metadata.seq_lens_int32 is not None - seq_lens = swa_metadata.seq_lens_int32[:num_decodes] - query = q - bmm1_scale: float | torch.Tensor = self.scale - bmm2_scale: float | torch.Tensor = 1.0 - if self.kv_cache_torch_dtype == torch.float8_e4m3fn: - assert query.dtype == torch.float8_e4m3fn - bmm1_scale = self._flashinfer_fp8_bmm1_scale - bmm2_scale = self._flashinfer_fp8_bmm2_scale - else: - query = query.contiguous() - - flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw( - query=query, - swa_kv_cache=self.swa_cache_layer.kv_cache, - workspace_buffer=_get_flashinfer_dsv4_workspace(q.device), - sparse_indices=sparse_indices, - compressed_kv_cache=compressed_kv_cache, - sparse_topk_lens=sparse_topk_lens, - seq_lens=seq_lens, - out=output, - bmm1_scale=bmm1_scale, - bmm2_scale=bmm2_scale, - sinks=self.attn_sink, - cum_seq_lens_q=query_start_loc, - max_q_len=max_q_len, - ) - - def _forward_mixed_flashinfer( + def _forward_flashinfer( self, q: torch.Tensor, kv_cache: torch.Tensor | None, @@ -1182,38 +1084,34 @@ def _forward_mixed_flashinfer( swa_only: bool, output: torch.Tensor, ) -> None: - assert self.kv_cache_torch_dtype == torch.float8_e4m3fn - assert q.dtype == torch.float8_e4m3fn + assert self.kv_cache_torch_dtype in (torch.bfloat16, torch.float8_e4m3fn) num_decodes = swa_metadata.num_decodes num_prefills = swa_metadata.num_prefills num_decode_tokens = swa_metadata.num_decode_tokens num_prefill_tokens = swa_metadata.num_prefill_tokens num_reqs = num_decodes + num_prefills num_tokens = num_decode_tokens + num_prefill_tokens + if num_tokens == 0: + return assert swa_metadata.seq_lens_int32 is not None assert swa_metadata.query_start_loc is not None assert swa_metadata.query_start_loc_cpu is not None assert swa_metadata.token_to_req_indices is not None assert swa_metadata.decode_swa_indices is not None + assert swa_metadata.decode_swa_sparse_topk_lens is not None - decode_swa_indices = swa_metadata.decode_swa_indices.view( - num_decode_tokens, -1 + decode_swa_indices = swa_metadata.decode_swa_indices.reshape( + num_decode_tokens, self.window_size ) - if decode_swa_indices.shape[-1] != self.window_size: - raise ValueError( - f"DeepSeek V4 FlashInfer path expects {self.window_size} SWA " - f"indices, got {decode_swa_indices.shape[-1]}" - ) + decode_sparse_topk_lens = swa_metadata.decode_swa_sparse_topk_lens if swa_only: - assert swa_metadata.decode_swa_sparse_topk_lens is not None assert self.topk_indices_buffer is not None compressed_kv_cache = swa_k_cache decode_compressed_indices = None - decode_sparse_topk_lens = swa_metadata.decode_swa_sparse_topk_lens prefill_topk_indices = self.topk_indices_buffer[ - num_decode_tokens : num_decode_tokens + num_prefill_tokens + num_decode_tokens:num_tokens, :0 ] compressed_block_table = None compressed_block_size = swa_metadata.block_size @@ -1227,36 +1125,55 @@ def _forward_mixed_flashinfer( if self.compress_ratio == 4: assert self.topk_indices_buffer is not None - assert swa_metadata.is_valid_token is not None - is_valid = swa_metadata.is_valid_token[:num_decode_tokens] - decode_global_indices, decode_sparse_topk_lens = ( - compute_global_topk_indices_and_lens( - self.topk_indices_buffer[:num_decode_tokens], - swa_metadata.token_to_req_indices, - attn_metadata.block_table[:num_decodes], - compressed_block_size, - is_valid, - self.window_size, + if num_prefill_tokens > 0: + prefill_topk_indices = self.topk_indices_buffer[ + num_decode_tokens:num_tokens + ] + top_k = prefill_topk_indices.shape[-1] + else: + prefill_topk_indices = self.topk_indices_buffer[:0, :0] + top_k = 0 + + if num_decode_tokens > 0: + assert swa_metadata.is_valid_token is not None + is_valid = swa_metadata.is_valid_token[:num_decode_tokens] + decode_global_indices, decode_sparse_topk_lens = ( + compute_global_topk_indices_and_lens( + self.topk_indices_buffer[:num_decode_tokens], + swa_metadata.token_to_req_indices, + attn_metadata.block_table[:num_decodes], + compressed_block_size, + is_valid, + self.window_size, + ) ) - ) - decode_compressed_indices = decode_global_indices.view( - num_decode_tokens, -1 - ) - prefill_topk_indices = self.topk_indices_buffer[ - num_decode_tokens : num_decode_tokens + num_prefill_tokens - ] - else: - assert attn_metadata.c128a_global_decode_topk_indices is not None - assert attn_metadata.c128a_decode_topk_lens is not None - assert attn_metadata.c128a_prefill_topk_indices is not None - decode_compressed_indices = ( - attn_metadata.c128a_global_decode_topk_indices.view( + decode_compressed_indices = decode_global_indices.view( num_decode_tokens, -1 ) - ) - decode_sparse_topk_lens = attn_metadata.c128a_decode_topk_lens - prefill_topk_indices = attn_metadata.c128a_prefill_topk_indices - top_k = prefill_topk_indices.shape[-1] + else: + decode_compressed_indices = prefill_topk_indices[:0, :0] + else: + if num_prefill_tokens > 0: + assert attn_metadata.c128a_prefill_topk_indices is not None + prefill_topk_indices = attn_metadata.c128a_prefill_topk_indices + top_k = prefill_topk_indices.shape[-1] + else: + prefill_topk_indices = decode_swa_indices[:0, :0] + top_k = 0 + + if num_decode_tokens > 0: + assert attn_metadata.c128a_global_decode_topk_indices is not None + assert attn_metadata.c128a_decode_topk_lens is not None + decode_compressed_indices = ( + attn_metadata.c128a_global_decode_topk_indices.view( + num_decode_tokens, -1 + ) + ) + decode_sparse_topk_lens = attn_metadata.c128a_decode_topk_lens + if num_prefill_tokens == 0: + prefill_topk_indices = decode_compressed_indices[:0, :0] + else: + decode_compressed_indices = prefill_topk_indices[:0, :0] query_start_loc = swa_metadata.query_start_loc[: num_reqs + 1] query_start_loc_cpu = swa_metadata.query_start_loc_cpu[: num_reqs + 1] @@ -1280,8 +1197,19 @@ def _forward_mixed_flashinfer( top_k, ) + query = q + bmm1_scale: float | torch.Tensor = self.scale + bmm2_scale: float | torch.Tensor = 1.0 + if self.kv_cache_torch_dtype == torch.float8_e4m3fn: + assert query.dtype == torch.float8_e4m3fn + bmm1_scale = self._flashinfer_fp8_bmm1_scale + bmm2_scale = self._flashinfer_fp8_bmm2_scale + else: + assert query.dtype == torch.bfloat16 + query = query.contiguous() + flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw( - query=q, + query=query, swa_kv_cache=swa_k_cache, workspace_buffer=_get_flashinfer_dsv4_workspace(q.device), sparse_indices=sparse_indices, @@ -1289,8 +1217,8 @@ def _forward_mixed_flashinfer( sparse_topk_lens=sparse_topk_lens, seq_lens=seq_lens, out=output, - bmm1_scale=self._flashinfer_fp8_bmm1_scale, - bmm2_scale=self._flashinfer_fp8_bmm2_scale, + bmm1_scale=bmm1_scale, + bmm2_scale=bmm2_scale, sinks=self.attn_sink, cum_seq_lens_q=query_start_loc, max_q_len=max_q_len, @@ -1347,67 +1275,6 @@ def _forward_prefill( top_k = 0 N = 0 - if ( - self.kv_cache_torch_dtype == torch.float8_e4m3fn - and not current_platform.is_rocm() - ): - assert q.dtype == torch.float8_e4m3fn - assert swa_metadata.prefill_query_start_loc is not None - assert swa_metadata.prefill_query_start_loc_cpu is not None - assert swa_metadata.seq_lens_int32 is not None - - prefill_query_start_loc = swa_metadata.prefill_query_start_loc - prefill_query_start_loc_cpu = swa_metadata.prefill_query_start_loc_cpu - query_lens_cpu = ( - prefill_query_start_loc_cpu[1:] - prefill_query_start_loc_cpu[:-1] - ) - max_q_len = int(query_lens_cpu.max().item()) - seq_lens_int32 = swa_metadata.seq_lens_int32[ - num_decodes : num_decodes + num_prefills - ] - swa_block_table = swa_metadata.block_table[num_decodes:] - - if swa_only: - compressed_kv_cache = swa_k_cache - compressed_block_table = None - compressed_block_size = swa_metadata.block_size - else: - assert compressed_k_cache is not None - assert attn_metadata is not None - compressed_kv_cache = compressed_k_cache - compressed_block_table = attn_metadata.block_table[num_decodes:] - compressed_block_size = attn_metadata.block_size // self.compress_ratio - - sparse_indices, sparse_topk_lens = build_flashinfer_prefill_sparse_indices( - topk_indices[:num_prefill_tokens], - prefill_query_start_loc, - seq_lens_int32, - swa_block_table, - swa_metadata.block_size, - compressed_block_table, - compressed_block_size, - self.window_size, - self.compress_ratio, - top_k, - ) - - flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw( - query=q, - swa_kv_cache=swa_k_cache, - workspace_buffer=_get_flashinfer_dsv4_workspace(q.device), - sparse_indices=sparse_indices, - compressed_kv_cache=compressed_kv_cache, - sparse_topk_lens=sparse_topk_lens, - seq_lens=seq_lens_int32, - out=output, - bmm1_scale=self._flashinfer_fp8_bmm1_scale, - bmm2_scale=self._flashinfer_fp8_bmm2_scale, - sinks=self.attn_sink, - cum_seq_lens_q=prefill_query_start_loc, - max_q_len=max_q_len, - ) - return - M = N + self.window_size + self.max_num_batched_tokens num_chunks = (num_prefills + PREFILL_CHUNK_SIZE - 1) // PREFILL_CHUNK_SIZE From 377fa2bf5926c4427b2e15c351358be243b10fe5 Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Mon, 11 May 2026 03:10:27 -0700 Subject: [PATCH 08/24] Clean up DeepSeek V4 FlashInfer FP8 path --- vllm/envs.py | 14 - .../layers/deepseek_v4_attention.py | 35 +- vllm/v1/attention/backends/mla/sparse_swa.py | 24 +- .../attention/ops/deepseek_v4_ops/__init__.py | 4 - .../ops/deepseek_v4_ops/cache_utils.py | 327 +----------------- 5 files changed, 16 insertions(+), 388 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 45a1cbd9e133..c9be71f9ef04 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -177,9 +177,6 @@ VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False VLLM_USE_FLASHINFER_MOE_INT4: bool = False - VLLM_DSV4_FLASHINFER_FP8_SCALE: float | None = None - VLLM_DSV4_FLASHINFER_FP8_Q_SCALE: float | None = None - VLLM_DSV4_FLASHINFER_FP8_KV_SCALE: float | None = None VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = ( "latency" ) @@ -1328,17 +1325,6 @@ def _get_or_set_default() -> str: "VLLM_USE_FLASHINFER_MOE_INT4": lambda: bool( int(os.getenv("VLLM_USE_FLASHINFER_MOE_INT4", "0")) ), - # Global and Q/KV-specific per-tensor scales for DeepSeek V4 FlashInfer - # sparse MLA FP8 cache/query tensors. - "VLLM_DSV4_FLASHINFER_FP8_SCALE": lambda: maybe_convert_float( - os.getenv("VLLM_DSV4_FLASHINFER_FP8_SCALE") - ), - "VLLM_DSV4_FLASHINFER_FP8_Q_SCALE": lambda: maybe_convert_float( - os.getenv("VLLM_DSV4_FLASHINFER_FP8_Q_SCALE") - ), - "VLLM_DSV4_FLASHINFER_FP8_KV_SCALE": lambda: maybe_convert_float( - os.getenv("VLLM_DSV4_FLASHINFER_FP8_KV_SCALE") - ), # If set to 1, use the FlashInfer # MXFP8 (activation) x MXFP4 (weight) MoE backend. "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8": lambda: bool( diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 693829d76588..f668a254f5b0 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -4,7 +4,6 @@ DeepseekV4 MLA Attention Layer """ -import os from collections.abc import Callable from dataclasses import dataclass from typing import TYPE_CHECKING, Any, cast @@ -90,7 +89,6 @@ _FLASHINFER_DSV4_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 _flashinfer_dsv4_workspace_by_device: dict[torch.device, torch.Tensor] = {} -_DEFAULT_FLASHINFER_DSV4_FP8_SCALE = 1.0 / 32.0 # Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather # workspace allocated at _forward_prefill (and the matching profile-time @@ -98,19 +96,6 @@ PREFILL_CHUNK_SIZE = 4 -def _get_dsv4_flashinfer_fp8_scale(kind: str) -> float: - specific_name = f"VLLM_DSV4_FLASHINFER_FP8_{kind.upper()}_SCALE" - for env_name in (specific_name, "VLLM_DSV4_FLASHINFER_FP8_SCALE"): - value = os.environ.get(env_name) - if value is None: - continue - scale = float(value) - if scale <= 0.0: - raise ValueError(f"{env_name} must be positive, got {value!r}") - return scale - return _DEFAULT_FLASHINFER_DSV4_FP8_SCALE - - def _normalize_dsv4_kv_cache_dtype( cache_config: CacheConfig | None, ) -> str: @@ -817,8 +802,10 @@ def __init__( fp8_q_scale = 1.0 fp8_kv_scale = 1.0 if self.kv_cache_torch_dtype == torch.float8_e4m3fn: - fp8_q_scale = _get_dsv4_flashinfer_fp8_scale("q") - fp8_kv_scale = _get_dsv4_flashinfer_fp8_scale("kv") + # TODO: load the per-tensor FP8 Q and KV scales from checkpoint + # weights. Use unit scales until the scale tensor names are wired. + fp8_q_scale = 1.0 + fp8_kv_scale = 1.0 self.register_buffer( "_flashinfer_fp8_q_scale", torch.tensor([fp8_q_scale], dtype=torch.float32), @@ -925,7 +912,12 @@ def forward( num_prefills = swa_metadata.num_prefills num_decode_tokens = swa_metadata.num_decode_tokens - if self.kv_cache_torch_dtype != torch.uint8 and not current_platform.is_rocm(): + if self.kv_cache_torch_dtype != torch.uint8: + if current_platform.is_rocm(): + raise NotImplementedError( + "DeepSeek V4 BF16/per-tensor FP8 FlashInfer sparse MLA " + "cache path is CUDA-only." + ) self._forward_flashinfer( q=q, kv_cache=self_kv_cache, @@ -1094,7 +1086,7 @@ def _forward_flashinfer( if num_tokens == 0: return - assert swa_metadata.seq_lens_int32 is not None + assert swa_metadata.seq_lens is not None assert swa_metadata.query_start_loc is not None assert swa_metadata.query_start_loc_cpu is not None assert swa_metadata.token_to_req_indices is not None @@ -1179,7 +1171,8 @@ def _forward_flashinfer( query_start_loc_cpu = swa_metadata.query_start_loc_cpu[: num_reqs + 1] query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] max_q_len = int(query_lens_cpu.max().item()) - seq_lens = swa_metadata.seq_lens_int32[:num_reqs] + seq_lens = swa_metadata.seq_lens[:num_reqs] + assert seq_lens.dtype == torch.int32 sparse_indices, sparse_topk_lens = build_flashinfer_mixed_sparse_indices( decode_swa_indices, decode_compressed_indices, @@ -1298,7 +1291,6 @@ def _forward_prefill( block_table=block_table[chunk_start:chunk_end], block_size=attn_metadata.block_size // self.compress_ratio, offset=0, - fp8_scale=self._flashinfer_fp8_kv_scale, ) # Gather SWA KV @@ -1311,7 +1303,6 @@ def _forward_prefill( block_table=swa_block_table[chunk_start:chunk_end], block_size=swa_metadata.block_size, offset=N, - fp8_scale=self._flashinfer_fp8_kv_scale, ) # Combine the topk indices and SWA indices for gathered KV cache diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py index 76c7dc6f7d79..62e6362dd4cd 100644 --- a/vllm/v1/attention/backends/mla/sparse_swa.py +++ b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -154,7 +154,6 @@ class DeepseekSparseSWAMetadata: slot_mapping: torch.Tensor block_size: int seq_lens: torch.Tensor | None = None # [num_seqs] - seq_lens_int32: torch.Tensor | None = None # [num_seqs] query_start_loc: torch.Tensor | None = None # [num_seqs + 1] query_start_loc_cpu: torch.Tensor | None = None # [num_seqs + 1] @@ -173,8 +172,6 @@ class DeepseekSparseSWAMetadata: # Pre-computed prefill metadata shared across all DeepseekV4 attention layers. prefill_seq_lens: torch.Tensor | None = None prefill_gather_lens: torch.Tensor | None = None - prefill_query_start_loc: torch.Tensor | None = None - prefill_query_start_loc_cpu: torch.Tensor | None = None # Per-layer-type FlashMLA tile-scheduler metadata. One FlashMLASchedMeta # per present DeepseekV4 layer type, shared across all ~60 layers of that type @@ -291,9 +288,7 @@ def build( query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu block_table = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping - seq_lens_int32 = seq_lens if seq_lens.dtype == torch.int32 else seq_lens.to( - torch.int32 - ) + assert seq_lens.dtype == torch.int32 # Split into decode and prefill portions using configurable threshold (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = ( @@ -333,9 +328,8 @@ def build( deepseek_v4_fields = self._build_deepseek_v4_metadata( num_decodes, num_prefills, - seq_lens_int32, + seq_lens, query_start_loc, - query_start_loc_cpu, ) # Per-layer-type tile-scheduler plan holders. Empty FlashMLASchedMeta @@ -346,7 +340,6 @@ def build( return DeepseekSparseSWAMetadata( seq_lens=seq_lens, - seq_lens_int32=seq_lens_int32, query_start_loc=query_start_loc, query_start_loc_cpu=query_start_loc_cpu, block_table=block_table, @@ -404,7 +397,6 @@ def _build_deepseek_v4_metadata( num_prefills: int, seq_lens: torch.Tensor, query_start_loc: torch.Tensor, - query_start_loc_cpu: torch.Tensor, ) -> dict[str, torch.Tensor | None]: """Pre-compute DeepseekV4 prefill metadata during the metadata build phase. @@ -433,18 +425,6 @@ def _build_deepseek_v4_metadata( result["prefill_seq_lens"] = seq_lens[num_decodes:] result["prefill_gather_lens"] = pfx_gather_lens - prefill_query_start_loc = query_start_loc[ - num_decodes : num_decodes + num_prefills + 1 - ] - prefill_query_start_loc_cpu = query_start_loc_cpu[ - num_decodes : num_decodes + num_prefills + 1 - ] - result["prefill_query_start_loc"] = ( - prefill_query_start_loc - prefill_query_start_loc[0] - ) - result["prefill_query_start_loc_cpu"] = ( - prefill_query_start_loc_cpu - prefill_query_start_loc_cpu[0] - ) return result diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py b/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py index bdd1a5955044..9e5499cbee4b 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py @@ -2,9 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from .cache_utils import ( - build_flashinfer_decode_sparse_indices, build_flashinfer_mixed_sparse_indices, - build_flashinfer_prefill_sparse_indices, combine_topk_swa_indices, compute_global_topk_indices_and_lens, dequantize_and_gather_k_cache, @@ -17,9 +15,7 @@ __all__ = [ "MXFP4_BLOCK_SIZE", - "build_flashinfer_decode_sparse_indices", "build_flashinfer_mixed_sparse_indices", - "build_flashinfer_prefill_sparse_indices", "combine_topk_swa_indices", "compute_global_topk_indices_and_lens", "dequantize_and_gather_k_cache", diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py index 163ae1dc94c8..664c19d436f2 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py @@ -526,58 +526,6 @@ def _dequantize_and_gather_k_kernel( tl.store(output_row_ptr + bf16_output_offset + chunk_offsets, bf16_vals) -@triton.jit -def _gather_full_k_cache_kernel( - out_ptr, - out_stride0, - out_stride1, - k_cache_ptr, - k_cache_stride0, - k_cache_stride1, - seq_lens_ptr, - block_table_ptr, - offset, - gather_lens_ptr, - fp8_scale_ptr, - max_blocks_per_seq: tl.constexpr, - cache_block_size: tl.constexpr, - output_dim: tl.constexpr, - STORE_FP8: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - batch_idx = tl.program_id(0) - worker_id = tl.program_id(1) - num_workers = tl.num_programs(1) - - seq_len = tl.load(seq_lens_ptr + batch_idx) - if gather_lens_ptr is not None: - gather_len = tl.load(gather_lens_ptr + batch_idx) - else: - gather_len = seq_len - start_pos = seq_len - gather_len - - offsets = tl.arange(0, BLOCK_SIZE) - mask = offsets < output_dim - for i in range(worker_id, gather_len, num_workers): - pos = start_pos + i - block_in_seq = pos // cache_block_size - pos_in_block = pos % cache_block_size - block_table_row_ptr = block_table_ptr + batch_idx * max_blocks_per_seq - physical_block_idx = tl.load(block_table_row_ptr + block_in_seq) - - cache_row = ( - k_cache_ptr - + physical_block_idx.to(tl.int64) * k_cache_stride0 - + pos_in_block * k_cache_stride1 - ) - values = tl.load(cache_row + offsets, mask=mask, other=0.0) - if STORE_FP8: - values = values.to(tl.float32) * tl.load(fp8_scale_ptr) - - out_row = out_ptr + batch_idx * out_stride0 + (offset + i) * out_stride1 - tl.store(out_row + offsets, values.to(tl.bfloat16), mask=mask) - - def dequantize_and_gather_k_cache_triton( # [num_reqs, max_num_tokens, head_size] out: torch.Tensor, @@ -637,36 +585,8 @@ def dequantize_and_gather_k_cache( block_table: torch.Tensor, block_size: int, offset: int, - fp8_scale: torch.Tensor | None = None, ) -> None: - if k_cache.dtype != torch.uint8: - assert k_cache.dtype in (torch.bfloat16, torch.float8_e4m3fn) - assert k_cache.dim() == 3 and k_cache.shape[-1] == 512 - if k_cache.dtype == torch.float8_e4m3fn: - assert fp8_scale is not None - num_reqs = seq_lens.shape[0] - NUM_WORKERS = 128 - _gather_full_k_cache_kernel[(num_reqs, NUM_WORKERS)]( - out, - out.stride(0), - out.stride(1), - k_cache, - k_cache.stride(0), - k_cache.stride(1), - seq_lens, - block_table, - offset, - gather_lens, - fp8_scale if fp8_scale is not None else k_cache, - max_blocks_per_seq=block_table.shape[-1], - cache_block_size=block_size, - output_dim=512, - STORE_FP8=k_cache.dtype == torch.float8_e4m3fn, - BLOCK_SIZE=512, - num_warps=8, - ) - return - + assert k_cache.dtype == torch.uint8 if has_cutedsl(): # lazily import, otherwise some tests fail due to CUDA driver init failure. from .dequant_gather_k_cutedsl import dequantize_and_gather_k_cache_cutedsl @@ -718,161 +638,6 @@ def compute_global_topk_indices_and_lens( return global_topk_indices, topk_lens -def build_flashinfer_decode_sparse_indices( - swa_indices: torch.Tensor, - compressed_indices: torch.Tensor | None, - window_size: int, -) -> torch.Tensor: - """Build FlashInfer DSV4 SWA-first sparse indices without torch.cat/pad.""" - assert swa_indices.dtype == torch.int32 - assert swa_indices.dim() == 2 and swa_indices.shape[-1] == window_size - if compressed_indices is None: - return swa_indices - assert compressed_indices.dtype == torch.int32 - assert compressed_indices.dim() == 2 - assert compressed_indices.shape[0] == swa_indices.shape[0] - - num_tokens = swa_indices.shape[0] - compressed_topk = compressed_indices.shape[-1] - padded_compressed_topk = (compressed_topk + 3) // 4 * 4 - sparse_indices = torch.empty( - (num_tokens, window_size + padded_compressed_topk), - dtype=torch.int32, - device=swa_indices.device, - ) - if num_tokens == 0: - return sparse_indices - - _merge_flashinfer_sparse_indices_kernel[(num_tokens,)]( - sparse_indices, - sparse_indices.stride(0), - swa_indices, - swa_indices.stride(0), - compressed_indices, - compressed_indices.stride(0), - WINDOW_SIZE=window_size, - COMPRESSED_TOPK=compressed_topk, - PADDED_COMPRESSED_TOPK=padded_compressed_topk, - BLOCK_SIZE=1024, - ) - return sparse_indices - - -@triton.jit -def _merge_flashinfer_sparse_indices_kernel( - sparse_indices_ptr, - sparse_indices_stride, - swa_indices_ptr, - swa_indices_stride, - compressed_indices_ptr, - compressed_indices_stride, - WINDOW_SIZE: tl.constexpr, - COMPRESSED_TOPK: tl.constexpr, - PADDED_COMPRESSED_TOPK: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - token_idx = tl.program_id(0) - - for i in range(0, WINDOW_SIZE, BLOCK_SIZE): - offset = i + tl.arange(0, BLOCK_SIZE) - mask = offset < WINDOW_SIZE - values = tl.load( - swa_indices_ptr + token_idx * swa_indices_stride + offset, - mask=mask, - other=-1, - ) - tl.store( - sparse_indices_ptr + token_idx * sparse_indices_stride + offset, - values, - mask=mask, - ) - - for i in range(0, PADDED_COMPRESSED_TOPK, BLOCK_SIZE): - offset = i + tl.arange(0, BLOCK_SIZE) - mask = offset < PADDED_COMPRESSED_TOPK - values = tl.load( - compressed_indices_ptr + token_idx * compressed_indices_stride + offset, - mask=offset < COMPRESSED_TOPK, - other=-1, - ) - tl.store( - sparse_indices_ptr - + token_idx * sparse_indices_stride - + WINDOW_SIZE - + offset, - values, - mask=mask, - ) - - -def build_flashinfer_prefill_sparse_indices( - topk_indices: torch.Tensor, - query_start_loc: torch.Tensor, - seq_lens: torch.Tensor, - swa_block_table: torch.Tensor, - swa_block_size: int, - compressed_block_table: torch.Tensor | None, - compressed_block_size: int, - window_size: int, - compress_ratio: int, - topk: int, -) -> tuple[torch.Tensor, torch.Tensor]: - """Build FlashInfer DSV4 prefill sparse indices from physical KV slots. - - FlashInfer's DSV4 launcher expects the first `window_size` columns to refer - to the SWA KV pool and the remaining columns to refer to the compressed KV - pool. Unlike FlashMLA prefill, these are physical cache slots rather than - offsets into a gathered BF16 workspace. - """ - assert topk_indices.dtype == torch.int32 - assert query_start_loc.dtype == torch.int32 - assert seq_lens.dtype == torch.int32 - assert swa_block_table.dtype == torch.int32 - assert topk_indices.dim() == 2 - - num_tokens = topk_indices.shape[0] - num_reqs = seq_lens.shape[0] - padded_topk = (topk + 3) // 4 * 4 - sparse_indices = torch.empty( - (num_tokens, window_size + padded_topk), - dtype=torch.int32, - device=topk_indices.device, - ) - sparse_topk_lens = torch.empty( - num_tokens, dtype=torch.int32, device=topk_indices.device - ) - if num_tokens == 0: - return sparse_indices, sparse_topk_lens - - if compressed_block_table is None: - compressed_block_table = swa_block_table - assert compressed_block_table.dtype == torch.int32 - - NUM_WORKERS = 128 - _build_flashinfer_prefill_sparse_indices_kernel[(num_reqs, NUM_WORKERS)]( - sparse_indices, - sparse_indices.stride(0), - sparse_topk_lens, - topk_indices, - topk_indices.stride(0), - query_start_loc, - seq_lens, - swa_block_table, - swa_block_table.stride(0), - swa_block_size, - compressed_block_table, - compressed_block_table.stride(0), - compressed_block_size, - WINDOW_SIZE=window_size, - COMPRESS_RATIO=compress_ratio, - TOP_K=topk, - PADDED_TOP_K=padded_topk, - TOPK_STRIDE=topk_indices.shape[-1], - BLOCK_SIZE=1024, - ) - return sparse_indices, sparse_topk_lens - - def build_flashinfer_mixed_sparse_indices( decode_swa_indices: torch.Tensor, decode_compressed_indices: torch.Tensor | None, @@ -966,96 +731,6 @@ def build_flashinfer_mixed_sparse_indices( return sparse_indices, sparse_topk_lens -@triton.jit -def _build_flashinfer_prefill_sparse_indices_kernel( - sparse_indices_ptr, - sparse_indices_stride, - sparse_topk_lens_ptr, - topk_indices_ptr, - topk_indices_stride, - query_start_loc_ptr, - seq_lens_ptr, - swa_block_table_ptr, - swa_block_table_stride, - swa_block_size, - compressed_block_table_ptr, - compressed_block_table_stride, - compressed_block_size, - WINDOW_SIZE: tl.constexpr, - COMPRESS_RATIO: tl.constexpr, - TOP_K: tl.constexpr, - PADDED_TOP_K: tl.constexpr, - TOPK_STRIDE: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - batch_idx = tl.program_id(0) - worker_id = tl.program_id(1) - num_workers = tl.num_programs(1) - - query_start = tl.load(query_start_loc_ptr + batch_idx) - query_end = tl.load(query_start_loc_ptr + batch_idx + 1) - query_len = query_end - query_start - seq_len = tl.load(seq_lens_ptr + batch_idx) - start_pos = seq_len - query_len - - for token_idx in range(query_start + worker_id, query_end, num_workers): - token_idx_in_query = token_idx - query_start - pos = start_pos + token_idx_in_query - swa_len = tl.minimum(pos + 1, WINDOW_SIZE) - swa_start_pos = pos - swa_len + 1 - topk_len = tl.minimum((pos + 1) // COMPRESS_RATIO, TOP_K) - - for i in range(0, WINDOW_SIZE, BLOCK_SIZE): - offset = i + tl.arange(0, BLOCK_SIZE) - mask = offset < WINDOW_SIZE - pos_offset = swa_start_pos + offset - block_indices = pos_offset // swa_block_size - block_numbers = tl.load( - swa_block_table_ptr + batch_idx * swa_block_table_stride + block_indices, - mask=mask & (offset < swa_len), - other=-1, - ) - block_offsets = pos_offset % swa_block_size - slot_ids = block_numbers * swa_block_size + block_offsets - slot_ids = tl.where(offset < swa_len, slot_ids, -1) - tl.store( - sparse_indices_ptr + token_idx * sparse_indices_stride + offset, - slot_ids, - mask=mask, - ) - - for i in range(0, PADDED_TOP_K, BLOCK_SIZE): - offset = i + tl.arange(0, BLOCK_SIZE) - mask = offset < PADDED_TOP_K - local_idx = tl.load( - topk_indices_ptr + token_idx * topk_indices_stride + offset, - mask=(offset < TOPK_STRIDE) & (offset < topk_len), - other=-1, - ) - is_valid = local_idx >= 0 - block_indices = local_idx // compressed_block_size - block_numbers = tl.load( - compressed_block_table_ptr - + batch_idx * compressed_block_table_stride - + block_indices, - mask=mask & is_valid, - other=-1, - ) - block_offsets = local_idx % compressed_block_size - slot_ids = block_numbers * compressed_block_size + block_offsets - slot_ids = tl.where((offset < topk_len) & is_valid, slot_ids, -1) - tl.store( - sparse_indices_ptr - + token_idx * sparse_indices_stride - + WINDOW_SIZE - + offset, - slot_ids, - mask=mask, - ) - - tl.store(sparse_topk_lens_ptr + token_idx, WINDOW_SIZE + topk_len) - - @triton.jit def _build_flashinfer_mixed_sparse_indices_kernel( sparse_indices_ptr, From 9c8f3b0d3d8289042daee557b69c09cc753523fb Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Mon, 11 May 2026 03:19:30 -0700 Subject: [PATCH 09/24] Fix DeepSeek V4 post-rebase test coverage --- tests/kernels/test_fused_inv_rope_fp8_quant.py | 9 ++++----- .../test_indexer_deepseek_v4_slot_mapping.py | 18 ++++++++++++++++-- vllm/model_executor/models/deepseek_v4.py | 9 +++++---- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/tests/kernels/test_fused_inv_rope_fp8_quant.py b/tests/kernels/test_fused_inv_rope_fp8_quant.py index 10561a8a0304..29ba1f8c6fe1 100644 --- a/tests/kernels/test_fused_inv_rope_fp8_quant.py +++ b/tests/kernels/test_fused_inv_rope_fp8_quant.py @@ -725,13 +725,12 @@ def test_einsum_end_to_end(num_tokens, num_heads, n_groups): This catches stride/layout bugs that only manifest when the einsum kernel actually consumes the quantized activations. """ - from deep_gemm.utils.math import ceil_div - from vllm.utils.deep_gemm import ( fp8_einsum, per_block_cast_to_fp8, transform_sf_into_required_layout, ) + from vllm.utils.math_utils import cdiv heads_per_group = num_heads // n_groups d = heads_per_group * HEAD_DIM @@ -753,8 +752,8 @@ def test_einsum_end_to_end(num_tokens, num_heads, n_groups): w_fp8 = torch.empty_like(w, dtype=torch.float8_e4m3fn) w_scale = torch.empty( n_groups, - ceil_div(o_lora_rank, 128), - ceil_div(d, 128), + cdiv(o_lora_rank, 128), + cdiv(d, 128), device=device, dtype=torch.float32, ) @@ -809,7 +808,7 @@ def test_einsum_end_to_end(num_tokens, num_heads, n_groups): # Einsum output: Triton and CUDA both rotate in fp32 now, so diffs # come from fp32 ordering and UE8M0 boundary shifts only. # Use relative diff (same metric as test_fp8_einsum.py). - from deep_gemm.testing import calc_diff + from vllm.third_party.deep_gemm.testing import calc_diff z_diff = calc_diff(z_fused, z_ref) assert z_diff < 0.01, ( diff --git a/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py b/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py index 159bb8af3fb9..8d5a9b5880d2 100644 --- a/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py +++ b/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py @@ -3,6 +3,7 @@ import pytest import torch +from transformers import LlamaConfig from tests.v1.attention.utils import create_vllm_config from vllm.v1.attention.backend import CommonAttentionMetadata @@ -11,7 +12,9 @@ @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") -def test_indexer_builder_deepseek_v4_compressed_slot_mapping_uses_storage_block_size(): +def test_indexer_builder_deepseek_v4_compressed_slot_mapping_uses_storage_block_size( + tmp_path, +): """Regression test: DeepseekV4 compression path must compute slot_mapping from compressed positions, not reuse the uncompressed common metadata mapping. """ @@ -25,7 +28,18 @@ def test_indexer_builder_deepseek_v4_compressed_slot_mapping_uses_storage_block_ dtype=torch.bfloat16, compress_ratio=4, ) - vllm_config = create_vllm_config(max_model_len=1024) + hf_config = LlamaConfig( + architectures=["LlamaForCausalLM"], + hidden_size=128, + intermediate_size=256, + max_position_embeddings=2048, + num_attention_heads=4, + num_hidden_layers=1, + num_key_value_heads=4, + vocab_size=32000, + ) + hf_config.save_pretrained(tmp_path) + vllm_config = create_vllm_config(model_name=str(tmp_path), max_model_len=1024) builder = DeepseekV32IndexerMetadataBuilder( kv_cache_spec=kv_cache_spec, layer_names=["dummy"], diff --git a/vllm/model_executor/models/deepseek_v4.py b/vllm/model_executor/models/deepseek_v4.py index 45b33035a76f..3cd5a9cb73a1 100644 --- a/vllm/model_executor/models/deepseek_v4.py +++ b/vllm/model_executor/models/deepseek_v4.py @@ -483,10 +483,11 @@ def __init__( # Register in the static forward context so the custom-op wrapper # can look up this module by name from within a torch.compile graph. - compilation_config = vllm_config.compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self + if prefix: + compilation_config = vllm_config.compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self def _map_global_expert_id(self, expert_id: int) -> int: if expert_id < self.experts_start_idx or expert_id >= self.experts_end_idx: From 706e92d07a6b1533228f343ba17041b578fdcecf Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Mon, 11 May 2026 04:54:11 -0700 Subject: [PATCH 10/24] Preserve DeepSeek V4 Flash accuracy after rebase --- vllm/envs.py | 7 ++ vllm/model_executor/models/deepseek_v4.py | 86 ++++++++++++------- .../ops/deepseek_v4_ops/fused_indexer_q.py | 3 +- 3 files changed, 62 insertions(+), 34 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index c9be71f9ef04..6d292a3599c8 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -57,6 +57,7 @@ VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") VLLM_XLA_CHECK_RECOMPILATION: bool = False VLLM_SPARSE_INDEXER_MAX_LOGITS_MB: int = 512 + VLLM_DSV4_USE_CUTEDSL_INDEXER_Q: bool = False VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto" VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True @@ -907,6 +908,12 @@ def _get_or_set_default() -> str: "VLLM_SPARSE_INDEXER_MAX_LOGITS_MB": lambda: int( os.getenv("VLLM_SPARSE_INDEXER_MAX_LOGITS_MB", "512") ), + # Enable the CuteDSL MXFP4 indexer Q kernel for DeepSeek V4 sparse MLA. + # The default Triton path preserves the validated DeepSeek V4 Flash + # end-to-end FP8 accuracy. + "VLLM_DSV4_USE_CUTEDSL_INDEXER_Q": lambda: bool( + int(os.getenv("VLLM_DSV4_USE_CUTEDSL_INDEXER_Q", "0")) + ), # If set, the OpenAI API server will stay alive even after the underlying # AsyncLLMEngine errors and stops serving requests "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH": lambda: bool( diff --git a/vllm/model_executor/models/deepseek_v4.py b/vllm/model_executor/models/deepseek_v4.py index 3cd5a9cb73a1..f347383084e5 100644 --- a/vllm/model_executor/models/deepseek_v4.py +++ b/vllm/model_executor/models/deepseek_v4.py @@ -1217,10 +1217,33 @@ def forward( x: torch.Tensor, positions: torch.Tensor, input_ids: torch.Tensor | None, - post_mix: torch.Tensor | None, - res_mix: torch.Tensor | None, - residual: torch.Tensor | None, - ) -> torch.Tensor: + post_mix: torch.Tensor | None = None, + res_mix: torch.Tensor | None = None, + residual: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: + if residual is None and post_mix is None and res_mix is None: + residual = x + x, post, comb = self.hc_pre( + x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base + ) + x = self.attn_norm(x) + x = self.attn(positions, x, None) + x = self.hc_post(x, residual, post, comb) + + residual = x + x, post, comb = self.hc_pre( + x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base + ) + x = self.ffn_norm(x) + x = self.ffn(x, input_ids) + x = self.hc_post(x, residual, post, comb) + return x + if residual is None: # Run standalone hc_pre on first layer residual = x @@ -1228,37 +1251,28 @@ def forward( x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base ) else: - residual, post_mix, res_mix, x = torch.ops.vllm.mhc_fused_post_pre( - x, + # Keep the post/pre operations separate here. The fused post_pre + # path is faster, but changes DeepSeek V4 Flash FP8 accuracy. + residual = self.hc_post(x, residual, post_mix, res_mix) + x, post_mix, res_mix = self.hc_pre( residual, - post_mix, - res_mix, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base, - self.rms_norm_eps, - self.hc_eps, - self.hc_eps, - self.hc_post_alpha, - self.hc_sinkhorn_iters, ) x = self.attn_norm(x) x = self.attn(positions, x, None) - residual, post_mix, res_mix, x = torch.ops.vllm.mhc_fused_post_pre( - x, + assert residual is not None + assert post_mix is not None + assert res_mix is not None + residual = self.hc_post(x, residual, post_mix, res_mix) + x, post_mix, res_mix = self.hc_pre( residual, - post_mix, - res_mix, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base, - self.rms_norm_eps, - self.hc_eps, - self.hc_eps, - self.hc_post_alpha, - self.hc_sinkhorn_iters, ) x = self.ffn_norm(x) @@ -1414,18 +1428,24 @@ def forward( if self.use_mega_moe: input_ids = input_ids.to(torch.int64) - residual, post_mix, res_mix = None, None, None - for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states, residual, post_mix, res_mix = layer( - hidden_states, - positions, - input_ids, - post_mix, - res_mix, - residual, - ) + if get_pp_group().world_size == 1: + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states = layer(hidden_states, positions, input_ids) else: - hidden_states = layer.hc_post(hidden_states, residual, post_mix, res_mix) + residual, post_mix, res_mix = None, None, None + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual, post_mix, res_mix = layer( + hidden_states, + positions, + input_ids, + post_mix, + res_mix, + residual, + ) + else: + hidden_states = layer.hc_post( + hidden_states, residual, post_mix, res_mix + ) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py index ec880f7ab4c4..99d1e0f301cc 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py @@ -3,6 +3,7 @@ import torch +import vllm.envs as envs from vllm.triton_utils import tl, triton from vllm.utils.import_utils import has_cutedsl @@ -344,7 +345,7 @@ def fused_indexer_q_rope_quant( dtype=torch.uint8, device=index_q.device, ) - if has_cutedsl(): + if envs.VLLM_DSV4_USE_CUTEDSL_INDEXER_Q and has_cutedsl(): # lazily import, otherwise some tests fail due to CUDA driver init failure. from .fused_indexer_q_cutedsl import ( fused_indexer_q_rope_quant_mxfp4_cutedsl, From 159b9b7db69cf5986b8008bfcd91c6ded67a0fc2 Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Mon, 11 May 2026 05:36:21 -0700 Subject: [PATCH 11/24] Rename DeepSeek V4 per-tensor FP8 KV cache dtype --- tests/test_config.py | 12 ++++++++++++ vllm/config/cache.py | 1 + vllm/model_executor/layers/deepseek_v4_attention.py | 8 ++++++-- vllm/utils/torch_utils.py | 1 + vllm/v1/attention/backends/mla/flashmla_sparse.py | 1 + 5 files changed, 21 insertions(+), 2 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index 57d1e1bc686b..5ceaeaf3d19f 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -14,6 +14,7 @@ import vllm.config.vllm as vllm_config_module from vllm.compilation.backends import VllmBackend from vllm.config import ( + CacheConfig, CompilationConfig, KernelConfig, ModelConfig, @@ -33,10 +34,21 @@ OptimizationLevel, ) from vllm.platforms import current_platform +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE DEVICE_TYPE = current_platform.device_type +def test_fp8_per_tensor_cache_dtype(): + cfg = CacheConfig(cache_dtype="fp8_per_tensor") + + assert cfg.cache_dtype == "fp8_per_tensor" + assert ( + STR_DTYPE_TO_TORCH_DTYPE["fp8_per_tensor"] + is STR_DTYPE_TO_TORCH_DTYPE["fp8_inc"] + ) + + def test_compile_config_repr_succeeds(): # setup: VllmBackend mutates the config object config = VllmConfig() diff --git a/vllm/config/cache.py b/vllm/config/cache.py index ae5023f1e348..ccd56f68e875 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -22,6 +22,7 @@ "fp8", "fp8_e4m3", "fp8_e5m2", + "fp8_per_tensor", "fp8_inc", "fp8_ds_mla", "turboquant_k8v4", diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index f668a254f5b0..b451ca138e0a 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -104,6 +104,10 @@ def _normalize_dsv4_kv_cache_dtype( assert cache_config is not None cache_config.cache_dtype = "fp8_ds_mla" return "fp8_ds_mla" + if kv_cache_dtype == "fp8_inc": + assert cache_config is not None + cache_config.cache_dtype = "fp8_per_tensor" + return "fp8_per_tensor" return kv_cache_dtype @@ -113,7 +117,7 @@ def _dsv4_kv_cache_torch_dtype( ) -> torch.dtype: if kv_cache_dtype == "fp8_ds_mla": return torch.uint8 - if kv_cache_dtype == "fp8_inc": + if kv_cache_dtype in ("fp8_per_tensor", "fp8_inc"): return torch.float8_e4m3fn if kv_cache_dtype == "bfloat16": return torch.bfloat16 @@ -124,7 +128,7 @@ def _dsv4_kv_cache_torch_dtype( raise ValueError( "DeepSeek V4 FlashInfer sparse MLA supports only BF16 or per-tensor " f"FP8 E4M3 KV cache; got kv_cache_dtype={kv_cache_dtype}. Use " - "`bfloat16`/`auto` for BF16, `fp8_inc` for per-tensor FP8, or " + "`bfloat16`/`auto` for BF16, `fp8_per_tensor` for per-tensor FP8, or " "`fp8`/`fp8_ds_mla` for the legacy UE8M0 FlashMLA path." ) diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index 12ec5b0fcc66..3db6cc2195ae 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -41,6 +41,7 @@ "int8": torch.int8, "int8_per_token_head": torch.int8, "fp8_per_token_head": torch.uint8, + "fp8_per_tensor": torch.float8_e4m3fn, "fp8_inc": torch.float8_e4m3fn, "fp8_ds_mla": torch.uint8, "turboquant_k8v4": torch.uint8, diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 15b01d1f0e21..d6dec4847b13 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -95,6 +95,7 @@ class FlashMLASparseBackend(AttentionBackend): supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ "auto", "bfloat16", + "fp8_per_tensor", "fp8_inc", "fp8_ds_mla", "fp8", # alias for fp8_ds_mla From f7510df1fc2bec2fbd3ed3749581da8179f099ac Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Mon, 11 May 2026 19:04:34 -0700 Subject: [PATCH 12/24] Clean DeepSeek V4 post-rebase indexer path --- vllm/envs.py | 7 -- .../layers/deepseek_v4_attention.py | 25 +------ .../ops/deepseek_v4_ops/cache_utils.py | 16 ++++- .../ops/deepseek_v4_ops/fused_indexer_q.py | 68 +++++++------------ 4 files changed, 40 insertions(+), 76 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 6d292a3599c8..c9be71f9ef04 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -57,7 +57,6 @@ VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") VLLM_XLA_CHECK_RECOMPILATION: bool = False VLLM_SPARSE_INDEXER_MAX_LOGITS_MB: int = 512 - VLLM_DSV4_USE_CUTEDSL_INDEXER_Q: bool = False VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto" VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True @@ -908,12 +907,6 @@ def _get_or_set_default() -> str: "VLLM_SPARSE_INDEXER_MAX_LOGITS_MB": lambda: int( os.getenv("VLLM_SPARSE_INDEXER_MAX_LOGITS_MB", "512") ), - # Enable the CuteDSL MXFP4 indexer Q kernel for DeepSeek V4 sparse MLA. - # The default Triton path preserves the validated DeepSeek V4 Flash - # end-to-end FP8 accuracy. - "VLLM_DSV4_USE_CUTEDSL_INDEXER_Q": lambda: bool( - int(os.getenv("VLLM_DSV4_USE_CUTEDSL_INDEXER_Q", "0")) - ), # If set, the OpenAI API server will stay alive even after the underlying # AsyncLLMEngine errors and stops serving requests "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH": lambda: bool( diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index b451ca138e0a..a418d82174f3 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -30,11 +30,7 @@ fused_q_kv_rmsnorm, qnorm_rope_and_insert_full_k_cache, ) -from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( - rocm_forward_decode_fallback, - rocm_inv_rope_einsum, - rocm_sparse_attn_prefill, -) +from vllm.v1.attention.ops.rocm_aiter_mla_sparse import rocm_inv_rope_einsum if TYPE_CHECKING: from vllm.v1.attention.backends.mla.sparse_swa import ( @@ -995,25 +991,6 @@ def _forward_decode( swa_indices = swa_metadata.decode_swa_indices swa_lens = swa_metadata.decode_swa_lens - if current_platform.is_rocm(): - rocm_forward_decode_fallback( - q=q, - kv_cache=kv_cache, - swa_k_cache=self.swa_cache_layer.kv_cache, - swa_only=swa_only, - topk_indices=topk_indices, - topk_lens=topk_lens, - swa_indices=swa_indices, - swa_lens=swa_lens, - attn_sink=self.attn_sink, - scale=self.scale, - head_dim=self.head_dim, - nope_head_dim=self.nope_head_dim, - rope_head_dim=self.rope_head_dim, - output=output, - ) - return - assert self.kv_cache_torch_dtype == torch.uint8 # We treat queries in the same seq as different queries diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py index 664c19d436f2..fca66b8c1ddc 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py @@ -14,12 +14,26 @@ window indices for sparse prefill. """ +from functools import lru_cache + import torch from vllm.triton_utils import tl, triton from vllm.utils.import_utils import has_cutedsl +@lru_cache(maxsize=1) +def _has_dequant_gather_k_cutedsl() -> bool: + if not has_cutedsl(): + return False + try: + from cutlass import cute + + return hasattr(cute.nvgpu, "LoadCacheMode") + except Exception: + return False + + @triton.jit def _apply_gptj_rope_512( values, @@ -587,7 +601,7 @@ def dequantize_and_gather_k_cache( offset: int, ) -> None: assert k_cache.dtype == torch.uint8 - if has_cutedsl(): + if _has_dequant_gather_k_cutedsl(): # lazily import, otherwise some tests fail due to CUDA driver init failure. from .dequant_gather_k_cutedsl import dequantize_and_gather_k_cache_cutedsl diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py index 99d1e0f301cc..dad4f33c0dc2 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py @@ -3,9 +3,7 @@ import torch -import vllm.envs as envs from vllm.triton_utils import tl, triton -from vllm.utils.import_utils import has_cutedsl # MXFP4: 32 elements per block, packed 2 nibbles per byte, ue8m0 block scale. MXFP4_BLOCK_SIZE = 32 @@ -345,48 +343,30 @@ def fused_indexer_q_rope_quant( dtype=torch.uint8, device=index_q.device, ) - if envs.VLLM_DSV4_USE_CUTEDSL_INDEXER_Q and has_cutedsl(): - # lazily import, otherwise some tests fail due to CUDA driver init failure. - from .fused_indexer_q_cutedsl import ( - fused_indexer_q_rope_quant_mxfp4_cutedsl, - ) - - fused_indexer_q_rope_quant_mxfp4_cutedsl( - positions, - index_q, - index_q_cos_sin_cache, - index_weights, - index_weights_softmax_scale, - index_weights_head_scale, - index_q_packed, - index_q_scale, - index_weights_out, - ) - else: - _fused_indexer_q_rope_mxfp4_kernel[(num_tokens, num_index_q_heads)]( - positions, - index_q, - index_q.stride(0), - index_q.stride(1), - index_q_cos_sin_cache, - index_q_cos_sin_cache.stride(0), - index_q_cos_sin_cache.shape[-1] // 2, - index_q_packed, - index_q_packed.stride(0), - index_q_packed.stride(1), - index_q_scale, - index_q_scale.stride(0), - index_q_scale.stride(1), - index_q_head_dim, - MXFP4_BLOCK_SIZE, - index_weights, - index_weights.stride(0), - index_weights_softmax_scale, - index_weights_head_scale, - index_weights_out, - index_weights_out.stride(0), - num_warps=1, # TODO: Tune this - ) + _fused_indexer_q_rope_mxfp4_kernel[(num_tokens, num_index_q_heads)]( + positions, + index_q, + index_q.stride(0), + index_q.stride(1), + index_q_cos_sin_cache, + index_q_cos_sin_cache.stride(0), + index_q_cos_sin_cache.shape[-1] // 2, + index_q_packed, + index_q_packed.stride(0), + index_q_packed.stride(1), + index_q_scale, + index_q_scale.stride(0), + index_q_scale.stride(1), + index_q_head_dim, + MXFP4_BLOCK_SIZE, + index_weights, + index_weights.stride(0), + index_weights_softmax_scale, + index_weights_head_scale, + index_weights_out, + index_weights_out.stride(0), + num_warps=1, # TODO: Tune this + ) # Values stay uint8 (2 E2M1 nibbles per byte). Scales are 4 ue8m0 # bytes per (token, head) reinterpreted as one int32, then squeezed From 0de10a276eced6036e9752a223cba53c996d31cd Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Mon, 11 May 2026 19:53:54 -0700 Subject: [PATCH 13/24] Clean DeepSeek V4 FlashInfer sparse path --- vllm/envs.py | 6 - .../layers/deepseek_compressor.py | 25 ++- .../layers/deepseek_v4_attention.py | 17 +- vllm/transformers_utils/config.py | 39 ---- .../attention/backends/mla/flashmla_sparse.py | 14 +- vllm/v1/attention/backends/mla/sparse_swa.py | 11 +- .../ops/deepseek_v4_ops/cache_utils.py | 94 ++++------ .../fused_compress_quant_cache.py | 169 +++++------------- .../ops/deepseek_v4_ops/fused_indexer_q.py | 67 ++++--- 9 files changed, 149 insertions(+), 293 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index c9be71f9ef04..03230eed0688 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -289,12 +289,6 @@ def maybe_convert_int(value: str | None) -> int | None: return int(value) -def maybe_convert_float(value: str | None) -> float | None: - if value is None: - return None - return float(value) - - def maybe_convert_bool(value: str | None) -> bool | None: if value is None: return None diff --git a/vllm/model_executor/layers/deepseek_compressor.py b/vllm/model_executor/layers/deepseek_compressor.py index edc71a686db3..8d5e917bc98a 100644 --- a/vllm/model_executor/layers/deepseek_compressor.py +++ b/vllm/model_executor/layers/deepseek_compressor.py @@ -27,7 +27,6 @@ _fused_kv_compress_norm_rope_insert_indexer_attn, _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn, _fused_kv_compress_norm_rope_insert_sparse_attn, - _fused_kv_compress_norm_rope_insert_sparse_attn_full_cache, ) from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import ( MXFP4_BLOCK_SIZE, @@ -345,9 +344,10 @@ def forward( k_cache_layer = self._static_forward_context[self.k_cache_prefix] kv_cache = k_cache_layer.kv_cache - if self.head_dim == 512 and kv_cache.dtype != torch.uint8: - assert kv_cache.dtype in (torch.bfloat16, torch.float8_e4m3fn) - _fused_kv_compress_norm_rope_insert_sparse_attn_full_cache[(num_actual,)]( + if self.head_dim == 512: + if kv_cache.dtype != torch.uint8: + assert kv_cache.dtype in (torch.bfloat16, torch.float8_e4m3fn) + self._fused_kernel[(num_actual,)]( # state cache state_cache, state_cache.stride(0), @@ -369,7 +369,11 @@ def forward( kv_cache, k_cache_metadata.slot_mapping, kv_cache.shape[1], - k_cache_layer._flashinfer_fp8_kv_scale, + ( + k_cache_layer._flashinfer_fp8_kv_scale + if kv_cache.dtype != torch.uint8 + else self.norm.weight + ), # constexprs HEAD_SIZE=self.head_dim, TRITON_BLOCK_SIZE=triton.next_power_of_2(self.head_dim), @@ -377,9 +381,16 @@ def forward( COMPRESS_RATIO=self.compress_ratio, OVERLAP=self.overlap, ROPE_HEAD_DIM=self.rope_head_dim, + FP8_MAX=448.0, + QUANT_BLOCK=self._quant_block, + TOKEN_STRIDE=self._token_stride, + SCALE_DIM=self._scale_dim, KV_BLOCK_STRIDE=kv_cache.stride(0), - KV_TOKEN_STRIDE=kv_cache.stride(1), - STORE_FP8=kv_cache.dtype == torch.float8_e4m3fn, + KV_TOKEN_STRIDE=( + kv_cache.stride(1) if kv_cache.dtype != torch.uint8 else 0 + ), + STORE_FULL_CACHE=kv_cache.dtype != torch.uint8, + STORE_FULL_FP8=kv_cache.dtype == torch.float8_e4m3fn, num_warps=self._num_warps, **pdl_kwargs, ) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index a418d82174f3..7726228a44f0 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -963,9 +963,6 @@ def _forward_decode( topk_indices = None topk_lens = None - use_flashinfer_dsv4 = ( - self.kv_cache_torch_dtype != torch.uint8 and not current_platform.is_rocm() - ) if not swa_only: assert attn_metadata is not None assert swa_metadata.is_valid_token is not None @@ -980,7 +977,6 @@ def _forward_decode( attn_metadata.block_table[:num_decodes], block_size, is_valid, - self.window_size if use_flashinfer_dsv4 else 0, ) topk_indices = global_indices.view(num_decode_tokens, 1, -1) else: @@ -1072,12 +1068,12 @@ def _forward_flashinfer( assert swa_metadata.query_start_loc_cpu is not None assert swa_metadata.token_to_req_indices is not None assert swa_metadata.decode_swa_indices is not None - assert swa_metadata.decode_swa_sparse_topk_lens is not None + assert swa_metadata.decode_zero_compressed_lens is not None decode_swa_indices = swa_metadata.decode_swa_indices.reshape( num_decode_tokens, self.window_size ) - decode_sparse_topk_lens = swa_metadata.decode_swa_sparse_topk_lens + decode_compressed_topk_lens = swa_metadata.decode_zero_compressed_lens if swa_only: assert self.topk_indices_buffer is not None @@ -1089,6 +1085,8 @@ def _forward_flashinfer( compressed_block_table = None compressed_block_size = swa_metadata.block_size top_k = 0 + sparse_indices = None + sparse_topk_lens = None else: assert kv_cache is not None assert attn_metadata is not None @@ -1110,14 +1108,13 @@ def _forward_flashinfer( if num_decode_tokens > 0: assert swa_metadata.is_valid_token is not None is_valid = swa_metadata.is_valid_token[:num_decode_tokens] - decode_global_indices, decode_sparse_topk_lens = ( + decode_global_indices, decode_compressed_topk_lens = ( compute_global_topk_indices_and_lens( self.topk_indices_buffer[:num_decode_tokens], swa_metadata.token_to_req_indices, attn_metadata.block_table[:num_decodes], compressed_block_size, is_valid, - self.window_size, ) ) decode_compressed_indices = decode_global_indices.view( @@ -1142,7 +1139,7 @@ def _forward_flashinfer( num_decode_tokens, -1 ) ) - decode_sparse_topk_lens = attn_metadata.c128a_decode_topk_lens + decode_compressed_topk_lens = attn_metadata.c128a_decode_topk_lens if num_prefill_tokens == 0: prefill_topk_indices = decode_compressed_indices[:0, :0] else: @@ -1157,7 +1154,7 @@ def _forward_flashinfer( sparse_indices, sparse_topk_lens = build_flashinfer_mixed_sparse_indices( decode_swa_indices, decode_compressed_indices, - decode_sparse_topk_lens, + decode_compressed_topk_lens, prefill_topk_indices[:num_prefill_tokens], query_start_loc, seq_lens, diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 5db672537ad6..ac62e9b279af 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -63,8 +63,6 @@ from transformers import AutoConfig MISTRAL_CONFIG_NAME = "params.json" -DEEPSEEK_V4_INFERENCE_CONFIG_NAME = "inference/config.json" -DEEPSEEK_V4_GLOBAL_INFERENCE_FIELDS = ("expert_dtype", "scale_fmt") logger = init_logger(__name__) @@ -85,41 +83,6 @@ def __getitem__(self, key): return getattr(configs, value) - -def _maybe_apply_deepseek_v4_inference_config( - config: PretrainedConfig, - model: str, - revision: str | None, -) -> None: - """Promote DeepSeek V4 inference config fields into hf_config.""" - if getattr(config, "model_type", None) != "deepseek_v4": - return - if not file_or_path_exists(model, DEEPSEEK_V4_INFERENCE_CONFIG_NAME, revision): - return - - inference_config = get_hf_file_to_dict( - DEEPSEEK_V4_INFERENCE_CONFIG_NAME, model, revision - ) - updates = { - key: inference_config[key] - for key in DEEPSEEK_V4_GLOBAL_INFERENCE_FIELDS - if key in inference_config and not hasattr(config, key) - } - if "scale_fmt" not in updates and not hasattr(config, "scale_fmt"): - quantization_config = getattr(config, "quantization_config", None) - if isinstance(quantization_config, dict): - scale_fmt = quantization_config.get("scale_fmt") - if scale_fmt is not None: - updates["scale_fmt"] = scale_fmt - - if updates: - config.update(updates) - logger.info_once( - "Applied DeepSeek V4 inference globals to hf_config: %s", - tuple(sorted(updates)), - ) - - _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( afmoe="AfmoeConfig", bagel="BagelConfig", @@ -851,8 +814,6 @@ def apply_gguf_default(key: str, gguf_default: Any): scale_fmt, ) - _maybe_apply_deepseek_v4_inference_config(config, model, revision) - if hf_overrides_kw: logger.debug("Overriding HF config with %s", hf_overrides_kw) config.update(hf_overrides_kw) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index d6dec4847b13..c8c11236c1fb 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -255,7 +255,6 @@ class Chunk: # Prefill: local topk indices (used by combine_topk_swa_indices). c128a_prefill_topk_indices: torch.Tensor | None = None - def get_prefill_workspace_size(max_model_len: int): # NOTE(Lucas): 5 is a magic number for controlling the prefill buffer size. # May be tuned later. @@ -359,11 +358,6 @@ def __init__( if self.is_deepseek_v4: assert hasattr(self.kv_cache_spec, "compress_ratio") self.compress_ratio = self.kv_cache_spec.compress_ratio - self.sliding_window = hf_config.sliding_window - self.use_dsv4_flashinfer_decode = ( - self.kv_cache_spec.dtype != torch.uint8 - and not current_platform.is_rocm() - ) # Pre-allocate compressed slot mapping buffer for CUDA graph # address stability when compress_ratio > 1. if self.compress_ratio > 1: @@ -703,9 +697,6 @@ def _build_c128a_metadata( self.c128a_decode_lens_buffer, self.c128a_prefill_buffer, max_compressed_tokens=self.c128a_max_compressed, - decode_lens_base=( - self.sliding_window if self.use_dsv4_flashinfer_decode else 0 - ), ) result: dict[str, torch.Tensor | None] = {} @@ -1074,7 +1065,6 @@ def build_c128a_topk_metadata( decode_lens_buffer: torch.Tensor, prefill_buffer: torch.Tensor, max_compressed_tokens: int = 8192, - decode_lens_base: int = 0, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Single kernel for all C128A tokens (decode + prefill). @@ -1109,7 +1099,6 @@ def build_c128a_topk_metadata( block_table.stride(0), block_size, slot_mapping, - decode_lens_base, BLOCK_SIZE=1024, ) return global_decode, decode_lens, prefill_local @@ -1134,7 +1123,6 @@ def _build_c128a_topk_metadata_kernel( block_table_stride, block_size, slot_mapping_ptr, - decode_lens_base: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): token_idx = tl.program_id(0) @@ -1170,7 +1158,7 @@ def _build_c128a_topk_metadata_kernel( tl.store( decode_lens_ptr + token_idx, - tl.where(is_valid_token, count, 0) + decode_lens_base, + tl.where(is_valid_token, count, 0), ) else: # --- Prefill: write local indices --- diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py index 62e6362dd4cd..6743165fc07b 100644 --- a/vllm/v1/attention/backends/mla/sparse_swa.py +++ b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -161,7 +161,9 @@ class DeepseekSparseSWAMetadata: token_to_req_indices: torch.Tensor | None = None # [num_tokens] decode_swa_indices: torch.Tensor | None = None # [num_decode_tokens, window_size] decode_swa_lens: torch.Tensor | None = None # [num_decode_tokens] - decode_swa_sparse_topk_lens: torch.Tensor | None = None # [num_decode_tokens] + # Zero compressed-prefix lengths used by the FlashInfer mixed sparse-index + # builder for SWA-only decode rows. + decode_zero_compressed_lens: torch.Tensor | None = None # [num_decode_tokens] # Number of decode/prefill requests/tokens (batch is reordered: decodes first) num_decodes: int = 0 @@ -256,9 +258,8 @@ def __init__(self, *args, **kwargs): dtype=torch.int32, device=self.device, ) - self.decode_swa_sparse_topk_lens = torch.full( + self.decode_zero_compressed_lens = torch.zeros( (max_tokens,), - self.window_size, dtype=torch.int32, device=self.device, ) @@ -348,8 +349,8 @@ def build( token_to_req_indices=token_to_req_indices, decode_swa_indices=self.decode_swa_indices[:num_decode_tokens], decode_swa_lens=self.decode_swa_lens[:num_decode_tokens], - decode_swa_sparse_topk_lens=( - self.decode_swa_sparse_topk_lens[:num_decode_tokens] + decode_zero_compressed_lens=( + self.decode_zero_compressed_lens[:num_decode_tokens] ), block_size=self.block_size, num_decodes=num_decodes, diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py index fca66b8c1ddc..071f4cbffb7a 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py @@ -66,7 +66,7 @@ def _apply_gptj_rope_512( @triton.jit -def _qnorm_rope_kernel( +def _qnorm_rope_insert_full_cache_kernel( q_ptr, q_stride0, q_stride1, @@ -75,14 +75,23 @@ def _qnorm_rope_kernel( q_fp8_stride0, q_fp8_stride1, q_fp8_scale_inv_ptr, + kv_ptr, + kv_stride0, + slot_mapping_ptr, positions_ptr, cos_sin_cache_ptr, cos_sin_stride, + k_cache_ptr, + cache_stride0, + cache_stride1, + cache_block_size, + fp8_scale_ptr, eps, HEAD_SIZE: tl.constexpr, ROPE_HEAD_DIM: tl.constexpr, BLOCK_SIZE: tl.constexpr, STORE_Q_FP8: tl.constexpr, + STORE_KV_FP8: tl.constexpr, ): token_idx = tl.program_id(0) head_idx = tl.program_id(1) @@ -98,12 +107,13 @@ def _qnorm_rope_kernel( ) return + position = tl.load(positions_ptr + token_idx) + q_row = q_ptr + token_idx * q_stride0 + head_idx * q_stride1 values = tl.load(q_row + offsets, mask=mask, other=0.0).to(tl.float32) variance = tl.sum(values * values, axis=0) / HEAD_SIZE values *= tl.rsqrt(variance + eps) - position = tl.load(positions_ptr + token_idx) values = _apply_gptj_rope_512( values, position, @@ -124,39 +134,18 @@ def _qnorm_rope_kernel( ) tl.store(q_fp8_row + offsets, q_fp8_values.to(tl.float8e4nv), mask=mask) + if head_idx != 0: + return -@triton.jit -def _kv_rope_insert_full_cache_kernel( - kv_ptr, - kv_stride0, - slot_mapping_ptr, - positions_ptr, - cos_sin_cache_ptr, - cos_sin_stride, - k_cache_ptr, - cache_stride0, - cache_stride1, - cache_block_size, - fp8_scale_ptr, - HEAD_SIZE: tl.constexpr, - ROPE_HEAD_DIM: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - STORE_FP8: tl.constexpr, -): - token_idx = tl.program_id(0) slot_idx = tl.load(slot_mapping_ptr + token_idx) if slot_idx < 0: return - offsets = tl.arange(0, BLOCK_SIZE) - mask = offsets < HEAD_SIZE - values = tl.load( + kv_values = tl.load( kv_ptr + token_idx * kv_stride0 + offsets, mask=mask, other=0.0 ).to(tl.float32) - - position = tl.load(positions_ptr + token_idx) - values = _apply_gptj_rope_512( - values, + kv_values = _apply_gptj_rope_512( + kv_values, position, cos_sin_cache_ptr, cos_sin_stride, @@ -172,12 +161,12 @@ def _kv_rope_insert_full_cache_kernel( + block_idx.to(tl.int64) * cache_stride0 + (pos_in_block * cache_stride1) ) - if STORE_FP8: + if STORE_KV_FP8: fp8_scale = tl.load(fp8_scale_ptr) - values = tl.clamp(values / fp8_scale, -448.0, 448.0) - tl.store(cache_row + offsets, values.to(tl.float8e4nv), mask=mask) + kv_values = tl.clamp(kv_values / fp8_scale, -448.0, 448.0) + tl.store(cache_row + offsets, kv_values.to(tl.float8e4nv), mask=mask) else: - tl.store(cache_row + offsets, values.to(tl.bfloat16), mask=mask) + tl.store(cache_row + offsets, kv_values.to(tl.bfloat16), mask=mask) def qnorm_rope_and_insert_full_k_cache( @@ -215,7 +204,7 @@ def qnorm_rope_and_insert_full_k_cache( num_tokens_full, num_heads, _ = q.shape q_heads_for_grid = q_fp8.shape[1] if q_fp8 is not None else num_heads - _qnorm_rope_kernel[(num_tokens_full, q_heads_for_grid)]( + _qnorm_rope_insert_full_cache_kernel[(num_tokens_full, q_heads_for_grid)]( q, q.stride(0), q.stride(1), @@ -224,19 +213,6 @@ def qnorm_rope_and_insert_full_k_cache( q_fp8.stride(0) if q_fp8 is not None else q.stride(0), q_fp8.stride(1) if q_fp8 is not None else q.stride(1), q_fp8_scale_inv if q_fp8_scale_inv is not None else fp8_scale, - positions, - cos_sin_cache, - cos_sin_cache.stride(0), - eps, - HEAD_SIZE=512, - ROPE_HEAD_DIM=64, - BLOCK_SIZE=512, - STORE_Q_FP8=q_fp8 is not None, - num_warps=8, - ) - - num_tokens_insert = slot_mapping.shape[0] - _kv_rope_insert_full_cache_kernel[(num_tokens_insert,)]( kv, kv.stride(0), slot_mapping, @@ -248,10 +224,12 @@ def qnorm_rope_and_insert_full_k_cache( k_cache.stride(1), cache_block_size, fp8_scale, + eps, HEAD_SIZE=512, ROPE_HEAD_DIM=64, BLOCK_SIZE=512, - STORE_FP8=k_cache.dtype == torch.float8_e4m3fn, + STORE_Q_FP8=q_fp8 is not None, + STORE_KV_FP8=k_cache.dtype == torch.float8_e4m3fn, num_warps=8, ) @@ -621,7 +599,6 @@ def compute_global_topk_indices_and_lens( block_table: torch.Tensor, block_size: int, is_valid_token: torch.Tensor, - topk_lens_base: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: """Map local topk indices to global KV cache slots and count valid entries. @@ -629,7 +606,6 @@ def compute_global_topk_indices_and_lens( 1. Block-table lookup (local index → global slot id) 2. Valid-entry counting (topk_lens per token) 3. Masking padding tokens to length 0 - 4. Optional constant top-k length base for callers with fixed prefixes """ num_tokens = topk_indices.shape[0] global_topk_indices = torch.empty_like(topk_indices) @@ -646,7 +622,6 @@ def compute_global_topk_indices_and_lens( block_table.stride(0), block_size, is_valid_token, - topk_lens_base, TRITON_BLOCK_SIZE=1024, ) return global_topk_indices, topk_lens @@ -655,7 +630,7 @@ def compute_global_topk_indices_and_lens( def build_flashinfer_mixed_sparse_indices( decode_swa_indices: torch.Tensor, decode_compressed_indices: torch.Tensor | None, - decode_sparse_topk_lens: torch.Tensor, + decode_compressed_topk_lens: torch.Tensor, prefill_topk_indices: torch.Tensor, query_start_loc: torch.Tensor, seq_lens: torch.Tensor, @@ -672,7 +647,7 @@ def build_flashinfer_mixed_sparse_indices( assert decode_swa_indices.dtype == torch.int32 assert decode_swa_indices.dim() == 2 assert decode_swa_indices.shape[-1] == window_size - assert decode_sparse_topk_lens.dtype == torch.int32 + assert decode_compressed_topk_lens.dtype == torch.int32 assert prefill_topk_indices.dtype == torch.int32 assert prefill_topk_indices.dim() == 2 assert query_start_loc.dtype == torch.int32 @@ -684,7 +659,7 @@ def build_flashinfer_mixed_sparse_indices( num_prefill_tokens = prefill_topk_indices.shape[0] num_tokens = num_decode_tokens + num_prefill_tokens assert token_to_req_indices.shape[0] >= num_tokens - assert decode_sparse_topk_lens.shape[0] >= num_decode_tokens + assert decode_compressed_topk_lens.shape[0] >= num_decode_tokens decode_compressed_topk = 0 if decode_compressed_indices is None: @@ -720,7 +695,7 @@ def build_flashinfer_mixed_sparse_indices( decode_swa_indices.stride(0), decode_compressed_indices, decode_compressed_indices.stride(0), - decode_sparse_topk_lens, + decode_compressed_topk_lens, prefill_topk_indices, prefill_topk_indices.stride(0), query_start_loc, @@ -754,7 +729,7 @@ def _build_flashinfer_mixed_sparse_indices_kernel( decode_swa_stride, decode_compressed_indices_ptr, decode_compressed_stride, - decode_sparse_topk_lens_ptr, + decode_compressed_topk_lens_ptr, prefill_topk_indices_ptr, prefill_topk_stride, query_start_loc_ptr, @@ -813,7 +788,7 @@ def _build_flashinfer_mixed_sparse_indices_kernel( tl.store( sparse_topk_lens_ptr + token_idx, - tl.load(decode_sparse_topk_lens_ptr + token_idx), + WINDOW_SIZE + tl.load(decode_compressed_topk_lens_ptr + token_idx), ) return @@ -894,7 +869,6 @@ def _compute_global_topk_indices_and_lens_kernel( block_table_stride, block_size, is_valid_token_ptr, - topk_lens_base: tl.constexpr, TRITON_BLOCK_SIZE: tl.constexpr, ): token_idx = tl.program_id(0) @@ -929,11 +903,7 @@ def _compute_global_topk_indices_and_lens_kernel( ) count += tl.sum(is_valid.to(tl.int32), axis=0) - # Mask compressed entries for padding tokens, then add any fixed prefix. - tl.store( - topk_lens_ptr + token_idx, - tl.where(is_valid_token, count, 0) + topk_lens_base, - ) + tl.store(topk_lens_ptr + token_idx, tl.where(is_valid_token, count, 0)) # FlashMLA sparse prefill asserts `params.topk % B_TOPK == 0` (see diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py index 7a9cc17b6e03..3012af309830 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py @@ -50,6 +50,7 @@ def _fused_kv_compress_norm_rope_insert_sparse_attn( k_cache_ptr, kv_slot_mapping_ptr, kv_cache_block_size, + fp8_scale_ptr, # ── constexprs ── HEAD_SIZE: tl.constexpr, TRITON_BLOCK_SIZE: tl.constexpr, @@ -62,6 +63,9 @@ def _fused_kv_compress_norm_rope_insert_sparse_attn( TOKEN_STRIDE: tl.constexpr, # 576 for DeepseekV4 SCALE_DIM: tl.constexpr, # 8 for DeepseekV4 (7 real + 1 pad) KV_BLOCK_STRIDE: tl.constexpr, + KV_TOKEN_STRIDE: tl.constexpr, + STORE_FULL_CACHE: tl.constexpr, + STORE_FULL_FP8: tl.constexpr, ): """Fused compress → RMSNorm → FP8 quant (nope) → RoPE → bf16 store (rope). @@ -141,6 +145,44 @@ def _fused_kv_compress_norm_rope_insert_sparse_attn( kv_block_idx = kv_slot_idx // kv_cache_block_size kv_pos_in_block = kv_slot_idx % kv_cache_block_size + if STORE_FULL_CACHE: + cache_row = ( + k_cache_ptr + + kv_block_idx.to(tl.int64) * KV_BLOCK_STRIDE + + kv_pos_in_block * KV_TOKEN_STRIDE + ) + + NOPE_HEAD_DIM: tl.constexpr = HEAD_SIZE - ROPE_HEAD_DIM + HALF_ROPE: tl.constexpr = ROPE_HEAD_DIM // 2 + NUM_PAIRS: tl.constexpr = TRITON_BLOCK_SIZE // 2 + NOPE_PAIRS: tl.constexpr = NOPE_HEAD_DIM // 2 + + pair_2d = tl.reshape(normed, (NUM_PAIRS, 2)) + even, odd = tl.split(pair_2d) + pair_idx = tl.arange(0, NUM_PAIRS) + rope_pair_local = pair_idx - NOPE_PAIRS + is_rope_pair = rope_pair_local >= 0 + cs_idx = tl.maximum(rope_pair_local, 0) + + compressed_pos = (position // COMPRESS_RATIO) * COMPRESS_RATIO + cache_base = cos_sin_cache_ptr + compressed_pos * cos_sin_stride + cos_v = tl.load(cache_base + cs_idx, mask=is_rope_pair, other=1.0) + sin_v = tl.load(cache_base + HALF_ROPE + cs_idx, + mask=is_rope_pair, + other=0.0) + + new_even = tl.where(is_rope_pair, even * cos_v - odd * sin_v, even) + new_odd = tl.where(is_rope_pair, odd * cos_v + even * sin_v, odd) + result = tl.interleave(new_even, new_odd).to(tl.bfloat16).to(tl.float32) + + if STORE_FULL_FP8: + fp8_scale = tl.load(fp8_scale_ptr) + result = tl.clamp(result / fp8_scale, -448.0, 448.0) + tl.store(cache_row + block, result.to(tl.float8e4nv), mask=mask) + else: + tl.store(cache_row + block, result.to(tl.bfloat16), mask=mask) + return + cache_block_ptr = k_cache_ptr + kv_block_idx.to(tl.int64) * KV_BLOCK_STRIDE fp8_ptr = cache_block_ptr + kv_pos_in_block * TOKEN_STRIDE scale_ptr = ( @@ -214,133 +256,6 @@ def _fused_kv_compress_norm_rope_insert_sparse_attn( tl.store(bf16_ptr + rope_local, result.to(tl.bfloat16), mask=is_rope) -@triton.jit -def _fused_kv_compress_norm_rope_insert_sparse_attn_full_cache( - # ── state cache (compressor internal state) ── - state_cache_ptr, - state_cache_stride0, - state_cache_stride1, - # ── metadata ── - token_to_req_indices_ptr, - positions_ptr, - slot_mapping_ptr, - block_table_ptr, - block_table_stride, - block_size, - # ── RMSNorm ── - rms_norm_weight_ptr, - rms_norm_eps, - # ── RoPE ── - cos_sin_cache_ptr, - cos_sin_stride, - # ── KV cache output ── - k_cache_ptr, - kv_slot_mapping_ptr, - kv_cache_block_size, - fp8_scale_ptr, - # ── constexprs ── - HEAD_SIZE: tl.constexpr, - TRITON_BLOCK_SIZE: tl.constexpr, - STATE_WIDTH: tl.constexpr, - COMPRESS_RATIO: tl.constexpr, - OVERLAP: tl.constexpr, - ROPE_HEAD_DIM: tl.constexpr, - KV_BLOCK_STRIDE: tl.constexpr, - KV_TOKEN_STRIDE: tl.constexpr, - STORE_FP8: tl.constexpr, -): - """Fused compress/RMSNorm/RoPE store for BF16 or per-tensor FP8 caches.""" - token_idx = tl.program_id(0) - - slot_id = tl.load(slot_mapping_ptr + token_idx) - if slot_id < 0: - return - - position = tl.load(positions_ptr + token_idx) - if (position + 1) % COMPRESS_RATIO != 0: - return - - req_idx = tl.load(token_to_req_indices_ptr + token_idx) - - start = position - (1 + OVERLAP) * COMPRESS_RATIO + 1 - tokens = tl.arange(0, (1 + OVERLAP) * COMPRESS_RATIO) - pos = start + tokens - mask_pos = pos >= 0 - - block_indices = pos // block_size - block_numbers = tl.load( - block_table_ptr + req_idx * block_table_stride + block_indices, - mask=mask_pos, - other=0, - ) - block_offsets = pos % block_size - head_offset = (tokens >= COMPRESS_RATIO).to(tl.int32) * HEAD_SIZE - - block = tl.arange(0, TRITON_BLOCK_SIZE) - mask = block < HEAD_SIZE - block_numbers_i64 = block_numbers.to(tl.int64) - row_base = ( - state_cache_ptr - + block_numbers_i64 * state_cache_stride0 - + block_offsets * state_cache_stride1 - + head_offset - ) - combined_mask = mask_pos[:, None] & mask[None, :] - - score = tl.load( - row_base[:, None] + STATE_WIDTH + block[None, :], - mask=combined_mask, - other=float("-inf"), - ) - score = tl.softmax(score, dim=0) - kv = tl.load(row_base[:, None] + block[None, :], mask=combined_mask, other=0.0) - compressed_kv = tl.sum(kv * score, axis=0) - - rms_w = tl.load(rms_norm_weight_ptr + block, mask=mask, other=0.0) - variance = tl.sum(compressed_kv * compressed_kv, axis=0) / HEAD_SIZE - normed = compressed_kv * tl.rsqrt(variance + rms_norm_eps) * rms_w - - kv_slot_idx = tl.load(kv_slot_mapping_ptr + token_idx) - if kv_slot_idx < 0: - return - kv_block_idx = kv_slot_idx // kv_cache_block_size - kv_pos_in_block = kv_slot_idx % kv_cache_block_size - cache_row = ( - k_cache_ptr - + kv_block_idx.to(tl.int64) * KV_BLOCK_STRIDE - + kv_pos_in_block * KV_TOKEN_STRIDE - ) - - NOPE_HEAD_DIM: tl.constexpr = HEAD_SIZE - ROPE_HEAD_DIM - HALF_ROPE: tl.constexpr = ROPE_HEAD_DIM // 2 - NUM_PAIRS: tl.constexpr = TRITON_BLOCK_SIZE // 2 - NOPE_PAIRS: tl.constexpr = NOPE_HEAD_DIM // 2 - - pair_2d = tl.reshape(normed, (NUM_PAIRS, 2)) - even, odd = tl.split(pair_2d) - pair_idx = tl.arange(0, NUM_PAIRS) - rope_pair_local = pair_idx - NOPE_PAIRS - is_rope_pair = rope_pair_local >= 0 - cs_idx = tl.maximum(rope_pair_local, 0) - - compressed_pos = (position // COMPRESS_RATIO) * COMPRESS_RATIO - cache_base = cos_sin_cache_ptr + compressed_pos * cos_sin_stride - cos_v = tl.load(cache_base + cs_idx, mask=is_rope_pair, other=1.0) - sin_v = tl.load(cache_base + HALF_ROPE + cs_idx, mask=is_rope_pair, other=0.0) - - new_even = tl.where(is_rope_pair, even * cos_v - odd * sin_v, even) - new_odd = tl.where(is_rope_pair, odd * cos_v + even * sin_v, odd) - result = tl.interleave(new_even, new_odd) - result = result.to(tl.bfloat16).to(tl.float32) - - if STORE_FP8: - fp8_scale = tl.load(fp8_scale_ptr) - result = tl.clamp(result / fp8_scale, -448.0, 448.0) - tl.store(cache_row + block, result.to(tl.float8e4nv), mask=mask) - else: - tl.store(cache_row + block, result.to(tl.bfloat16), mask=mask) - - # ============================================================================= # Indexer path (head=128, all FP8, single quant block) # ============================================================================= diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py index dad4f33c0dc2..ec880f7ab4c4 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py @@ -4,6 +4,7 @@ import torch from vllm.triton_utils import tl, triton +from vllm.utils.import_utils import has_cutedsl # MXFP4: 32 elements per block, packed 2 nibbles per byte, ue8m0 block scale. MXFP4_BLOCK_SIZE = 32 @@ -343,30 +344,48 @@ def fused_indexer_q_rope_quant( dtype=torch.uint8, device=index_q.device, ) - _fused_indexer_q_rope_mxfp4_kernel[(num_tokens, num_index_q_heads)]( - positions, - index_q, - index_q.stride(0), - index_q.stride(1), - index_q_cos_sin_cache, - index_q_cos_sin_cache.stride(0), - index_q_cos_sin_cache.shape[-1] // 2, - index_q_packed, - index_q_packed.stride(0), - index_q_packed.stride(1), - index_q_scale, - index_q_scale.stride(0), - index_q_scale.stride(1), - index_q_head_dim, - MXFP4_BLOCK_SIZE, - index_weights, - index_weights.stride(0), - index_weights_softmax_scale, - index_weights_head_scale, - index_weights_out, - index_weights_out.stride(0), - num_warps=1, # TODO: Tune this - ) + if has_cutedsl(): + # lazily import, otherwise some tests fail due to CUDA driver init failure. + from .fused_indexer_q_cutedsl import ( + fused_indexer_q_rope_quant_mxfp4_cutedsl, + ) + + fused_indexer_q_rope_quant_mxfp4_cutedsl( + positions, + index_q, + index_q_cos_sin_cache, + index_weights, + index_weights_softmax_scale, + index_weights_head_scale, + index_q_packed, + index_q_scale, + index_weights_out, + ) + else: + _fused_indexer_q_rope_mxfp4_kernel[(num_tokens, num_index_q_heads)]( + positions, + index_q, + index_q.stride(0), + index_q.stride(1), + index_q_cos_sin_cache, + index_q_cos_sin_cache.stride(0), + index_q_cos_sin_cache.shape[-1] // 2, + index_q_packed, + index_q_packed.stride(0), + index_q_packed.stride(1), + index_q_scale, + index_q_scale.stride(0), + index_q_scale.stride(1), + index_q_head_dim, + MXFP4_BLOCK_SIZE, + index_weights, + index_weights.stride(0), + index_weights_softmax_scale, + index_weights_head_scale, + index_weights_out, + index_weights_out.stride(0), + num_warps=1, # TODO: Tune this + ) # Values stay uint8 (2 E2M1 nibbles per byte). Scales are 4 ue8m0 # bytes per (token, head) reinterpreted as one int32, then squeezed From 250070e0d58017fca95964d833e046488ee8db10 Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Mon, 11 May 2026 20:19:02 -0700 Subject: [PATCH 14/24] Fix merged DeepSeek V4 sparse compressor compile --- .../fused_compress_quant_cache.py | 22 ++++++++----------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py index 3012af309830..4e68c312b406 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py @@ -145,6 +145,11 @@ def _fused_kv_compress_norm_rope_insert_sparse_attn( kv_block_idx = kv_slot_idx // kv_cache_block_size kv_pos_in_block = kv_slot_idx % kv_cache_block_size + NOPE_HEAD_DIM: tl.constexpr = HEAD_SIZE - ROPE_HEAD_DIM # 448 + HALF_ROPE: tl.constexpr = ROPE_HEAD_DIM // 2 # 32 + NUM_PAIRS: tl.constexpr = TRITON_BLOCK_SIZE // 2 + NOPE_PAIRS: tl.constexpr = NOPE_HEAD_DIM // 2 + if STORE_FULL_CACHE: cache_row = ( k_cache_ptr @@ -152,11 +157,6 @@ def _fused_kv_compress_norm_rope_insert_sparse_attn( + kv_pos_in_block * KV_TOKEN_STRIDE ) - NOPE_HEAD_DIM: tl.constexpr = HEAD_SIZE - ROPE_HEAD_DIM - HALF_ROPE: tl.constexpr = ROPE_HEAD_DIM // 2 - NUM_PAIRS: tl.constexpr = TRITON_BLOCK_SIZE // 2 - NOPE_PAIRS: tl.constexpr = NOPE_HEAD_DIM // 2 - pair_2d = tl.reshape(normed, (NUM_PAIRS, 2)) even, odd = tl.split(pair_2d) pair_idx = tl.arange(0, NUM_PAIRS) @@ -184,15 +184,14 @@ def _fused_kv_compress_norm_rope_insert_sparse_attn( return cache_block_ptr = k_cache_ptr + kv_block_idx.to(tl.int64) * KV_BLOCK_STRIDE - fp8_ptr = cache_block_ptr + kv_pos_in_block * TOKEN_STRIDE + fp8_ptr = (cache_block_ptr + kv_pos_in_block * TOKEN_STRIDE).to( + tl.pointer_type(tl.uint8) + ) scale_ptr = ( cache_block_ptr + kv_cache_block_size * TOKEN_STRIDE + kv_pos_in_block * SCALE_DIM - ) - - NOPE_HEAD_DIM: tl.constexpr = HEAD_SIZE - ROPE_HEAD_DIM # 448 - HALF_ROPE: tl.constexpr = ROPE_HEAD_DIM // 2 # 32 + ).to(tl.pointer_type(tl.uint8)) # FP8 UE8M0 quant: cast fp32 → bf16 → fp32 before quant to match reference. N_QUANT_BLOCKS: tl.constexpr = TRITON_BLOCK_SIZE // QUANT_BLOCK @@ -229,9 +228,6 @@ def _fused_kv_compress_norm_rope_insert_sparse_attn( tl.store(scale_ptr + N_NOPE_BLOCKS, tl.zeros((), dtype=tl.uint8)) # Register-based GPT-J RoPE in fp32. - NUM_PAIRS: tl.constexpr = TRITON_BLOCK_SIZE // 2 - NOPE_PAIRS: tl.constexpr = NOPE_HEAD_DIM // 2 - pair_2d = tl.reshape(normed, (NUM_PAIRS, 2)) even, odd = tl.split(pair_2d) # each [NUM_PAIRS] fp32 From ed87b1368785be825a5068e364dc4d9fdf1f1acd Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Mon, 11 May 2026 20:57:43 -0700 Subject: [PATCH 15/24] Clean DeepSeek V4 FlashInfer metadata path --- .../kernels/test_fused_inv_rope_fp8_quant.py | 10 +-- .../test_indexer_deepseek_v4_slot_mapping.py | 6 +- .../layers/deepseek_v4_attention.py | 27 +++--- vllm/model_executor/models/deepseek_v4.py | 86 +++++++------------ vllm/v1/attention/backends/mla/sparse_swa.py | 11 --- .../ops/deepseek_v4_ops/cache_utils.py | 57 ++++++++++-- 6 files changed, 103 insertions(+), 94 deletions(-) diff --git a/tests/kernels/test_fused_inv_rope_fp8_quant.py b/tests/kernels/test_fused_inv_rope_fp8_quant.py index 29ba1f8c6fe1..cf001e37b97f 100644 --- a/tests/kernels/test_fused_inv_rope_fp8_quant.py +++ b/tests/kernels/test_fused_inv_rope_fp8_quant.py @@ -725,12 +725,14 @@ def test_einsum_end_to_end(num_tokens, num_heads, n_groups): This catches stride/layout bugs that only manifest when the einsum kernel actually consumes the quantized activations. """ + from deep_gemm.testing import calc_diff + from deep_gemm.utils.math import ceil_div + from vllm.utils.deep_gemm import ( fp8_einsum, per_block_cast_to_fp8, transform_sf_into_required_layout, ) - from vllm.utils.math_utils import cdiv heads_per_group = num_heads // n_groups d = heads_per_group * HEAD_DIM @@ -752,8 +754,8 @@ def test_einsum_end_to_end(num_tokens, num_heads, n_groups): w_fp8 = torch.empty_like(w, dtype=torch.float8_e4m3fn) w_scale = torch.empty( n_groups, - cdiv(o_lora_rank, 128), - cdiv(d, 128), + ceil_div(o_lora_rank, 128), + ceil_div(d, 128), device=device, dtype=torch.float32, ) @@ -808,8 +810,6 @@ def test_einsum_end_to_end(num_tokens, num_heads, n_groups): # Einsum output: Triton and CUDA both rotate in fp32 now, so diffs # come from fp32 ordering and UE8M0 boundary shifts only. # Use relative diff (same metric as test_fp8_einsum.py). - from vllm.third_party.deep_gemm.testing import calc_diff - z_diff = calc_diff(z_fused, z_ref) assert z_diff < 0.01, ( f"Einsum output diff too large: {z_diff:.6f} (expected < 0.01)" diff --git a/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py b/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py index 8d5a9b5880d2..7b47264823fc 100644 --- a/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py +++ b/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py @@ -3,9 +3,9 @@ import pytest import torch -from transformers import LlamaConfig from tests.v1.attention.utils import create_vllm_config +from vllm.transformers_utils.configs.deepseek_v4 import DeepseekV4Config from vllm.v1.attention.backend import CommonAttentionMetadata from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadataBuilder from vllm.v1.kv_cache_interface import MLAAttentionSpec @@ -28,8 +28,8 @@ def test_indexer_builder_deepseek_v4_compressed_slot_mapping_uses_storage_block_ dtype=torch.bfloat16, compress_ratio=4, ) - hf_config = LlamaConfig( - architectures=["LlamaForCausalLM"], + hf_config = DeepseekV4Config( + architectures=["DeepseekV4ForCausalLM"], hidden_size=128, intermediate_size=256, max_position_embeddings=2048, diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 7726228a44f0..99772a5334da 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -1068,12 +1068,13 @@ def _forward_flashinfer( assert swa_metadata.query_start_loc_cpu is not None assert swa_metadata.token_to_req_indices is not None assert swa_metadata.decode_swa_indices is not None - assert swa_metadata.decode_zero_compressed_lens is not None decode_swa_indices = swa_metadata.decode_swa_indices.reshape( num_decode_tokens, self.window_size ) - decode_compressed_topk_lens = swa_metadata.decode_zero_compressed_lens + decode_compressed_topk_lens = None + decode_compressed_indices_are_local = False + decode_is_valid_token = None if swa_only: assert self.topk_indices_buffer is not None @@ -1107,19 +1108,13 @@ def _forward_flashinfer( if num_decode_tokens > 0: assert swa_metadata.is_valid_token is not None - is_valid = swa_metadata.is_valid_token[:num_decode_tokens] - decode_global_indices, decode_compressed_topk_lens = ( - compute_global_topk_indices_and_lens( - self.topk_indices_buffer[:num_decode_tokens], - swa_metadata.token_to_req_indices, - attn_metadata.block_table[:num_decodes], - compressed_block_size, - is_valid, - ) - ) - decode_compressed_indices = decode_global_indices.view( - num_decode_tokens, -1 - ) + decode_compressed_indices = self.topk_indices_buffer[ + :num_decode_tokens + ] + decode_compressed_indices_are_local = True + decode_is_valid_token = swa_metadata.is_valid_token[ + :num_decode_tokens + ] else: decode_compressed_indices = prefill_topk_indices[:0, :0] else: @@ -1166,6 +1161,8 @@ def _forward_flashinfer( self.window_size, self.compress_ratio, top_k, + decode_compressed_indices_are_local=decode_compressed_indices_are_local, + decode_is_valid_token=decode_is_valid_token, ) query = q diff --git a/vllm/model_executor/models/deepseek_v4.py b/vllm/model_executor/models/deepseek_v4.py index f347383084e5..3cd5a9cb73a1 100644 --- a/vllm/model_executor/models/deepseek_v4.py +++ b/vllm/model_executor/models/deepseek_v4.py @@ -1217,33 +1217,10 @@ def forward( x: torch.Tensor, positions: torch.Tensor, input_ids: torch.Tensor | None, - post_mix: torch.Tensor | None = None, - res_mix: torch.Tensor | None = None, - residual: torch.Tensor | None = None, - ) -> torch.Tensor | tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - ]: - if residual is None and post_mix is None and res_mix is None: - residual = x - x, post, comb = self.hc_pre( - x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base - ) - x = self.attn_norm(x) - x = self.attn(positions, x, None) - x = self.hc_post(x, residual, post, comb) - - residual = x - x, post, comb = self.hc_pre( - x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base - ) - x = self.ffn_norm(x) - x = self.ffn(x, input_ids) - x = self.hc_post(x, residual, post, comb) - return x - + post_mix: torch.Tensor | None, + res_mix: torch.Tensor | None, + residual: torch.Tensor | None, + ) -> torch.Tensor: if residual is None: # Run standalone hc_pre on first layer residual = x @@ -1251,28 +1228,37 @@ def forward( x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base ) else: - # Keep the post/pre operations separate here. The fused post_pre - # path is faster, but changes DeepSeek V4 Flash FP8 accuracy. - residual = self.hc_post(x, residual, post_mix, res_mix) - x, post_mix, res_mix = self.hc_pre( + residual, post_mix, res_mix, x = torch.ops.vllm.mhc_fused_post_pre( + x, residual, + post_mix, + res_mix, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base, + self.rms_norm_eps, + self.hc_eps, + self.hc_eps, + self.hc_post_alpha, + self.hc_sinkhorn_iters, ) x = self.attn_norm(x) x = self.attn(positions, x, None) - assert residual is not None - assert post_mix is not None - assert res_mix is not None - residual = self.hc_post(x, residual, post_mix, res_mix) - x, post_mix, res_mix = self.hc_pre( + residual, post_mix, res_mix, x = torch.ops.vllm.mhc_fused_post_pre( + x, residual, + post_mix, + res_mix, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base, + self.rms_norm_eps, + self.hc_eps, + self.hc_eps, + self.hc_post_alpha, + self.hc_sinkhorn_iters, ) x = self.ffn_norm(x) @@ -1428,24 +1414,18 @@ def forward( if self.use_mega_moe: input_ids = input_ids.to(torch.int64) - if get_pp_group().world_size == 1: - for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states = layer(hidden_states, positions, input_ids) + residual, post_mix, res_mix = None, None, None + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual, post_mix, res_mix = layer( + hidden_states, + positions, + input_ids, + post_mix, + res_mix, + residual, + ) else: - residual, post_mix, res_mix = None, None, None - for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states, residual, post_mix, res_mix = layer( - hidden_states, - positions, - input_ids, - post_mix, - res_mix, - residual, - ) - else: - hidden_states = layer.hc_post( - hidden_states, residual, post_mix, res_mix - ) + hidden_states = layer.hc_post(hidden_states, residual, post_mix, res_mix) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py index 6743165fc07b..8aa061d6ee42 100644 --- a/vllm/v1/attention/backends/mla/sparse_swa.py +++ b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -161,9 +161,6 @@ class DeepseekSparseSWAMetadata: token_to_req_indices: torch.Tensor | None = None # [num_tokens] decode_swa_indices: torch.Tensor | None = None # [num_decode_tokens, window_size] decode_swa_lens: torch.Tensor | None = None # [num_decode_tokens] - # Zero compressed-prefix lengths used by the FlashInfer mixed sparse-index - # builder for SWA-only decode rows. - decode_zero_compressed_lens: torch.Tensor | None = None # [num_decode_tokens] # Number of decode/prefill requests/tokens (batch is reordered: decodes first) num_decodes: int = 0 @@ -258,11 +255,6 @@ def __init__(self, *args, **kwargs): dtype=torch.int32, device=self.device, ) - self.decode_zero_compressed_lens = torch.zeros( - (max_tokens,), - dtype=torch.int32, - device=self.device, - ) self.is_valid_token = torch.zeros( max_tokens, dtype=torch.bool, @@ -349,9 +341,6 @@ def build( token_to_req_indices=token_to_req_indices, decode_swa_indices=self.decode_swa_indices[:num_decode_tokens], decode_swa_lens=self.decode_swa_lens[:num_decode_tokens], - decode_zero_compressed_lens=( - self.decode_zero_compressed_lens[:num_decode_tokens] - ), block_size=self.block_size, num_decodes=num_decodes, num_prefills=num_prefills, diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py index 071f4cbffb7a..d7deeb1d80a4 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py @@ -630,7 +630,7 @@ def compute_global_topk_indices_and_lens( def build_flashinfer_mixed_sparse_indices( decode_swa_indices: torch.Tensor, decode_compressed_indices: torch.Tensor | None, - decode_compressed_topk_lens: torch.Tensor, + decode_compressed_topk_lens: torch.Tensor | None, prefill_topk_indices: torch.Tensor, query_start_loc: torch.Tensor, seq_lens: torch.Tensor, @@ -642,12 +642,15 @@ def build_flashinfer_mixed_sparse_indices( window_size: int, compress_ratio: int, topk: int, + decode_compressed_indices_are_local: bool = False, + decode_is_valid_token: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Build FlashInfer DSV4 sparse indices for decode-first mixed batches.""" assert decode_swa_indices.dtype == torch.int32 assert decode_swa_indices.dim() == 2 assert decode_swa_indices.shape[-1] == window_size - assert decode_compressed_topk_lens.dtype == torch.int32 + if decode_compressed_topk_lens is not None: + assert decode_compressed_topk_lens.dtype == torch.int32 assert prefill_topk_indices.dtype == torch.int32 assert prefill_topk_indices.dim() == 2 assert query_start_loc.dtype == torch.int32 @@ -659,7 +662,8 @@ def build_flashinfer_mixed_sparse_indices( num_prefill_tokens = prefill_topk_indices.shape[0] num_tokens = num_decode_tokens + num_prefill_tokens assert token_to_req_indices.shape[0] >= num_tokens - assert decode_compressed_topk_lens.shape[0] >= num_decode_tokens + if decode_compressed_topk_lens is not None: + assert decode_compressed_topk_lens.shape[0] >= num_decode_tokens decode_compressed_topk = 0 if decode_compressed_indices is None: @@ -669,10 +673,19 @@ def build_flashinfer_mixed_sparse_indices( assert decode_compressed_indices.dim() == 2 assert decode_compressed_indices.shape[0] == num_decode_tokens decode_compressed_topk = decode_compressed_indices.shape[-1] + if decode_compressed_topk > 0 and decode_compressed_indices_are_local: + assert decode_is_valid_token is not None + assert decode_is_valid_token.dtype == torch.bool + assert decode_is_valid_token.shape[0] >= num_decode_tokens + else: + decode_is_valid_token = token_to_req_indices if compressed_block_table is None: compressed_block_table = swa_block_table assert compressed_block_table.dtype == torch.int32 + has_decode_compressed_lens = decode_compressed_topk_lens is not None + if decode_compressed_topk_lens is None: + decode_compressed_topk_lens = token_to_req_indices padded_topk = max(topk, decode_compressed_topk) padded_topk = (padded_topk + 3) // 4 * 4 @@ -696,6 +709,7 @@ def build_flashinfer_mixed_sparse_indices( decode_compressed_indices, decode_compressed_indices.stride(0), decode_compressed_topk_lens, + decode_is_valid_token, prefill_topk_indices, prefill_topk_indices.stride(0), query_start_loc, @@ -714,6 +728,8 @@ def build_flashinfer_mixed_sparse_indices( PADDED_TOP_K=padded_topk, PREFILL_TOPK_STRIDE=prefill_topk_indices.shape[-1], DECODE_COMPRESSED_TOPK=decode_compressed_topk, + DECODE_COMPRESSED_INDICES_ARE_LOCAL=decode_compressed_indices_are_local, + HAS_DECODE_COMPRESSED_LENS=has_decode_compressed_lens, BLOCK_SIZE=1024, num_warps=8, ) @@ -730,6 +746,7 @@ def _build_flashinfer_mixed_sparse_indices_kernel( decode_compressed_indices_ptr, decode_compressed_stride, decode_compressed_topk_lens_ptr, + decode_is_valid_token_ptr, prefill_topk_indices_ptr, prefill_topk_stride, query_start_loc_ptr, @@ -748,6 +765,8 @@ def _build_flashinfer_mixed_sparse_indices_kernel( PADDED_TOP_K: tl.constexpr, PREFILL_TOPK_STRIDE: tl.constexpr, DECODE_COMPRESSED_TOPK: tl.constexpr, + DECODE_COMPRESSED_INDICES_ARE_LOCAL: tl.constexpr, + HAS_DECODE_COMPRESSED_LENS: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): token_idx = tl.program_id(0) @@ -767,6 +786,7 @@ def _build_flashinfer_mixed_sparse_indices_kernel( mask=mask, ) + compressed_len = tl.zeros((), dtype=tl.int32) for i in range(0, PADDED_TOP_K, BLOCK_SIZE): offset = i + tl.arange(0, BLOCK_SIZE) mask = offset < PADDED_TOP_K @@ -777,6 +797,24 @@ def _build_flashinfer_mixed_sparse_indices_kernel( mask=offset < DECODE_COMPRESSED_TOPK, other=-1, ) + if DECODE_COMPRESSED_INDICES_ARE_LOCAL: + token_valid = tl.load(decode_is_valid_token_ptr + token_idx) + is_valid = values >= 0 + req_idx = tl.load(token_to_req_indices_ptr + token_idx) + block_indices = values // compressed_block_size + block_numbers = tl.load( + compressed_block_table_ptr + + req_idx * compressed_block_table_stride + + block_indices, + mask=mask & is_valid, + other=-1, + ) + block_offsets = values % compressed_block_size + values = block_numbers * compressed_block_size + block_offsets + values = tl.where(is_valid, values, -1) + compressed_len += tl.sum( + (is_valid & token_valid).to(tl.int32), axis=0 + ) tl.store( sparse_indices_ptr + token_idx * sparse_indices_stride @@ -786,10 +824,15 @@ def _build_flashinfer_mixed_sparse_indices_kernel( mask=mask, ) - tl.store( - sparse_topk_lens_ptr + token_idx, - WINDOW_SIZE + tl.load(decode_compressed_topk_lens_ptr + token_idx), - ) + if DECODE_COMPRESSED_TOPK == 0: + compressed_len = tl.zeros((), dtype=tl.int32) + elif not DECODE_COMPRESSED_INDICES_ARE_LOCAL: + if HAS_DECODE_COMPRESSED_LENS: + compressed_len = tl.load(decode_compressed_topk_lens_ptr + token_idx) + else: + compressed_len = tl.full((), DECODE_COMPRESSED_TOPK, dtype=tl.int32) + + tl.store(sparse_topk_lens_ptr + token_idx, WINDOW_SIZE + compressed_len) return prefill_idx = token_idx - NUM_DECODE_TOKENS From 7fc2ba0d5fefea838d83a7c119464a109d0f2eeb Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Tue, 12 May 2026 07:57:04 -0700 Subject: [PATCH 16/24] Fix DeepSeek V4 FlashInfer padded graph tokens --- vllm/model_executor/layers/deepseek_v4_attention.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 99772a5334da..f4f0d2db2ae7 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -1165,7 +1165,11 @@ def _forward_flashinfer( decode_is_valid_token=decode_is_valid_token, ) - query = q + # CUDA graph execution can pad q/output past the scheduled token count. + # The FlashInfer DSV4 launcher validates sparse_indices against the + # query length, so pass only the real tokens described by metadata. + query = q[:num_tokens] + output = output[:num_tokens] bmm1_scale: float | torch.Tensor = self.scale bmm2_scale: float | torch.Tensor = 1.0 if self.kv_cache_torch_dtype == torch.float8_e4m3fn: From 5f104504fb463735df48ebaf91dcf1b99afeb115 Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Tue, 19 May 2026 09:22:36 -0700 Subject: [PATCH 17/24] Clean DeepSeek V4 FlashInfer sparse attention path --- vllm/config/vllm.py | 43 +++- .../layers/deepseek_v4_attention.py | 210 +++++++++++++----- vllm/model_executor/warmup/kernel_warmup.py | 126 +++++++++++ vllm/utils/flashinfer.py | 2 + .../attention/backends/mla/flashmla_sparse.py | 4 +- vllm/v1/attention/backends/mla/sparse_swa.py | 23 +- .../ops/deepseek_v4_ops/cache_utils.py | 73 +++--- 7 files changed, 389 insertions(+), 92 deletions(-) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index d220aa65035d..b403eeb97260 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -829,11 +829,27 @@ def __post_init__(self): executor_backend = self.parallel_config.distributed_executor_backend executor_class = Executor.get_class(self) executor_supports_async_sched = executor_class.supports_async_scheduling() + deepseek_v4_async_unsupported = False + if self.model_config is not None: + architectures = ( + getattr(self.model_config.hf_config, "architectures", None) or [] + ) + model_type = getattr(self.model_config.hf_text_config, "model_type", None) + deepseek_v4_async_unsupported = ( + model_type == "deepseek_v4" + or "DeepseekV4ForCausalLM" in architectures + ) if self.scheduler_config.async_scheduling: # Async scheduling explicitly enabled, hard fail any incompatibilities. # Currently, async scheduling only support eagle speculative # decoding. + if deepseek_v4_async_unsupported: + raise ValueError( + "Async scheduling is currently disabled for DeepSeek V4. " + "The sparse MLA FlashInfer path can produce non-repeatable " + "outputs with async scheduling; set async_scheduling=False." + ) if self.speculative_config is not None: if ( self.speculative_config.method not in get_args(EagleModelTypes) @@ -856,7 +872,14 @@ def __post_init__(self): ) elif self.scheduler_config.async_scheduling is None: # Enable async scheduling unless there is an incompatible option. - if ( + if deepseek_v4_async_unsupported: + logger.warning_once( + "Async scheduling is disabled by default for DeepSeek V4 " + "because the sparse MLA FlashInfer path can produce " + "non-repeatable outputs with async scheduling." + ) + self.scheduler_config.async_scheduling = False + elif ( self.model_config is not None and self.model_config.runner_type == "pooling" ): @@ -901,6 +924,18 @@ def __post_init__(self): "enabled" if self.scheduler_config.async_scheduling else "disabled", ) + if ( + deepseek_v4_async_unsupported + and self.cache_config is not None + and self.cache_config.enable_prefix_caching + ): + logger.warning_once( + "Prefix caching is disabled for DeepSeek V4 because the sparse " + "MLA FlashInfer path can produce non-repeatable outputs with " + "cache-hit requests." + ) + self.cache_config.enable_prefix_caching = False + if self.parallel_config.disable_nccl_for_dp_synchronization is None: if self.scheduler_config.async_scheduling: if self.parallel_config.data_parallel_size > 1 and ( @@ -1044,6 +1079,12 @@ def has_blocked_weights(): # async tp is built on top of sequence parallelism and requires it. pass_config = self.compilation_config.pass_config + if deepseek_v4_async_unsupported and pass_config.fuse_allreduce_rms: + logger.warning_once( + "AllReduce + RMSNorm fusion is disabled for DeepSeek V4 " + "because this fused path can produce non-repeatable outputs." + ) + pass_config.fuse_allreduce_rms = False if pass_config.fuse_gemm_comms: pass_config.enable_sp = True if pass_config.enable_sp: diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index f4f0d2db2ae7..6129dc6ced6d 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -85,6 +85,14 @@ _FLASHINFER_DSV4_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 _flashinfer_dsv4_workspace_by_device: dict[torch.device, torch.Tensor] = {} +FlashInferSparseIndexMetadata = tuple[ + torch.Tensor, # compressed KV cache consumed by FlashInfer. + torch.Tensor, # query_start_loc. + torch.Tensor, # query_start_loc_cpu. + torch.Tensor, # seq_lens. + torch.Tensor, # sparse_indices. + torch.Tensor, # sparse_topk_lens. +] # Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather # workspace allocated at _forward_prefill (and the matching profile-time @@ -354,11 +362,18 @@ def forward( hidden_states: torch.Tensor, llama_4_scaling: torch.Tensor | None = None, ) -> torch.Tensor: - # Pre-allocate attention output with FlashMLA-padded head count. - # The op writes into `o_padded`; we slice to n_local_heads after. + # FlashMLA requires 64/128 heads. FlashInfer full-cache modes run on + # the actual local head count, avoiding padded Q/output work. num_tokens = hidden_states.shape[0] - o_padded = torch.empty( - (num_tokens, self.padded_heads, self.head_dim), + use_flashinfer_full_cache = ( + self.mla_attn.kv_cache_torch_dtype != torch.uint8 + and not current_platform.is_rocm() + ) + output_heads = ( + self.n_local_heads if use_flashinfer_full_cache else self.padded_heads + ) + o_attn = torch.empty( + (num_tokens, output_heads, self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device, ) @@ -367,10 +382,10 @@ def forward( torch.ops.vllm.deepseek_v4_attention( hidden_states, positions, - o_padded, + o_attn, self.layer_name, ) - o = o_padded[:, : self.n_local_heads, :] + o = o_attn if use_flashinfer_full_cache else o_attn[:, : self.n_local_heads, :] # Keep ROCm on the BF16 reference wo_a path util kernel ready. if current_platform.is_rocm(): @@ -501,18 +516,19 @@ def attention_impl( # wq_b + kv_insert (+ MLA compressor when an indexer is present) ride # on the default stream so q stays on its consumer stream (mla_attn - # downstream reads q on default). Indexer/compressor go on aux for + # downstream reads q on current). Indexer/compressor go on aux for # overlap with default's GEMM + cache write. if self.indexer is not None: aux_stream = ( self.aux_stream_list[0] if self.aux_stream_list is not None else None ) indexer = self.indexer - # Local ref so the closure keeps a non-None type for mypy. assert self.compressor is not None compressor = self.compressor - def wq_b_kv_insert_and_compress() -> tuple[torch.Tensor, torch.Tensor | None]: + def wq_b_kv_insert_and_compress() -> tuple[ + torch.Tensor, torch.Tensor | None + ]: q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) q_fp8 = self._fused_qnorm_rope_kv_insert( q, kv, positions, attn_metadata @@ -584,16 +600,25 @@ def wq_b_kv_insert() -> tuple[torch.Tensor, torch.Tensor | None]: q_for_attn = q_fp8 if q_fp8 is not None else q - # Pad q to FlashMLA-required head count (64 or 128). The per-tensor - # FP8 FlashInfer path emits a padded q_fp8 tensor directly from the - # qnorm/RoPE kernel. - if q_fp8 is None and self.n_local_heads < self.padded_heads: + # Pad q only for the legacy FlashMLA path, which requires 64 or 128 + # heads. FlashInfer full-cache modes keep the actual local head count. + if ( + q_fp8 is None + and self.mla_attn.kv_cache_torch_dtype == torch.uint8 + and self.n_local_heads < self.padded_heads + ): pad_size = self.padded_heads - self.n_local_heads q_for_attn = F.pad(q_for_attn, (0, 0, 0, pad_size), value=0.0) - # MLA attention writes into the pre-allocated `out` buffer - # ([num_tokens, padded_heads, head_dim]). - self.mla_attn(q_for_attn, kv, positions, output=out) + # MLA attention writes into the pre-allocated `out` buffer. FlashMLA + # gets a padded-head buffer; FlashInfer full-cache modes get actual + # local heads. + self.mla_attn( + q_for_attn, + kv, + positions, + output=out, + ) def _fused_qnorm_rope_kv_insert( self, @@ -617,7 +642,7 @@ def _fused_qnorm_rope_kv_insert( q_fp8 = None if swa_kv_cache.dtype == torch.float8_e4m3fn: q_fp8 = torch.empty( - (q.shape[0], self.padded_heads, q.shape[-1]), + (q.shape[0], self.n_local_heads, q.shape[-1]), dtype=torch.float8_e4m3fn, device=q.device, ) @@ -645,7 +670,7 @@ def _fused_qnorm_rope_kv_insert( kv, swa_kv_cache, swa_metadata.slot_mapping, - positions.to(torch.int64), + positions, self.rotary_emb.cos_sin_cache, self.eps, swa_metadata.block_size, @@ -1043,31 +1068,27 @@ def _forward_decode( out=output.unsqueeze(1), ) - def _forward_flashinfer( + def _build_flashinfer_sparse_index_metadata( self, - q: torch.Tensor, kv_cache: torch.Tensor | None, swa_k_cache: torch.Tensor, swa_metadata: "DeepseekSparseSWAMetadata", attn_metadata: FlashMLASparseMetadata | None, swa_only: bool, - output: torch.Tensor, - ) -> None: - assert self.kv_cache_torch_dtype in (torch.bfloat16, torch.float8_e4m3fn) + ) -> FlashInferSparseIndexMetadata: num_decodes = swa_metadata.num_decodes num_prefills = swa_metadata.num_prefills num_decode_tokens = swa_metadata.num_decode_tokens num_prefill_tokens = swa_metadata.num_prefill_tokens num_reqs = num_decodes + num_prefills num_tokens = num_decode_tokens + num_prefill_tokens - if num_tokens == 0: - return assert swa_metadata.seq_lens is not None assert swa_metadata.query_start_loc is not None assert swa_metadata.query_start_loc_cpu is not None assert swa_metadata.token_to_req_indices is not None assert swa_metadata.decode_swa_indices is not None + assert swa_metadata.block_table is not None decode_swa_indices = swa_metadata.decode_swa_indices.reshape( num_decode_tokens, self.window_size @@ -1086,8 +1107,6 @@ def _forward_flashinfer( compressed_block_table = None compressed_block_size = swa_metadata.block_size top_k = 0 - sparse_indices = None - sparse_topk_lens = None else: assert kv_cache is not None assert attn_metadata is not None @@ -1106,17 +1125,19 @@ def _forward_flashinfer( prefill_topk_indices = self.topk_indices_buffer[:0, :0] top_k = 0 + decode_compressed_indices_are_local = True + assert swa_metadata.is_valid_token is not None + decode_is_valid_token = swa_metadata.is_valid_token[:num_decode_tokens] if num_decode_tokens > 0: - assert swa_metadata.is_valid_token is not None decode_compressed_indices = self.topk_indices_buffer[ :num_decode_tokens ] - decode_compressed_indices_are_local = True - decode_is_valid_token = swa_metadata.is_valid_token[ - :num_decode_tokens - ] else: - decode_compressed_indices = prefill_topk_indices[:0, :0] + # The decode-side pointers are unused when there are no + # decode tokens. Keep their logical width aligned with the + # mixed-batch case so pure-prefill steps reuse the same + # Triton specialization compiled during graph capture. + decode_compressed_indices = prefill_topk_indices[:0] else: if num_prefill_tokens > 0: assert attn_metadata.c128a_prefill_topk_indices is not None @@ -1138,12 +1159,14 @@ def _forward_flashinfer( if num_prefill_tokens == 0: prefill_topk_indices = decode_compressed_indices[:0, :0] else: - decode_compressed_indices = prefill_topk_indices[:0, :0] + # As above, these decode tensors are unused for pure prefill. + # Preserve the C128A topk width and lens-present flag to + # share the mixed-batch sparse-index kernel variant. + decode_compressed_indices = prefill_topk_indices[:0] + decode_compressed_topk_lens = swa_metadata.seq_lens[:0] query_start_loc = swa_metadata.query_start_loc[: num_reqs + 1] query_start_loc_cpu = swa_metadata.query_start_loc_cpu[: num_reqs + 1] - query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] - max_q_len = int(query_lens_cpu.max().item()) seq_lens = swa_metadata.seq_lens[:num_reqs] assert seq_lens.dtype == torch.int32 sparse_indices, sparse_topk_lens = build_flashinfer_mixed_sparse_indices( @@ -1164,10 +1187,54 @@ def _forward_flashinfer( decode_compressed_indices_are_local=decode_compressed_indices_are_local, decode_is_valid_token=decode_is_valid_token, ) + return ( + compressed_kv_cache, + query_start_loc, + query_start_loc_cpu, + seq_lens, + sparse_indices, + sparse_topk_lens, + ) + + def _forward_flashinfer( + self, + q: torch.Tensor, + kv_cache: torch.Tensor | None, + swa_k_cache: torch.Tensor, + swa_metadata: "DeepseekSparseSWAMetadata", + attn_metadata: FlashMLASparseMetadata | None, + swa_only: bool, + output: torch.Tensor, + ) -> None: + assert self.kv_cache_torch_dtype in (torch.bfloat16, torch.float8_e4m3fn) + num_decodes = swa_metadata.num_decodes + num_prefills = swa_metadata.num_prefills + num_decode_tokens = swa_metadata.num_decode_tokens + num_prefill_tokens = swa_metadata.num_prefill_tokens + num_reqs = num_decodes + num_prefills + num_tokens = num_decode_tokens + num_prefill_tokens + if num_tokens == 0: + return + + flashinfer_sparse_metadata = self._build_flashinfer_sparse_index_metadata( + kv_cache=kv_cache, + swa_k_cache=swa_k_cache, + swa_metadata=swa_metadata, + attn_metadata=attn_metadata, + swa_only=swa_only, + ) + ( + compressed_kv_cache, + query_start_loc, + query_start_loc_cpu, + seq_lens, + sparse_indices, + sparse_topk_lens, + ) = flashinfer_sparse_metadata # CUDA graph execution can pad q/output past the scheduled token count. - # The FlashInfer DSV4 launcher validates sparse_indices against the - # query length, so pass only the real tokens described by metadata. + # The FlashInfer DSV4 launcher validates sparse_indices against real + # tokens, so keep the tensors restricted to the scheduled token range. query = q[:num_tokens] output = output[:num_tokens] bmm1_scale: float | torch.Tensor = self.scale @@ -1180,21 +1247,58 @@ def _forward_flashinfer( assert query.dtype == torch.bfloat16 query = query.contiguous() - flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw( - query=query, - swa_kv_cache=swa_k_cache, - workspace_buffer=_get_flashinfer_dsv4_workspace(q.device), - sparse_indices=sparse_indices, - compressed_kv_cache=compressed_kv_cache, - sparse_topk_lens=sparse_topk_lens, - seq_lens=seq_lens, - out=output, - bmm1_scale=bmm1_scale, - bmm2_scale=bmm2_scale, - sinks=self.attn_sink, - cum_seq_lens_q=query_start_loc, - max_q_len=max_q_len, - ) + workspace = _get_flashinfer_dsv4_workspace(q.device) + + if num_decode_tokens > 0: + decode_query_start_loc = query_start_loc[: num_decodes + 1] + decode_query_start_loc_cpu = query_start_loc_cpu[: num_decodes + 1] + decode_query_lens_cpu = ( + decode_query_start_loc_cpu[1:] - decode_query_start_loc_cpu[:-1] + ) + flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw( + query=query[:num_decode_tokens], + swa_kv_cache=swa_k_cache, + workspace_buffer=workspace, + sparse_indices=sparse_indices[:num_decode_tokens], + compressed_kv_cache=compressed_kv_cache, + sparse_topk_lens=sparse_topk_lens[:num_decode_tokens], + seq_lens=seq_lens[:num_decodes], + out=output[:num_decode_tokens], + bmm1_scale=bmm1_scale, + bmm2_scale=bmm2_scale, + sinks=self.attn_sink, + cum_seq_lens_q=decode_query_start_loc, + max_q_len=int(decode_query_lens_cpu.max().item()), + ) + + if num_prefill_tokens > 0: + assert swa_metadata.prefill_query_start_loc is not None + prefill_query_start_loc = swa_metadata.prefill_query_start_loc + prefill_query_start_loc_cpu = query_start_loc_cpu[ + num_decodes : num_reqs + 1 + ] + prefill_query_lens_cpu = ( + prefill_query_start_loc_cpu[1:] - prefill_query_start_loc_cpu[:-1] + ) + prefill_query = query[num_decode_tokens:num_tokens] + prefill_output = output[num_decode_tokens:num_tokens] + prefill_sparse_indices = sparse_indices[num_decode_tokens:num_tokens] + prefill_sparse_topk_lens = sparse_topk_lens[num_decode_tokens:num_tokens] + flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw( + query=prefill_query, + swa_kv_cache=swa_k_cache, + workspace_buffer=workspace, + sparse_indices=prefill_sparse_indices, + compressed_kv_cache=compressed_kv_cache, + sparse_topk_lens=prefill_sparse_topk_lens, + seq_lens=seq_lens[num_decodes:num_reqs], + out=prefill_output, + bmm1_scale=bmm1_scale, + bmm2_scale=bmm2_scale, + sinks=self.attn_sink, + cum_seq_lens_q=prefill_query_start_loc, + max_q_len=int(prefill_query_lens_cpu.max().item()), + ) def _forward_prefill( self, diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 70abd8a6c503..af9f30af28dd 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -45,6 +45,8 @@ def kernel_warmup(worker: "Worker"): elif has_flashinfer() and current_platform.has_device_capability(90): flashinfer_autotune(worker.model_runner) + deepseek_v4_flashinfer_sparse_mla_warmup(worker) + # FlashInfer attention warmup # Only warmup if the model has FlashInfer attention groups # and is not a pooling model @@ -78,6 +80,130 @@ def _is_flashinfer_backend(backend): ) +def deepseek_v4_flashinfer_sparse_mla_warmup(worker: "Worker") -> None: + """Warm the DSV4 FlashInfer sparse-index builder variants. + + CUDA graph capture exercises mixed batches, but Triton can still see the + first real prefill wave as a separate specialization for the per-layer C4A + and C128A index shapes. Compile those tiny index-builder launches during + engine warmup so they do not appear as inference-time bubbles. + """ + from vllm.v1.attention.backends.mla.sparse_swa import ( + _compute_prefill_metadata_kernel, + ) + from vllm.v1.attention.ops.deepseek_v4_ops.cache_utils import ( + build_flashinfer_mixed_sparse_indices, + ) + + hf_config = worker.vllm_config.model_config.hf_config + compress_ratios = { + int(ratio) for ratio in getattr(hf_config, "compress_ratios", ()) + } + if not compress_ratios: + return + + window_size = int(getattr(hf_config, "sliding_window", 0)) + if window_size <= 0: + return + + logger.info("Warming up DeepSeek V4 FlashInfer sparse MLA index kernels.") + device = worker.model_runner.device + index_topk = int(getattr(hf_config, "index_topk", 0)) + max_model_len = worker.vllm_config.model_config.max_model_len + max_num_seqs = max(1, worker.scheduler_config.max_num_seqs) + + def _prefill_batch_sizes() -> list[int]: + sizes: list[int] = [] + size = 1 + while size < max_num_seqs: + sizes.append(size) + size *= 2 + sizes.append(max_num_seqs) + return sizes + + max_prefill_reqs = max(_prefill_batch_sizes()) + seq_lens = torch.ones((max_prefill_reqs,), device=device, dtype=torch.int32) + query_start_loc = torch.arange( + max_prefill_reqs + 1, device=device, dtype=torch.int32 + ) + prefill_query_start_loc = torch.empty( + max_prefill_reqs + 1, device=device, dtype=torch.int32 + ) + prefill_gather_lens = torch.empty( + max_prefill_reqs, device=device, dtype=torch.int32 + ) + for num_prefills in _prefill_batch_sizes(): + _compute_prefill_metadata_kernel[(1,)]( + prefill_query_start_loc[: num_prefills + 1], + prefill_gather_lens[:num_prefills], + seq_lens[:num_prefills], + query_start_loc[: num_prefills + 1], + num_prefills, + 0, + window_size, + BLOCK_SIZE=1 << num_prefills.bit_length(), + ) + + for compress_ratio in sorted(compress_ratios): + if compress_ratio == 4: + topk = index_topk + decode_compressed_indices_are_local = True + has_decode_compressed_lens = False + elif compress_ratio == 128: + topk = (max_model_len + compress_ratio - 1) // compress_ratio + topk = ((topk + 127) // 128) * 128 + decode_compressed_indices_are_local = False + has_decode_compressed_lens = True + else: + continue + + if topk <= 0: + continue + + decode_swa_indices = torch.zeros( + (1, window_size), device=device, dtype=torch.int32 + ) + decode_compressed_indices = torch.zeros( + (1, topk), device=device, dtype=torch.int32 + ) + prefill_topk_indices = torch.zeros((1, topk), device=device, dtype=torch.int32) + query_start_loc = torch.tensor([0, 1, 2], device=device, dtype=torch.int32) + seq_lens = torch.tensor([1, 2], device=device, dtype=torch.int32) + token_to_req_indices = torch.tensor([0, 1], device=device, dtype=torch.int32) + swa_block_table = torch.zeros((2, 1), device=device, dtype=torch.int32) + compressed_block_table = torch.zeros((2, 1), device=device, dtype=torch.int32) + decode_compressed_topk_lens = ( + torch.ones((1,), device=device, dtype=torch.int32) + if has_decode_compressed_lens + else None + ) + decode_is_valid_token = ( + torch.ones((1,), device=device, dtype=torch.bool) + if decode_compressed_indices_are_local + else None + ) + + build_flashinfer_mixed_sparse_indices( + decode_swa_indices, + decode_compressed_indices, + decode_compressed_topk_lens, + prefill_topk_indices, + query_start_loc, + seq_lens, + token_to_req_indices, + swa_block_table, + 256, + compressed_block_table, + max(1, 256 // compress_ratio), + window_size, + compress_ratio, + topk, + decode_compressed_indices_are_local=decode_compressed_indices_are_local, + decode_is_valid_token=decode_is_valid_token, + ) + torch.cuda.synchronize() + + def flashinfer_autotune(runner: "GPUModelRunner") -> None: """ Autotune FlashInfer operations. diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 31e12883a1c6..23dfc985bf0a 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -161,6 +161,8 @@ def _get_dsv4_sparse_mla_raw_impl(): return None op = core.get_trtllm_gen_fmha_module() run_func = getattr(op, "trtllm_paged_attention_decode_sparse_mla_dsv4", None) + if run_func is None: + run_func = getattr(op, "dsv4_sparse_mla", None) if run_func is None: return None return run_func, core.device_support_pdl, core.get_device_sm_count diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index c8c11236c1fb..6a4519b17f6f 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -449,7 +449,7 @@ def _build_fp8_separate_prefill_decode( ) -> "FlashMLASparseMetadata.FP8SeparatePrefillDecode": num_tokens = common_attn_metadata.num_actual_tokens - (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = ( + (num_decodes, _, num_decode_tokens, num_prefill_tokens) = ( split_decodes_and_prefills( common_attn_metadata, decode_threshold=self.reorder_batch_threshold or 1, @@ -670,7 +670,7 @@ def _build_c128a_metadata( # `c128a_global_decode_topk_indices.shape[0]` lines up with q in # `_forward_decode`. The per-token C128A kernel handles non-uniform # query lengths. - (num_decodes, _, num_decode_tokens, num_prefill_tokens) = ( + (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = ( split_decodes_and_prefills( cm, decode_threshold=self.reorder_batch_threshold or 1, diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py index 8aa061d6ee42..8364e044c150 100644 --- a/vllm/v1/attention/backends/mla/sparse_swa.py +++ b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -169,6 +169,7 @@ class DeepseekSparseSWAMetadata: num_prefill_tokens: int = 0 # Pre-computed prefill metadata shared across all DeepseekV4 attention layers. + prefill_query_start_loc: torch.Tensor | None = None prefill_seq_lens: torch.Tensor | None = None prefill_gather_lens: torch.Tensor | None = None @@ -400,28 +401,34 @@ def _build_deepseek_v4_metadata( # --- Prefill query metadata (single Triton kernel + CPU slicing) --- if num_prefills > 0: + pfx_query_start_loc = torch.empty( + num_prefills + 1, dtype=torch.int32, device=seq_lens.device + ) pfx_gather_lens = torch.empty( num_prefills, dtype=torch.int32, device=seq_lens.device ) _compute_prefill_metadata_kernel[(1,)]( + pfx_query_start_loc, pfx_gather_lens, seq_lens, query_start_loc, num_prefills, num_decodes, self.window_size, - BLOCK_SIZE=triton.next_power_of_2(num_prefills), + BLOCK_SIZE=triton.next_power_of_2(num_prefills + 1), ) + result["prefill_query_start_loc"] = pfx_query_start_loc result["prefill_seq_lens"] = seq_lens[num_decodes:] result["prefill_gather_lens"] = pfx_gather_lens return result -@triton.jit +@triton.jit(do_not_specialize=["num_prefills", "num_decodes", "window_size"]) def _compute_prefill_metadata_kernel( # Outputs + prefill_query_start_loc_ptr, prefill_gather_lens_ptr, # Inputs seq_lens_ptr, @@ -431,8 +438,18 @@ def _compute_prefill_metadata_kernel( window_size, BLOCK_SIZE: tl.constexpr, ): - """Compute prefill gather_lens in a single pass.""" + """Compute prefill-local query offsets and gather_lens in a single pass.""" offset = tl.arange(0, BLOCK_SIZE) + qsl_base = tl.load(query_start_loc_ptr + num_decodes) + + qsl_mask = offset < (num_prefills + 1) + qsl_value = tl.load( + query_start_loc_ptr + num_decodes + offset, + mask=qsl_mask, + other=qsl_base, + ) + tl.store(prefill_query_start_loc_ptr + offset, qsl_value - qsl_base, mask=qsl_mask) + mask = offset < num_prefills seq_len = tl.load(seq_lens_ptr + num_decodes + offset, mask=mask) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py index d7deeb1d80a4..33f544152323 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py @@ -70,7 +70,6 @@ def _qnorm_rope_insert_full_cache_kernel( q_ptr, q_stride0, q_stride1, - num_q_heads, q_fp8_ptr, q_fp8_stride0, q_fp8_stride1, @@ -98,15 +97,6 @@ def _qnorm_rope_insert_full_cache_kernel( offsets = tl.arange(0, BLOCK_SIZE) mask = offsets < HEAD_SIZE - if STORE_Q_FP8 and head_idx >= num_q_heads: - q_fp8_row = q_fp8_ptr + token_idx * q_fp8_stride0 + head_idx * q_fp8_stride1 - tl.store( - q_fp8_row + offsets, - tl.zeros((BLOCK_SIZE,), dtype=tl.float32).to(tl.float8e4nv), - mask=mask, - ) - return - position = tl.load(positions_ptr + token_idx) q_row = q_ptr + token_idx * q_stride0 + head_idx * q_stride1 @@ -123,16 +113,13 @@ def _qnorm_rope_insert_full_cache_kernel( ROPE_HEAD_DIM, BLOCK_SIZE, ) - q_bf16 = values.to(tl.bfloat16) - tl.store(q_row + offsets, q_bf16, mask=mask) - if STORE_Q_FP8: q_fp8_scale_inv = tl.load(q_fp8_scale_inv_ptr) q_fp8_row = q_fp8_ptr + token_idx * q_fp8_stride0 + head_idx * q_fp8_stride1 - q_fp8_values = tl.clamp( - q_bf16.to(tl.float32) * q_fp8_scale_inv, -448.0, 448.0 - ) + q_fp8_values = tl.clamp(values * q_fp8_scale_inv, -448.0, 448.0) tl.store(q_fp8_row + offsets, q_fp8_values.to(tl.float8e4nv), mask=mask) + else: + tl.store(q_row + offsets, values.to(tl.bfloat16), mask=mask) if head_idx != 0: return @@ -193,22 +180,22 @@ def qnorm_rope_and_insert_full_k_cache( assert q.dtype == torch.bfloat16 assert kv.dtype == torch.bfloat16 assert k_cache.dtype in (torch.bfloat16, torch.float8_e4m3fn) + assert positions.dtype in (torch.int32, torch.int64) assert cos_sin_cache.dtype == torch.float32 if q_fp8 is not None: assert q_fp8.dtype == torch.float8_e4m3fn assert q_fp8.dim() == 3 and q_fp8.shape[0] == q.shape[0] + assert q_fp8.shape[1] == q.shape[1] assert q_fp8.shape[-1] == q.shape[-1] assert q_fp8_scale_inv is not None assert q_fp8_scale_inv.dtype == torch.float32 assert q_fp8_scale_inv.numel() == 1 num_tokens_full, num_heads, _ = q.shape - q_heads_for_grid = q_fp8.shape[1] if q_fp8 is not None else num_heads - _qnorm_rope_insert_full_cache_kernel[(num_tokens_full, q_heads_for_grid)]( + _qnorm_rope_insert_full_cache_kernel[(num_tokens_full, num_heads)]( q, q.stride(0), q.stride(1), - num_heads, q_fp8 if q_fp8 is not None else q, q_fp8.stride(0) if q_fp8 is not None else q.stride(0), q_fp8.stride(1) if q_fp8 is not None else q.stride(1), @@ -700,6 +687,11 @@ def build_flashinfer_mixed_sparse_indices( if num_tokens == 0: return sparse_indices, sparse_topk_lens + window_block_size = triton.next_power_of_2(max(window_size, 1)) + topk_block_size = triton.next_power_of_2(max(padded_topk, 1)) + max_block_size = max(window_block_size, topk_block_size) + num_warps = 4 if max_block_size >= 256 else 1 + _build_flashinfer_mixed_sparse_indices_kernel[(num_tokens,)]( sparse_indices, sparse_indices.stride(0), @@ -730,13 +722,27 @@ def build_flashinfer_mixed_sparse_indices( DECODE_COMPRESSED_TOPK=decode_compressed_topk, DECODE_COMPRESSED_INDICES_ARE_LOCAL=decode_compressed_indices_are_local, HAS_DECODE_COMPRESSED_LENS=has_decode_compressed_lens, - BLOCK_SIZE=1024, - num_warps=8, + WINDOW_BLOCK_SIZE=window_block_size, + TOPK_BLOCK_SIZE=topk_block_size, + num_warps=num_warps, ) return sparse_indices, sparse_topk_lens -@triton.jit +@triton.jit( + do_not_specialize=[ + "sparse_indices_stride", + "decode_swa_stride", + "decode_compressed_stride", + "prefill_topk_stride", + "swa_block_table_stride", + "swa_block_size", + "compressed_block_table_stride", + "compressed_block_size", + "NUM_DECODE_TOKENS", + "PREFILL_TOPK_STRIDE", + ] +) def _build_flashinfer_mixed_sparse_indices_kernel( sparse_indices_ptr, sparse_indices_stride, @@ -758,22 +764,23 @@ def _build_flashinfer_mixed_sparse_indices_kernel( compressed_block_table_ptr, compressed_block_table_stride, compressed_block_size, - NUM_DECODE_TOKENS: tl.constexpr, + NUM_DECODE_TOKENS, WINDOW_SIZE: tl.constexpr, COMPRESS_RATIO: tl.constexpr, TOP_K: tl.constexpr, PADDED_TOP_K: tl.constexpr, - PREFILL_TOPK_STRIDE: tl.constexpr, + PREFILL_TOPK_STRIDE, DECODE_COMPRESSED_TOPK: tl.constexpr, DECODE_COMPRESSED_INDICES_ARE_LOCAL: tl.constexpr, HAS_DECODE_COMPRESSED_LENS: tl.constexpr, - BLOCK_SIZE: tl.constexpr, + WINDOW_BLOCK_SIZE: tl.constexpr, + TOPK_BLOCK_SIZE: tl.constexpr, ): token_idx = tl.program_id(0) if token_idx < NUM_DECODE_TOKENS: - for i in range(0, WINDOW_SIZE, BLOCK_SIZE): - offset = i + tl.arange(0, BLOCK_SIZE) + for i in range(0, WINDOW_SIZE, WINDOW_BLOCK_SIZE): + offset = i + tl.arange(0, WINDOW_BLOCK_SIZE) mask = offset < WINDOW_SIZE values = tl.load( decode_swa_indices_ptr + token_idx * decode_swa_stride + offset, @@ -787,8 +794,8 @@ def _build_flashinfer_mixed_sparse_indices_kernel( ) compressed_len = tl.zeros((), dtype=tl.int32) - for i in range(0, PADDED_TOP_K, BLOCK_SIZE): - offset = i + tl.arange(0, BLOCK_SIZE) + for i in range(0, PADDED_TOP_K, TOPK_BLOCK_SIZE): + offset = i + tl.arange(0, TOPK_BLOCK_SIZE) mask = offset < PADDED_TOP_K values = tl.load( decode_compressed_indices_ptr @@ -848,8 +855,8 @@ def _build_flashinfer_mixed_sparse_indices_kernel( swa_start_pos = pos - swa_len + 1 topk_len = tl.minimum((pos + 1) // COMPRESS_RATIO, TOP_K) - for i in range(0, WINDOW_SIZE, BLOCK_SIZE): - offset = i + tl.arange(0, BLOCK_SIZE) + for i in range(0, WINDOW_SIZE, WINDOW_BLOCK_SIZE): + offset = i + tl.arange(0, WINDOW_BLOCK_SIZE) mask = offset < WINDOW_SIZE pos_offset = swa_start_pos + offset block_indices = pos_offset // swa_block_size @@ -867,8 +874,8 @@ def _build_flashinfer_mixed_sparse_indices_kernel( mask=mask, ) - for i in range(0, PADDED_TOP_K, BLOCK_SIZE): - offset = i + tl.arange(0, BLOCK_SIZE) + for i in range(0, PADDED_TOP_K, TOPK_BLOCK_SIZE): + offset = i + tl.arange(0, TOPK_BLOCK_SIZE) mask = offset < PADDED_TOP_K local_idx = tl.load( prefill_topk_indices_ptr + prefill_idx * prefill_topk_stride + offset, From b58aafe4b861d2d03a7f339a02e662738f162130 Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Tue, 19 May 2026 23:13:40 -0700 Subject: [PATCH 18/24] Allow async scheduling for DeepSeek V4 Signed-off-by: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> --- vllm/config/vllm.py | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index b403eeb97260..2c611f3dcd87 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -829,27 +829,20 @@ def __post_init__(self): executor_backend = self.parallel_config.distributed_executor_backend executor_class = Executor.get_class(self) executor_supports_async_sched = executor_class.supports_async_scheduling() - deepseek_v4_async_unsupported = False + is_deepseek_v4 = False if self.model_config is not None: architectures = ( getattr(self.model_config.hf_config, "architectures", None) or [] ) model_type = getattr(self.model_config.hf_text_config, "model_type", None) - deepseek_v4_async_unsupported = ( - model_type == "deepseek_v4" - or "DeepseekV4ForCausalLM" in architectures + is_deepseek_v4 = ( + model_type == "deepseek_v4" or "DeepseekV4ForCausalLM" in architectures ) if self.scheduler_config.async_scheduling: # Async scheduling explicitly enabled, hard fail any incompatibilities. # Currently, async scheduling only support eagle speculative # decoding. - if deepseek_v4_async_unsupported: - raise ValueError( - "Async scheduling is currently disabled for DeepSeek V4. " - "The sparse MLA FlashInfer path can produce non-repeatable " - "outputs with async scheduling; set async_scheduling=False." - ) if self.speculative_config is not None: if ( self.speculative_config.method not in get_args(EagleModelTypes) @@ -872,14 +865,7 @@ def __post_init__(self): ) elif self.scheduler_config.async_scheduling is None: # Enable async scheduling unless there is an incompatible option. - if deepseek_v4_async_unsupported: - logger.warning_once( - "Async scheduling is disabled by default for DeepSeek V4 " - "because the sparse MLA FlashInfer path can produce " - "non-repeatable outputs with async scheduling." - ) - self.scheduler_config.async_scheduling = False - elif ( + if ( self.model_config is not None and self.model_config.runner_type == "pooling" ): @@ -925,7 +911,7 @@ def __post_init__(self): ) if ( - deepseek_v4_async_unsupported + is_deepseek_v4 and self.cache_config is not None and self.cache_config.enable_prefix_caching ): @@ -1079,7 +1065,7 @@ def has_blocked_weights(): # async tp is built on top of sequence parallelism and requires it. pass_config = self.compilation_config.pass_config - if deepseek_v4_async_unsupported and pass_config.fuse_allreduce_rms: + if is_deepseek_v4 and pass_config.fuse_allreduce_rms: logger.warning_once( "AllReduce + RMSNorm fusion is disabled for DeepSeek V4 " "because this fused path can produce non-repeatable outputs." From 27b39dd2219f5dedb669014f60649f1c48596321 Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Tue, 19 May 2026 23:22:34 -0700 Subject: [PATCH 19/24] Fix sparse MLA pre-commit issues Signed-off-by: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> --- docs/design/attention_backends.md | 2 +- vllm/model_executor/warmup/kernel_warmup.py | 2 +- vllm/v1/attention/backends/mla/flashmla_sparse.py | 3 ++- vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py | 4 +--- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index 3f6feecf5f78..416d7073ffa1 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -217,7 +217,7 @@ MLA decode backends are selected using the standard | `FLASHINFER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x | | `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | 576 | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x | | `FLASHMLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x | -| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 512, 576 | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x | +| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_per_tensor`, `fp8_inc`, `fp8_ds_mla`, `fp8_e4m3` | 64 | 512, 576 | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x | | `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x | | `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %1 | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 1, 64 | Any | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | N/A | diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index af9f30af28dd..d08624d6fca2 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -201,7 +201,7 @@ def _prefill_batch_sizes() -> list[int]: decode_compressed_indices_are_local=decode_compressed_indices_are_local, decode_is_valid_token=decode_is_valid_token, ) - torch.cuda.synchronize() + torch.accelerator.synchronize() def flashinfer_autotune(runner: "GPUModelRunner") -> None: diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 6a4519b17f6f..99d3b002558d 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -255,6 +255,7 @@ class Chunk: # Prefill: local topk indices (used by combine_topk_swa_indices). c128a_prefill_topk_indices: torch.Tensor | None = None + def get_prefill_workspace_size(max_model_len: int): # NOTE(Lucas): 5 is a magic number for controlling the prefill buffer size. # May be tuned later. @@ -449,7 +450,7 @@ def _build_fp8_separate_prefill_decode( ) -> "FlashMLASparseMetadata.FP8SeparatePrefillDecode": num_tokens = common_attn_metadata.num_actual_tokens - (num_decodes, _, num_decode_tokens, num_prefill_tokens) = ( + (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = ( split_decodes_and_prefills( common_attn_metadata, decode_threshold=self.reorder_batch_threshold or 1, diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py index 33f544152323..095064f813f6 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py @@ -819,9 +819,7 @@ def _build_flashinfer_mixed_sparse_indices_kernel( block_offsets = values % compressed_block_size values = block_numbers * compressed_block_size + block_offsets values = tl.where(is_valid, values, -1) - compressed_len += tl.sum( - (is_valid & token_valid).to(tl.int32), axis=0 - ) + compressed_len += tl.sum((is_valid & token_valid).to(tl.int32), axis=0) tl.store( sparse_indices_ptr + token_idx * sparse_indices_stride From 60a4b227a03a177c4fd180256f8e16ed099f90a0 Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Tue, 19 May 2026 23:24:54 -0700 Subject: [PATCH 20/24] Restore DeepSeek V4 slot mapping test to main Signed-off-by: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> --- .../test_indexer_deepseek_v4_slot_mapping.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py b/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py index 7b47264823fc..159bb8af3fb9 100644 --- a/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py +++ b/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py @@ -5,16 +5,13 @@ import torch from tests.v1.attention.utils import create_vllm_config -from vllm.transformers_utils.configs.deepseek_v4 import DeepseekV4Config from vllm.v1.attention.backend import CommonAttentionMetadata from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadataBuilder from vllm.v1.kv_cache_interface import MLAAttentionSpec @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") -def test_indexer_builder_deepseek_v4_compressed_slot_mapping_uses_storage_block_size( - tmp_path, -): +def test_indexer_builder_deepseek_v4_compressed_slot_mapping_uses_storage_block_size(): """Regression test: DeepseekV4 compression path must compute slot_mapping from compressed positions, not reuse the uncompressed common metadata mapping. """ @@ -28,18 +25,7 @@ def test_indexer_builder_deepseek_v4_compressed_slot_mapping_uses_storage_block_ dtype=torch.bfloat16, compress_ratio=4, ) - hf_config = DeepseekV4Config( - architectures=["DeepseekV4ForCausalLM"], - hidden_size=128, - intermediate_size=256, - max_position_embeddings=2048, - num_attention_heads=4, - num_hidden_layers=1, - num_key_value_heads=4, - vocab_size=32000, - ) - hf_config.save_pretrained(tmp_path) - vllm_config = create_vllm_config(model_name=str(tmp_path), max_model_len=1024) + vllm_config = create_vllm_config(max_model_len=1024) builder = DeepseekV32IndexerMetadataBuilder( kv_cache_spec=kv_cache_spec, layer_names=["dummy"], From 057dcc0c2b547a54af75d3d6c77506b76cf79ef5 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 20 May 2026 21:19:02 -0700 Subject: [PATCH 21/24] Use CUDA full-cache FP8 insert for DeepSeek V4 --- ...deepseek_v4_qnorm_rope_kv_insert_kernel.cu | 273 +++++++++++++++--- csrc/ops.h | 7 + csrc/torch_bindings.cpp | 13 + ..._fused_deepseek_v4_qnorm_rope_kv_insert.py | 112 ++++++- .../ops/deepseek_v4_ops/cache_utils.py | 196 +++---------- 5 files changed, 405 insertions(+), 196 deletions(-) diff --git a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu index 2f2e7ecc1829..0b19a2cf16f9 100644 --- a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu +++ b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu @@ -91,6 +91,32 @@ constexpr float kFp8Max = 448.0f; constexpr int kNumLanes = 32; constexpr int kElemsPerLane = kHeadDim / kNumLanes; // 16 +__device__ __forceinline__ uint4 packFp8E4M3x16(float const* values, + float const scale) { +#ifndef USE_ROCM + uint4 out; + auto* out2 = reinterpret_cast<__nv_fp8x2_storage_t*>(&out); + #pragma unroll + for (int i = 0; i < kElemsPerLane / 2; i++) { + float2 scaled = + make_float2(values[2 * i] * scale, values[2 * i + 1] * scale); + scaled.x = fminf(fmaxf(scaled.x, -kFp8Max), kFp8Max); + scaled.y = fminf(fmaxf(scaled.y, -kFp8Max), kFp8Max); + out2[i] = __nv_cvt_float2_to_fp8x2(scaled, __NV_SATFINITE, __NV_E4M3); + } + return out; +#else + uint8_t out_bytes[kElemsPerLane]; + #pragma unroll + for (int i = 0; i < kElemsPerLane; i++) { + float scaled = values[i] * scale; + scaled = fminf(fmaxf(scaled, -kFp8Max), kFp8Max); + out_bytes[i] = rocm_cvt_float_to_fp8_e4m3(scaled); + } + return *reinterpret_cast(out_bytes); +#endif +} + // ──────────────────────────────────────────────────────────────────────────── // Small inline helpers // ──────────────────────────────────────────────────────────────────────────── @@ -127,20 +153,27 @@ __device__ __forceinline__ float warpSum(float val) { // them). The KV branch only inserts the first `num_tokens_insert` tokens // (= slot_mapping length) into the paged cache. // -template +template __global__ void fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel( - scalar_t_in* __restrict__ q_inout, // [N, H, 512] bf16, in place - scalar_t_in const* __restrict__ kv_in, // [N, 512] bf16 - uint8_t* __restrict__ k_cache, // [num_blocks, block_stride] - int64_t const* __restrict__ slot_mapping, // [num_tokens_insert] i64 - int64_t const* __restrict__ position_ids, // [N] i64 - float const* __restrict__ cos_sin_cache, // [max_pos, 64] fp32 + scalar_t_in* __restrict__ q_inout, // [N, H, 512] bf16 + uint8_t* __restrict__ q_fp8_out, // [N, H, 512] fp8, optional + int64_t const q_fp8_stride0, // elements, fp8 == bytes + int64_t const q_fp8_stride1, // elements, fp8 == bytes + scalar_t_in const* __restrict__ kv_in, // [N, 512] bf16 + uint8_t* __restrict__ k_cache, // legacy uint8 or full fp8 + int64_t const* __restrict__ slot_mapping, // [num_tokens_insert] i64 + int64_t const* __restrict__ position_ids, // [N] i64 + float const* __restrict__ cos_sin_cache, // [max_pos, 64] fp32 + float const* __restrict__ fp8_scale_ptr, // scalar, full-cache fp8 only + float const* __restrict__ q_fp8_scale_inv, // scalar, q fp8 only float const eps, - int const num_tokens_full, // = q.size(0) = kv.size(0) - int const num_tokens_insert, // = slot_mapping.size(0), ≤ num_tokens_full - int const num_heads_q, // H - int const cache_block_size, // tokens per paged-cache block - int const kv_block_stride) { // bytes per paged-cache block + int const num_tokens_full, // = q.size(0) = kv.size(0) + int const num_tokens_insert, // = slot_mapping.size(0), ≤ num_tokens_full + int const num_heads_q, // H + int const cache_block_size, // tokens per paged-cache block + int64_t const kv_block_stride, // legacy bytes or full-cache elements + int64_t const kv_token_stride) { // full-cache elements, unused by legacy #if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM) // BF16 _typeConvert specialization is unavailable on pre-Ampere. The // DeepseekV4 kernel only runs with bf16 inputs in practice, so compile a @@ -256,30 +289,41 @@ __global__ void fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel( } // ═══════════════════════════════════════════════════════════════════════ - // Q branch: cast to bf16 and store back in place. + // Q branch: cast and store. Legacy writes bf16 in place. Full-cache + // per-tensor-FP8 writes q_fp8 and leaves q unchanged. // ═══════════════════════════════════════════════════════════════════════ if (!isKV) { - uint4 out0, out1; - typename Converter::packed_hip_type* po0 = - reinterpret_cast(&out0); - typename Converter::packed_hip_type* po1 = - reinterpret_cast(&out1); + if constexpr (STORE_Q_FP8) { + float const scale_inv = VLLM_LDG(q_fp8_scale_inv); + uint4 const out = packFp8E4M3x16(elements, scale_inv); + uint8_t* dst = q_fp8_out + + static_cast(tokenIdx) * q_fp8_stride0 + + static_cast(slotIdx) * q_fp8_stride1 + dim_base; + *reinterpret_cast(dst) = out; + } else { + uint4 out0, out1; + typename Converter::packed_hip_type* po0 = + reinterpret_cast(&out0); + typename Converter::packed_hip_type* po1 = + reinterpret_cast(&out1); #pragma unroll - for (int i = 0; i < 4; i++) { - po0[i] = Converter::convert( - make_float2(elements[2 * i], elements[2 * i + 1])); - } + for (int i = 0; i < 4; i++) { + po0[i] = Converter::convert( + make_float2(elements[2 * i], elements[2 * i + 1])); + } #pragma unroll - for (int i = 0; i < 4; i++) { - po1[i] = Converter::convert( - make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1])); + for (int i = 0; i < 4; i++) { + po1[i] = Converter::convert( + make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1])); + } + scalar_t_in* dst = + q_inout + + (static_cast(tokenIdx) * num_heads_q + slotIdx) * + kHeadDim + + dim_base; + *reinterpret_cast(dst) = out0; + *reinterpret_cast(dst + 8) = out1; } - scalar_t_in* dst = - q_inout + - (static_cast(tokenIdx) * num_heads_q + slotIdx) * kHeadDim + - dim_base; - *reinterpret_cast(dst) = out0; - *reinterpret_cast(dst + 8) = out1; #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) cudaTriggerProgrammaticLaunchCompletion(); #endif @@ -299,6 +343,20 @@ __global__ void fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel( int64_t const block_idx = slot_id / cache_block_size; int64_t const pos_in_block = slot_id % cache_block_size; + if constexpr (STORE_FULL_CACHE) { + uint8_t* cache_row = k_cache + block_idx * kv_block_stride + + pos_in_block * kv_token_stride; + if constexpr (STORE_KV_FP8) { + float const inv_scale = 1.0f / VLLM_LDG(fp8_scale_ptr); + uint4 const out = packFp8E4M3x16(elements, inv_scale); + *reinterpret_cast(cache_row + dim_base) = out; + } +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + cudaTriggerProgrammaticLaunchCompletion(); +#endif + return; + } + uint8_t* block_base = k_cache + block_idx * static_cast(kv_block_stride); uint8_t* token_fp8_ptr = block_base + pos_in_block * kTokenDataBytes; @@ -431,18 +489,76 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert( config.numAttrs = (sm_version >= 90) ? 1 : 0; cudaLaunchKernelEx( - &config, fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel, - q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, eps, - num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size, - kv_block_stride); + &config, + fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel, + q_inout, nullptr, 0, 0, kv_in, k_cache, slot_mapping, position_ids, + cos_sin_cache, nullptr, nullptr, eps, num_tokens_full, num_tokens_insert, + num_heads_q, cache_block_size, kv_block_stride, 0); #else // ROCm: use standard kernel launch syntax (no PDL/stream serialization) // clang-format off - fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel + fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel + <<>>( + q_inout, nullptr, 0, 0, kv_in, k_cache, slot_mapping, position_ids, + cos_sin_cache, nullptr, nullptr, eps, num_tokens_full, + num_tokens_insert, num_heads_q, cache_block_size, kv_block_stride, 0); +#endif +} + +template +void launchFusedDeepseekV4QNormRopeFullCacheFP8Insert( + scalar_t_in* q_in, scalar_t_in const* kv_in, uint8_t* q_fp8_out, + int64_t const q_fp8_stride0, int64_t const q_fp8_stride1, + uint8_t* k_cache, int64_t const* slot_mapping, + int64_t const* position_ids, float const* cos_sin_cache, + float const* fp8_scale, float const* q_fp8_scale_inv, float const eps, + int const num_tokens_full, int const num_tokens_insert, + int const num_heads_q, int const cache_block_size, + int64_t const kv_block_stride, int64_t const kv_token_stride, + cudaStream_t stream) { + constexpr int kBlockSize = 256; + constexpr int kWarpsPerBlock = kBlockSize / 32; + int64_t const total_warps = + static_cast(num_tokens_full) * (num_heads_q + 1); + int const grid = + static_cast((total_warps + kWarpsPerBlock - 1) / kWarpsPerBlock); + +#ifndef USE_ROCM + static int const sm_version = getSMVersion(); + TORCH_CHECK( + sm_version >= 80, + "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert requires " + "sm_80+ (Ampere or newer); got sm_", + sm_version); + cudaLaunchConfig_t config; + config.gridDim = dim3(grid); + config.blockDim = dim3(kBlockSize); + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = 1; + config.attrs = attrs; + config.numAttrs = (sm_version >= 90) ? 1 : 0; + + cudaLaunchKernelEx( + &config, + fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel, + q_in, q_fp8_out, q_fp8_stride0, q_fp8_stride1, kv_in, k_cache, + slot_mapping, position_ids, cos_sin_cache, fp8_scale, q_fp8_scale_inv, + eps, num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size, + kv_block_stride, kv_token_stride); +#else + fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel <<>>( - q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, - eps, num_tokens_full, num_tokens_insert, num_heads_q, - cache_block_size, kv_block_stride); + q_in, q_fp8_out, q_fp8_stride0, q_fp8_stride1, kv_in, k_cache, + slot_mapping, position_ids, cos_sin_cache, fp8_scale, + q_fp8_scale_inv, eps, num_tokens_full, num_tokens_insert, + num_heads_q, cache_block_size, kv_block_stride, kv_token_stride); #endif } @@ -509,3 +625,82 @@ void fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( cache_block_size_i, kv_block_stride, stream); }); } + +void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert( + torch::Tensor const& q, // [N, H, 512] bf16, read-only + torch::Tensor const& kv, // [N, 512] bf16, read-only + torch::Tensor& q_fp8, // [N, H, 512] fp8 e4m3 + torch::Tensor& k_cache, // [num_blocks, block_size, 512] fp8 + torch::Tensor const& slot_mapping, // [num_tokens_insert] int64 + torch::Tensor const& position_ids, // [N] int64 + torch::Tensor const& cos_sin_cache, // [max_pos, rope_dim] float32 + torch::Tensor const& fp8_scale, // scalar float32 + torch::Tensor const& q_fp8_scale_inv, // scalar float32 + double eps, int64_t cache_block_size) { + TORCH_CHECK(q.is_cuda() && q.is_contiguous(), "q must be contiguous CUDA"); + TORCH_CHECK(kv.is_cuda() && kv.is_contiguous(), "kv must be contiguous CUDA"); + TORCH_CHECK(q_fp8.is_cuda() && q_fp8.is_contiguous(), + "q_fp8 must be contiguous CUDA"); + TORCH_CHECK(k_cache.is_cuda(), "k_cache must be CUDA"); + TORCH_CHECK(slot_mapping.is_cuda() && slot_mapping.dtype() == torch::kInt64, + "slot_mapping must be int64 CUDA"); + TORCH_CHECK(position_ids.is_cuda() && position_ids.dtype() == torch::kInt64, + "position_ids must be int64 CUDA"); + TORCH_CHECK(cos_sin_cache.is_cuda(), "cos_sin_cache must be CUDA"); + TORCH_CHECK(fp8_scale.is_cuda() && fp8_scale.dtype() == torch::kFloat32 && + fp8_scale.numel() == 1, + "fp8_scale must be a scalar float32 CUDA tensor"); + TORCH_CHECK(q_fp8_scale_inv.is_cuda() && + q_fp8_scale_inv.dtype() == torch::kFloat32 && + q_fp8_scale_inv.numel() == 1, + "q_fp8_scale_inv must be a scalar float32 CUDA tensor"); + TORCH_CHECK(q.dim() == 3 && q.size(2) == 512, "q shape [N, H, 512]"); + TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]"); + TORCH_CHECK(q.dtype() == kv.dtype(), "q and kv dtype must match"); + TORCH_CHECK(q_fp8.sizes() == q.sizes(), "q_fp8 must match q shape"); + TORCH_CHECK(q_fp8.dtype() == torch::kFloat8_e4m3fn, + "q_fp8 must be float8_e4m3fn"); + TORCH_CHECK(k_cache.dim() == 3 && k_cache.size(1) == cache_block_size && + k_cache.size(2) == 512, + "k_cache shape [num_blocks, cache_block_size, 512]"); + TORCH_CHECK(k_cache.dtype() == torch::kFloat8_e4m3fn, + "k_cache must be float8_e4m3fn"); + TORCH_CHECK(k_cache.stride(2) == 1, + "k_cache last dimension must be contiguous"); + TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64, + "cos_sin_cache shape [max_pos, 64]"); + TORCH_CHECK(cos_sin_cache.dtype() == torch::kFloat32, + "cos_sin_cache must be float32"); + + int const num_tokens_full = static_cast(q.size(0)); + int const num_tokens_insert = static_cast(slot_mapping.size(0)); + TORCH_CHECK(static_cast(kv.size(0)) == num_tokens_full && + static_cast(position_ids.size(0)) == num_tokens_full, + "q/kv/position_ids row counts must match"); + TORCH_CHECK(num_tokens_insert <= num_tokens_full, + "slot_mapping must not exceed q row count"); + int const num_heads_q = static_cast(q.size(1)); + int const cache_block_size_i = static_cast(cache_block_size); + + at::cuda::OptionalCUDAGuard device_guard(device_of(q)); + auto stream = at::cuda::getCurrentCUDAStream(); + + VLLM_DISPATCH_HALF_TYPES( + q.scalar_type(), "fused_deepseek_v4_qnorm_rope_full_cache_fp8_insert", [&] { + using qkv_scalar_t = scalar_t; + vllm::deepseek_v4_fused_ops:: + launchFusedDeepseekV4QNormRopeFullCacheFP8Insert( + reinterpret_cast(q.data_ptr()), + reinterpret_cast(kv.data_ptr()), + reinterpret_cast(q_fp8.data_ptr()), + q_fp8.stride(0), q_fp8.stride(1), + reinterpret_cast(k_cache.data_ptr()), + reinterpret_cast(slot_mapping.data_ptr()), + reinterpret_cast(position_ids.data_ptr()), + cos_sin_cache.data_ptr(), fp8_scale.data_ptr(), + q_fp8_scale_inv.data_ptr(), static_cast(eps), + num_tokens_full, num_tokens_insert, num_heads_q, + cache_block_size_i, k_cache.stride(0), k_cache.stride(1), + stream); + }); +} diff --git a/csrc/ops.h b/csrc/ops.h index 16a78f570cf6..5e5bf1afdf4f 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -105,6 +105,13 @@ void fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( torch::Tensor const& slot_mapping, torch::Tensor const& position_ids, torch::Tensor const& cos_sin_cache, double eps, int64_t cache_block_size); +void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert( + torch::Tensor const& q, torch::Tensor const& kv, torch::Tensor& q_fp8, + torch::Tensor& k_cache, torch::Tensor const& slot_mapping, + torch::Tensor const& position_ids, torch::Tensor const& cos_sin_cache, + torch::Tensor const& fp8_scale, torch::Tensor const& q_fp8_scale_inv, + double eps, int64_t cache_block_size); + void apply_repetition_penalties_(torch::Tensor& logits, const torch::Tensor& prompt_mask, const torch::Tensor& output_mask, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 7562d90c0b99..cc146b5dbf08 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -194,6 +194,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA, &fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert); + // Full-cache per-tensor FP8 variant for FlashInfer sparse MLA. Reuses the + // same CUDA warp-slot kernel structure as the legacy UE8M0 op, but writes Q + // to a separate FP8 tensor and KV into a full 512-wide FP8 paged cache. + ops.def( + "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert(" + "Tensor q, Tensor kv, Tensor! q_fp8, Tensor! k_cache, " + "Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, " + "Tensor fp8_scale, Tensor q_fp8_scale_inv, float eps, " + "int cache_block_size) -> ()"); + ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert", + torch::kCUDA, + &fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert); + // Apply repetition penalties to logits in-place ops.def( "apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, " diff --git a/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py b/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py index 46d226e0f74e..8ad55dd22008 100644 --- a/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py +++ b/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py @@ -21,6 +21,7 @@ from vllm.v1.attention.ops.deepseek_v4_ops import ( dequantize_and_gather_k_cache, + qnorm_rope_and_insert_full_k_cache, quantize_and_insert_k_cache, ) @@ -68,7 +69,7 @@ def apply_rope_gptj_last_k( nope_dim = head_dim - rope_dim # Gather cos/sin for each token position: [num_tokens, rope_dim] - cs = cos_sin_cache[positions].to(torch.float32) # [N, rope_dim] + cs = cos_sin_cache[positions.long()].to(torch.float32) # [N, rope_dim] cos = cs[..., :half] # [N, half] sin = cs[..., half:] # [N, half] @@ -113,6 +114,12 @@ def _op_available() -> bool: return hasattr(torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert") +def _full_cache_fp8_op_available() -> bool: + return hasattr( + torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert" + ) + + pytestmark = pytest.mark.skipif( not torch.cuda.is_available() or not _op_available(), reason="CUDA not available or fused DeepseekV4 op not built in", @@ -125,6 +132,37 @@ def _call_fused(q, kv, k_cache, slot_mapping, positions, cos_sin_cache, eps, bs) ) +def _fp8_full_cache_reference( + q, + kv, + k_cache, + q_fp8, + slot_mapping, + positions, + cos_sin_cache, + eps, + block_size, + fp8_scale, + q_fp8_scale_inv, +): + q_ref = rmsnorm_no_weight(q, eps) + q_ref = apply_rope_gptj_last_k(q_ref, positions, cos_sin_cache) + q_fp8.copy_( + torch.clamp(q_ref.float() * q_fp8_scale_inv, -FP8_MAX, FP8_MAX).to( + torch.float8_e4m3fn + ) + ) + + kv_ref = apply_rope_gptj_last_k(kv, positions, cos_sin_cache) + valid = slot_mapping >= 0 + slots = slot_mapping[valid] + block_idx = slots // block_size + pos_in_block = slots % block_size + k_cache[block_idx, pos_in_block] = torch.clamp( + kv_ref[valid].float() / fp8_scale, -FP8_MAX, FP8_MAX + ).to(torch.float8_e4m3fn) + + # ── Test 1: Q path numerical parity ────────────────────────────────────────── @@ -357,3 +395,75 @@ def test_combined_q_and_kv(num_tokens: int, n_heads: int, block_size: int): torch.testing.assert_close(q_fused, q_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0) + + +@pytest.mark.skipif( + not _full_cache_fp8_op_available(), + reason="full-cache per-tensor FP8 DeepseekV4 op not built in", +) +@pytest.mark.parametrize("num_tokens", [4, 17]) +@pytest.mark.parametrize("n_heads", [8, 17]) +@pytest.mark.parametrize("positions_dtype", [torch.int32, torch.int64]) +def test_full_cache_per_tensor_fp8_matches_reference( + num_tokens: int, + n_heads: int, + positions_dtype: torch.dtype, +): + torch.manual_seed(4) + device = "cuda" + dtype = torch.bfloat16 + eps = 1e-6 + block_size = 16 + max_pos = 4096 + + q = torch.randn(num_tokens, n_heads, HEAD_DIM, dtype=dtype, device=device) + kv = torch.randn(num_tokens, HEAD_DIM, dtype=dtype, device=device) + positions = torch.arange(num_tokens, dtype=positions_dtype, device=device) + cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device) + + num_blocks = (num_tokens + block_size - 1) // block_size + 1 + slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device) + fp8_scale = torch.tensor([1.0], dtype=torch.float32, device=device) + q_fp8_scale_inv = torch.tensor([1.0], dtype=torch.float32, device=device) + + q_fp8_ref = torch.empty_like(q, dtype=torch.float8_e4m3fn) + q_fp8_fused = torch.empty_like(q, dtype=torch.float8_e4m3fn) + k_cache_ref = torch.empty( + num_blocks, block_size, HEAD_DIM, dtype=torch.float8_e4m3fn, device=device + ) + k_cache_fused = torch.empty_like(k_cache_ref) + + _fp8_full_cache_reference( + q, + kv, + k_cache_ref, + q_fp8_ref, + slot_mapping, + positions, + cos_sin_cache, + eps, + block_size, + fp8_scale, + q_fp8_scale_inv, + ) + + qnorm_rope_and_insert_full_k_cache( + q.clone(), + kv, + k_cache_fused, + slot_mapping, + positions, + cos_sin_cache, + eps, + block_size, + fp8_scale, + q_fp8=q_fp8_fused, + q_fp8_scale_inv=q_fp8_scale_inv, + ) + + torch.testing.assert_close( + q_fp8_fused.float(), q_fp8_ref.float(), rtol=0, atol=0.25 + ) + torch.testing.assert_close( + k_cache_fused.float(), k_cache_ref.float(), rtol=0, atol=0.25 + ) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py index 095064f813f6..43054974be5a 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py @@ -34,128 +34,6 @@ def _has_dequant_gather_k_cutedsl() -> bool: return False -@triton.jit -def _apply_gptj_rope_512( - values, - position, - cos_sin_cache_ptr, - cos_sin_stride, - HEAD_SIZE: tl.constexpr, - ROPE_HEAD_DIM: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - NUM_PAIRS: tl.constexpr = BLOCK_SIZE // 2 - NOPE_PAIRS: tl.constexpr = (HEAD_SIZE - ROPE_HEAD_DIM) // 2 - HALF_ROPE: tl.constexpr = ROPE_HEAD_DIM // 2 - - pairs = tl.reshape(values, (NUM_PAIRS, 2)) - even, odd = tl.split(pairs) - - pair_idx = tl.arange(0, NUM_PAIRS) - rope_pair_local = pair_idx - NOPE_PAIRS - is_rope_pair = rope_pair_local >= 0 - cs_idx = tl.maximum(rope_pair_local, 0) - - cache_base = cos_sin_cache_ptr + position * cos_sin_stride - cos_v = tl.load(cache_base + cs_idx, mask=is_rope_pair, other=1.0) - sin_v = tl.load(cache_base + HALF_ROPE + cs_idx, mask=is_rope_pair, other=0.0) - - new_even = tl.where(is_rope_pair, even * cos_v - odd * sin_v, even) - new_odd = tl.where(is_rope_pair, odd * cos_v + even * sin_v, odd) - return tl.interleave(new_even, new_odd) - - -@triton.jit -def _qnorm_rope_insert_full_cache_kernel( - q_ptr, - q_stride0, - q_stride1, - q_fp8_ptr, - q_fp8_stride0, - q_fp8_stride1, - q_fp8_scale_inv_ptr, - kv_ptr, - kv_stride0, - slot_mapping_ptr, - positions_ptr, - cos_sin_cache_ptr, - cos_sin_stride, - k_cache_ptr, - cache_stride0, - cache_stride1, - cache_block_size, - fp8_scale_ptr, - eps, - HEAD_SIZE: tl.constexpr, - ROPE_HEAD_DIM: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - STORE_Q_FP8: tl.constexpr, - STORE_KV_FP8: tl.constexpr, -): - token_idx = tl.program_id(0) - head_idx = tl.program_id(1) - offsets = tl.arange(0, BLOCK_SIZE) - mask = offsets < HEAD_SIZE - - position = tl.load(positions_ptr + token_idx) - - q_row = q_ptr + token_idx * q_stride0 + head_idx * q_stride1 - values = tl.load(q_row + offsets, mask=mask, other=0.0).to(tl.float32) - variance = tl.sum(values * values, axis=0) / HEAD_SIZE - values *= tl.rsqrt(variance + eps) - - values = _apply_gptj_rope_512( - values, - position, - cos_sin_cache_ptr, - cos_sin_stride, - HEAD_SIZE, - ROPE_HEAD_DIM, - BLOCK_SIZE, - ) - if STORE_Q_FP8: - q_fp8_scale_inv = tl.load(q_fp8_scale_inv_ptr) - q_fp8_row = q_fp8_ptr + token_idx * q_fp8_stride0 + head_idx * q_fp8_stride1 - q_fp8_values = tl.clamp(values * q_fp8_scale_inv, -448.0, 448.0) - tl.store(q_fp8_row + offsets, q_fp8_values.to(tl.float8e4nv), mask=mask) - else: - tl.store(q_row + offsets, values.to(tl.bfloat16), mask=mask) - - if head_idx != 0: - return - - slot_idx = tl.load(slot_mapping_ptr + token_idx) - if slot_idx < 0: - return - - kv_values = tl.load( - kv_ptr + token_idx * kv_stride0 + offsets, mask=mask, other=0.0 - ).to(tl.float32) - kv_values = _apply_gptj_rope_512( - kv_values, - position, - cos_sin_cache_ptr, - cos_sin_stride, - HEAD_SIZE, - ROPE_HEAD_DIM, - BLOCK_SIZE, - ) - - block_idx = slot_idx // cache_block_size - pos_in_block = slot_idx % cache_block_size - cache_row = ( - k_cache_ptr - + block_idx.to(tl.int64) * cache_stride0 - + (pos_in_block * cache_stride1) - ) - if STORE_KV_FP8: - fp8_scale = tl.load(fp8_scale_ptr) - kv_values = tl.clamp(kv_values / fp8_scale, -448.0, 448.0) - tl.store(cache_row + offsets, kv_values.to(tl.float8e4nv), mask=mask) - else: - tl.store(cache_row + offsets, kv_values.to(tl.bfloat16), mask=mask) - - def qnorm_rope_and_insert_full_k_cache( q: torch.Tensor, kv: torch.Tensor, @@ -169,55 +47,61 @@ def qnorm_rope_and_insert_full_k_cache( q_fp8: torch.Tensor | None = None, q_fp8_scale_inv: torch.Tensor | None = None, ) -> None: - """Apply DeepSeek V4 Q RMSNorm/RoPE and insert full-width BF16/FP8 KV. + """Apply DeepSeek V4 Q RMSNorm/RoPE and insert full-width FP8 KV. This path is for FlashInfer's DeepSeek V4 sparse MLA launcher, which accepts - full 512-wide BF16 or per-tensor FP8 E4M3 KV pools. The existing 584-byte - UE8M0 cache path remains handled by the CUDA fused op. + full 512-wide per-tensor FP8 E4M3 KV pools. The existing 584-byte UE8M0 + cache path remains handled by the CUDA fused op. """ assert q.dim() == 3 and q.shape[-1] == 512 assert kv.dim() == 2 and kv.shape[-1] == 512 assert q.dtype == torch.bfloat16 assert kv.dtype == torch.bfloat16 - assert k_cache.dtype in (torch.bfloat16, torch.float8_e4m3fn) + assert k_cache.dtype == torch.float8_e4m3fn assert positions.dtype in (torch.int32, torch.int64) assert cos_sin_cache.dtype == torch.float32 - if q_fp8 is not None: - assert q_fp8.dtype == torch.float8_e4m3fn - assert q_fp8.dim() == 3 and q_fp8.shape[0] == q.shape[0] - assert q_fp8.shape[1] == q.shape[1] - assert q_fp8.shape[-1] == q.shape[-1] - assert q_fp8_scale_inv is not None - assert q_fp8_scale_inv.dtype == torch.float32 - assert q_fp8_scale_inv.numel() == 1 - - num_tokens_full, num_heads, _ = q.shape - _qnorm_rope_insert_full_cache_kernel[(num_tokens_full, num_heads)]( + assert q_fp8 is not None + assert q_fp8.dtype == torch.float8_e4m3fn + assert q_fp8.dim() == 3 and q_fp8.shape == q.shape + assert q_fp8_scale_inv is not None + assert q_fp8_scale_inv.dtype == torch.float32 + assert q_fp8_scale_inv.numel() == 1 + + cuda_full_cache_fp8_op = ( + "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert" + ) + assert hasattr(torch.ops._C, cuda_full_cache_fp8_op) + assert q.is_cuda + assert kv.is_cuda + assert q_fp8.is_cuda + assert k_cache.is_cuda + assert slot_mapping.is_cuda + assert slot_mapping.dtype == torch.int64 + assert positions.is_cuda + assert cos_sin_cache.is_cuda + assert fp8_scale.is_cuda + assert q_fp8_scale_inv.is_cuda + assert q.is_contiguous() + assert kv.is_contiguous() + assert q_fp8.is_contiguous() + assert k_cache.dim() == 3 + assert k_cache.stride(-1) == 1 + + positions_i64 = ( + positions if positions.dtype == torch.int64 else positions.to(torch.int64) + ) + torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert( q, - q.stride(0), - q.stride(1), - q_fp8 if q_fp8 is not None else q, - q_fp8.stride(0) if q_fp8 is not None else q.stride(0), - q_fp8.stride(1) if q_fp8 is not None else q.stride(1), - q_fp8_scale_inv if q_fp8_scale_inv is not None else fp8_scale, kv, - kv.stride(0), + q_fp8, + k_cache, slot_mapping, - positions, + positions_i64, cos_sin_cache, - cos_sin_cache.stride(0), - k_cache, - k_cache.stride(0), - k_cache.stride(1), - cache_block_size, fp8_scale, + q_fp8_scale_inv, eps, - HEAD_SIZE=512, - ROPE_HEAD_DIM=64, - BLOCK_SIZE=512, - STORE_Q_FP8=q_fp8 is not None, - STORE_KV_FP8=k_cache.dtype == torch.float8_e4m3fn, - num_warps=8, + cache_block_size, ) From c6dc5d20c764e5899f58a7b420458c222712fe77 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 20 May 2026 21:27:32 -0700 Subject: [PATCH 22/24] Call DeepSeek V4 full-cache FP8 op directly --- ..._fused_deepseek_v4_qnorm_rope_kv_insert.py | 37 +++- vllm/config/vllm.py | 195 ++---------------- .../layers/deepseek_v4_attention.py | 12 +- .../attention/ops/deepseek_v4_ops/__init__.py | 2 - .../ops/deepseek_v4_ops/cache_utils.py | 71 ------- 5 files changed, 51 insertions(+), 266 deletions(-) diff --git a/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py b/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py index 8ad55dd22008..eebb2c0270c4 100644 --- a/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py +++ b/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py @@ -21,7 +21,6 @@ from vllm.v1.attention.ops.deepseek_v4_ops import ( dequantize_and_gather_k_cache, - qnorm_rope_and_insert_full_k_cache, quantize_and_insert_k_cache, ) @@ -132,6 +131,34 @@ def _call_fused(q, kv, k_cache, slot_mapping, positions, cos_sin_cache, eps, bs) ) +def _call_full_cache_fp8_fused( + q, + kv, + q_fp8, + k_cache, + slot_mapping, + positions, + cos_sin_cache, + fp8_scale, + q_fp8_scale_inv, + eps, + bs, +): + torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert( + q, + kv, + q_fp8, + k_cache, + slot_mapping, + positions.long(), + cos_sin_cache, + fp8_scale, + q_fp8_scale_inv, + eps, + bs, + ) + + def _fp8_full_cache_reference( q, kv, @@ -447,18 +474,18 @@ def test_full_cache_per_tensor_fp8_matches_reference( q_fp8_scale_inv, ) - qnorm_rope_and_insert_full_k_cache( + _call_full_cache_fp8_fused( q.clone(), kv, + q_fp8_fused, k_cache_fused, slot_mapping, positions, cos_sin_cache, + fp8_scale, + q_fp8_scale_inv, eps, block_size, - fp8_scale, - q_fp8=q_fp8_fused, - q_fp8_scale_inv=q_fp8_scale_inv, ) torch.testing.assert_close( diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 2c611f3dcd87..f591605d08c7 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -121,13 +121,6 @@ def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool: from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer - if current_platform.is_rocm(): - from vllm._aiter_ops import rocm_aiter_ops - - return ( - rocm_aiter_ops.is_enabled() and cfg.parallel_config.tensor_parallel_size > 1 - ) - return ( cfg.parallel_config.tensor_parallel_size > 1 and current_platform.is_cuda() @@ -136,6 +129,12 @@ def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool: current_platform.is_device_capability_family(100) or current_platform.is_device_capability(90) ) + # tp-dp combination broken: + # https://github.com/vllm-project/vllm/issues/34458 + and cfg.parallel_config.data_parallel_size == 1 + # tp-pp combination broken: + # https://github.com/vllm-project/vllm/issues/35426 + and cfg.parallel_config.pipeline_parallel_size == 1 ) @@ -155,20 +154,12 @@ def enable_rope_kvcache_fusion(cfg: "VllmConfig") -> bool: ) -def enable_rope_kvcache_mla_fusion(cfg: "VllmConfig") -> bool: - """Enable if use_inductor_graph_partition is enabled.""" - - return ( - cfg.compilation_config.use_inductor_graph_partition - or not cfg.compilation_config.splitting_ops_contain_kv_cache_update() - ) - - def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool: """Enable if using AITER RMSNorm and hidden size is 2880 i.e. gpt-oss.""" + from vllm._aiter_ops import rocm_aiter_ops return ( - cfg.kernel_config.ir_op_priority.fused_add_rms_norm[0] == "aiter" + rocm_aiter_ops.is_rmsnorm_enabled() and cfg.model_config is not None and cfg.model_config.get_hidden_size() == 2880 ) @@ -193,7 +184,6 @@ def enable_mla_dual_rms_norm_fusion(cfg: "VllmConfig") -> bool: "fuse_act_padding": False, "fuse_mla_dual_rms_norm": False, "fuse_rope_kvcache": False, - "fuse_rope_kvcache_cat_mla": False, }, "cudagraph_mode": CUDAGraphMode.NONE, "use_inductor_graph_partition": False, @@ -214,15 +204,12 @@ def enable_mla_dual_rms_norm_fusion(cfg: "VllmConfig") -> bool: "fuse_act_padding": enable_norm_pad_fusion, "fuse_mla_dual_rms_norm": enable_mla_dual_rms_norm_fusion, "fuse_rope_kvcache": False, - "fuse_rope_kvcache_cat_mla": False, }, "cudagraph_mode": CUDAGraphMode.PIECEWISE, "use_inductor_graph_partition": False, }, "kernel_config": { - # Disabled for now due to correctness issues: - # https://github.com/flashinfer-ai/flashinfer/issues/3197 - "enable_flashinfer_autotune": False, + "enable_flashinfer_autotune": True, }, } OPTIMIZATION_LEVEL_02 = { @@ -237,15 +224,12 @@ def enable_mla_dual_rms_norm_fusion(cfg: "VllmConfig") -> bool: "fuse_act_padding": enable_norm_pad_fusion, "fuse_mla_dual_rms_norm": enable_mla_dual_rms_norm_fusion, "fuse_rope_kvcache": enable_rope_kvcache_fusion, - "fuse_rope_kvcache_cat_mla": enable_rope_kvcache_mla_fusion, }, "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE, "use_inductor_graph_partition": False, }, "kernel_config": { - # Disabled for now due to correctness issues: - # https://github.com/flashinfer-ai/flashinfer/issues/3197 - "enable_flashinfer_autotune": False, + "enable_flashinfer_autotune": True, }, } OPTIMIZATION_LEVEL_03 = { @@ -260,7 +244,6 @@ def enable_mla_dual_rms_norm_fusion(cfg: "VllmConfig") -> bool: "fuse_act_padding": enable_norm_pad_fusion, "fuse_mla_dual_rms_norm": enable_mla_dual_rms_norm_fusion, "fuse_rope_kvcache": enable_rope_kvcache_fusion, - "fuse_rope_kvcache_cat_mla": enable_rope_kvcache_mla_fusion, }, "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE, "use_inductor_graph_partition": False, @@ -726,48 +709,6 @@ def _post_init_kv_transfer_config(self) -> None: # This is the same for all backends self.kv_transfer_config.kv_role = "kv_both" - def _verify_kv_transfer_compat(self) -> None: - """Reject configurations that silently corrupt KV transfers.""" - if ( - self.kv_transfer_config is None - or self.kv_transfer_config.kv_connector is None - ): - return - - # PyTorch's expandable_segments allocator uses CUDA VMM, which can - # remap a virtual address range to different physical pages over the - # engine's lifetime. KV connectors that pin KV cache memory (e.g. - # NixlConnector via ibv_reg_mr, MooncakeConnector) end up with their - # registrations pointing at stale physical pages after any remap, - # producing RDMA failures like IBV_WC_REM_ACCESS_ERR / - # NIXL_ERR_REMOTE_DISCONNECT at the first inter-node KV transfer. - # We can't enumerate every in-tree and out-of-tree connector that - # pins memory, so we conservatively reject the combination whenever - # any KV connector is configured. - # - # Sleep mode is exempt: CuMemAllocator.use_memory_pool toggles - # expandable_segments off around its pool (see #40812), so the KV - # cache allocated within that context lands on stable physical pages - # even when the env var is set. - if "expandable_segments:True" not in os.environ.get( - "PYTORCH_CUDA_ALLOC_CONF", "" - ): - return - if self.model_config is not None and self.model_config.enable_sleep_mode: - return - - raise ValueError( - f"KV connector {self.kv_transfer_config.kv_connector} is " - "incompatible with PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True " - "unless enable_sleep_mode is also enabled. PyTorch's CUDA VMM " - "allocator can remap KV cache virtual addresses to different " - "physical pages, invalidating any pinned/registered KV memory " - "(e.g. IB memory regions registered by NIXL or Mooncake). Either " - "unset expandable_segments:True or enable sleep mode (which " - "routes KV allocations through CuMemAllocator's pool, where " - "expandable_segments is automatically disabled)." - ) - def __post_init__(self): """Verify configs are valid & consistent with each other.""" @@ -829,15 +770,6 @@ def __post_init__(self): executor_backend = self.parallel_config.distributed_executor_backend executor_class = Executor.get_class(self) executor_supports_async_sched = executor_class.supports_async_scheduling() - is_deepseek_v4 = False - if self.model_config is not None: - architectures = ( - getattr(self.model_config.hf_config, "architectures", None) or [] - ) - model_type = getattr(self.model_config.hf_text_config, "model_type", None) - is_deepseek_v4 = ( - model_type == "deepseek_v4" or "DeepseekV4ForCausalLM" in architectures - ) if self.scheduler_config.async_scheduling: # Async scheduling explicitly enabled, hard fail any incompatibilities. @@ -910,18 +842,6 @@ def __post_init__(self): "enabled" if self.scheduler_config.async_scheduling else "disabled", ) - if ( - is_deepseek_v4 - and self.cache_config is not None - and self.cache_config.enable_prefix_caching - ): - logger.warning_once( - "Prefix caching is disabled for DeepSeek V4 because the sparse " - "MLA FlashInfer path can produce non-repeatable outputs with " - "cache-hit requests." - ) - self.cache_config.enable_prefix_caching = False - if self.parallel_config.disable_nccl_for_dp_synchronization is None: if self.scheduler_config.async_scheduling: if self.parallel_config.data_parallel_size > 1 and ( @@ -1065,12 +985,6 @@ def has_blocked_weights(): # async tp is built on top of sequence parallelism and requires it. pass_config = self.compilation_config.pass_config - if is_deepseek_v4 and pass_config.fuse_allreduce_rms: - logger.warning_once( - "AllReduce + RMSNorm fusion is disabled for DeepSeek V4 " - "because this fused path can produce non-repeatable outputs." - ) - pass_config.fuse_allreduce_rms = False if pass_config.fuse_gemm_comms: pass_config.enable_sp = True if pass_config.enable_sp: @@ -1239,12 +1153,6 @@ def has_blocked_weights(): if envs.VLLM_USE_V2_MODEL_RUNNER: self._validate_v2_model_runner() - if ( - self.model_config is not None - and self.model_config.enable_return_routed_experts - ): - self._validate_return_routed_experts() - # Re-compute compile ranges after platform-specific config updates # (e.g., XPU may lower max_num_batched_tokens when MLA is enabled) self._set_compile_ranges() @@ -1435,7 +1343,6 @@ def has_blocked_weights(): # Handle the KV connector configs self._post_init_kv_transfer_config() - self._verify_kv_transfer_compat() # Log the custom passes that are enabled self.compilation_config.pass_config.log_enabled_passes() @@ -1525,10 +1432,6 @@ def _set_cudagraph_sizes(self): cudagraph_capture_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list( range(256, max_graph_size + 1, 16)) - `max_num_batched_tokens` is also appended to the list if it fits - within `max_cudagraph_capture_size`, so the max batch size is captured - even when off-stride. - In the end, `vllm_config.compilation_config.cudagraph_capture_sizes` will be the final sizes to capture cudagraph (in ascending order). @@ -1617,12 +1520,6 @@ def _set_cudagraph_sizes(self): cudagraph_capture_sizes += list( range(256, max_cudagraph_capture_size + 1, 16) ) - # ensure max_num_tokens is captured if within max capture size - if ( - max_num_tokens <= max_cudagraph_capture_size - and max_num_tokens not in cudagraph_capture_sizes - ): - cudagraph_capture_sizes.append(max_num_tokens) # de-duplicate and sort the sizes cudagraph_capture_sizes = sorted(set(cudagraph_capture_sizes)) @@ -1697,16 +1594,11 @@ def _set_compile_ranges(self): if compile_range_end is not None: computed_compile_ranges_endpoints.append(compile_range_end) - # Add the compile ranges for flashinfer/aiter. + # Add the compile ranges for flashinfer if compilation_config.pass_config.fuse_allreduce_rms: tp_size = self.parallel_config.tensor_parallel_size - from vllm._aiter_ops import rocm_aiter_ops - - if rocm_aiter_ops.is_enabled(): - max_size = rocm_aiter_ops.get_aiter_allreduce_max_size() - else: - max_size = compilation_config.pass_config.flashinfer_max_size(tp_size) - if max_size is not None and self.model_config is not None: + max_size = compilation_config.pass_config.flashinfer_max_size(tp_size) + if max_size is not None: assert isinstance(self.model_config.dtype, torch.dtype) max_token_num = max_size // ( self.model_config.get_hidden_size() @@ -1935,51 +1827,6 @@ def _validate_v2_model_runner(self) -> None: + ", ".join(unsupported) ) - def _validate_return_routed_experts(self) -> None: - """Reject parallelism configurations not yet validated with - --enable-return-routed-experts. - - Validated scope (PR #39917): TP, EP, DP, single-node and multi-node, - prefix caching, and speculative decoding (MTP validated end-to-end; - Eagle/Eagle3/Ngram/Medusa supported by construction since the - routing buffer is bound only to the target model and verified-token - routing lands at the correct positions during the main forward). - - Out-of-scope (block until validated): PP > 1, prefill context - parallelism (PCP) > 1, decode context parallelism (DCP) > 1, - async scheduling. - """ - unsupported: list[str] = [] - - if self.parallel_config.pipeline_parallel_size > 1: - unsupported.append( - "pipeline parallelism " - f"(pipeline_parallel_size=" - f"{self.parallel_config.pipeline_parallel_size})" - ) - if self.parallel_config.prefill_context_parallel_size > 1: - unsupported.append( - "prefill context parallelism " - f"(prefill_context_parallel_size=" - f"{self.parallel_config.prefill_context_parallel_size})" - ) - if self.parallel_config.decode_context_parallel_size > 1: - unsupported.append( - "decode context parallelism " - f"(decode_context_parallel_size=" - f"{self.parallel_config.decode_context_parallel_size})" - ) - if self.scheduler_config.async_scheduling: - unsupported.append("async scheduling") - - if unsupported: - raise ValueError( - "--enable-return-routed-experts is not yet validated with: " - + ", ".join(unsupported) - + ". Disable these features or omit " - "--enable-return-routed-experts." - ) - def validate_block_size(self) -> None: """Validate block_size against DCP and mamba constraints. @@ -2026,22 +1873,6 @@ def validate_block_size(self) -> None: "to schedule a multiple of block_size tokens even if they are " "in the middle of a mm input" ) - # TODO: support align mamba cache mode for model runner v2 - assert not envs.VLLM_USE_V2_MODEL_RUNNER, ( - "Model Runner V2 has not yet supported mamba_cache_mode='align'. " - ) - - @model_validator(mode="after") - def validate_nvfp4_kv_cache_with_mla(self) -> "VllmConfig": - if self.model_config is None: - return self - if self.cache_config.cache_dtype == "nvfp4" and self.model_config.use_mla: - raise ValueError( - "nvfp4 KV cache is not supported with MLA (Multi-head Latent " - "Attention) backends. Please use a different --kv-cache-dtype " - "(e.g., 'fp8' or 'auto') for MLA models such as DeepSeek." - ) - return self @model_validator(mode="after") def validate_mamba_block_size(self) -> "VllmConfig": diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 6129dc6ced6d..ca745dc5282a 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -28,7 +28,6 @@ fused_indexer_q_rope_quant, fused_inv_rope_fp8_quant, fused_q_kv_rmsnorm, - qnorm_rope_and_insert_full_k_cache, ) from vllm.v1.attention.ops.rocm_aiter_mla_sparse import rocm_inv_rope_einsum @@ -665,18 +664,19 @@ def _fused_qnorm_rope_kv_insert( swa_metadata.block_size, ) else: - qnorm_rope_and_insert_full_k_cache( + assert q_fp8 is not None + torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert( q, kv, + q_fp8, swa_kv_cache, swa_metadata.slot_mapping, - positions, + positions.to(torch.int64), self.rotary_emb.cos_sin_cache, + self.mla_attn._flashinfer_fp8_kv_scale, + self.mla_attn._flashinfer_fp8_q_scale_inv, self.eps, swa_metadata.block_size, - self.mla_attn._flashinfer_fp8_kv_scale, - q_fp8=q_fp8, - q_fp8_scale_inv=self.mla_attn._flashinfer_fp8_q_scale_inv, ) return q_fp8 diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py b/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py index 9e5499cbee4b..bc247adfb6b4 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py @@ -6,7 +6,6 @@ combine_topk_swa_indices, compute_global_topk_indices_and_lens, dequantize_and_gather_k_cache, - qnorm_rope_and_insert_full_k_cache, quantize_and_insert_k_cache, ) from .fused_indexer_q import MXFP4_BLOCK_SIZE, fused_indexer_q_rope_quant @@ -22,6 +21,5 @@ "fused_indexer_q_rope_quant", "fused_inv_rope_fp8_quant", "fused_q_kv_rmsnorm", - "qnorm_rope_and_insert_full_k_cache", "quantize_and_insert_k_cache", ] diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py index 43054974be5a..e71e5fe6f497 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py @@ -34,77 +34,6 @@ def _has_dequant_gather_k_cutedsl() -> bool: return False -def qnorm_rope_and_insert_full_k_cache( - q: torch.Tensor, - kv: torch.Tensor, - k_cache: torch.Tensor, - slot_mapping: torch.Tensor, - positions: torch.Tensor, - cos_sin_cache: torch.Tensor, - eps: float, - cache_block_size: int, - fp8_scale: torch.Tensor, - q_fp8: torch.Tensor | None = None, - q_fp8_scale_inv: torch.Tensor | None = None, -) -> None: - """Apply DeepSeek V4 Q RMSNorm/RoPE and insert full-width FP8 KV. - - This path is for FlashInfer's DeepSeek V4 sparse MLA launcher, which accepts - full 512-wide per-tensor FP8 E4M3 KV pools. The existing 584-byte UE8M0 - cache path remains handled by the CUDA fused op. - """ - assert q.dim() == 3 and q.shape[-1] == 512 - assert kv.dim() == 2 and kv.shape[-1] == 512 - assert q.dtype == torch.bfloat16 - assert kv.dtype == torch.bfloat16 - assert k_cache.dtype == torch.float8_e4m3fn - assert positions.dtype in (torch.int32, torch.int64) - assert cos_sin_cache.dtype == torch.float32 - assert q_fp8 is not None - assert q_fp8.dtype == torch.float8_e4m3fn - assert q_fp8.dim() == 3 and q_fp8.shape == q.shape - assert q_fp8_scale_inv is not None - assert q_fp8_scale_inv.dtype == torch.float32 - assert q_fp8_scale_inv.numel() == 1 - - cuda_full_cache_fp8_op = ( - "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert" - ) - assert hasattr(torch.ops._C, cuda_full_cache_fp8_op) - assert q.is_cuda - assert kv.is_cuda - assert q_fp8.is_cuda - assert k_cache.is_cuda - assert slot_mapping.is_cuda - assert slot_mapping.dtype == torch.int64 - assert positions.is_cuda - assert cos_sin_cache.is_cuda - assert fp8_scale.is_cuda - assert q_fp8_scale_inv.is_cuda - assert q.is_contiguous() - assert kv.is_contiguous() - assert q_fp8.is_contiguous() - assert k_cache.dim() == 3 - assert k_cache.stride(-1) == 1 - - positions_i64 = ( - positions if positions.dtype == torch.int64 else positions.to(torch.int64) - ) - torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert( - q, - kv, - q_fp8, - k_cache, - slot_mapping, - positions_i64, - cos_sin_cache, - fp8_scale, - q_fp8_scale_inv, - eps, - cache_block_size, - ) - - @triton.jit def quantize_and_insert_k_kernel( # Input tensors From ddbbdc388e8f97100e5c8cbeaeb143bc756d8575 Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Wed, 20 May 2026 21:35:10 -0700 Subject: [PATCH 23/24] Restore vLLM config from PR base --- vllm/config/vllm.py | 168 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 155 insertions(+), 13 deletions(-) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index f591605d08c7..d220aa65035d 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -121,6 +121,13 @@ def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool: from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer + if current_platform.is_rocm(): + from vllm._aiter_ops import rocm_aiter_ops + + return ( + rocm_aiter_ops.is_enabled() and cfg.parallel_config.tensor_parallel_size > 1 + ) + return ( cfg.parallel_config.tensor_parallel_size > 1 and current_platform.is_cuda() @@ -129,12 +136,6 @@ def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool: current_platform.is_device_capability_family(100) or current_platform.is_device_capability(90) ) - # tp-dp combination broken: - # https://github.com/vllm-project/vllm/issues/34458 - and cfg.parallel_config.data_parallel_size == 1 - # tp-pp combination broken: - # https://github.com/vllm-project/vllm/issues/35426 - and cfg.parallel_config.pipeline_parallel_size == 1 ) @@ -154,12 +155,20 @@ def enable_rope_kvcache_fusion(cfg: "VllmConfig") -> bool: ) +def enable_rope_kvcache_mla_fusion(cfg: "VllmConfig") -> bool: + """Enable if use_inductor_graph_partition is enabled.""" + + return ( + cfg.compilation_config.use_inductor_graph_partition + or not cfg.compilation_config.splitting_ops_contain_kv_cache_update() + ) + + def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool: """Enable if using AITER RMSNorm and hidden size is 2880 i.e. gpt-oss.""" - from vllm._aiter_ops import rocm_aiter_ops return ( - rocm_aiter_ops.is_rmsnorm_enabled() + cfg.kernel_config.ir_op_priority.fused_add_rms_norm[0] == "aiter" and cfg.model_config is not None and cfg.model_config.get_hidden_size() == 2880 ) @@ -184,6 +193,7 @@ def enable_mla_dual_rms_norm_fusion(cfg: "VllmConfig") -> bool: "fuse_act_padding": False, "fuse_mla_dual_rms_norm": False, "fuse_rope_kvcache": False, + "fuse_rope_kvcache_cat_mla": False, }, "cudagraph_mode": CUDAGraphMode.NONE, "use_inductor_graph_partition": False, @@ -204,12 +214,15 @@ def enable_mla_dual_rms_norm_fusion(cfg: "VllmConfig") -> bool: "fuse_act_padding": enable_norm_pad_fusion, "fuse_mla_dual_rms_norm": enable_mla_dual_rms_norm_fusion, "fuse_rope_kvcache": False, + "fuse_rope_kvcache_cat_mla": False, }, "cudagraph_mode": CUDAGraphMode.PIECEWISE, "use_inductor_graph_partition": False, }, "kernel_config": { - "enable_flashinfer_autotune": True, + # Disabled for now due to correctness issues: + # https://github.com/flashinfer-ai/flashinfer/issues/3197 + "enable_flashinfer_autotune": False, }, } OPTIMIZATION_LEVEL_02 = { @@ -224,12 +237,15 @@ def enable_mla_dual_rms_norm_fusion(cfg: "VllmConfig") -> bool: "fuse_act_padding": enable_norm_pad_fusion, "fuse_mla_dual_rms_norm": enable_mla_dual_rms_norm_fusion, "fuse_rope_kvcache": enable_rope_kvcache_fusion, + "fuse_rope_kvcache_cat_mla": enable_rope_kvcache_mla_fusion, }, "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE, "use_inductor_graph_partition": False, }, "kernel_config": { - "enable_flashinfer_autotune": True, + # Disabled for now due to correctness issues: + # https://github.com/flashinfer-ai/flashinfer/issues/3197 + "enable_flashinfer_autotune": False, }, } OPTIMIZATION_LEVEL_03 = { @@ -244,6 +260,7 @@ def enable_mla_dual_rms_norm_fusion(cfg: "VllmConfig") -> bool: "fuse_act_padding": enable_norm_pad_fusion, "fuse_mla_dual_rms_norm": enable_mla_dual_rms_norm_fusion, "fuse_rope_kvcache": enable_rope_kvcache_fusion, + "fuse_rope_kvcache_cat_mla": enable_rope_kvcache_mla_fusion, }, "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE, "use_inductor_graph_partition": False, @@ -709,6 +726,48 @@ def _post_init_kv_transfer_config(self) -> None: # This is the same for all backends self.kv_transfer_config.kv_role = "kv_both" + def _verify_kv_transfer_compat(self) -> None: + """Reject configurations that silently corrupt KV transfers.""" + if ( + self.kv_transfer_config is None + or self.kv_transfer_config.kv_connector is None + ): + return + + # PyTorch's expandable_segments allocator uses CUDA VMM, which can + # remap a virtual address range to different physical pages over the + # engine's lifetime. KV connectors that pin KV cache memory (e.g. + # NixlConnector via ibv_reg_mr, MooncakeConnector) end up with their + # registrations pointing at stale physical pages after any remap, + # producing RDMA failures like IBV_WC_REM_ACCESS_ERR / + # NIXL_ERR_REMOTE_DISCONNECT at the first inter-node KV transfer. + # We can't enumerate every in-tree and out-of-tree connector that + # pins memory, so we conservatively reject the combination whenever + # any KV connector is configured. + # + # Sleep mode is exempt: CuMemAllocator.use_memory_pool toggles + # expandable_segments off around its pool (see #40812), so the KV + # cache allocated within that context lands on stable physical pages + # even when the env var is set. + if "expandable_segments:True" not in os.environ.get( + "PYTORCH_CUDA_ALLOC_CONF", "" + ): + return + if self.model_config is not None and self.model_config.enable_sleep_mode: + return + + raise ValueError( + f"KV connector {self.kv_transfer_config.kv_connector} is " + "incompatible with PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True " + "unless enable_sleep_mode is also enabled. PyTorch's CUDA VMM " + "allocator can remap KV cache virtual addresses to different " + "physical pages, invalidating any pinned/registered KV memory " + "(e.g. IB memory regions registered by NIXL or Mooncake). Either " + "unset expandable_segments:True or enable sleep mode (which " + "routes KV allocations through CuMemAllocator's pool, where " + "expandable_segments is automatically disabled)." + ) + def __post_init__(self): """Verify configs are valid & consistent with each other.""" @@ -1153,6 +1212,12 @@ def has_blocked_weights(): if envs.VLLM_USE_V2_MODEL_RUNNER: self._validate_v2_model_runner() + if ( + self.model_config is not None + and self.model_config.enable_return_routed_experts + ): + self._validate_return_routed_experts() + # Re-compute compile ranges after platform-specific config updates # (e.g., XPU may lower max_num_batched_tokens when MLA is enabled) self._set_compile_ranges() @@ -1343,6 +1408,7 @@ def has_blocked_weights(): # Handle the KV connector configs self._post_init_kv_transfer_config() + self._verify_kv_transfer_compat() # Log the custom passes that are enabled self.compilation_config.pass_config.log_enabled_passes() @@ -1432,6 +1498,10 @@ def _set_cudagraph_sizes(self): cudagraph_capture_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list( range(256, max_graph_size + 1, 16)) + `max_num_batched_tokens` is also appended to the list if it fits + within `max_cudagraph_capture_size`, so the max batch size is captured + even when off-stride. + In the end, `vllm_config.compilation_config.cudagraph_capture_sizes` will be the final sizes to capture cudagraph (in ascending order). @@ -1520,6 +1590,12 @@ def _set_cudagraph_sizes(self): cudagraph_capture_sizes += list( range(256, max_cudagraph_capture_size + 1, 16) ) + # ensure max_num_tokens is captured if within max capture size + if ( + max_num_tokens <= max_cudagraph_capture_size + and max_num_tokens not in cudagraph_capture_sizes + ): + cudagraph_capture_sizes.append(max_num_tokens) # de-duplicate and sort the sizes cudagraph_capture_sizes = sorted(set(cudagraph_capture_sizes)) @@ -1594,11 +1670,16 @@ def _set_compile_ranges(self): if compile_range_end is not None: computed_compile_ranges_endpoints.append(compile_range_end) - # Add the compile ranges for flashinfer + # Add the compile ranges for flashinfer/aiter. if compilation_config.pass_config.fuse_allreduce_rms: tp_size = self.parallel_config.tensor_parallel_size - max_size = compilation_config.pass_config.flashinfer_max_size(tp_size) - if max_size is not None: + from vllm._aiter_ops import rocm_aiter_ops + + if rocm_aiter_ops.is_enabled(): + max_size = rocm_aiter_ops.get_aiter_allreduce_max_size() + else: + max_size = compilation_config.pass_config.flashinfer_max_size(tp_size) + if max_size is not None and self.model_config is not None: assert isinstance(self.model_config.dtype, torch.dtype) max_token_num = max_size // ( self.model_config.get_hidden_size() @@ -1827,6 +1908,51 @@ def _validate_v2_model_runner(self) -> None: + ", ".join(unsupported) ) + def _validate_return_routed_experts(self) -> None: + """Reject parallelism configurations not yet validated with + --enable-return-routed-experts. + + Validated scope (PR #39917): TP, EP, DP, single-node and multi-node, + prefix caching, and speculative decoding (MTP validated end-to-end; + Eagle/Eagle3/Ngram/Medusa supported by construction since the + routing buffer is bound only to the target model and verified-token + routing lands at the correct positions during the main forward). + + Out-of-scope (block until validated): PP > 1, prefill context + parallelism (PCP) > 1, decode context parallelism (DCP) > 1, + async scheduling. + """ + unsupported: list[str] = [] + + if self.parallel_config.pipeline_parallel_size > 1: + unsupported.append( + "pipeline parallelism " + f"(pipeline_parallel_size=" + f"{self.parallel_config.pipeline_parallel_size})" + ) + if self.parallel_config.prefill_context_parallel_size > 1: + unsupported.append( + "prefill context parallelism " + f"(prefill_context_parallel_size=" + f"{self.parallel_config.prefill_context_parallel_size})" + ) + if self.parallel_config.decode_context_parallel_size > 1: + unsupported.append( + "decode context parallelism " + f"(decode_context_parallel_size=" + f"{self.parallel_config.decode_context_parallel_size})" + ) + if self.scheduler_config.async_scheduling: + unsupported.append("async scheduling") + + if unsupported: + raise ValueError( + "--enable-return-routed-experts is not yet validated with: " + + ", ".join(unsupported) + + ". Disable these features or omit " + "--enable-return-routed-experts." + ) + def validate_block_size(self) -> None: """Validate block_size against DCP and mamba constraints. @@ -1873,6 +1999,22 @@ def validate_block_size(self) -> None: "to schedule a multiple of block_size tokens even if they are " "in the middle of a mm input" ) + # TODO: support align mamba cache mode for model runner v2 + assert not envs.VLLM_USE_V2_MODEL_RUNNER, ( + "Model Runner V2 has not yet supported mamba_cache_mode='align'. " + ) + + @model_validator(mode="after") + def validate_nvfp4_kv_cache_with_mla(self) -> "VllmConfig": + if self.model_config is None: + return self + if self.cache_config.cache_dtype == "nvfp4" and self.model_config.use_mla: + raise ValueError( + "nvfp4 KV cache is not supported with MLA (Multi-head Latent " + "Attention) backends. Please use a different --kv-cache-dtype " + "(e.g., 'fp8' or 'auto') for MLA models such as DeepSeek." + ) + return self @model_validator(mode="after") def validate_mamba_block_size(self) -> "VllmConfig": From 8f0603b23c5668b779de49465ad7088197f21a52 Mon Sep 17 00:00:00 2001 From: PerkzZheng <67892460+PerkzZheng@users.noreply.github.com> Date: Wed, 20 May 2026 23:21:01 -0700 Subject: [PATCH 24/24] Use CUDA fused insert for full KV cache --- ...deepseek_v4_qnorm_rope_kv_insert_kernel.cu | 139 +++++++++++++++++- csrc/ops.h | 5 + csrc/torch_bindings.cpp | 9 ++ ..._fused_deepseek_v4_qnorm_rope_kv_insert.py | 112 +++++++++++++- .../layers/deepseek_v4_attention.py | 34 +++-- 5 files changed, 283 insertions(+), 16 deletions(-) diff --git a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu index 0b19a2cf16f9..2f817bc5e6a5 100644 --- a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu +++ b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu @@ -172,8 +172,8 @@ __global__ void fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel( int const num_tokens_insert, // = slot_mapping.size(0), ≤ num_tokens_full int const num_heads_q, // H int const cache_block_size, // tokens per paged-cache block - int64_t const kv_block_stride, // legacy bytes or full-cache elements - int64_t const kv_token_stride) { // full-cache elements, unused by legacy + int64_t const kv_block_stride, // bytes per paged-cache block + int64_t const kv_token_stride) { // bytes per token, unused by legacy #if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM) // BF16 _typeConvert specialization is unavailable on pre-Ampere. The // DeepseekV4 kernel only runs with bf16 inputs in practice, so compile a @@ -350,6 +350,25 @@ __global__ void fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel( float const inv_scale = 1.0f / VLLM_LDG(fp8_scale_ptr); uint4 const out = packFp8E4M3x16(elements, inv_scale); *reinterpret_cast(cache_row + dim_base) = out; + } else { + uint4 out0, out1; + typename Converter::packed_hip_type* po0 = + reinterpret_cast(&out0); + typename Converter::packed_hip_type* po1 = + reinterpret_cast(&out1); +#pragma unroll + for (int i = 0; i < 4; i++) { + po0[i] = Converter::convert( + make_float2(elements[2 * i], elements[2 * i + 1])); + } +#pragma unroll + for (int i = 0; i < 4; i++) { + po1[i] = Converter::convert( + make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1])); + } + scalar_t_in* dst = reinterpret_cast(cache_row) + dim_base; + *reinterpret_cast(dst) = out0; + *reinterpret_cast(dst + 8) = out1; } #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) cudaTriggerProgrammaticLaunchCompletion(); @@ -507,6 +526,57 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert( #endif } +template +void launchFusedDeepseekV4QNormRopeFullCacheBF16Insert( + scalar_t_in* q_inout, scalar_t_in const* kv_in, uint8_t* k_cache, + int64_t const* slot_mapping, int64_t const* position_ids, + float const* cos_sin_cache, float const eps, int const num_tokens_full, + int const num_tokens_insert, int const num_heads_q, + int const cache_block_size, int64_t const kv_block_stride, + int64_t const kv_token_stride, cudaStream_t stream) { + constexpr int kBlockSize = 256; + constexpr int kWarpsPerBlock = kBlockSize / 32; + int64_t const total_warps = + static_cast(num_tokens_full) * (num_heads_q + 1); + int const grid = + static_cast((total_warps + kWarpsPerBlock - 1) / kWarpsPerBlock); + +#ifndef USE_ROCM + static int const sm_version = getSMVersion(); + TORCH_CHECK( + sm_version >= 80, + "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert requires " + "sm_80+ (Ampere or newer); got sm_", + sm_version); + cudaLaunchConfig_t config; + config.gridDim = dim3(grid); + config.blockDim = dim3(kBlockSize); + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = 1; + config.attrs = attrs; + config.numAttrs = (sm_version >= 90) ? 1 : 0; + + cudaLaunchKernelEx( + &config, + fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel, + q_inout, nullptr, 0, 0, kv_in, k_cache, slot_mapping, position_ids, + cos_sin_cache, nullptr, nullptr, eps, num_tokens_full, num_tokens_insert, + num_heads_q, cache_block_size, kv_block_stride, kv_token_stride); +#else + fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel + <<>>( + q_inout, nullptr, 0, 0, kv_in, k_cache, slot_mapping, position_ids, + cos_sin_cache, nullptr, nullptr, eps, num_tokens_full, + num_tokens_insert, num_heads_q, cache_block_size, kv_block_stride, + kv_token_stride); +#endif +} + template void launchFusedDeepseekV4QNormRopeFullCacheFP8Insert( scalar_t_in* q_in, scalar_t_in const* kv_in, uint8_t* q_fp8_out, @@ -700,7 +770,68 @@ void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert( cos_sin_cache.data_ptr(), fp8_scale.data_ptr(), q_fp8_scale_inv.data_ptr(), static_cast(eps), num_tokens_full, num_tokens_insert, num_heads_q, - cache_block_size_i, k_cache.stride(0), k_cache.stride(1), - stream); + cache_block_size_i, k_cache.stride(0) * k_cache.element_size(), + k_cache.stride(1) * k_cache.element_size(), stream); + }); +} + +void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert( + torch::Tensor& q, // [N, H, 512] bf16, in place + torch::Tensor const& kv, // [N, 512] bf16, read-only + torch::Tensor& k_cache, // [num_blocks, block_size, 512] bf16 + torch::Tensor const& slot_mapping, // [num_tokens_insert] int64 + torch::Tensor const& position_ids, // [N] int64 + torch::Tensor const& cos_sin_cache, // [max_pos, rope_dim] float32 + double eps, int64_t cache_block_size) { + TORCH_CHECK(q.is_cuda() && q.is_contiguous(), "q must be contiguous CUDA"); + TORCH_CHECK(kv.is_cuda() && kv.is_contiguous(), "kv must be contiguous CUDA"); + TORCH_CHECK(k_cache.is_cuda(), "k_cache must be CUDA"); + TORCH_CHECK(slot_mapping.is_cuda() && slot_mapping.dtype() == torch::kInt64, + "slot_mapping must be int64 CUDA"); + TORCH_CHECK(position_ids.is_cuda() && position_ids.dtype() == torch::kInt64, + "position_ids must be int64 CUDA"); + TORCH_CHECK(cos_sin_cache.is_cuda(), "cos_sin_cache must be CUDA"); + TORCH_CHECK(q.dim() == 3 && q.size(2) == 512, "q shape [N, H, 512]"); + TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]"); + TORCH_CHECK(q.dtype() == kv.dtype(), "q and kv dtype must match"); + TORCH_CHECK(q.dtype() == torch::kBFloat16, "q and kv must be bfloat16"); + TORCH_CHECK(k_cache.dim() == 3 && k_cache.size(1) == cache_block_size && + k_cache.size(2) == 512, + "k_cache shape [num_blocks, cache_block_size, 512]"); + TORCH_CHECK(k_cache.dtype() == torch::kBFloat16, "k_cache must be bfloat16"); + TORCH_CHECK(k_cache.stride(2) == 1, + "k_cache last dimension must be contiguous"); + TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64, + "cos_sin_cache shape [max_pos, 64]"); + TORCH_CHECK(cos_sin_cache.dtype() == torch::kFloat32, + "cos_sin_cache must be float32"); + + int const num_tokens_full = static_cast(q.size(0)); + int const num_tokens_insert = static_cast(slot_mapping.size(0)); + TORCH_CHECK(static_cast(kv.size(0)) == num_tokens_full && + static_cast(position_ids.size(0)) == num_tokens_full, + "q/kv/position_ids row counts must match"); + TORCH_CHECK(num_tokens_insert <= num_tokens_full, + "slot_mapping must not exceed q row count"); + int const num_heads_q = static_cast(q.size(1)); + int const cache_block_size_i = static_cast(cache_block_size); + + at::cuda::OptionalCUDAGuard device_guard(device_of(q)); + auto stream = at::cuda::getCurrentCUDAStream(); + + VLLM_DISPATCH_HALF_TYPES( + q.scalar_type(), "fused_deepseek_v4_qnorm_rope_full_cache_bf16_insert", [&] { + using qkv_scalar_t = scalar_t; + vllm::deepseek_v4_fused_ops:: + launchFusedDeepseekV4QNormRopeFullCacheBF16Insert( + reinterpret_cast(q.data_ptr()), + reinterpret_cast(kv.data_ptr()), + reinterpret_cast(k_cache.data_ptr()), + reinterpret_cast(slot_mapping.data_ptr()), + reinterpret_cast(position_ids.data_ptr()), + cos_sin_cache.data_ptr(), static_cast(eps), + num_tokens_full, num_tokens_insert, num_heads_q, + cache_block_size_i, k_cache.stride(0) * k_cache.element_size(), + k_cache.stride(1) * k_cache.element_size(), stream); }); } diff --git a/csrc/ops.h b/csrc/ops.h index 5e5bf1afdf4f..3ffa9a5bed44 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -112,6 +112,11 @@ void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert( torch::Tensor const& fp8_scale, torch::Tensor const& q_fp8_scale_inv, double eps, int64_t cache_block_size); +void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert( + torch::Tensor& q, torch::Tensor const& kv, torch::Tensor& k_cache, + torch::Tensor const& slot_mapping, torch::Tensor const& position_ids, + torch::Tensor const& cos_sin_cache, double eps, int64_t cache_block_size); + void apply_repetition_penalties_(torch::Tensor& logits, const torch::Tensor& prompt_mask, const torch::Tensor& output_mask, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index cc146b5dbf08..e99beca249d2 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -207,6 +207,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { torch::kCUDA, &fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert); + ops.def( + "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert(" + "Tensor! q, Tensor kv, Tensor! k_cache, Tensor slot_mapping, " + "Tensor position_ids, Tensor cos_sin_cache, float eps, " + "int cache_block_size) -> ()"); + ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert", + torch::kCUDA, + &fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert); + // Apply repetition penalties to logits in-place ops.def( "apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, " diff --git a/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py b/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py index eebb2c0270c4..03c64aa3fd83 100644 --- a/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py +++ b/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py @@ -119,6 +119,12 @@ def _full_cache_fp8_op_available() -> bool: ) +def _full_cache_bf16_op_available() -> bool: + return hasattr( + torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert" + ) + + pytestmark = pytest.mark.skipif( not torch.cuda.is_available() or not _op_available(), reason="CUDA not available or fused DeepseekV4 op not built in", @@ -159,6 +165,28 @@ def _call_full_cache_fp8_fused( ) +def _call_full_cache_bf16_fused( + q, + kv, + k_cache, + slot_mapping, + positions, + cos_sin_cache, + eps, + bs, +): + torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert( + q, + kv, + k_cache, + slot_mapping, + positions.long(), + cos_sin_cache, + eps, + bs, + ) + + def _fp8_full_cache_reference( q, kv, @@ -190,6 +218,28 @@ def _fp8_full_cache_reference( ).to(torch.float8_e4m3fn) +def _bf16_full_cache_reference( + q, + kv, + k_cache, + slot_mapping, + positions, + cos_sin_cache, + eps, + block_size, +): + q_ref = rmsnorm_no_weight(q, eps) + q_ref = apply_rope_gptj_last_k(q_ref, positions, cos_sin_cache) + + kv_ref = apply_rope_gptj_last_k(kv, positions, cos_sin_cache) + valid = slot_mapping >= 0 + slots = slot_mapping[valid] + block_idx = slots // block_size + pos_in_block = slots % block_size + k_cache[block_idx, pos_in_block] = kv_ref[valid] + return q_ref + + # ── Test 1: Q path numerical parity ────────────────────────────────────────── @@ -455,10 +505,10 @@ def test_full_cache_per_tensor_fp8_matches_reference( q_fp8_ref = torch.empty_like(q, dtype=torch.float8_e4m3fn) q_fp8_fused = torch.empty_like(q, dtype=torch.float8_e4m3fn) - k_cache_ref = torch.empty( + k_cache_ref = torch.zeros( num_blocks, block_size, HEAD_DIM, dtype=torch.float8_e4m3fn, device=device ) - k_cache_fused = torch.empty_like(k_cache_ref) + k_cache_fused = torch.zeros_like(k_cache_ref) _fp8_full_cache_reference( q, @@ -494,3 +544,61 @@ def test_full_cache_per_tensor_fp8_matches_reference( torch.testing.assert_close( k_cache_fused.float(), k_cache_ref.float(), rtol=0, atol=0.25 ) + + +@pytest.mark.skipif( + not _full_cache_bf16_op_available(), + reason="full-cache BF16 DeepseekV4 op not built in", +) +@pytest.mark.parametrize("num_tokens", [4, 17]) +@pytest.mark.parametrize("n_heads", [8, 17]) +@pytest.mark.parametrize("positions_dtype", [torch.int32, torch.int64]) +def test_full_cache_bf16_matches_reference( + num_tokens: int, + n_heads: int, + positions_dtype: torch.dtype, +): + torch.manual_seed(5) + device = "cuda" + dtype = torch.bfloat16 + eps = 1e-6 + block_size = 16 + max_pos = 4096 + + q = torch.randn(num_tokens, n_heads, HEAD_DIM, dtype=dtype, device=device) + kv = torch.randn(num_tokens, HEAD_DIM, dtype=dtype, device=device) + positions = torch.arange(num_tokens, dtype=positions_dtype, device=device) + cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device) + + num_blocks = (num_tokens + block_size - 1) // block_size + 1 + slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device) + + q_fused = q.clone() + k_cache_ref = torch.zeros( + num_blocks, block_size, HEAD_DIM, dtype=torch.bfloat16, device=device + ) + k_cache_fused = torch.zeros_like(k_cache_ref) + q_ref = _bf16_full_cache_reference( + q, + kv, + k_cache_ref, + slot_mapping, + positions, + cos_sin_cache, + eps, + block_size, + ) + + _call_full_cache_bf16_fused( + q_fused, + kv, + k_cache_fused, + slot_mapping, + positions, + cos_sin_cache, + eps, + block_size, + ) + + torch.testing.assert_close(q_fused, q_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index ca745dc5282a..356cb2d61d2a 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -638,13 +638,6 @@ def _fused_qnorm_rope_kv_insert( assert swa_metadata is not None swa_kv_cache = self.swa_cache_layer.kv_cache - q_fp8 = None - if swa_kv_cache.dtype == torch.float8_e4m3fn: - q_fp8 = torch.empty( - (q.shape[0], self.n_local_heads, q.shape[-1]), - dtype=torch.float8_e4m3fn, - device=q.device, - ) # Horizontally fused: # Q side: q_head_norm (per-head RMSNorm, no weight) + GPT-J RoPE @@ -663,8 +656,27 @@ def _fused_qnorm_rope_kv_insert( self.eps, swa_metadata.block_size, ) - else: - assert q_fp8 is not None + return None + + if swa_kv_cache.dtype == torch.bfloat16: + torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert( + q, + kv, + swa_kv_cache, + swa_metadata.slot_mapping, + positions.to(torch.int64), + self.rotary_emb.cos_sin_cache, + self.eps, + swa_metadata.block_size, + ) + return None + + if swa_kv_cache.dtype == torch.float8_e4m3fn: + q_fp8 = torch.empty( + (q.shape[0], self.n_local_heads, q.shape[-1]), + dtype=torch.float8_e4m3fn, + device=q.device, + ) torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert( q, kv, @@ -678,7 +690,9 @@ def _fused_qnorm_rope_kv_insert( self.eps, swa_metadata.block_size, ) - return q_fp8 + return q_fp8 + + raise AssertionError(f"Unsupported SWA KV cache dtype {swa_kv_cache.dtype}") def deepseek_v4_attention(