Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
0fb4233
Port DeepSeek V4 FlashInfer sparse MLA
PerkzZheng May 9, 2026
6d0ccbd
Optimize DeepSeek V4 FlashInfer decode
PerkzZheng May 9, 2026
1e2a685
Fix DeepSeek V4 FlashInfer FP8 cache scaling
PerkzZheng May 9, 2026
32934d1
Optimize DeepSeek V4 FlashInfer FP8 decode quantization
PerkzZheng May 9, 2026
53fe443
Optimize DeepSeek V4 FlashInfer FP8 sparse MLA
PerkzZheng May 11, 2026
cc5ec7b
Use one FlashInfer call for mixed DSV4 FP8 batches
PerkzZheng May 11, 2026
db6aff8
Unify DeepSeek V4 FlashInfer attention path
PerkzZheng May 11, 2026
377fa2b
Clean up DeepSeek V4 FlashInfer FP8 path
PerkzZheng May 11, 2026
9c8f3b0
Fix DeepSeek V4 post-rebase test coverage
PerkzZheng May 11, 2026
706e92d
Preserve DeepSeek V4 Flash accuracy after rebase
PerkzZheng May 11, 2026
159b9b7
Rename DeepSeek V4 per-tensor FP8 KV cache dtype
PerkzZheng May 11, 2026
f7510df
Clean DeepSeek V4 post-rebase indexer path
PerkzZheng May 12, 2026
0de10a2
Clean DeepSeek V4 FlashInfer sparse path
PerkzZheng May 12, 2026
250070e
Fix merged DeepSeek V4 sparse compressor compile
PerkzZheng May 12, 2026
ed87b13
Clean DeepSeek V4 FlashInfer metadata path
PerkzZheng May 12, 2026
7fc2ba0
Fix DeepSeek V4 FlashInfer padded graph tokens
PerkzZheng May 12, 2026
5f10450
Clean DeepSeek V4 FlashInfer sparse attention path
PerkzZheng May 19, 2026
b58aafe
Allow async scheduling for DeepSeek V4
PerkzZheng May 20, 2026
27b39dd
Fix sparse MLA pre-commit issues
PerkzZheng May 20, 2026
60a4b22
Restore DeepSeek V4 slot mapping test to main
PerkzZheng May 20, 2026
057dcc0
Use CUDA full-cache FP8 insert for DeepSeek V4
May 21, 2026
c6dc5d2
Call DeepSeek V4 full-cache FP8 op directly
May 21, 2026
ddbbdc3
Restore vLLM config from PR base
PerkzZheng May 21, 2026
8f0603b
Use CUDA fused insert for full KV cache
PerkzZheng May 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
404 changes: 365 additions & 39 deletions csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,18 @@ void fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
torch::Tensor const& slot_mapping, torch::Tensor const& position_ids,
torch::Tensor const& cos_sin_cache, double eps, int64_t cache_block_size);

// Full-cache per-tensor FP8 variant of the fused DeepSeek V4
// qnorm + RoPE + KV-insert op. `q_fp8` and `k_cache` are the only tensors
// taken by non-const reference (the outputs); every other tensor argument is
// a const input. `fp8_scale` and `q_fp8_scale_inv` are passed as tensors.
void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert(
torch::Tensor const& q, torch::Tensor const& kv, torch::Tensor& q_fp8,
torch::Tensor& k_cache, torch::Tensor const& slot_mapping,
torch::Tensor const& position_ids, torch::Tensor const& cos_sin_cache,
torch::Tensor const& fp8_scale, torch::Tensor const& q_fp8_scale_inv,
double eps, int64_t cache_block_size);

// Full-cache BF16 variant of the fused DeepSeek V4 qnorm + RoPE + KV-insert
// op. `q` and `k_cache` are taken by non-const reference (updated by the op);
// `kv`, `slot_mapping`, `position_ids` and `cos_sin_cache` are const inputs.
// Unlike the FP8 variant there is no separate Q output and no scale tensors.
void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert(
torch::Tensor& q, torch::Tensor const& kv, torch::Tensor& k_cache,
torch::Tensor const& slot_mapping, torch::Tensor const& position_ids,
torch::Tensor const& cos_sin_cache, double eps, int64_t cache_block_size);

void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& prompt_mask,
const torch::Tensor& output_mask,
Expand Down
22 changes: 22 additions & 0 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,28 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA,
&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert);

// Full-cache per-tensor FP8 variant for FlashInfer sparse MLA. Reuses the
// same CUDA warp-slot kernel structure as the legacy UE8M0 op, but writes Q
// to a separate FP8 tensor and KV into a full 512-wide FP8 paged cache.
// In the schema below only `q_fp8` and `k_cache` carry the `Tensor!`
// (mutated) annotation; the op returns nothing.
ops.def(
"fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert("
"Tensor q, Tensor kv, Tensor! q_fp8, Tensor! k_cache, "
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
"Tensor fp8_scale, Tensor q_fp8_scale_inv, float eps, "
"int cache_block_size) -> ()");
// Only a CUDA implementation is registered for this op here.
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert",
torch::kCUDA,
&fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert);

