From 0e4bd5e5b2801b606363db42860bbe7af42deb69 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 26 May 2026 20:09:57 +0000 Subject: [PATCH 1/5] [DSv4] Refactor compressor & BugFix for ROCm Signed-off-by: Woosuk Kwon --- .../models/deepseek_v4/common/ops/__init__.py | 2 + .../common/ops/fused_compress_quant_cache.py | 112 +++++--- .../common/ops/save_partial_states.py | 101 +++++++ vllm/models/deepseek_v4/compressor.py | 266 +++++------------- vllm/models/deepseek_v4/nvidia/compressor.py | 108 +++++++ .../ops/sparse_attn_compress_cutedsl.py | 6 +- 6 files changed, 356 insertions(+), 239 deletions(-) create mode 100644 vllm/models/deepseek_v4/common/ops/save_partial_states.py create mode 100644 vllm/models/deepseek_v4/nvidia/compressor.py rename vllm/models/deepseek_v4/{common => nvidia}/ops/sparse_attn_compress_cutedsl.py (99%) diff --git a/vllm/models/deepseek_v4/common/ops/__init__.py b/vllm/models/deepseek_v4/common/ops/__init__.py index 959a79f292a5..0f80c329740b 100644 --- a/vllm/models/deepseek_v4/common/ops/__init__.py +++ b/vllm/models/deepseek_v4/common/ops/__init__.py @@ -10,6 +10,7 @@ from .fused_indexer_q import MXFP4_BLOCK_SIZE, fused_indexer_q_rope_quant from .fused_inv_rope_fp8_quant import fused_inv_rope_fp8_quant from .fused_qk_rmsnorm import fused_q_kv_rmsnorm +from .save_partial_states import save_partial_states __all__ = [ "MXFP4_BLOCK_SIZE", @@ -20,4 +21,5 @@ "fused_inv_rope_fp8_quant", "fused_q_kv_rmsnorm", "quantize_and_insert_k_cache", + "save_partial_states", ] diff --git a/vllm/models/deepseek_v4/common/ops/fused_compress_quant_cache.py b/vllm/models/deepseek_v4/common/ops/fused_compress_quant_cache.py index ed839e8c30c6..5d1131c5edd6 100644 --- a/vllm/models/deepseek_v4/common/ops/fused_compress_quant_cache.py +++ b/vllm/models/deepseek_v4/common/ops/fused_compress_quant_cache.py @@ -11,12 +11,6 @@ - _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn: head=128, MXFP4 (block=32), 4 ue8m0 bytes -Additional cutedsl kernels: - - 
_compress_kv_sparse_attn_cutedsl / _norm_rope_insert_sparse_attn_cutedsl: - CuTe DSL split kernels for C128 - - _fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl: - CuTe DSL fused kernels for C4 - RoPE is register-based via tl.reshape -> tl.split -> tl.interleave (or the even/odd halves are consumed directly for MXFP4, no interleave needed). FP8 UE8M0 quant uses tl.reshape to tile [N_QUANT_BLOCKS, QUANT_BLOCK] for @@ -25,43 +19,93 @@ and N_QUANT_BLOCKS ue8m0 bytes. """ -from functools import cache +from typing import Any + +import torch from vllm.triton_utils import tl, triton from .fused_indexer_q import _fp32x2_to_fp4x2 -@cache -def _get_sparse_attn_cutedsl_impls(): - from .sparse_attn_compress_cutedsl import ( - _compress_kv_sparse_attn_cutedsl, - _fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl, - _norm_rope_insert_sparse_attn_cutedsl, - ) - - return ( - _compress_kv_sparse_attn_cutedsl, - _norm_rope_insert_sparse_attn_cutedsl, - _fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl, +def fused_kv_compress_norm_rope_insert( + state_cache: torch.Tensor, + num_actual: int, + token_to_req_indices: torch.Tensor, + positions: torch.Tensor, + slot_mapping: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + state_width: int, + cos_sin_cache: torch.Tensor, + kv_cache: torch.Tensor, + k_cache_metadata: Any, + pdl_kwargs: dict, + head_dim: int, + rope_head_dim: int, + compress_ratio: int, + overlap: bool, + use_fp4_cache: bool, + rms_norm_weight: torch.Tensor, + rms_norm_eps: float, + quant_block: int, + token_stride: int, + scale_dim: int, +) -> None: + """Shared triton launcher for the fused compress+norm+RoPE+insert path. + + Picks one of the three kernels in this module based on ``head_dim`` and + ``use_fp4_cache``. Identical launch signature for all three. 
+ """ + if head_dim == 512: + kernel = _fused_kv_compress_norm_rope_insert_sparse_attn + num_warps = 4 + elif use_fp4_cache: + kernel = _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn + num_warps = 1 + else: + kernel = _fused_kv_compress_norm_rope_insert_indexer_attn + num_warps = 1 + + kernel[(num_actual,)]( + # state cache + state_cache, + state_cache.stride(0), + state_cache.stride(1), + # metadata + token_to_req_indices, + positions, + slot_mapping, + block_table, + block_table.stride(0), + block_size, + # RMSNorm + rms_norm_weight, + rms_norm_eps, + # RoPE + cos_sin_cache, + cos_sin_cache.stride(0), + # KV cache + kv_cache, + k_cache_metadata.slot_mapping, + kv_cache.shape[1], # paged KV cache block size (tokens per block) + # constexprs + HEAD_SIZE=head_dim, + TRITON_BLOCK_SIZE=triton.next_power_of_2(head_dim), + STATE_WIDTH=state_width, + COMPRESS_RATIO=compress_ratio, + OVERLAP=overlap, + ROPE_HEAD_DIM=rope_head_dim, + FP8_MAX=448.0, + QUANT_BLOCK=quant_block, + TOKEN_STRIDE=token_stride, + SCALE_DIM=scale_dim, + KV_BLOCK_STRIDE=kv_cache.stride(0), + num_warps=num_warps, + **pdl_kwargs, ) -def _compress_kv_sparse_attn_cutedsl(*args, **kwargs): - """CuTe DSL sparse-attention compress wrapper.""" - return _get_sparse_attn_cutedsl_impls()[0](*args, **kwargs) - - -def _norm_rope_insert_sparse_attn_cutedsl(*args, **kwargs): - """CuTe DSL RMSNorm/RoPE/FP8-store wrapper.""" - return _get_sparse_attn_cutedsl_impls()[1](*args, **kwargs) - - -def _fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl(*args, **kwargs): - """CuTe DSL fused C4 sparse-attention compressor wrapper.""" - return _get_sparse_attn_cutedsl_impls()[2](*args, **kwargs) - - # ============================================================================= # DeepseekV4 Attention path (head=512, nope=448 FP8 + rope=64 bf16) # ============================================================================= diff --git a/vllm/models/deepseek_v4/common/ops/save_partial_states.py 
b/vllm/models/deepseek_v4/common/ops/save_partial_states.py new file mode 100644 index 000000000000..e3d7d38f454b --- /dev/null +++ b/vllm/models/deepseek_v4/common/ops/save_partial_states.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.triton_utils import tl, triton + + +def save_partial_states( + kv: torch.Tensor, + score: torch.Tensor, + ape: torch.Tensor, + positions: torch.Tensor, + state_cache: torch.Tensor, + slot_mapping: torch.Tensor, + block_size: int, + state_width: int, + compress_ratio: int, + pdl_kwargs: dict | None = None, +) -> None: + """Write packed [kv, score+ape] partial states into the compressor cache. + + One program per token; pads (slot_id == -1) are skipped. + """ + num_actual = slot_mapping.shape[0] + head_size = kv.shape[-1] + _save_partial_states_kernel[(num_actual,)]( + kv, + kv.stride(0), + score, + score.stride(0), + ape, + ape.stride(0), + positions, + state_cache, + state_cache.stride(0), + state_cache.stride(1), + slot_mapping, + block_size, + HEAD_SIZE=head_size, + TRITON_BLOCK_SIZE=triton.next_power_of_2(head_size), + STATE_WIDTH=state_width, + COMPRESS_RATIO=compress_ratio, + **(pdl_kwargs or {}), + ) + + +@triton.jit +def _save_partial_states_kernel( + kv_ptr, + kv_stride, + score_ptr, + score_stride, + ape_ptr, + ape_stride, + positions_ptr, + state_cache_ptr, + state_cache_stride0, + state_cache_stride1, + slot_mapping_ptr, + block_size, + HEAD_SIZE: tl.constexpr, + TRITON_BLOCK_SIZE: tl.constexpr, + # state_cache last dim packs [kv_state, score_state], each STATE_WIDTH wide. + STATE_WIDTH: tl.constexpr, + COMPRESS_RATIO: tl.constexpr, +): + token_idx = tl.program_id(0) + slot_id = tl.load(slot_mapping_ptr + token_idx) + + # Skip padded / invalid tokens (slot_id == -1 is the PAD sentinel used + # by vLLM). 
During CUDA graph replay the batch may contain padding + # tokens whose slot_mapping is -1; writing to kv_state[-1] would be an + # illegal memory access. + if slot_id < 0: + return + + block_idx = slot_id // block_size + pos_in_block = slot_id % block_size + base_ptr = ( + state_cache_ptr + + block_idx * state_cache_stride0 + + pos_in_block * state_cache_stride1 + ) + + block = tl.arange(0, TRITON_BLOCK_SIZE) + mask = block < HEAD_SIZE + + kv = tl.load(kv_ptr + token_idx * kv_stride + block, mask=mask) + tl.store(base_ptr + block, kv, mask=mask) + + # Fused: score += ape[position % compress_ratio] + position = tl.load(positions_ptr + token_idx) + ape_row = position % COMPRESS_RATIO + ape = tl.load(ape_ptr + ape_row * ape_stride + block, mask=mask) + score = tl.load(score_ptr + token_idx * score_stride + block, mask=mask) + tl.store( + base_ptr + STATE_WIDTH + block, + score + ape, + mask=mask, + ) diff --git a/vllm/models/deepseek_v4/compressor.py b/vllm/models/deepseek_v4/compressor.py index bc66178f74bf..7918b0a86cfe 100644 --- a/vllm/models/deepseek_v4/compressor.py +++ b/vllm/models/deepseek_v4/compressor.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Any, ClassVar, cast +from typing import TYPE_CHECKING, Any, ClassVar, cast import torch from torch import nn @@ -13,15 +13,19 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import MergedColumnParallelLinear from vllm.models.deepseek_v4.common.ops.fused_compress_quant_cache import ( - _compress_kv_sparse_attn_cutedsl, - _fused_kv_compress_norm_rope_insert_indexer_attn, - _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn, - _fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl, - _norm_rope_insert_sparse_attn_cutedsl, + fused_kv_compress_norm_rope_insert, ) from vllm.models.deepseek_v4.common.ops.fused_indexer_q import MXFP4_BLOCK_SIZE +from 
vllm.models.deepseek_v4.common.ops.save_partial_states import ( + save_partial_states, +) from vllm.platforms import current_platform -from vllm.triton_utils import tl, triton + +if TYPE_CHECKING or not current_platform.is_rocm(): + from vllm.models.deepseek_v4.nvidia.compressor import compress_norm_rope_store +else: + # AMD head=512 has no extra logic over the shared triton launcher. + compress_norm_rope_store = fused_kv_compress_norm_rope_insert from vllm.v1.attention.backend import ( AttentionBackend, AttentionCGSupport, @@ -173,6 +177,14 @@ def get_attn_backend(self) -> type[AttentionBackend]: class DeepseekCompressor(nn.Module): + """DeepSeek V4 KV/score compressor. + + Owns the linear / norm / state-cache / ape state and the shared forward + prologue (kv/score split, save_partial_states launch). The per-platform + compress → norm → RoPE → store step is dispatched to + ``compress_norm_rope_store`` imported from ``nvidia/`` or ``amd/``. + """ + def __init__( self, vllm_config: VllmConfig, @@ -242,32 +254,18 @@ def __init__( assert not use_fp4_cache, ( "MXFP4 cache is only supported for indexer (head=128)" ) - self._use_cutedsl_sparse_compressor = True - self._use_cutedsl_fused_sparse_compressor = self.compress_ratio == 4 - self._compress_kernel = _compress_kv_sparse_attn_cutedsl - self._norm_rope_store_kernel = _norm_rope_insert_sparse_attn_cutedsl - self._fused_sparse_kernel = ( - _fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl - ) self._quant_block = 64 self._token_stride = self.nope_head_dim + self.rope_head_dim * 2 self._scale_dim = self.nope_head_dim // 64 + 1 # 7 real + 1 pad - self._num_warps = 4 elif self.head_dim == 128: - self._use_cutedsl_sparse_compressor = False if use_fp4_cache: - self._fused_kernel = ( - _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn - ) self._quant_block = MXFP4_BLOCK_SIZE self._token_stride = self.head_dim // 2 self._scale_dim = self.head_dim // MXFP4_BLOCK_SIZE else: - self._fused_kernel = 
_fused_kv_compress_norm_rope_insert_indexer_attn self._quant_block = 128 self._token_stride = self.head_dim self._scale_dim = 4 # single float32 scale - self._num_warps = 1 else: raise ValueError( f"Unsupported head_dim for fused quant+cache: {self.head_dim}" @@ -312,29 +310,22 @@ def forward( ) # Store the KV and score (with fused APE addition) in the state. - # NOTE: PDL is disabled — both this kernel and _fused_kernel below - # depend on preceding kernel outputs (kv/score from the cublas GEMM; - # state_cache from this kernel) but neither emits/waits on PDL grid - # dependency primitives, so launch_pdl=True caused a read-after-write - # race and non-deterministic output. - _save_partial_states_kernel[(num_actual,)]( - kv, - kv.stride(0), - score, - score.stride(0), - self.ape, - self.ape.stride(0), - positions, - state_cache, - state_cache.stride(0), - state_cache.stride(1), - slot_mapping, - block_size, - HEAD_SIZE=kv.shape[-1], - TRITON_BLOCK_SIZE=triton.next_power_of_2(kv.shape[-1]), - STATE_WIDTH=state_width, - COMPRESS_RATIO=self.compress_ratio, - **pdl_kwargs, + # NOTE: PDL is disabled — both this kernel and the compress kernels + # below depend on preceding kernel outputs (kv/score from the cublas + # GEMM; state_cache from this kernel) but neither emits/waits on PDL + # grid dependency primitives, so launch_pdl=True caused a + # read-after-write race and non-deterministic output. + save_partial_states( + kv=kv, + score=score, + ape=self.ape, + positions=positions, + state_cache=state_cache, + slot_mapping=slot_mapping, + block_size=block_size, + state_width=state_width, + compress_ratio=self.compress_ratio, + pdl_kwargs=pdl_kwargs, ) # Fused: compress → RMSNorm → RoPE → FP8 quant → KV cache write. 
@@ -348,161 +339,32 @@ def forward( k_cache_metadata = cast(Any, attn_metadata[self.k_cache_prefix]) kv_cache = self._static_forward_context[self.k_cache_prefix].kv_cache - if self._use_cutedsl_sparse_compressor: - if self._use_cutedsl_fused_sparse_compressor: - self._fused_sparse_kernel( - state_cache, - token_to_req_indices, - positions, - slot_mapping, - block_table, - block_size, - self.norm.weight, - self.rms_norm_eps, - cos_sin_cache, - kv_cache, - k_cache_metadata.slot_mapping, - kv_cache.shape[1], # paged KV cache block size - kv_cache.stride(0), - head_size=self.head_dim, - state_width=state_width, - rope_head_dim=self.rope_head_dim, - fp8_max=448.0, - quant_block=self._quant_block, - token_stride=self._token_stride, - scale_dim=self._scale_dim, - compress_ratio=self.compress_ratio, - overlap=self.overlap, - ) - else: - compressed_kv = torch.empty( - (num_actual, self.head_dim), - dtype=torch.float32, - device=state_cache.device, - ) - self._compress_kernel( - state_cache, - token_to_req_indices, - positions, - slot_mapping, - block_table, - block_size, - compressed_kv, - head_size=self.head_dim, - state_width=state_width, - compress_ratio=self.compress_ratio, - overlap=self.overlap, - ) - self._norm_rope_store_kernel( - compressed_kv, - positions, - slot_mapping, - self.norm.weight, - self.rms_norm_eps, - cos_sin_cache, - kv_cache, - k_cache_metadata.slot_mapping, - kv_cache.shape[1], # paged KV cache block size - kv_cache.stride(0), - head_size=self.head_dim, - rope_head_dim=self.rope_head_dim, - fp8_max=448.0, - quant_block=self._quant_block, - token_stride=self._token_stride, - scale_dim=self._scale_dim, - compress_ratio=self.compress_ratio, - ) - else: - self._fused_kernel[(num_actual,)]( - # state cache - state_cache, - state_cache.stride(0), - state_cache.stride(1), - # metadata - token_to_req_indices, - positions, - slot_mapping, - block_table, - block_table.stride(0), - block_size, - # RMSNorm - self.norm.weight, - self.rms_norm_eps, - # RoPE - 
cos_sin_cache, - cos_sin_cache.stride(0), - # KV cache - kv_cache, - k_cache_metadata.slot_mapping, - kv_cache.shape[1], # paged KV cache block size (tokens per block) - # constexprs - HEAD_SIZE=self.head_dim, - TRITON_BLOCK_SIZE=triton.next_power_of_2(self.head_dim), - STATE_WIDTH=state_width, - COMPRESS_RATIO=self.compress_ratio, - OVERLAP=self.overlap, - ROPE_HEAD_DIM=self.rope_head_dim, - FP8_MAX=448.0, - QUANT_BLOCK=self._quant_block, - TOKEN_STRIDE=self._token_stride, - SCALE_DIM=self._scale_dim, - KV_BLOCK_STRIDE=kv_cache.stride(0), - num_warps=self._num_warps, - **pdl_kwargs, - ) - - -@triton.jit -def _save_partial_states_kernel( - kv_ptr, - kv_stride, - score_ptr, - score_stride, - ape_ptr, - ape_stride, - positions_ptr, - state_cache_ptr, - state_cache_stride0, - state_cache_stride1, - slot_mapping_ptr, - block_size, - HEAD_SIZE: tl.constexpr, - TRITON_BLOCK_SIZE: tl.constexpr, - # state_cache last dim packs [kv_state, score_state], each STATE_WIDTH wide. - STATE_WIDTH: tl.constexpr, - COMPRESS_RATIO: tl.constexpr, -): - token_idx = tl.program_id(0) - slot_id = tl.load(slot_mapping_ptr + token_idx) - - # Skip padded / invalid tokens (slot_id == -1 is the PAD sentinel used - # by vLLM). During CUDA graph replay the batch may contain padding - # tokens whose slot_mapping is -1; writing to kv_state[-1] would be an - # illegal memory access. 
- if slot_id < 0: - return - - block_idx = slot_id // block_size - pos_in_block = slot_id % block_size - base_ptr = ( - state_cache_ptr - + block_idx * state_cache_stride0 - + pos_in_block * state_cache_stride1 - ) - - block = tl.arange(0, TRITON_BLOCK_SIZE) - mask = block < HEAD_SIZE - - kv = tl.load(kv_ptr + token_idx * kv_stride + block, mask=mask) - tl.store(base_ptr + block, kv, mask=mask) - - # Fused: score += ape[position % compress_ratio] - position = tl.load(positions_ptr + token_idx) - ape_row = position % COMPRESS_RATIO - ape = tl.load(ape_ptr + ape_row * ape_stride + block, mask=mask) - score = tl.load(score_ptr + token_idx * score_stride + block, mask=mask) - tl.store( - base_ptr + STATE_WIDTH + block, - score + ape, - mask=mask, - ) + dispatch = ( + fused_kv_compress_norm_rope_insert + if self.head_dim == 128 + else compress_norm_rope_store + ) + dispatch( + state_cache=state_cache, + num_actual=num_actual, + token_to_req_indices=token_to_req_indices, + positions=positions, + slot_mapping=slot_mapping, + block_table=block_table, + block_size=block_size, + state_width=state_width, + cos_sin_cache=cos_sin_cache, + kv_cache=kv_cache, + k_cache_metadata=k_cache_metadata, + pdl_kwargs=pdl_kwargs, + head_dim=self.head_dim, + rope_head_dim=self.rope_head_dim, + compress_ratio=self.compress_ratio, + overlap=self.overlap, + use_fp4_cache=self.use_fp4_cache, + rms_norm_weight=self.norm.weight, + rms_norm_eps=self.rms_norm_eps, + quant_block=self._quant_block, + token_stride=self._token_stride, + scale_dim=self._scale_dim, + ) diff --git a/vllm/models/deepseek_v4/nvidia/compressor.py b/vllm/models/deepseek_v4/nvidia/compressor.py new file mode 100644 index 000000000000..8840971632ae --- /dev/null +++ b/vllm/models/deepseek_v4/nvidia/compressor.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""NVIDIA compress → norm → RoPE → store dispatch for head=512. 
+ +The head=128 indexer path is handled by the shared triton launcher in +``common/ops/fused_compress_quant_cache.py``; only head=512 (cutedsl) +flows through here. Uses fused C4 or split C128 kernels based on +``compress_ratio``. +""" + +from typing import Any + +import torch + +from vllm.models.deepseek_v4.nvidia.ops.sparse_attn_compress_cutedsl import ( + compress_kv_sparse_attn_cutedsl, + fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl, + norm_rope_insert_sparse_attn_cutedsl, +) + + +def compress_norm_rope_store( + state_cache: torch.Tensor, + num_actual: int, + token_to_req_indices: torch.Tensor, + positions: torch.Tensor, + slot_mapping: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + state_width: int, + cos_sin_cache: torch.Tensor, + kv_cache: torch.Tensor, + k_cache_metadata: Any, + pdl_kwargs: dict, + head_dim: int, + rope_head_dim: int, + compress_ratio: int, + overlap: bool, + use_fp4_cache: bool, + rms_norm_weight: torch.Tensor, + rms_norm_eps: float, + quant_block: int, + token_stride: int, + scale_dim: int, +) -> None: + if compress_ratio == 4: + fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl( + state_cache, + token_to_req_indices, + positions, + slot_mapping, + block_table, + block_size, + rms_norm_weight, + rms_norm_eps, + cos_sin_cache, + kv_cache, + k_cache_metadata.slot_mapping, + kv_cache.shape[1], # paged KV cache block size + kv_cache.stride(0), + head_size=head_dim, + state_width=state_width, + rope_head_dim=rope_head_dim, + fp8_max=448.0, + quant_block=quant_block, + token_stride=token_stride, + scale_dim=scale_dim, + compress_ratio=compress_ratio, + overlap=overlap, + ) + else: + compressed_kv = torch.empty( + (num_actual, head_dim), + dtype=torch.float32, + device=state_cache.device, + ) + compress_kv_sparse_attn_cutedsl( + state_cache, + token_to_req_indices, + positions, + slot_mapping, + block_table, + block_size, + compressed_kv, + head_size=head_dim, + state_width=state_width, + 
compress_ratio=compress_ratio, + overlap=overlap, + ) + norm_rope_insert_sparse_attn_cutedsl( + compressed_kv, + positions, + slot_mapping, + rms_norm_weight, + rms_norm_eps, + cos_sin_cache, + kv_cache, + k_cache_metadata.slot_mapping, + kv_cache.shape[1], # paged KV cache block size + kv_cache.stride(0), + head_size=head_dim, + rope_head_dim=rope_head_dim, + fp8_max=448.0, + quant_block=quant_block, + token_stride=token_stride, + scale_dim=scale_dim, + compress_ratio=compress_ratio, + ) diff --git a/vllm/models/deepseek_v4/common/ops/sparse_attn_compress_cutedsl.py b/vllm/models/deepseek_v4/nvidia/ops/sparse_attn_compress_cutedsl.py similarity index 99% rename from vllm/models/deepseek_v4/common/ops/sparse_attn_compress_cutedsl.py rename to vllm/models/deepseek_v4/nvidia/ops/sparse_attn_compress_cutedsl.py index 1f475c8fb840..8109d2d695ec 100644 --- a/vllm/models/deepseek_v4/common/ops/sparse_attn_compress_cutedsl.py +++ b/vllm/models/deepseek_v4/nvidia/ops/sparse_attn_compress_cutedsl.py @@ -1086,7 +1086,7 @@ def compile( ) -def _compress_kv_sparse_attn_cutedsl( +def compress_kv_sparse_attn_cutedsl( state_cache: torch.Tensor, token_to_req_indices: torch.Tensor, positions: torch.Tensor, @@ -1118,7 +1118,7 @@ def _compress_kv_sparse_attn_cutedsl( ) -def _norm_rope_insert_sparse_attn_cutedsl( +def norm_rope_insert_sparse_attn_cutedsl( compressed_kv: torch.Tensor, positions: torch.Tensor, slot_mapping: torch.Tensor, @@ -1174,7 +1174,7 @@ def _norm_rope_insert_sparse_attn_cutedsl( ) -def _fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl( +def fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl( state_cache: torch.Tensor, token_to_req_indices: torch.Tensor, positions: torch.Tensor, From 80cdb82ab056e3cd85ebd3b0422dbf48681cf2c2 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 26 May 2026 21:23:21 +0000 Subject: [PATCH 2/5] minor Signed-off-by: Woosuk Kwon --- vllm/models/deepseek_v4/compressor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) 
diff --git a/vllm/models/deepseek_v4/compressor.py b/vllm/models/deepseek_v4/compressor.py index 7918b0a86cfe..4933b9e0e74b 100644 --- a/vllm/models/deepseek_v4/compressor.py +++ b/vllm/models/deepseek_v4/compressor.py @@ -339,12 +339,12 @@ def forward( k_cache_metadata = cast(Any, attn_metadata[self.k_cache_prefix]) kv_cache = self._static_forward_context[self.k_cache_prefix].kv_cache - dispatch = ( + compress_norm_rope_insert_fn = ( fused_kv_compress_norm_rope_insert if self.head_dim == 128 else compress_norm_rope_store ) - dispatch( + compress_norm_rope_insert_fn( state_cache=state_cache, num_actual=num_actual, token_to_req_indices=token_to_req_indices, From efd875b924e6a39dedd3d157e037c6112039fe26 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 27 May 2026 02:32:39 +0000 Subject: [PATCH 3/5] cleanup Signed-off-by: Woosuk Kwon --- .../deepseek_v4/common/ops/cache_utils.py | 2 +- .../common/ops/fused_compress_quant_cache.py | 2 +- .../deepseek_v4/common/ops/fused_indexer_q.py | 4 +- vllm/models/deepseek_v4/compressor.py | 32 +++--- vllm/models/deepseek_v4/nvidia/compressor.py | 108 ------------------ vllm/models/deepseek_v4/nvidia/model.py | 2 +- .../models/deepseek_v4/nvidia/ops/__init__.py | 16 +++ .../ops/sparse_attn_compress_cutedsl.py | 92 +++++++++++++++ 8 files changed, 131 insertions(+), 127 deletions(-) delete mode 100644 vllm/models/deepseek_v4/nvidia/compressor.py diff --git a/vllm/models/deepseek_v4/common/ops/cache_utils.py b/vllm/models/deepseek_v4/common/ops/cache_utils.py index ac66751e3111..efb936be6c07 100644 --- a/vllm/models/deepseek_v4/common/ops/cache_utils.py +++ b/vllm/models/deepseek_v4/common/ops/cache_utils.py @@ -366,7 +366,7 @@ def dequantize_and_gather_k_cache( ) -> None: if has_cutedsl(): # lazily import, otherwise some tests fail due to CUDA driver init failure. 
- from vllm.models.deepseek_v4.nvidia.ops.dequant_gather_k_cutedsl import ( + from vllm.models.deepseek_v4.nvidia.ops import ( dequantize_and_gather_k_cache_cutedsl, ) diff --git a/vllm/models/deepseek_v4/common/ops/fused_compress_quant_cache.py b/vllm/models/deepseek_v4/common/ops/fused_compress_quant_cache.py index 5d1131c5edd6..9a5e478e315f 100644 --- a/vllm/models/deepseek_v4/common/ops/fused_compress_quant_cache.py +++ b/vllm/models/deepseek_v4/common/ops/fused_compress_quant_cache.py @@ -28,7 +28,7 @@ from .fused_indexer_q import _fp32x2_to_fp4x2 -def fused_kv_compress_norm_rope_insert( +def compress_norm_rope_store_triton( state_cache: torch.Tensor, num_actual: int, token_to_req_indices: torch.Tensor, diff --git a/vllm/models/deepseek_v4/common/ops/fused_indexer_q.py b/vllm/models/deepseek_v4/common/ops/fused_indexer_q.py index d5aaf10feba4..e88fe1529cd6 100644 --- a/vllm/models/deepseek_v4/common/ops/fused_indexer_q.py +++ b/vllm/models/deepseek_v4/common/ops/fused_indexer_q.py @@ -346,7 +346,7 @@ def fused_indexer_q_rope_quant( ) if has_cutedsl(): # lazily import, otherwise some tests fail due to CUDA driver init failure. - from vllm.models.deepseek_v4.nvidia.ops.fused_indexer_q_cutedsl import ( + from vllm.models.deepseek_v4.nvidia.ops import ( fused_indexer_q_rope_quant_mxfp4_cutedsl, ) @@ -400,7 +400,7 @@ def fused_indexer_q_rope_quant( index_q_fp8 = torch.empty_like(index_q, dtype=torch.float8_e4m3fn) if has_cutedsl(): # lazily import, otherwise some tests fail due to CUDA driver init failure. 
- from vllm.models.deepseek_v4.nvidia.ops.fused_indexer_q_cutedsl import ( + from vllm.models.deepseek_v4.nvidia.ops import ( fused_indexer_q_rope_quant_fp8_cutedsl, ) diff --git a/vllm/models/deepseek_v4/compressor.py b/vllm/models/deepseek_v4/compressor.py index 4933b9e0e74b..674490a18d98 100644 --- a/vllm/models/deepseek_v4/compressor.py +++ b/vllm/models/deepseek_v4/compressor.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, cast +from typing import Any, ClassVar, cast import torch from torch import nn @@ -13,19 +13,13 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import MergedColumnParallelLinear from vllm.models.deepseek_v4.common.ops.fused_compress_quant_cache import ( - fused_kv_compress_norm_rope_insert, + compress_norm_rope_store_triton, ) from vllm.models.deepseek_v4.common.ops.fused_indexer_q import MXFP4_BLOCK_SIZE from vllm.models.deepseek_v4.common.ops.save_partial_states import ( save_partial_states, ) from vllm.platforms import current_platform - -if TYPE_CHECKING or not current_platform.is_rocm(): - from vllm.models.deepseek_v4.nvidia.compressor import compress_norm_rope_store -else: - # AMD head=512 has no extra logic over the shared triton launcher. - compress_norm_rope_store = fused_kv_compress_norm_rope_insert from vllm.v1.attention.backend import ( AttentionBackend, AttentionCGSupport, @@ -339,12 +333,22 @@ def forward( k_cache_metadata = cast(Any, attn_metadata[self.k_cache_prefix]) kv_cache = self._static_forward_context[self.k_cache_prefix].kv_cache - compress_norm_rope_insert_fn = ( - fused_kv_compress_norm_rope_insert - if self.head_dim == 128 - else compress_norm_rope_store - ) - compress_norm_rope_insert_fn( + if current_platform.is_cuda(): + # NVIDIA GPUs. + if self.head_dim == 512: + from .nvidia.ops import compress_norm_rope_store_cutedsl + + # Main compressor path.
Use a cutedsl kernel as it performs better. + compress_norm_rope_store_fn = compress_norm_rope_store_cutedsl + else: + # Indexer path (head_dim == 128). Use a triton kernel. + compress_norm_rope_store_fn = compress_norm_rope_store_triton + else: + # AMD GPUs. + # Always use a triton kernel. + compress_norm_rope_store_fn = compress_norm_rope_store_triton + + compress_norm_rope_store_fn( state_cache=state_cache, num_actual=num_actual, token_to_req_indices=token_to_req_indices, diff --git a/vllm/models/deepseek_v4/nvidia/compressor.py b/vllm/models/deepseek_v4/nvidia/compressor.py deleted file mode 100644 index 8840971632ae..000000000000 --- a/vllm/models/deepseek_v4/nvidia/compressor.py +++ /dev/null @@ -1,108 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""NVIDIA compress → norm → RoPE → store dispatch for head=512. - -The head=128 indexer path is handled by the shared triton launcher in -``common/ops/fused_compress_quant_cache.py``; only head=512 (cutedsl) -flows through here. Uses fused C4 or split C128 kernels based on -``compress_ratio``.
-""" - -from typing import Any - -import torch - -from vllm.models.deepseek_v4.nvidia.ops.sparse_attn_compress_cutedsl import ( - compress_kv_sparse_attn_cutedsl, - fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl, - norm_rope_insert_sparse_attn_cutedsl, -) - - -def compress_norm_rope_store( - state_cache: torch.Tensor, - num_actual: int, - token_to_req_indices: torch.Tensor, - positions: torch.Tensor, - slot_mapping: torch.Tensor, - block_table: torch.Tensor, - block_size: int, - state_width: int, - cos_sin_cache: torch.Tensor, - kv_cache: torch.Tensor, - k_cache_metadata: Any, - pdl_kwargs: dict, - head_dim: int, - rope_head_dim: int, - compress_ratio: int, - overlap: bool, - use_fp4_cache: bool, - rms_norm_weight: torch.Tensor, - rms_norm_eps: float, - quant_block: int, - token_stride: int, - scale_dim: int, -) -> None: - if compress_ratio == 4: - fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl( - state_cache, - token_to_req_indices, - positions, - slot_mapping, - block_table, - block_size, - rms_norm_weight, - rms_norm_eps, - cos_sin_cache, - kv_cache, - k_cache_metadata.slot_mapping, - kv_cache.shape[1], # paged KV cache block size - kv_cache.stride(0), - head_size=head_dim, - state_width=state_width, - rope_head_dim=rope_head_dim, - fp8_max=448.0, - quant_block=quant_block, - token_stride=token_stride, - scale_dim=scale_dim, - compress_ratio=compress_ratio, - overlap=overlap, - ) - else: - compressed_kv = torch.empty( - (num_actual, head_dim), - dtype=torch.float32, - device=state_cache.device, - ) - compress_kv_sparse_attn_cutedsl( - state_cache, - token_to_req_indices, - positions, - slot_mapping, - block_table, - block_size, - compressed_kv, - head_size=head_dim, - state_width=state_width, - compress_ratio=compress_ratio, - overlap=overlap, - ) - norm_rope_insert_sparse_attn_cutedsl( - compressed_kv, - positions, - slot_mapping, - rms_norm_weight, - rms_norm_eps, - cos_sin_cache, - kv_cache, - k_cache_metadata.slot_mapping, - 
kv_cache.shape[1], # paged KV cache block size - kv_cache.stride(0), - head_size=head_dim, - rope_head_dim=rope_head_dim, - fp8_max=448.0, - quant_block=quant_block, - token_stride=token_stride, - scale_dim=scale_dim, - compress_ratio=compress_ratio, - ) diff --git a/vllm/models/deepseek_v4/nvidia/model.py b/vllm/models/deepseek_v4/nvidia/model.py index 25f0a730fdbd..974593a8d390 100644 --- a/vllm/models/deepseek_v4/nvidia/model.py +++ b/vllm/models/deepseek_v4/nvidia/model.py @@ -59,7 +59,7 @@ DeepseekV4MLAModules, DeepseekV4MultiHeadLatentAttentionWrapper, ) -from vllm.models.deepseek_v4.nvidia.ops.prepare_megamoe import prepare_megamoe_inputs +from vllm.models.deepseek_v4.nvidia.ops import prepare_megamoe_inputs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.torch_utils import direct_register_custom_op diff --git a/vllm/models/deepseek_v4/nvidia/ops/__init__.py b/vllm/models/deepseek_v4/nvidia/ops/__init__.py index 37276e1816f0..dca25345ea6f 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/__init__.py +++ b/vllm/models/deepseek_v4/nvidia/ops/__init__.py @@ -6,3 +6,19 @@ not be imported on non-CUDA platforms. Callers should gate on ``vllm.utils.import_utils.has_cutedsl()`` before importing from here. 
""" + +from .dequant_gather_k_cutedsl import dequantize_and_gather_k_cache_cutedsl +from .fused_indexer_q_cutedsl import ( + fused_indexer_q_rope_quant_fp8_cutedsl, + fused_indexer_q_rope_quant_mxfp4_cutedsl, +) +from .prepare_megamoe import prepare_megamoe_inputs +from .sparse_attn_compress_cutedsl import compress_norm_rope_store_cutedsl + +__all__ = [ + "compress_norm_rope_store_cutedsl", + "dequantize_and_gather_k_cache_cutedsl", + "fused_indexer_q_rope_quant_fp8_cutedsl", + "fused_indexer_q_rope_quant_mxfp4_cutedsl", + "prepare_megamoe_inputs", +] diff --git a/vllm/models/deepseek_v4/nvidia/ops/sparse_attn_compress_cutedsl.py b/vllm/models/deepseek_v4/nvidia/ops/sparse_attn_compress_cutedsl.py index 8109d2d695ec..0eba82126afd 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/sparse_attn_compress_cutedsl.py +++ b/vllm/models/deepseek_v4/nvidia/ops/sparse_attn_compress_cutedsl.py @@ -8,6 +8,7 @@ from __future__ import annotations from functools import cache +from typing import Any import cutlass import cutlass.cute as cute @@ -1238,3 +1239,94 @@ def fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl( kv_slot_mapping, kv_cache_block_size, ) + + +def compress_norm_rope_store_cutedsl( + state_cache: torch.Tensor, + num_actual: int, + token_to_req_indices: torch.Tensor, + positions: torch.Tensor, + slot_mapping: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + state_width: int, + cos_sin_cache: torch.Tensor, + kv_cache: torch.Tensor, + k_cache_metadata: Any, + pdl_kwargs: dict, + head_dim: int, + rope_head_dim: int, + compress_ratio: int, + overlap: bool, + use_fp4_cache: bool, + rms_norm_weight: torch.Tensor, + rms_norm_eps: float, + quant_block: int, + token_stride: int, + scale_dim: int, +) -> None: + if compress_ratio == 4: + # For C4A, the single fused kernel is faster than the two-kernel version. 
+ fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl( + state_cache, + token_to_req_indices, + positions, + slot_mapping, + block_table, + block_size, + rms_norm_weight, + rms_norm_eps, + cos_sin_cache, + kv_cache, + k_cache_metadata.slot_mapping, + kv_cache.shape[1], # paged KV cache block size + kv_cache.stride(0), + head_size=head_dim, + state_width=state_width, + rope_head_dim=rope_head_dim, + fp8_max=448.0, + quant_block=quant_block, + token_stride=token_stride, + scale_dim=scale_dim, + compress_ratio=compress_ratio, + overlap=overlap, + ) + else: + # For C128, the two-kernel version is faster than the single fused kernel. + compressed_kv = torch.empty( + (num_actual, head_dim), + dtype=torch.float32, + device=state_cache.device, + ) + compress_kv_sparse_attn_cutedsl( + state_cache, + token_to_req_indices, + positions, + slot_mapping, + block_table, + block_size, + compressed_kv, + head_size=head_dim, + state_width=state_width, + compress_ratio=compress_ratio, + overlap=overlap, + ) + norm_rope_insert_sparse_attn_cutedsl( + compressed_kv, + positions, + slot_mapping, + rms_norm_weight, + rms_norm_eps, + cos_sin_cache, + kv_cache, + k_cache_metadata.slot_mapping, + kv_cache.shape[1], # paged KV cache block size + kv_cache.stride(0), + head_size=head_dim, + rope_head_dim=rope_head_dim, + fp8_max=448.0, + quant_block=quant_block, + token_stride=token_stride, + scale_dim=scale_dim, + compress_ratio=compress_ratio, + ) From ea794557332724f4172cc1672d69acb82f002f71 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 27 May 2026 02:35:17 +0000 Subject: [PATCH 4/5] docstring Signed-off-by: Woosuk Kwon --- vllm/models/deepseek_v4/compressor.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/models/deepseek_v4/compressor.py b/vllm/models/deepseek_v4/compressor.py index 674490a18d98..0de30ef0ab62 100644 --- a/vllm/models/deepseek_v4/compressor.py +++ b/vllm/models/deepseek_v4/compressor.py @@ -174,9 +174,11 @@ class 
DeepseekCompressor(nn.Module): """DeepSeek V4 KV/score compressor. Owns the linear / norm / state-cache / ape state and the shared forward - prologue (kv/score split, save_partial_states launch). The per-platform - compress → norm → RoPE → store step is dispatched to - ``compress_norm_rope_store`` imported from ``nvidia/`` or ``amd/``. + prologue (kv/score split, save_partial_states launch). The + compress → norm → RoPE → store step is dispatched to a triton kernel + (``compress_norm_rope_store_triton``) by default, except for the NVIDIA + head_dim=512 main compressor path which uses the cutedsl kernel + (``compress_norm_rope_store_cutedsl``) for better performance. """ def __init__( From b1a607f58cadf7549ec681cf92d7d6098b747e2a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 27 May 2026 02:41:40 +0000 Subject: [PATCH 5/5] fix Signed-off-by: Woosuk Kwon --- vllm/models/deepseek_v4/compressor.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/models/deepseek_v4/compressor.py b/vllm/models/deepseek_v4/compressor.py index 0de30ef0ab62..3234faa5eb05 100644 --- a/vllm/models/deepseek_v4/compressor.py +++ b/vllm/models/deepseek_v4/compressor.py @@ -337,13 +337,15 @@ def forward( if current_platform.is_cuda(): # NVIDIA GPUs. - if self.head_dim == 128: + if self.head_dim == 512: from .nvidia.ops import compress_norm_rope_store_cutedsl - # Indexer path. Use a cutedsl kernel as it performs better. + # Main compressor path. + # Use a cutedsl kernel for better performance. compress_norm_rope_store_fn = compress_norm_rope_store_cutedsl else: - # Main compressor path (head_dim == 512). Use a triton kernel. + # Indexer path (head_dim == 128). + # Use a triton kernel. compress_norm_rope_store_fn = compress_norm_rope_store_triton else: # AMD GPUs.