Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions vllm/models/deepseek_v4/common/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .fused_indexer_q import MXFP4_BLOCK_SIZE, fused_indexer_q_rope_quant
from .fused_inv_rope_fp8_quant import fused_inv_rope_fp8_quant
from .fused_qk_rmsnorm import fused_q_kv_rmsnorm
from .save_partial_states import save_partial_states

__all__ = [
"MXFP4_BLOCK_SIZE",
Expand All @@ -20,4 +21,5 @@
"fused_inv_rope_fp8_quant",
"fused_q_kv_rmsnorm",
"quantize_and_insert_k_cache",
"save_partial_states",
]
112 changes: 78 additions & 34 deletions vllm/models/deepseek_v4/common/ops/fused_compress_quant_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,6 @@
- _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn:
head=128, MXFP4 (block=32), 4 ue8m0 bytes

Additional cutedsl kernels:
- _compress_kv_sparse_attn_cutedsl / _norm_rope_insert_sparse_attn_cutedsl:
CuTe DSL split kernels for C128
- _fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl:
CuTe DSL fused kernels for C4

RoPE is register-based via tl.reshape -> tl.split -> tl.interleave (or the
even/odd halves are consumed directly for MXFP4, no interleave needed).
FP8 UE8M0 quant uses tl.reshape to tile [N_QUANT_BLOCKS, QUANT_BLOCK] for
Expand All @@ -25,43 +19,93 @@
and N_QUANT_BLOCKS ue8m0 bytes.
"""

from functools import cache
from typing import Any

import torch

from vllm.triton_utils import tl, triton

from .fused_indexer_q import _fp32x2_to_fp4x2


@cache
def _get_sparse_attn_cutedsl_impls():
from .sparse_attn_compress_cutedsl import (
_compress_kv_sparse_attn_cutedsl,
_fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl,
_norm_rope_insert_sparse_attn_cutedsl,
)

return (
_compress_kv_sparse_attn_cutedsl,
_norm_rope_insert_sparse_attn_cutedsl,
_fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl,
def fused_kv_compress_norm_rope_insert(
state_cache: torch.Tensor,
num_actual: int,
token_to_req_indices: torch.Tensor,
positions: torch.Tensor,
slot_mapping: torch.Tensor,
block_table: torch.Tensor,
block_size: int,
state_width: int,
cos_sin_cache: torch.Tensor,
kv_cache: torch.Tensor,
k_cache_metadata: Any,
pdl_kwargs: dict,
head_dim: int,
rope_head_dim: int,
compress_ratio: int,
overlap: bool,
use_fp4_cache: bool,
rms_norm_weight: torch.Tensor,
rms_norm_eps: float,
quant_block: int,
token_stride: int,
scale_dim: int,
) -> None:
"""Shared triton launcher for the fused compress+norm+RoPE+insert path.

