Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions vllm/models/deepseek_v4/common/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from .fused_indexer_q import MXFP4_BLOCK_SIZE, fused_indexer_q_rope_quant
from .fused_inv_rope_fp8_quant import fused_inv_rope_fp8_quant
from .fused_qk_rmsnorm import fused_q_kv_rmsnorm
from .save_partial_states import save_partial_states

__all__ = [
"MXFP4_BLOCK_SIZE",
Expand All @@ -21,5 +20,4 @@
"fused_inv_rope_fp8_quant",
"fused_q_kv_rmsnorm",
"quantize_and_insert_k_cache",
"save_partial_states",
]
2 changes: 1 addition & 1 deletion vllm/models/deepseek_v4/common/ops/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def dequantize_and_gather_k_cache(
) -> None:
if has_cutedsl():
# lazily import, otherwise some tests fail due to CUDA driver init failure.
from vllm.models.deepseek_v4.nvidia.ops import (
from vllm.models.deepseek_v4.nvidia.ops.dequant_gather_k_cutedsl import (
dequantize_and_gather_k_cache_cutedsl,
)

Expand Down
112 changes: 34 additions & 78 deletions vllm/models/deepseek_v4/common/ops/fused_compress_quant_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
- _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn:
head=128, MXFP4 (block=32), 4 ue8m0 bytes

Additional cutedsl kernels:
- _compress_kv_sparse_attn_cutedsl / _norm_rope_insert_sparse_attn_cutedsl:
CuTe DSL split kernels for C128
- _fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl:
CuTe DSL fused kernels for C4

RoPE is register-based via tl.reshape -> tl.split -> tl.interleave (or the
even/odd halves are consumed directly for MXFP4, no interleave needed).
FP8 UE8M0 quant uses tl.reshape to tile [N_QUANT_BLOCKS, QUANT_BLOCK] for
Expand All @@ -19,92 +25,42 @@
and N_QUANT_BLOCKS ue8m0 bytes.
"""

from typing import Any

import torch
from functools import cache

from vllm.triton_utils import tl, triton

from .fused_indexer_q import _fp32x2_to_fp4x2


def compress_norm_rope_store_triton(
state_cache: torch.Tensor,
num_actual: int,
token_to_req_indices: torch.Tensor,
positions: torch.Tensor,
slot_mapping: torch.Tensor,
block_table: torch.Tensor,
block_size: int,
state_width: int,
cos_sin_cache: torch.Tensor,
kv_cache: torch.Tensor,
k_cache_metadata: Any,
pdl_kwargs: dict,
head_dim: int,
rope_head_dim: int,
compress_ratio: int,
overlap: bool,
use_fp4_cache: bool,
rms_norm_weight: torch.Tensor,
rms_norm_eps: float,
quant_block: int,
token_stride: int,
scale_dim: int,
) -> None:
"""Shared triton launcher for the fused compress+norm+RoPE+insert path.

Picks one of the three kernels in this module based on ``head_dim`` and
``use_fp4_cache``. Identical launch signature for all three.
"""
if head_dim == 512:
kernel = _fused_kv_compress_norm_rope_insert_sparse_attn
num_warps = 4
elif use_fp4_cache:
kernel = _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn
num_warps = 1
else:
kernel = _fused_kv_compress_norm_rope_insert_indexer_attn
num_warps = 1

kernel[(num_actual,)](
# state cache
state_cache,
state_cache.stride(0),
state_cache.stride(1),
# metadata
token_to_req_indices,
positions,
slot_mapping,
block_table,
block_table.stride(0),
block_size,
# RMSNorm
rms_norm_weight,
rms_norm_eps,
# RoPE
cos_sin_cache,
cos_sin_cache.stride(0),
# KV cache
kv_cache,
k_cache_metadata.slot_mapping,
kv_cache.shape[1], # paged KV cache block size (tokens per block)
# constexprs
HEAD_SIZE=head_dim,
TRITON_BLOCK_SIZE=triton.next_power_of_2(head_dim),
STATE_WIDTH=state_width,
COMPRESS_RATIO=compress_ratio,
OVERLAP=overlap,
ROPE_HEAD_DIM=rope_head_dim,
FP8_MAX=448.0,
QUANT_BLOCK=quant_block,
TOKEN_STRIDE=token_stride,
SCALE_DIM=scale_dim,
KV_BLOCK_STRIDE=kv_cache.stride(0),
num_warps=num_warps,
**pdl_kwargs,
@cache
def _get_sparse_attn_cutedsl_impls():
from .sparse_attn_compress_cutedsl import (
_compress_kv_sparse_attn_cutedsl,
_fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl,
_norm_rope_insert_sparse_attn_cutedsl,
)

return (
_compress_kv_sparse_attn_cutedsl,
_norm_rope_insert_sparse_attn_cutedsl,
_fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl,
)


def _compress_kv_sparse_attn_cutedsl(*args, **kwargs):
"""CuTe DSL sparse-attention compress wrapper."""
return _get_sparse_attn_cutedsl_impls()[0](*args, **kwargs)


def _norm_rope_insert_sparse_attn_cutedsl(*args, **kwargs):
"""CuTe DSL RMSNorm/RoPE/FP8-store wrapper."""
return _get_sparse_attn_cutedsl_impls()[1](*args, **kwargs)


def _fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl(*args, **kwargs):
"""CuTe DSL fused C4 sparse-attention compressor wrapper."""
return _get_sparse_attn_cutedsl_impls()[2](*args, **kwargs)


# =============================================================================
# DeepseekV4 Attention path (head=512, nope=448 FP8 + rope=64 bf16)
Expand Down
4 changes: 2 additions & 2 deletions vllm/models/deepseek_v4/common/ops/fused_indexer_q.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def fused_indexer_q_rope_quant(
)
if has_cutedsl():
# lazily import, otherwise some tests fail due to CUDA driver init failure.
from vllm.models.deepseek_v4.nvidia.ops import (
from vllm.models.deepseek_v4.nvidia.ops.fused_indexer_q_cutedsl import (
fused_indexer_q_rope_quant_mxfp4_cutedsl,
)

Expand Down Expand Up @@ -400,7 +400,7 @@ def fused_indexer_q_rope_quant(
index_q_fp8 = torch.empty_like(index_q, dtype=torch.float8_e4m3fn)
if has_cutedsl():
# lazily import, otherwise some tests fail due to CUDA driver init failure.
from vllm.models.deepseek_v4.nvidia.ops import (
from vllm.models.deepseek_v4.nvidia.ops.fused_indexer_q_cutedsl import (
fused_indexer_q_rope_quant_fp8_cutedsl,
)

Expand Down
101 changes: 0 additions & 101 deletions vllm/models/deepseek_v4/common/ops/save_partial_states.py

This file was deleted.

Loading
Loading