Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions .buildkite/test_areas/kernels.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,28 @@ steps:
commands:
- pytest -v -s kernels/core/test_minimax_reduce_rms.py

- label: Deepseek V4 Kernel Test (H100)
key: deepseek-v4-kernel-test-h100
timeout_in_minutes: 15
device: h100
source_file_dependencies:
- csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu
- vllm/models/deepseek_v4/common/ops/
- tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py
commands:
- pytest -v -s kernels/test_fused_deepseek_v4_*.py

- label: Deepseek V4 Kernel Test (B200)
key: deepseek-v4-kernel-test-b200
timeout_in_minutes: 15
device: b200-k8s
source_file_dependencies:
- csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu
- vllm/models/deepseek_v4/common/ops/
- tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py
commands:
- pytest -v -s kernels/test_fused_deepseek_v4_*.py

- label: Kernels Attention Test %N
key: kernels-attention-test
timeout_in_minutes: 35
Expand Down
267 changes: 184 additions & 83 deletions csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,11 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);

void fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
torch::Tensor& q, torch::Tensor const& kv, torch::Tensor& k_cache,
torch::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
torch::Tensor const& q_in, torch::Tensor const& kv, torch::Tensor& k_cache,
torch::Tensor const& slot_mapping, torch::Tensor const& position_ids,
torch::Tensor const& cos_sin_cache, double eps, int64_t cache_block_size);
torch::Tensor const& cos_sin_cache, int64_t q_head_padded, double eps,
int64_t cache_block_size);

void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& prompt_mask,
Expand Down
4 changes: 2 additions & 2 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// kernel launch.
ops.def(
"fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert("
"Tensor! q, Tensor kv, Tensor! k_cache, "
"Tensor q_in, Tensor kv, Tensor! k_cache, "
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
"float eps, int cache_block_size) -> ()");
"int q_head_padded, float eps, int cache_block_size) -> Tensor");
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA,
&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert);

Expand Down
89 changes: 72 additions & 17 deletions tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,18 +120,43 @@ def _op_available() -> bool:
)


def _call_fused(q, kv, k_cache, slot_mapping, positions, cos_sin_cache, eps, bs):
torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
q, kv, k_cache, slot_mapping, positions, cos_sin_cache, eps, bs
def _call_fused(
    q_in, q_head_padded, kv, k_cache, slot_mapping, positions, cos_sin_cache, eps, bs
):
    """Invoke the fused qnorm/RoPE/KV-insert custom op and return its result.

    The wrapper takes ``q_head_padded`` as its second parameter, but the op
    is called positionally with ``q_head_padded`` placed after
    ``cos_sin_cache`` and before ``eps``; ``bs`` is passed last.
    """
    fused_op = torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert
    # Positional order expected by the op; do not reorder.
    op_args = (
        q_in,
        kv,
        k_cache,
        slot_mapping,
        positions,
        cos_sin_cache,
        q_head_padded,
        eps,
        bs,
    )
    return fused_op(*op_args)


# ── Test 1: Q path numerical parity ──────────────────────────────────────────


@pytest.mark.parametrize("num_tokens", [1, 4, 17, 64, 2048])
@pytest.mark.parametrize("n_heads", [8, 64])
def test_q_path_matches_reference(num_tokens: int, n_heads: int):
@pytest.mark.parametrize(
"n_heads,padded_heads",
[
# Each supported padded_heads instantiation: padded (n_heads <
# padded_heads) and unpadded (n_heads == padded_heads).
(1, 8),
(8, 8),
(8, 16),
(16, 16),
(16, 32),
(32, 32),
(8, 64),
(64, 64),
(64, 128),
],
)
def test_q_path_matches_reference(num_tokens: int, n_heads: int, padded_heads: int):
torch.manual_seed(0)
device = "cuda"
dtype = torch.bfloat16
Expand All @@ -156,10 +181,16 @@ def test_q_path_matches_reference(num_tokens: int, n_heads: int):
num_blocks, bs, HEAD_BYTES, dtype=torch.uint8, device=device
).view(num_blocks, -1)
slot_mapping = torch.full((num_tokens,), -1, dtype=torch.int64, device=device)
q_fused = q.clone()
_call_fused(q_fused, kv, k_cache, slot_mapping, positions, cos_sin_cache, eps, bs)
q_out = _call_fused(
q, padded_heads, kv, k_cache, slot_mapping, positions, cos_sin_cache, eps, bs
)

