Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions vllm/models/deepseek_v4/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
_compress_kv_sparse_attn_cutedsl,
_fused_kv_compress_norm_rope_insert_indexer_attn,
_fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn,
+_fused_kv_compress_norm_rope_insert_sparse_attn,
_fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl,
_norm_rope_insert_sparse_attn_cutedsl,
)
from vllm.models.deepseek_v4.common.ops.fused_indexer_q import MXFP4_BLOCK_SIZE
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
+from vllm.utils.import_utils import has_cutedsl
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
Expand Down Expand Up @@ -242,13 +244,21 @@ def __init__(
assert not use_fp4_cache, (
"MXFP4 cache is only supported for indexer (head=128)"
)
-self._use_cutedsl_sparse_compressor = True
-self._use_cutedsl_fused_sparse_compressor = self.compress_ratio == 4
-self._compress_kernel = _compress_kv_sparse_attn_cutedsl
-self._norm_rope_store_kernel = _norm_rope_insert_sparse_attn_cutedsl
-self._fused_sparse_kernel = (
-_fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl
-)
+self._use_cutedsl_sparse_compressor = has_cutedsl()
+if self._use_cutedsl_sparse_compressor:
+self._use_cutedsl_fused_sparse_compressor = self.compress_ratio == 4
+self._compress_kernel = _compress_kv_sparse_attn_cutedsl
+self._norm_rope_store_kernel = _norm_rope_insert_sparse_attn_cutedsl
+self._fused_sparse_kernel = (
+_fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl
+)
+self._compressed_kv_buffer = self._get_compressed_kv_buffer(
+self.device,
+vllm_config.scheduler_config.max_num_batched_tokens,
+self.head_dim,
+)
+else:
+self._fused_kernel = _fused_kv_compress_norm_rope_insert_sparse_attn
self._quant_block = 64
self._token_stride = self.nope_head_dim + self.rope_head_dim * 2
self._scale_dim = self.nope_head_dim // 64 + 1 # 7 real + 1 pad
Expand Down
Loading