// Full-cache BF16 variant. The schema marks `q` and `k_cache` as mutated
// (`Tensor!`), takes no FP8 scale tensors, and returns nothing.
ops.def(
"fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert("
"Tensor! q, Tensor kv, Tensor! k_cache, Tensor slot_mapping, "
"Tensor position_ids, Tensor cos_sin_cache, float eps, "
"int cache_block_size) -> ()");
// Only a CUDA implementation is registered for this op here.
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert",
torch::kCUDA,
&fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert);

// Apply repetition penalties to logits in-place
ops.def(
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
Expand Down
2 changes: 1 addition & 1 deletion docs/design/attention_backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ MLA decode backends are selected using the standard
| `FLASHINFER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x |
| `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | 576 | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x |
| `FLASHMLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x |
| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 512, 576 | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_per_tensor`, `fp8_inc`, `fp8_ds_mla`, `fp8_e4m3` | 64 | 512, 576 | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
| `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
| `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %1 | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 1, 64 | Any | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
Expand Down
247 changes: 246 additions & 1 deletion tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def apply_rope_gptj_last_k(
nope_dim = head_dim - rope_dim

# Gather cos/sin for each token position: [num_tokens, rope_dim]
cs = cos_sin_cache[positions].to(torch.float32) # [N, rope_dim]
cs = cos_sin_cache[positions.long()].to(torch.float32) # [N, rope_dim]
cos = cs[..., :half] # [N, half]
sin = cs[..., half:] # [N, half]

Expand Down Expand Up @@ -113,6 +113,18 @@ def _op_available() -> bool:
return hasattr(torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert")


def _full_cache_fp8_op_available() -> bool:
return hasattr(
torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert"
)


def _full_cache_bf16_op_available() -> bool:
return hasattr(
torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert"
)


pytestmark = pytest.mark.skipif(
not torch.cuda.is_available() or not _op_available(),
reason="CUDA not available or fused DeepseekV4 op not built in",
Expand All @@ -125,6 +137,109 @@ def _call_fused(q, kv, k_cache, slot_mapping, positions, cos_sin_cache, eps, bs)
)


def _call_full_cache_fp8_fused(
    q,
    kv,
    q_fp8,
    k_cache,
    slot_mapping,
    positions,
    cos_sin_cache,
    fp8_scale,
    q_fp8_scale_inv,
    eps,
    bs,
):
    """Invoke the fused full-cache FP8 insert op.

    ``positions`` is converted with ``.long()`` before the call, so callers
    may pass either int32 or int64 position tensors. ``bs`` is forwarded as
    the op's final (cache block size) argument. Nothing is returned.
    """
    fused_op = torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert
    positions_i64 = positions.long()
    fused_op(
        q,
        kv,
        q_fp8,
        k_cache,
        slot_mapping,
        positions_i64,
        cos_sin_cache,
        fp8_scale,
        q_fp8_scale_inv,
        eps,
        bs,
    )


def _call_full_cache_bf16_fused(
    q,
    kv,
    k_cache,
    slot_mapping,
    positions,
    cos_sin_cache,
    eps,
    bs,
):
    """Invoke the fused full-cache BF16 insert op.

    ``positions`` is converted with ``.long()`` before the call, so callers
    may pass either int32 or int64 position tensors. ``bs`` is forwarded as
    the op's final (cache block size) argument. Nothing is returned.
    """
    fused_op = torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert
    positions_i64 = positions.long()
    fused_op(
        q,
        kv,
        k_cache,
        slot_mapping,
        positions_i64,
        cos_sin_cache,
        eps,
        bs,
    )


def _fp8_full_cache_reference(
    q,
    kv,
    k_cache,
    q_fp8,
    slot_mapping,
    positions,
    cos_sin_cache,
    eps,
    block_size,
    fp8_scale,
    q_fp8_scale_inv,
):
    """Reference implementation of the full-cache per-tensor FP8 path.

    Q: ``rmsnorm_no_weight`` followed by ``apply_rope_gptj_last_k``, then
    multiplied by ``q_fp8_scale_inv``, clamped to ``[-FP8_MAX, FP8_MAX]``,
    cast to ``float8_e4m3fn`` and copied into ``q_fp8``.

    KV: ``apply_rope_gptj_last_k`` only, then divided by ``fp8_scale``,
    clamped and cast the same way. Rows whose slot is ``>= 0`` are written
    to ``k_cache[slot // block_size, slot % block_size]``; rows with a
    negative slot are skipped.

    Both outputs are written in place; nothing is returned.
    """

    def _to_fp8(values):
        # Saturate before the cast so out-of-range values land on +-FP8_MAX.
        return torch.clamp(values, -FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

    q_normed = rmsnorm_no_weight(q, eps)
    q_rot = apply_rope_gptj_last_k(q_normed, positions, cos_sin_cache)
    q_fp8.copy_(_to_fp8(q_rot.float() * q_fp8_scale_inv))

    kv_rot = apply_rope_gptj_last_k(kv, positions, cos_sin_cache)
    keep = slot_mapping >= 0
    live_slots = slot_mapping[keep]
    dst = (live_slots // block_size, live_slots % block_size)
    k_cache[dst] = _to_fp8(kv_rot[keep].float() / fp8_scale)


def _bf16_full_cache_reference(
    q,
    kv,
    k_cache,
    slot_mapping,
    positions,
    cos_sin_cache,
    eps,
    block_size,
):
    """Reference implementation of the full-cache BF16 path.

    Returns Q after ``rmsnorm_no_weight`` and ``apply_rope_gptj_last_k``.
    KV receives ``apply_rope_gptj_last_k`` only; rows whose slot is ``>= 0``
    are written in place to ``k_cache[slot // block_size, slot % block_size]``
    and rows with a negative slot are skipped.
    """
    q_normed = rmsnorm_no_weight(q, eps)
    q_out = apply_rope_gptj_last_k(q_normed, positions, cos_sin_cache)

    kv_rot = apply_rope_gptj_last_k(kv, positions, cos_sin_cache)
    keep = slot_mapping >= 0
    live_slots = slot_mapping[keep]
    dst = (live_slots // block_size, live_slots % block_size)
    k_cache[dst] = kv_rot[keep]
    return q_out


# ── Test 1: Q path numerical parity ──────────────────────────────────────────


Expand Down Expand Up @@ -357,3 +472,133 @@ def test_combined_q_and_kv(num_tokens: int, n_heads: int, block_size: int):

torch.testing.assert_close(q_fused, q_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0)


@pytest.mark.skipif(
    not _full_cache_fp8_op_available(),
    reason="full-cache per-tensor FP8 DeepseekV4 op not built in",
)
@pytest.mark.parametrize("num_tokens", [4, 17])
@pytest.mark.parametrize("n_heads", [8, 17])
@pytest.mark.parametrize("positions_dtype", [torch.int32, torch.int64])
def test_full_cache_per_tensor_fp8_matches_reference(
    num_tokens: int,
    n_heads: int,
    positions_dtype: torch.dtype,
):
    """Fused full-cache FP8 op vs. the PyTorch reference, with unit scales.

    Both the FP8 Q output and the FP8 KV cache are compared after upcasting
    to float32, using an absolute tolerance of 0.25 and no relative tolerance.
    """
    torch.manual_seed(4)
    device = "cuda"
    bf16 = torch.bfloat16
    fp8 = torch.float8_e4m3fn
    eps, block_size, max_pos = 1e-6, 16, 4096

    # Random draws stay in this order so the seeded inputs are reproducible.
    q = torch.randn(num_tokens, n_heads, HEAD_DIM, dtype=bf16, device=device)
    kv = torch.randn(num_tokens, HEAD_DIM, dtype=bf16, device=device)
    positions = torch.arange(num_tokens, dtype=positions_dtype, device=device)
    cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device)

    # ceil(num_tokens / block_size) blocks plus one spare block.
    num_blocks = -(-num_tokens // block_size) + 1
    slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
    fp8_scale = torch.ones(1, dtype=torch.float32, device=device)
    q_fp8_scale_inv = torch.ones(1, dtype=torch.float32, device=device)

    q_fp8_ref = torch.empty_like(q, dtype=fp8)
    q_fp8_fused = torch.empty_like(q, dtype=fp8)
    cache_shape = (num_blocks, block_size, HEAD_DIM)
    k_cache_ref = torch.zeros(cache_shape, dtype=fp8, device=device)
    k_cache_fused = torch.zeros_like(k_cache_ref)

    shared = dict(
        slot_mapping=slot_mapping,
        positions=positions,
        cos_sin_cache=cos_sin_cache,
        eps=eps,
        fp8_scale=fp8_scale,
        q_fp8_scale_inv=q_fp8_scale_inv,
    )

    _fp8_full_cache_reference(
        q=q,
        kv=kv,
        k_cache=k_cache_ref,
        q_fp8=q_fp8_ref,
        block_size=block_size,
        **shared,
    )

    _call_full_cache_fp8_fused(
        q=q.clone(),
        kv=kv,
        q_fp8=q_fp8_fused,
        k_cache=k_cache_fused,
        bs=block_size,
        **shared,
    )

    for fused, ref in ((q_fp8_fused, q_fp8_ref), (k_cache_fused, k_cache_ref)):
        torch.testing.assert_close(fused.float(), ref.float(), rtol=0, atol=0.25)


@pytest.mark.skipif(
    not _full_cache_bf16_op_available(),
    reason="full-cache BF16 DeepseekV4 op not built in",
)
@pytest.mark.parametrize("num_tokens", [4, 17])
@pytest.mark.parametrize("n_heads", [8, 17])
@pytest.mark.parametrize("positions_dtype", [torch.int32, torch.int64])
def test_full_cache_bf16_matches_reference(
    num_tokens: int,
    n_heads: int,
    positions_dtype: torch.dtype,
):
    """Fused full-cache BF16 op vs. the PyTorch reference.

    Q (updated in place on a clone) is compared with rtol=atol=1e-2; the
    BF16 KV cache must match the reference exactly.
    """
    torch.manual_seed(5)
    device = "cuda"
    bf16 = torch.bfloat16
    eps, block_size, max_pos = 1e-6, 16, 4096

    # Random draws stay in this order so the seeded inputs are reproducible.
    q = torch.randn(num_tokens, n_heads, HEAD_DIM, dtype=bf16, device=device)
    kv = torch.randn(num_tokens, HEAD_DIM, dtype=bf16, device=device)
    positions = torch.arange(num_tokens, dtype=positions_dtype, device=device)
    cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device)

    # ceil(num_tokens / block_size) blocks plus one spare block.
    num_blocks = -(-num_tokens // block_size) + 1
    slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)

    # The fused op receives this clone as its q argument, so q stays pristine.
    q_fused = q.clone()
    cache_shape = (num_blocks, block_size, HEAD_DIM)
    k_cache_ref = torch.zeros(cache_shape, dtype=bf16, device=device)
    k_cache_fused = torch.zeros_like(k_cache_ref)

    shared = dict(
        slot_mapping=slot_mapping,
        positions=positions,
        cos_sin_cache=cos_sin_cache,
        eps=eps,
    )

    q_ref = _bf16_full_cache_reference(
        q=q,
        kv=kv,
        k_cache=k_cache_ref,
        block_size=block_size,
        **shared,
    )

    _call_full_cache_bf16_fused(
        q=q_fused,
        kv=kv,
        k_cache=k_cache_fused,
        bs=block_size,
        **shared,
    )

    torch.testing.assert_close(q_fused, q_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0)
3 changes: 1 addition & 2 deletions tests/kernels/test_fused_inv_rope_fp8_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,7 @@ def test_einsum_end_to_end(num_tokens, num_heads, n_groups):
This catches stride/layout bugs that only manifest when the einsum
kernel actually consumes the quantized activations.
"""
from deep_gemm.testing import calc_diff
from deep_gemm.utils.math import ceil_div

from vllm.utils.deep_gemm import (
Expand Down Expand Up @@ -809,8 +810,6 @@ def test_einsum_end_to_end(num_tokens, num_heads, n_groups):
# Einsum output: Triton and CUDA both rotate in fp32 now, so diffs
# come from fp32 ordering and UE8M0 boundary shifts only.
# Use relative diff (same metric as test_fp8_einsum.py).
from deep_gemm.testing import calc_diff

z_diff = calc_diff(z_fused, z_ref)
assert z_diff < 0.01, (
f"Einsum output diff too large: {z_diff:.6f} (expected < 0.01)"
Expand Down
12 changes: 12 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import vllm.config.vllm as vllm_config_module
from vllm.compilation.backends import VllmBackend
from vllm.config import (
CacheConfig,
CompilationConfig,
KernelConfig,
ModelConfig,
Expand All @@ -33,10 +34,21 @@
OptimizationLevel,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE

DEVICE_TYPE = current_platform.device_type


def test_fp8_per_tensor_cache_dtype():
    """``fp8_per_tensor`` is accepted by ``CacheConfig`` and maps to the same
    torch dtype object as ``fp8_inc``."""
    dtype_name = "fp8_per_tensor"
    cache_config = CacheConfig(cache_dtype=dtype_name)
    assert cache_config.cache_dtype == dtype_name

    per_tensor_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype_name]
    assert per_tensor_dtype is STR_DTYPE_TO_TORCH_DTYPE["fp8_inc"]


def test_compile_config_repr_succeeds():
# setup: VllmBackend mutates the config object
config = VllmConfig()
Expand Down
1 change: 1 addition & 0 deletions vllm/config/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"fp8",
"fp8_e4m3",
"fp8_e5m2",
"fp8_per_tensor",
"fp8_inc",
"fp8_ds_mla",
"turboquant_k8v4",
Expand Down
Loading