torch.testing.assert_close(q_fused, q_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(q_out[:, :n_heads], q_ref, rtol=1e-2, atol=1e-2)
if n_heads < padded_heads:
pad_region = q_out[:, n_heads:padded_heads]
assert pad_region.abs().max().item() == 0.0, (
"padded head slots must be exact zero"
)


# ── Test 2: KV path round-trip byte/value parity ─────────────────────────────
Expand Down Expand Up @@ -201,11 +232,12 @@ def test_kv_path_matches_reference(num_tokens: int, block_size: int):
kv_ref, k_cache_ref, slot_mapping, block_size=block_size
)

# ── Fused path (dummy q, single head) ──────────────────────────────────
# ── Fused path (dummy q, padded to FlashMLA's min head count 64) ───────
k_cache_fused = torch.zeros_like(k_cache_ref)
q_dummy = torch.zeros(num_tokens, 1, HEAD_DIM, dtype=dtype, device=device)
_call_fused(
_ = _call_fused(
q_dummy,
64,
kv,
k_cache_fused,
slot_mapping,
Expand Down Expand Up @@ -298,8 +330,9 @@ def test_kv_path_with_dp_padding(num_tokens: int, pad: int, block_size: int):
# Fused: pass full-sized q/kv/positions, shorter slot_mapping.
q_dummy = torch.zeros(total, 1, HEAD_DIM, dtype=dtype, device=device)
k_cache_fused = torch.zeros_like(k_cache_ref)
_call_fused(
_ = _call_fused(
q_dummy,
64,
kv,
k_cache_fused,
slot_mapping,
Expand All @@ -316,9 +349,26 @@ def test_kv_path_with_dp_padding(num_tokens: int, pad: int, block_size: int):


@pytest.mark.parametrize("num_tokens", [1, 4, 17, 2048])
@pytest.mark.parametrize("n_heads", [8, 64])
@pytest.mark.parametrize(
"n_heads,padded_heads",
[
# Each supported padded_heads instantiation: padded (n_heads <
# padded_heads) and unpadded (n_heads == padded_heads).
(1, 8),
(8, 8),
(8, 16),
(16, 16),
(16, 32),
(32, 32),
(8, 64),
(64, 64),
(64, 128),
],
)
@pytest.mark.parametrize("block_size", [16, 64])
def test_combined_q_and_kv(num_tokens: int, n_heads: int, block_size: int):
def test_combined_q_and_kv(
num_tokens: int, n_heads: int, padded_heads: int, block_size: int
):
torch.manual_seed(2)
device = "cuda"
dtype = torch.bfloat16
Expand All @@ -345,10 +395,10 @@ def test_combined_q_and_kv(num_tokens: int, n_heads: int, block_size: int):
)

# Fused single call.
q_fused = q.clone()
k_cache_fused = torch.zeros_like(k_cache_ref)
_call_fused(
q_fused,
q_out = _call_fused(
q,
padded_heads,
kv,
k_cache_fused,
slot_mapping,
Expand All @@ -358,5 +408,10 @@ def test_combined_q_and_kv(num_tokens: int, n_heads: int, block_size: int):
block_size,
)

torch.testing.assert_close(q_fused, q_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(q_out[:, :n_heads], q_ref, rtol=1e-2, atol=1e-2)
if n_heads < padded_heads:
pad_region = q_out[:, n_heads:padded_heads]
assert pad_region.abs().max().item() == 0.0, (
"padded head slots must be exact zero"
)
torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0)
4 changes: 2 additions & 2 deletions tests/kernels/test_fused_inv_rope_fp8_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
via per_token_group_quant_fp8).

The reference faithfully reproduces the exact flow in
deepseek_v4/nvidia/ops/attention.py:295-310:
deepseek_v4/attention.py:295-310:
1. Apply inverse RoPE (NeoX style, last rope_dim=64 dims of each head)
2. Reshape [T, H, head_dim] -> [T, G, D]
3. Transpose+flatten to [G*T, D], quantize, reshape back
Expand Down Expand Up @@ -668,7 +668,7 @@ def _unfused_inv_rope_fp8_quant(
nope_dim: int = NOPE_DIM,
rope_dim: int = ROPE_DIM,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Unfused path matching deepseek_v4/nvidia/ops/attention.py:295-310.
"""Unfused path matching deepseek_v4/attention.py:295-310.

Uses the production CUDA RoPE kernel + per_token_group_quant_fp8.
"""
Expand Down
6 changes: 5 additions & 1 deletion vllm/models/deepseek_v4/amd/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from vllm.v1.worker.workspace import current_workspace_manager

if TYPE_CHECKING:
from vllm.models.deepseek_v4.nvidia.ops.attention import (
from vllm.models.deepseek_v4.attention import (
DeepseekV4MLAAttention,
)

Expand Down Expand Up @@ -592,6 +592,10 @@ class DeepseekV4ROCMAiterMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):

backend_cls = DeepseekV4ROCMAiterMLASparseBackend

# Returns num_heads unchanged, i.e. this impl requests no extra (padded)
# Q-head slots.
# NOTE(review): DeepseekV4MLAAttention assigns the selected impl's
# get_padded_num_q_heads(num_heads) to padded_heads, and padded_heads is
# passed to the fused qnorm/rope/kv-insert op as q_head_padded. Presumably
# this impl is the one selected on ROCm; confirm the kernel accepts every
# num_heads value reachable there.
@classmethod
def get_padded_num_q_heads(cls, num_heads: int) -> int:
return num_heads
Comment on lines +596 to +597

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

On ROCm, get_padded_num_q_heads currently returns num_heads without any padding. However, the horizontally-fused preprocessing kernel (fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert) is shared between NVIDIA and ROCm, and its dispatch logic (in csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu) only supports kNumHeadsQPadded values of 64 and 128. This will cause a runtime crash on ROCm for any configuration where num_heads is not exactly 64 or 128 (e.g., TP=4 where num_heads=32). To ensure compatibility with the fused kernel's template instantiations, ROCm should use the same padding logic as NVIDIA.

    @classmethod
    def get_padded_num_q_heads(cls, num_heads: int) -> int:
        # Match fused kernel dispatch requirements (64 or 128 heads)
        if num_heads > 128:
            raise ValueError(
                f"DeepseekV4 ROCm does not support {num_heads} heads "
                "(Fused kernel requires h_q in {64, 128}).")
        return 64 if num_heads <= 64 else 128


@classmethod
def forward_mqa( # type: ignore[override]
cls,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,18 +156,6 @@ def __init__(
self.head_dim = head_dim
self.scale = scale

# FlashMLA sparse kernel only supports 64 or 128 heads; pad up to the
# next supported size. Must match DeepseekV4MLAAttention.padded_heads.
if num_heads <= 64:
self.padded_heads = 64
elif num_heads <= 128:
self.padded_heads = 128
else:
raise ValueError(
f"DeepseekV4 attention does not support {num_heads} heads "
"(must be <= 128)."
)

self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.window_size = window_size
Expand Down Expand Up @@ -263,6 +251,9 @@ def __init__(
indexer=self.indexer,
topk_indices_buffer=self.topk_indices_buffer,
)
# Mirror the inner layer's padded head count (single source of truth).
self.padded_heads = self.mla_attn.padded_heads

# Register this layer in the compilation config's static forward context
# This allows the custom op to retrieve the layer during execution
compilation_config = mla_modules.vllm_config.compilation_config
Expand Down Expand Up @@ -450,7 +441,7 @@ def attention_impl(

def wq_b_kv_insert() -> torch.Tensor:
q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)
q = self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)
return q

# 3-way overlap (matches TRT-LLM PR #14142 Level 1): default runs
Expand Down Expand Up @@ -484,7 +475,7 @@ def wq_b_kv_insert() -> torch.Tensor:

def wq_b_kv_insert() -> torch.Tensor:
q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)
q = self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)
return q

q, _ = maybe_execute_in_parallel(
Expand All @@ -497,12 +488,7 @@ def wq_b_kv_insert() -> torch.Tensor:
else:
# SWA-only layer: no compressor, no overlap.
q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)

# Pad q to FlashMLA-required head count (64 or 128)
if self.n_local_heads < self.padded_heads:
pad_size = self.padded_heads - self.n_local_heads
q = F.pad(q, (0, 0, 0, pad_size), value=0.0)
q = self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)

# MLA attention writes into the pre-allocated `out` buffer
# ([num_tokens, padded_heads, head_dim]).
Expand All @@ -516,9 +502,17 @@ def _fused_qnorm_rope_kv_insert(
attn_metadata: (
dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None
),
) -> None:
) -> torch.Tensor:
if not isinstance(attn_metadata, dict):
return
# Profile run: kernel doesn't fire; produce a padded tensor so
# downstream FlashMLA gets the right shape.
if self.n_local_heads < self.padded_heads:
return F.pad(
q,
(0, 0, 0, self.padded_heads - self.n_local_heads),
value=0.0,
)
return q

swa_metadata = cast(
"DeepseekSparseSWAMetadata | None",
Expand All @@ -530,16 +524,19 @@ def _fused_qnorm_rope_kv_insert(
swa_kv_cache_2d = swa_kv_cache.view(swa_kv_cache.shape[0], -1)

# Horizontally fused:
# Q side: q_head_norm (per-head RMSNorm, no weight) + GPT-J RoPE
# Q side: q_head_norm (per-head RMSNorm, no weight) + GPT-J RoPE,
# with zero-fill for the padding head slots. The kernel
# allocates and returns the padded q tensor.
# KV side: GPT-J RoPE + UE8M0 FP8 quant + paged cache insert
# kv is unchanged; mla_attn reads kv solely via swa_kv_cache.
torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
return torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
q,
kv,
swa_kv_cache_2d,
swa_metadata.slot_mapping,
positions.to(torch.int64),
self.rotary_emb.cos_sin_cache,
self.padded_heads,
self.eps,
swa_metadata.block_size,
)
Expand Down Expand Up @@ -607,9 +604,6 @@ def deepseek_v4_fp8_einsum_fake(


class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
# FlashMLA FP8 sparse only supports 64 or 128 heads
SUPPORTED_HEAD_COUNTS = (64, 128)

def __init__(
self,
num_heads: int,
Expand Down Expand Up @@ -655,19 +649,8 @@ def __init__(
self.aux_stream = aux_stream
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]

# Determine padded head count for FlashMLA
if num_heads not in self.SUPPORTED_HEAD_COUNTS:
if num_heads < 64:
self.padded_heads = 64
elif num_heads < 128:
self.padded_heads = 128
else:
raise ValueError(
f"DeepseekV4MLAAttention does not support {num_heads} heads. "
f"Supported: <= 128 (will be padded to 64 or 128)"
)
else:
self.padded_heads = num_heads
# Padded Q head count is dictated by the selected impl.
self.padded_heads = self.impl_cls.get_padded_num_q_heads(num_heads)

# Store attention sink
assert attn_sink is not None
Expand Down
Loading
Loading