diff --git a/vllm/models/deepseek_v4/common/ops/__init__.py b/vllm/models/deepseek_v4/common/ops/__init__.py index 0f80c329740b..959a79f292a5 100644 --- a/vllm/models/deepseek_v4/common/ops/__init__.py +++ b/vllm/models/deepseek_v4/common/ops/__init__.py @@ -10,7 +10,6 @@ from .fused_indexer_q import MXFP4_BLOCK_SIZE, fused_indexer_q_rope_quant from .fused_inv_rope_fp8_quant import fused_inv_rope_fp8_quant from .fused_qk_rmsnorm import fused_q_kv_rmsnorm -from .save_partial_states import save_partial_states __all__ = [ "MXFP4_BLOCK_SIZE", @@ -21,5 +20,4 @@ "fused_inv_rope_fp8_quant", "fused_q_kv_rmsnorm", "quantize_and_insert_k_cache", - "save_partial_states", ] diff --git a/vllm/models/deepseek_v4/common/ops/cache_utils.py b/vllm/models/deepseek_v4/common/ops/cache_utils.py index efb936be6c07..ac66751e3111 100644 --- a/vllm/models/deepseek_v4/common/ops/cache_utils.py +++ b/vllm/models/deepseek_v4/common/ops/cache_utils.py @@ -366,7 +366,7 @@ def dequantize_and_gather_k_cache( ) -> None: if has_cutedsl(): # lazily import, otherwise some tests fail due to CUDA driver init failure. - from vllm.models.deepseek_v4.nvidia.ops import ( + from vllm.models.deepseek_v4.nvidia.ops.dequant_gather_k_cutedsl import ( dequantize_and_gather_k_cache_cutedsl, ) diff --git a/vllm/models/deepseek_v4/common/ops/fused_compress_quant_cache.py b/vllm/models/deepseek_v4/common/ops/fused_compress_quant_cache.py index 9a5e478e315f..ed839e8c30c6 100644 --- a/vllm/models/deepseek_v4/common/ops/fused_compress_quant_cache.py +++ b/vllm/models/deepseek_v4/common/ops/fused_compress_quant_cache.py @@ -11,6 +11,12 @@ - _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn: head=128, MXFP4 (block=32), 4 ue8m0 bytes +Additional cutedsl kernels: + - _compress_kv_sparse_attn_cutedsl / _norm_rope_insert_sparse_attn_cutedsl: + CuTe DSL split kernels for C128 + - _fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl: + CuTe DSL fused kernels for C4 + RoPE is register-based via tl.reshape -> tl.split -> tl.interleave (or the even/odd halves are consumed directly for MXFP4, no interleave needed). FP8 UE8M0 quant uses tl.reshape to tile [N_QUANT_BLOCKS, QUANT_BLOCK] for @@ -19,92 +25,42 @@ and N_QUANT_BLOCKS ue8m0 bytes. """ -from typing import Any - -import torch +from functools import cache from vllm.triton_utils import tl, triton from .fused_indexer_q import _fp32x2_to_fp4x2 -def compress_norm_rope_store_triton( - state_cache: torch.Tensor, - num_actual: int, - token_to_req_indices: torch.Tensor, - positions: torch.Tensor, - slot_mapping: torch.Tensor, - block_table: torch.Tensor, - block_size: int, - state_width: int, - cos_sin_cache: torch.Tensor, - kv_cache: torch.Tensor, - k_cache_metadata: Any, - pdl_kwargs: dict, - head_dim: int, - rope_head_dim: int, - compress_ratio: int, - overlap: bool, - use_fp4_cache: bool, - rms_norm_weight: torch.Tensor, - rms_norm_eps: float, - quant_block: int, - token_stride: int, - scale_dim: int, -) -> None: - """Shared triton launcher for the fused compress+norm+RoPE+insert path. - - Picks one of the three kernels in this module based on ``head_dim`` and - ``use_fp4_cache``. Identical launch signature for all three. - """ - if head_dim == 512: - kernel = _fused_kv_compress_norm_rope_insert_sparse_attn - num_warps = 4 - elif use_fp4_cache: - kernel = _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn - num_warps = 1 - else: - kernel = _fused_kv_compress_norm_rope_insert_indexer_attn - num_warps = 1 - - kernel[(num_actual,)]( - # state cache - state_cache, - state_cache.stride(0), - state_cache.stride(1), - # metadata - token_to_req_indices, - positions, - slot_mapping, - block_table, - block_table.stride(0), - block_size, - # RMSNorm - rms_norm_weight, - rms_norm_eps, - # RoPE - cos_sin_cache, - cos_sin_cache.stride(0), - # KV cache - kv_cache, - k_cache_metadata.slot_mapping, - kv_cache.shape[1], # paged KV cache block size (tokens per block) - # constexprs - HEAD_SIZE=head_dim, - TRITON_BLOCK_SIZE=triton.next_power_of_2(head_dim), - STATE_WIDTH=state_width, - COMPRESS_RATIO=compress_ratio, - OVERLAP=overlap, - ROPE_HEAD_DIM=rope_head_dim, - FP8_MAX=448.0, - QUANT_BLOCK=quant_block, - TOKEN_STRIDE=token_stride, - SCALE_DIM=scale_dim, - KV_BLOCK_STRIDE=kv_cache.stride(0), - num_warps=num_warps, - **pdl_kwargs, +@cache +def _get_sparse_attn_cutedsl_impls(): + from .sparse_attn_compress_cutedsl import ( + _compress_kv_sparse_attn_cutedsl, + _fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl, + _norm_rope_insert_sparse_attn_cutedsl, ) + return ( + _compress_kv_sparse_attn_cutedsl, + _norm_rope_insert_sparse_attn_cutedsl, + _fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl, + ) + + +def _compress_kv_sparse_attn_cutedsl(*args, **kwargs): + """CuTe DSL sparse-attention compress wrapper.""" + return _get_sparse_attn_cutedsl_impls()[0](*args, **kwargs) + + +def _norm_rope_insert_sparse_attn_cutedsl(*args, **kwargs): + """CuTe DSL RMSNorm/RoPE/FP8-store wrapper.""" + return _get_sparse_attn_cutedsl_impls()[1](*args, **kwargs) + + +def _fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl(*args, **kwargs): + """CuTe DSL fused C4 sparse-attention compressor wrapper.""" + return _get_sparse_attn_cutedsl_impls()[2](*args, **kwargs) + # ============================================================================= # DeepseekV4 Attention path (head=512, nope=448 FP8 + rope=64 bf16) diff --git a/vllm/models/deepseek_v4/common/ops/fused_indexer_q.py b/vllm/models/deepseek_v4/common/ops/fused_indexer_q.py index e88fe1529cd6..d5aaf10feba4 100644 --- a/vllm/models/deepseek_v4/common/ops/fused_indexer_q.py +++ b/vllm/models/deepseek_v4/common/ops/fused_indexer_q.py @@ -346,7 +346,7 @@ def fused_indexer_q_rope_quant( ) if has_cutedsl(): # lazily import, otherwise some tests fail due to CUDA driver init failure. - from vllm.models.deepseek_v4.nvidia.ops import ( + from vllm.models.deepseek_v4.nvidia.ops.fused_indexer_q_cutedsl import ( fused_indexer_q_rope_quant_mxfp4_cutedsl, ) @@ -400,7 +400,7 @@ def fused_indexer_q_rope_quant( index_q_fp8 = torch.empty_like(index_q, dtype=torch.float8_e4m3fn) if has_cutedsl(): # lazily import, otherwise some tests fail due to CUDA driver init failure. - from vllm.models.deepseek_v4.nvidia.ops import ( + from vllm.models.deepseek_v4.nvidia.ops.fused_indexer_q_cutedsl import ( fused_indexer_q_rope_quant_fp8_cutedsl, ) diff --git a/vllm/models/deepseek_v4/common/ops/save_partial_states.py b/vllm/models/deepseek_v4/common/ops/save_partial_states.py deleted file mode 100644 index e3d7d38f454b..000000000000 --- a/vllm/models/deepseek_v4/common/ops/save_partial_states.py +++ /dev/null @@ -1,101 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -from vllm.triton_utils import tl, triton - - -def save_partial_states( - kv: torch.Tensor, - score: torch.Tensor, - ape: torch.Tensor, - positions: torch.Tensor, - state_cache: torch.Tensor, - slot_mapping: torch.Tensor, - block_size: int, - state_width: int, - compress_ratio: int, - pdl_kwargs: dict | None = None, -) -> None: - """Write packed [kv, score+ape] partial states into the compressor cache. - - One program per token; pads (slot_id == -1) are skipped. - """ - num_actual = slot_mapping.shape[0] - head_size = kv.shape[-1] - _save_partial_states_kernel[(num_actual,)]( - kv, - kv.stride(0), - score, - score.stride(0), - ape, - ape.stride(0), - positions, - state_cache, - state_cache.stride(0), - state_cache.stride(1), - slot_mapping, - block_size, - HEAD_SIZE=head_size, - TRITON_BLOCK_SIZE=triton.next_power_of_2(head_size), - STATE_WIDTH=state_width, - COMPRESS_RATIO=compress_ratio, - **(pdl_kwargs or {}), - ) - - -@triton.jit -def _save_partial_states_kernel( - kv_ptr, - kv_stride, - score_ptr, - score_stride, - ape_ptr, - ape_stride, - positions_ptr, - state_cache_ptr, - state_cache_stride0, - state_cache_stride1, - slot_mapping_ptr, - block_size, - HEAD_SIZE: tl.constexpr, - TRITON_BLOCK_SIZE: tl.constexpr, - # state_cache last dim packs [kv_state, score_state], each STATE_WIDTH wide. - STATE_WIDTH: tl.constexpr, - COMPRESS_RATIO: tl.constexpr, -): - token_idx = tl.program_id(0) - slot_id = tl.load(slot_mapping_ptr + token_idx) - - # Skip padded / invalid tokens (slot_id == -1 is the PAD sentinel used - # by vLLM). During CUDA graph replay the batch may contain padding - # tokens whose slot_mapping is -1; writing to kv_state[-1] would be an - # illegal memory access. - if slot_id < 0: - return - - block_idx = slot_id // block_size - pos_in_block = slot_id % block_size - base_ptr = ( - state_cache_ptr - + block_idx * state_cache_stride0 - + pos_in_block * state_cache_stride1 - ) - - block = tl.arange(0, TRITON_BLOCK_SIZE) - mask = block < HEAD_SIZE - - kv = tl.load(kv_ptr + token_idx * kv_stride + block, mask=mask) - tl.store(base_ptr + block, kv, mask=mask) - - # Fused: score += ape[position % compress_ratio] - position = tl.load(positions_ptr + token_idx) - ape_row = position % COMPRESS_RATIO - ape = tl.load(ape_ptr + ape_row * ape_stride + block, mask=mask) - score = tl.load(score_ptr + token_idx * score_stride + block, mask=mask) - tl.store( - base_ptr + STATE_WIDTH + block, - score + ape, - mask=mask, - ) diff --git a/vllm/models/deepseek_v4/nvidia/ops/sparse_attn_compress_cutedsl.py b/vllm/models/deepseek_v4/common/ops/sparse_attn_compress_cutedsl.py similarity index 93% rename from vllm/models/deepseek_v4/nvidia/ops/sparse_attn_compress_cutedsl.py rename to vllm/models/deepseek_v4/common/ops/sparse_attn_compress_cutedsl.py index 0eba82126afd..1f475c8fb840 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/sparse_attn_compress_cutedsl.py +++ b/vllm/models/deepseek_v4/common/ops/sparse_attn_compress_cutedsl.py @@ -8,7 +8,6 @@ from __future__ import annotations from functools import cache -from typing import Any import cutlass import cutlass.cute as cute @@ -1087,7 +1086,7 @@ def compile( ) -def compress_kv_sparse_attn_cutedsl( +def _compress_kv_sparse_attn_cutedsl( state_cache: torch.Tensor, token_to_req_indices: torch.Tensor, positions: torch.Tensor, @@ -1119,7 +1118,7 @@ def compress_kv_sparse_attn_cutedsl( ) -def norm_rope_insert_sparse_attn_cutedsl( +def _norm_rope_insert_sparse_attn_cutedsl( compressed_kv: torch.Tensor, positions: torch.Tensor, slot_mapping: torch.Tensor, @@ -1175,7 +1174,7 @@ def norm_rope_insert_sparse_attn_cutedsl( ) -def fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl( +def _fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl( state_cache: torch.Tensor, token_to_req_indices: torch.Tensor, positions: torch.Tensor, @@ -1239,94 +1238,3 @@ def fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl( kv_slot_mapping, kv_cache_block_size, ) - - -def compress_norm_rope_store_cutedsl( - state_cache: torch.Tensor, - num_actual: int, - token_to_req_indices: torch.Tensor, - positions: torch.Tensor, - slot_mapping: torch.Tensor, - block_table: torch.Tensor, - block_size: int, - state_width: int, - cos_sin_cache: torch.Tensor, - kv_cache: torch.Tensor, - k_cache_metadata: Any, - pdl_kwargs: dict, - head_dim: int, - rope_head_dim: int, - compress_ratio: int, - overlap: bool, - use_fp4_cache: bool, - rms_norm_weight: torch.Tensor, - rms_norm_eps: float, - quant_block: int, - token_stride: int, - scale_dim: int, -) -> None: - if compress_ratio == 4: - # For C4A, the single fused kernel is faster than the two-kernel version. - fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl( - state_cache, - token_to_req_indices, - positions, - slot_mapping, - block_table, - block_size, - rms_norm_weight, - rms_norm_eps, - cos_sin_cache, - kv_cache, - k_cache_metadata.slot_mapping, - kv_cache.shape[1], # paged KV cache block size - kv_cache.stride(0), - head_size=head_dim, - state_width=state_width, - rope_head_dim=rope_head_dim, - fp8_max=448.0, - quant_block=quant_block, - token_stride=token_stride, - scale_dim=scale_dim, - compress_ratio=compress_ratio, - overlap=overlap, - ) - else: - # For C128, the two-kernel version is faster than the single fused kernel. - compressed_kv = torch.empty( - (num_actual, head_dim), - dtype=torch.float32, - device=state_cache.device, - ) - compress_kv_sparse_attn_cutedsl( - state_cache, - token_to_req_indices, - positions, - slot_mapping, - block_table, - block_size, - compressed_kv, - head_size=head_dim, - state_width=state_width, - compress_ratio=compress_ratio, - overlap=overlap, - ) - norm_rope_insert_sparse_attn_cutedsl( - compressed_kv, - positions, - slot_mapping, - rms_norm_weight, - rms_norm_eps, - cos_sin_cache, - kv_cache, - k_cache_metadata.slot_mapping, - kv_cache.shape[1], # paged KV cache block size - kv_cache.stride(0), - head_size=head_dim, - rope_head_dim=rope_head_dim, - fp8_max=448.0, - quant_block=quant_block, - token_stride=token_stride, - scale_dim=scale_dim, - compress_ratio=compress_ratio, - ) diff --git a/vllm/models/deepseek_v4/compressor.py b/vllm/models/deepseek_v4/compressor.py index 3234faa5eb05..bc66178f74bf 100644 --- a/vllm/models/deepseek_v4/compressor.py +++ b/vllm/models/deepseek_v4/compressor.py @@ -13,13 +13,15 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import MergedColumnParallelLinear from vllm.models.deepseek_v4.common.ops.fused_compress_quant_cache import ( - compress_norm_rope_store_triton, + _compress_kv_sparse_attn_cutedsl, + _fused_kv_compress_norm_rope_insert_indexer_attn, + _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn, + _fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl, + _norm_rope_insert_sparse_attn_cutedsl, ) from vllm.models.deepseek_v4.common.ops.fused_indexer_q import MXFP4_BLOCK_SIZE -from vllm.models.deepseek_v4.common.ops.save_partial_states import ( - save_partial_states, -) from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton from vllm.v1.attention.backend import ( AttentionBackend, AttentionCGSupport, @@ -171,16 +173,6 @@ def get_attn_backend(self) -> type[AttentionBackend]: class DeepseekCompressor(nn.Module): - """DeepSeek V4 KV/score compressor. - - Owns the linear / norm / state-cache / ape state and the shared forward - prologue (kv/score split, save_partial_states launch). The - compress → norm → RoPE → store step is dispatched to a triton kernel - (``compress_norm_rope_store_triton``) by default, except for the NVIDIA - head_dim=128 indexer path which uses the cutedsl kernel - (``compress_norm_rope_store_cutedsl``) for better performance. - """ - def __init__( self, vllm_config: VllmConfig, @@ -250,18 +242,32 @@ def __init__( assert not use_fp4_cache, ( "MXFP4 cache is only supported for indexer (head=128)" ) + self._use_cutedsl_sparse_compressor = True + self._use_cutedsl_fused_sparse_compressor = self.compress_ratio == 4 + self._compress_kernel = _compress_kv_sparse_attn_cutedsl + self._norm_rope_store_kernel = _norm_rope_insert_sparse_attn_cutedsl + self._fused_sparse_kernel = ( + _fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl + ) self._quant_block = 64 self._token_stride = self.nope_head_dim + self.rope_head_dim * 2 self._scale_dim = self.nope_head_dim // 64 + 1 # 7 real + 1 pad + self._num_warps = 4 elif self.head_dim == 128: + self._use_cutedsl_sparse_compressor = False if use_fp4_cache: + self._fused_kernel = ( + _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn + ) self._quant_block = MXFP4_BLOCK_SIZE self._token_stride = self.head_dim // 2 self._scale_dim = self.head_dim // MXFP4_BLOCK_SIZE else: + self._fused_kernel = _fused_kv_compress_norm_rope_insert_indexer_attn self._quant_block = 128 self._token_stride = self.head_dim self._scale_dim = 4 # single float32 scale + self._num_warps = 1 else: raise ValueError( f"Unsupported head_dim for fused quant+cache: {self.head_dim}" @@ -306,22 +312,29 @@ def forward( ) # Store the KV and score (with fused APE addition) in the state. - # NOTE: PDL is disabled — both this kernel and the compress kernels - # below depend on preceding kernel outputs (kv/score from the cublas - # GEMM; state_cache from this kernel) but neither emits/waits on PDL - # grid dependency primitives, so launch_pdl=True caused a - # read-after-write race and non-deterministic output. - save_partial_states( - kv=kv, - score=score, - ape=self.ape, - positions=positions, - state_cache=state_cache, - slot_mapping=slot_mapping, - block_size=block_size, - state_width=state_width, - compress_ratio=self.compress_ratio, - pdl_kwargs=pdl_kwargs, + # NOTE: PDL is disabled — both this kernel and _fused_kernel below + # depend on preceding kernel outputs (kv/score from the cublas GEMM; + # state_cache from this kernel) but neither emits/waits on PDL grid + # dependency primitives, so launch_pdl=True caused a read-after-write + # race and non-deterministic output. + _save_partial_states_kernel[(num_actual,)]( + kv, + kv.stride(0), + score, + score.stride(0), + self.ape, + self.ape.stride(0), + positions, + state_cache, + state_cache.stride(0), + state_cache.stride(1), + slot_mapping, + block_size, + HEAD_SIZE=kv.shape[-1], + TRITON_BLOCK_SIZE=triton.next_power_of_2(kv.shape[-1]), + STATE_WIDTH=state_width, + COMPRESS_RATIO=self.compress_ratio, + **pdl_kwargs, ) # Fused: compress → RMSNorm → RoPE → FP8 quant → KV cache write. @@ -335,44 +348,161 @@ def forward( k_cache_metadata = cast(Any, attn_metadata[self.k_cache_prefix]) kv_cache = self._static_forward_context[self.k_cache_prefix].kv_cache - if current_platform.is_cuda(): - # NVIDIA GPUs. - if self.head_dim == 512: - from .nvidia.ops import compress_norm_rope_store_cutedsl - - # Main compressor path. - # Use a cutedsl kernel for better performance. - compress_norm_rope_store_fn = compress_norm_rope_store_cutedsl + if self._use_cutedsl_sparse_compressor: + if self._use_cutedsl_fused_sparse_compressor: + self._fused_sparse_kernel( + state_cache, + token_to_req_indices, + positions, + slot_mapping, + block_table, + block_size, + self.norm.weight, + self.rms_norm_eps, + cos_sin_cache, + kv_cache, + k_cache_metadata.slot_mapping, + kv_cache.shape[1], # paged KV cache block size + kv_cache.stride(0), + head_size=self.head_dim, + state_width=state_width, + rope_head_dim=self.rope_head_dim, + fp8_max=448.0, + quant_block=self._quant_block, + token_stride=self._token_stride, + scale_dim=self._scale_dim, + compress_ratio=self.compress_ratio, + overlap=self.overlap, + ) else: - # Indexer path (head_dim == 128). - # Use a triton kernel. - compress_norm_rope_store_fn = compress_norm_rope_store_triton + compressed_kv = torch.empty( + (num_actual, self.head_dim), + dtype=torch.float32, + device=state_cache.device, + ) + self._compress_kernel( + state_cache, + token_to_req_indices, + positions, + slot_mapping, + block_table, + block_size, + compressed_kv, + head_size=self.head_dim, + state_width=state_width, + compress_ratio=self.compress_ratio, + overlap=self.overlap, + ) + self._norm_rope_store_kernel( + compressed_kv, + positions, + slot_mapping, + self.norm.weight, + self.rms_norm_eps, + cos_sin_cache, + kv_cache, + k_cache_metadata.slot_mapping, + kv_cache.shape[1], # paged KV cache block size + kv_cache.stride(0), + head_size=self.head_dim, + rope_head_dim=self.rope_head_dim, + fp8_max=448.0, + quant_block=self._quant_block, + token_stride=self._token_stride, + scale_dim=self._scale_dim, + compress_ratio=self.compress_ratio, + ) else: - # AMD GPUs. - # Always use a triton kernel. - compress_norm_rope_store_fn = compress_norm_rope_store_triton + self._fused_kernel[(num_actual,)]( + # state cache + state_cache, + state_cache.stride(0), + state_cache.stride(1), + # metadata + token_to_req_indices, + positions, + slot_mapping, + block_table, + block_table.stride(0), + block_size, + # RMSNorm + self.norm.weight, + self.rms_norm_eps, + # RoPE + cos_sin_cache, + cos_sin_cache.stride(0), + # KV cache + kv_cache, + k_cache_metadata.slot_mapping, + kv_cache.shape[1], # paged KV cache block size (tokens per block) + # constexprs + HEAD_SIZE=self.head_dim, + TRITON_BLOCK_SIZE=triton.next_power_of_2(self.head_dim), + STATE_WIDTH=state_width, + COMPRESS_RATIO=self.compress_ratio, + OVERLAP=self.overlap, + ROPE_HEAD_DIM=self.rope_head_dim, + FP8_MAX=448.0, + QUANT_BLOCK=self._quant_block, + TOKEN_STRIDE=self._token_stride, + SCALE_DIM=self._scale_dim, + KV_BLOCK_STRIDE=kv_cache.stride(0), + num_warps=self._num_warps, + **pdl_kwargs, + ) - compress_norm_rope_store_fn( - state_cache=state_cache, - num_actual=num_actual, - token_to_req_indices=token_to_req_indices, - positions=positions, - slot_mapping=slot_mapping, - block_table=block_table, - block_size=block_size, - state_width=state_width, - cos_sin_cache=cos_sin_cache, - kv_cache=kv_cache, - k_cache_metadata=k_cache_metadata, - pdl_kwargs=pdl_kwargs, - head_dim=self.head_dim, - rope_head_dim=self.rope_head_dim, - compress_ratio=self.compress_ratio, - overlap=self.overlap, - use_fp4_cache=self.use_fp4_cache, - rms_norm_weight=self.norm.weight, - rms_norm_eps=self.rms_norm_eps, - quant_block=self._quant_block, - token_stride=self._token_stride, - scale_dim=self._scale_dim, - ) + +@triton.jit +def _save_partial_states_kernel( + kv_ptr, + kv_stride, + score_ptr, + score_stride, + ape_ptr, + ape_stride, + positions_ptr, + state_cache_ptr, + state_cache_stride0, + state_cache_stride1, + slot_mapping_ptr, + block_size, + HEAD_SIZE: tl.constexpr, + TRITON_BLOCK_SIZE: tl.constexpr, + # state_cache last dim packs [kv_state, score_state], each STATE_WIDTH wide. + STATE_WIDTH: tl.constexpr, + COMPRESS_RATIO: tl.constexpr, +): + token_idx = tl.program_id(0) + slot_id = tl.load(slot_mapping_ptr + token_idx) + + # Skip padded / invalid tokens (slot_id == -1 is the PAD sentinel used + # by vLLM). During CUDA graph replay the batch may contain padding + # tokens whose slot_mapping is -1; writing to kv_state[-1] would be an + # illegal memory access. + if slot_id < 0: + return + + block_idx = slot_id // block_size + pos_in_block = slot_id % block_size + base_ptr = ( + state_cache_ptr + + block_idx * state_cache_stride0 + + pos_in_block * state_cache_stride1 + ) + + block = tl.arange(0, TRITON_BLOCK_SIZE) + mask = block < HEAD_SIZE + + kv = tl.load(kv_ptr + token_idx * kv_stride + block, mask=mask) + tl.store(base_ptr + block, kv, mask=mask) + + # Fused: score += ape[position % compress_ratio] + position = tl.load(positions_ptr + token_idx) + ape_row = position % COMPRESS_RATIO + ape = tl.load(ape_ptr + ape_row * ape_stride + block, mask=mask) + score = tl.load(score_ptr + token_idx * score_stride + block, mask=mask) + tl.store( + base_ptr + STATE_WIDTH + block, + score + ape, + mask=mask, + ) diff --git a/vllm/models/deepseek_v4/nvidia/model.py b/vllm/models/deepseek_v4/nvidia/model.py index 974593a8d390..25f0a730fdbd 100644 --- a/vllm/models/deepseek_v4/nvidia/model.py +++ b/vllm/models/deepseek_v4/nvidia/model.py @@ -59,7 +59,7 @@ DeepseekV4MLAModules, DeepseekV4MultiHeadLatentAttentionWrapper, ) -from vllm.models.deepseek_v4.nvidia.ops import prepare_megamoe_inputs +from vllm.models.deepseek_v4.nvidia.ops.prepare_megamoe import prepare_megamoe_inputs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.torch_utils import direct_register_custom_op diff --git a/vllm/models/deepseek_v4/nvidia/ops/__init__.py b/vllm/models/deepseek_v4/nvidia/ops/__init__.py index dca25345ea6f..37276e1816f0 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/__init__.py +++ b/vllm/models/deepseek_v4/nvidia/ops/__init__.py @@ -6,19 +6,3 @@ not be imported on non-CUDA platforms. Callers should gate on ``vllm.utils.import_utils.has_cutedsl()`` before importing from here. """ - -from .dequant_gather_k_cutedsl import dequantize_and_gather_k_cache_cutedsl -from .fused_indexer_q_cutedsl import ( - fused_indexer_q_rope_quant_fp8_cutedsl, - fused_indexer_q_rope_quant_mxfp4_cutedsl, -) -from .prepare_megamoe import prepare_megamoe_inputs -from .sparse_attn_compress_cutedsl import compress_norm_rope_store_cutedsl - -__all__ = [ - "compress_norm_rope_store_cutedsl", - "dequantize_and_gather_k_cache_cutedsl", - "fused_indexer_q_rope_quant_fp8_cutedsl", - "fused_indexer_q_rope_quant_mxfp4_cutedsl", - "prepare_megamoe_inputs", -]