Picks one of the three kernels in this module based on ``head_dim`` and
``use_fp4_cache``. Identical launch signature for all three.
"""
if head_dim == 512:
kernel = _fused_kv_compress_norm_rope_insert_sparse_attn
num_warps = 4
elif use_fp4_cache:
kernel = _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn
num_warps = 1
else:
kernel = _fused_kv_compress_norm_rope_insert_indexer_attn
num_warps = 1

kernel[(num_actual,)](
# state cache
state_cache,
state_cache.stride(0),
state_cache.stride(1),
# metadata
token_to_req_indices,
positions,
slot_mapping,
block_table,
block_table.stride(0),
block_size,
# RMSNorm
rms_norm_weight,
rms_norm_eps,
# RoPE
cos_sin_cache,
cos_sin_cache.stride(0),
# KV cache
kv_cache,
k_cache_metadata.slot_mapping,
kv_cache.shape[1], # paged KV cache block size (tokens per block)
# constexprs
HEAD_SIZE=head_dim,
TRITON_BLOCK_SIZE=triton.next_power_of_2(head_dim),
STATE_WIDTH=state_width,
COMPRESS_RATIO=compress_ratio,
OVERLAP=overlap,
ROPE_HEAD_DIM=rope_head_dim,
FP8_MAX=448.0,
QUANT_BLOCK=quant_block,
TOKEN_STRIDE=token_stride,
SCALE_DIM=scale_dim,
KV_BLOCK_STRIDE=kv_cache.stride(0),
num_warps=num_warps,
**pdl_kwargs,
)


def _compress_kv_sparse_attn_cutedsl(*args, **kwargs):
"""CuTe DSL sparse-attention compress wrapper."""
return _get_sparse_attn_cutedsl_impls()[0](*args, **kwargs)


def _norm_rope_insert_sparse_attn_cutedsl(*args, **kwargs):
"""CuTe DSL RMSNorm/RoPE/FP8-store wrapper."""
return _get_sparse_attn_cutedsl_impls()[1](*args, **kwargs)


def _fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl(*args, **kwargs):
"""CuTe DSL fused C4 sparse-attention compressor wrapper."""
return _get_sparse_attn_cutedsl_impls()[2](*args, **kwargs)


# =============================================================================
# DeepseekV4 Attention path (head=512, nope=448 FP8 + rope=64 bf16)
# =============================================================================
Expand Down
101 changes: 101 additions & 0 deletions vllm/models/deepseek_v4/common/ops/save_partial_states.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.triton_utils import tl, triton


def save_partial_states(
kv: torch.Tensor,
score: torch.Tensor,
ape: torch.Tensor,
positions: torch.Tensor,
state_cache: torch.Tensor,
slot_mapping: torch.Tensor,
block_size: int,
state_width: int,
compress_ratio: int,
pdl_kwargs: dict | None = None,
) -> None:
"""Write packed [kv, score+ape] partial states into the compressor cache.

One program per token; pads (slot_id == -1) are skipped.
"""
num_actual = slot_mapping.shape[0]
head_size = kv.shape[-1]
_save_partial_states_kernel[(num_actual,)](
kv,
kv.stride(0),
score,
score.stride(0),
ape,
ape.stride(0),
positions,
state_cache,
state_cache.stride(0),
state_cache.stride(1),
slot_mapping,
block_size,
HEAD_SIZE=head_size,
TRITON_BLOCK_SIZE=triton.next_power_of_2(head_size),
STATE_WIDTH=state_width,
COMPRESS_RATIO=compress_ratio,
**(pdl_kwargs or {}),
)


@triton.jit
def _save_partial_states_kernel(
kv_ptr,
kv_stride,
score_ptr,
score_stride,
ape_ptr,
ape_stride,
positions_ptr,
state_cache_ptr,
state_cache_stride0,
state_cache_stride1,
slot_mapping_ptr,
block_size,
HEAD_SIZE: tl.constexpr,
TRITON_BLOCK_SIZE: tl.constexpr,
# state_cache last dim packs [kv_state, score_state], each STATE_WIDTH wide.
STATE_WIDTH: tl.constexpr,
COMPRESS_RATIO: tl.constexpr,
):
token_idx = tl.program_id(0)
slot_id = tl.load(slot_mapping_ptr + token_idx)

# Skip padded / invalid tokens (slot_id == -1 is the PAD sentinel used
# by vLLM). During CUDA graph replay the batch may contain padding
# tokens whose slot_mapping is -1; writing to kv_state[-1] would be an
# illegal memory access.
if slot_id < 0:
return

block_idx = slot_id // block_size
pos_in_block = slot_id % block_size
base_ptr = (
state_cache_ptr
+ block_idx * state_cache_stride0
+ pos_in_block * state_cache_stride1
)

block = tl.arange(0, TRITON_BLOCK_SIZE)
mask = block < HEAD_SIZE

kv = tl.load(kv_ptr + token_idx * kv_stride + block, mask=mask)
tl.store(base_ptr + block, kv, mask=mask)

# Fused: score += ape[position % compress_ratio]
position = tl.load(positions_ptr + token_idx)
ape_row = position % COMPRESS_RATIO
ape = tl.load(ape_ptr + ape_row * ape_stride + block, mask=mask)
score = tl.load(score_ptr + token_idx * score_stride + block, mask=mask)
tl.store(
base_ptr + STATE_WIDTH + block,
score + ape,
mask=mask,
)
Loading
Loading