63 changes: 59 additions & 4 deletions csrc/topk.cu
```diff
@@ -82,18 +82,73 @@ void launch_persistent_topk(const torch::Tensor& logits,
   size_t smem_size = P::kFixedSmemLarge + chunk_size * sizeof(uint32_t);
   if (smem_size < P::kSmemMedium) smem_size = P::kSmemMedium;
 
+  // Query occupancy for the instantiation that will actually launch;
+  // overestimating it deadlocks the cooperative barrier.
   int occupancy = 1;
-  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-      &occupancy, P::persistent_topk_kernel<TopK, 4>, P::kThreadsPerBlock,
-      smem_size);
+  cudaError_t occ_err = cudaSuccess;
+  if (vec_size == 4) {
+    occ_err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &occupancy, P::persistent_topk_kernel<TopK, 4>, P::kThreadsPerBlock,
+        smem_size);
+  } else if (vec_size == 2) {
+    occ_err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &occupancy, P::persistent_topk_kernel<TopK, 2>, P::kThreadsPerBlock,
+        smem_size);
+  } else {
+    occ_err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &occupancy, P::persistent_topk_kernel<TopK, 1>, P::kThreadsPerBlock,
+        smem_size);
+  }
+  TORCH_CHECK(occ_err == cudaSuccess,
+              "persistent_topk occupancy query failed: ",
+              cudaGetErrorString(occ_err));
   if (occupancy < 1) occupancy = 1;
 
-  uint32_t max_resident_ctas = static_cast<uint32_t>(num_sms) * occupancy;
+  // The cooperative spin-wait barrier only runs when at least one row hits
+  // the radix path (seq_len > RADIX_THRESHOLD). Below that, non-CTA-0 CTAs
+  // early-exit, so oversubscription can't deadlock and headroom is wasted.
+  const bool needs_cooperative =
+      static_cast<uint32_t>(max_seq_len) > P::RADIX_THRESHOLD;
+
+  const uint32_t hw_resident_cap =
+      static_cast<uint32_t>(num_sms) * static_cast<uint32_t>(occupancy);
+  uint32_t max_resident_ctas = hw_resident_cap;
+  if (needs_cooperative) {
+    // Reserve one CTA per SM when occupancy allows; fall back to a single
+    // CTA when occupancy == 1 (the most deadlock-prone case — any straggler
+    // kernel that takes the only slot on one SM hangs the barrier). Never
+    // drop below one full group's worth.
+    uint32_t headroom = (occupancy > 1) ? static_cast<uint32_t>(num_sms) : 1u;
+    if (max_resident_ctas >= headroom + ctas_per_group) {
+      max_resident_ctas -= headroom;
+    }
+  }
   uint32_t num_groups = std::min(max_resident_ctas / ctas_per_group,
                                  static_cast<uint32_t>(num_rows));
   if (num_groups == 0) num_groups = 1;
   uint32_t total_ctas = num_groups * ctas_per_group;
 
+  // If the cooperative launch wouldn't fit, fall back to FilteredTopK
+  // instead of deadlocking. Only relevant when needs_cooperative.
+  if (needs_cooperative && total_ctas > hw_resident_cap) {
+    TORCH_CHECK(max_smem_per_block >= 128 * 1024,
+                "persistent_topk would oversubscribe and the FilteredTopK "
+                "fallback requires >=128KB smem per block (have ",
+                max_smem_per_block, "). total_ctas=", total_ctas,
+                " > num_sms*occupancy=", hw_resident_cap, " (TopK=", TopK,
+                ", vec_size=", vec_size, ", ctas_per_group=", ctas_per_group,
+                ", smem=", smem_size, ").");
+    cudaError_t status =
+        vllm::FilteredTopKRaggedTransform<float, int32_t, TopK>(
+            logits.data_ptr<float>(), output.data_ptr<int32_t>(),
+            lengths.data_ptr<int32_t>(), static_cast<uint32_t>(num_rows),
+            static_cast<uint32_t>(TopK), static_cast<uint32_t>(stride),
+            stream);
+    TORCH_CHECK(status == cudaSuccess,
+                "FilteredTopK fallback failed: ", cudaGetErrorString(status));
+    return;
+  }
+
   size_t state_bytes = num_groups * sizeof(P::RadixRowState);
   TORCH_CHECK(workspace.size(0) >= static_cast<int64_t>(state_bytes),
               "workspace too small, need ", state_bytes, " bytes");
```
2 changes: 1 addition & 1 deletion docs/design/moe_kernel_features.md
```diff
@@ -36,7 +36,7 @@ th {
 | deepep_high_throughput | standard | fp8 | G(128),A,T<sup>2</sup> | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.deepep_ht.DeepEPHTPrepareAndFinalize] |
 | deepep_low_latency | batched | fp8 | G(128),A,T<sup>3</sup> | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.deepep_ll.DeepEPLLPrepareAndFinalize] |
 | flashinfer_nvlink_two_sided | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferNVLinkTwoSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_two_sided.FlashInferNVLinkTwoSidedPrepareAndFinalize] |
-| flashinfer_nvlink_one_sided | standard | nvfp4 | G,A,T | N | N | [`FlashInferNVLinkOneSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_one_sided.FlashInferNVLinkOneSidedPrepareAndFinalize] |
+| flashinfer_nvlink_one_sided | standard | nvfp4,bf16,mxfp8 | G,A,T | N | N | [`FlashInferNVLinkOneSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_one_sided.FlashInferNVLinkOneSidedPrepareAndFinalize] |
 
 !!! info "Table key"
     1. All types: mxfp4, nvfp4, int4, int8, fp8
```
7 changes: 5 additions & 2 deletions tests/compile/h100/test_startup.py
```diff
@@ -34,7 +34,10 @@ def _run_vllm(vllm_runner):
             mode=CompilationMode.VLLM_COMPILE,
             cudagraph_mode=CUDAGraphMode.NONE,
         ),
-        num_gpu_blocks_override=8,
+        # Phi-tiny-MoE uses SWA, whose admission cap is `cdiv(L, block_size) + 1`
+        # at default block_size=16 — i.e. 17 blocks for max_model_len=256. Use
+        # 32 for headroom.
+        num_gpu_blocks_override=32,
     ):
         pass
 
```
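The cap quoted in that comment is easy to verify with plain integer arithmetic. A standalone sketch, with `cdiv` defined locally rather than imported from vLLM:

```python
def cdiv(a: int, b: int) -> int:
    return -(a // -b)  # ceiling division

max_model_len, block_size = 256, 16
print(cdiv(max_model_len, block_size) + 1)  # 17: the old override of 8 was below the cap
```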

```diff
@@ -190,7 +193,7 @@ def _run_model(vllm_runner, spec: ModelStartupSpec):
             cudagraph_mode=CUDAGraphMode.NONE,
             pass_config=PassConfig(fuse_allreduce_rms=False),
         ),
-        num_gpu_blocks_override=8,
+        num_gpu_blocks_override=16,
     ):
         pass
 
```
3 changes: 3 additions & 0 deletions tests/compile/test_config.py
```diff
@@ -405,6 +405,9 @@ def test_should_split():
         (None, 0, 1, False, 2048, CUDAGraphMode.NONE, 0),
         # truncated to nearest multiple of 8 or 16
         (None, 257, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
+        # max_num_batched_tokens <= max_cudagraph_capture_size should always be
+        # captured even if not landing on a 16-stride step
+        (None, 2048, 1, False, 257, CUDAGraphMode.FULL_AND_PIECEWISE, 257),
         # max from list
         ([1, 2, 4, 15], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 15),
         # SP forces full-graph compilation, sizes are filtered by TP
```
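A sketch of the rule those two cases pin down. The semantics are inferred from the expected values, not taken from the actual vLLM helper:

```python
def effective_capture_size(max_capture: int, max_num_batched_tokens: int) -> int:
    # If the whole batch fits under the capture cap, capture it exactly,
    # even when it does not land on a 16-stride step (cap 2048, 257 tokens -> 257).
    if max_num_batched_tokens <= max_capture:
        return max_num_batched_tokens
    # Otherwise truncate the cap to the capture stride (257 -> 256).
    return max_capture - max_capture % 16

assert effective_capture_size(257, 2048) == 256   # existing case
assert effective_capture_size(2048, 257) == 257   # new case
```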
232 changes: 228 additions & 4 deletions tests/kernels/test_compressor_kv_cache.py
```diff
@@ -3,12 +3,11 @@
 """
 Round-trip tests for compressor → FP8 quant + KV cache insert → gather + dequant.
 
-Two paths tested:
+Four test functions cover five paths:
 A) DeepseekV4 Attention: head_dim=512 (448 FP8 nope + 64 bf16 rope), quant_block=64
 B) Indexer: head_dim=128 (all FP8), quant_block=128
-
-These serve as golden references for validating the future fused
-compressor+quant+cache kernel.
+C) DeepseekV4 Attention magnitude range: correctness across small/large values
+D) Indexer fused Triton kernel: compress+norm+rope+quant+insert
 """
 
 import math
```
```diff
@@ -21,6 +20,12 @@
     dequantize_and_gather_k_cache,
     quantize_and_insert_k_cache,
 )
+from vllm.v1.attention.ops.deepseek_v4_ops.fused_compress_quant_cache import (
+    _fused_kv_compress_norm_rope_insert_indexer_attn,
+    _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn,
+)
+
+from .test_fused_indexer_q_rope_quant import quantize_to_mxfp4
 
 
 def _ue8m0_reference(x: torch.Tensor, block_size: int, fp8_max: float):
```
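The body of `_ue8m0_reference` is collapsed in this view. For orientation, a ue8m0 block-quant reference typically looks like the sketch below: power-of-two scales chosen so each scaled block fits the FP8 e4m3 range. This is an illustrative reconstruction, not the file's actual code.

```python
# Illustrative sketch only; the real _ue8m0_reference body is collapsed above.
import torch

def ue8m0_quant_sketch(x: torch.Tensor, block_size: int, fp8_max: float = 448.0):
    """Per-block power-of-two (ue8m0-style) scaling into FP8 e4m3 range."""
    blocks = x.float().view(-1, block_size)
    amax = blocks.abs().amax(dim=1, keepdim=True)
    # Round the scale up to a power of two: scale = 2^ceil(log2(amax / fp8_max)).
    scale = torch.exp2(torch.ceil(torch.log2(amax.clamp(min=1e-10) / fp8_max)))
    q = (blocks / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.view_as(x), scale.squeeze(1)
```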
@@ -309,3 +314,222 @@ def test_deepseek_v4_quant_magnitude_range():

Everything below the first three context lines of this hunk is new code:

```python
            f"Token {t}: rel_err={rel_err:.4f}, abs_diff={abs_diff:.6f}, "
            f"magnitude={magnitude:.4f}"
        )


# ── Test D: Indexer fused K-cache insert (Triton kernels) ────────────────────
#
# Both kernels share the same Triton signature; use_fp4 selects between them.
# Full pipeline: state-cache gather → softmax-weighted compress → RMSNorm →
# GPT-J RoPE → quant (MXFP4 or FP8) → paged cache insert.


def _reference_kv_compress_norm_rope(
    state_cache: torch.Tensor,
    block_table: torch.Tensor,
    positions: torch.Tensor,
    rms_weight: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    compress_ratio: int = 1,
    overlap: int = 0,
    use_fp4: bool = False,
    rms_eps: float = 1e-6,
    fp8_max: float = 448.0,
):
    """Compress → RMSNorm → GPT-J RoPE → quantize.

    Gathers (1+overlap)*compress_ratio state entries per output token, applies
    per-element softmax over the scores, and computes the weighted kv sum.
    Returns (quantized_values, scale) matching the kernel's output layout.
    """
    device = state_cache.device
    head_dim = rms_weight.shape[0]
    rope_dim = cos_sin_cache.shape[-1]
    state_block_size = state_cache.shape[1]
    state_width = state_cache.shape[-1] // 2
    nope_dim = head_dim - rope_dim
    total = (1 + overlap) * compress_ratio
    results = []
    for pos in positions.tolist():
        src = torch.arange(pos - total + 1, pos + 1, dtype=torch.int64, device=device)
        valid = src >= 0
        idx = src.clamp(min=0)
        pages = block_table[0, idx // state_block_size]
        offsets = idx % state_block_size
        raw = state_cache[pages, offsets].float()  # [total, state_dim]

        # Group 0 (tokens 0..cr-1):   kv[:H],    score[SW:SW+H]
        # Group 1 (tokens cr..2cr-1): kv[H:2H],  score[SW+H:SW+2H]
        if overlap:
            sw = state_width
            g0_kv = raw[:compress_ratio, :head_dim]
            g1_kv = raw[compress_ratio:, head_dim : 2 * head_dim]
            g0_scores = raw[:compress_ratio, sw : sw + head_dim]
            g1_scores = raw[compress_ratio:, sw + head_dim : sw + 2 * head_dim]
            kv = torch.cat([g0_kv, g1_kv])
            scores = torch.cat([g0_scores, g1_scores])
        else:
            kv = raw[:, :head_dim]
            scores = raw[:, state_width : state_width + head_dim]

        scores[~valid] = float("-inf")
        kv[~valid] = 0.0
        weights = torch.softmax(scores, dim=0)
        compressed = (kv * weights).sum(dim=0)  # [H]
        var = (compressed * compressed).mean()
        normed = compressed * torch.rsqrt(var + rms_eps) * rms_weight.float()
        compressed_pos = (pos // compress_ratio) * compress_ratio
        cos, sin = cos_sin_cache[compressed_pos].float().chunk(2)
        nope, rope = normed.split([nope_dim, rope_dim])
        rope = torch.stack(
            [rope[0::2] * cos - rope[1::2] * sin, rope[1::2] * cos + rope[0::2] * sin],
            dim=-1,
        ).reshape(rope_dim)
        results.append(torch.cat([nope, rope]).to(state_cache.dtype))
    result = torch.stack(results)

    if use_fp4:
        return quantize_to_mxfp4(result)
    else:
        pairs = [
            _ue8m0_reference(result[t], head_dim, fp8_max) for t in range(len(result))
        ]
        quants, scales = zip(*pairs)
        return torch.stack(quants), torch.cat(scales)
```
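The overlap grouping is the subtle part of this reference: group 0's kv and scores live in the first `head_dim` columns of each half of the state row, group 1's in the second. A toy-shaped illustration, standalone and with hypothetical sizes:

```python
import torch

H, cr, overlap = 4, 2, 1
total = (1 + overlap) * cr              # 4 source entries per output token
sw = (1 + overlap) * H                  # state_width = coff * H
raw = torch.randn(total, 2 * sw)        # [kv_g0 | kv_g1 | score_g0 | score_g1]
kv = torch.cat([raw[:cr, :H], raw[cr:, H : 2 * H]])
scores = torch.cat([raw[:cr, sw : sw + H], raw[cr:, sw + H : sw + 2 * H]])
weights = torch.softmax(scores, dim=0)  # per-element softmax over the window
compressed = (kv * weights).sum(dim=0)  # [H]
assert compressed.shape == (H,) and torch.allclose(weights.sum(0), torch.ones(H))
```

The added test below then drives both Triton kernels against this reference.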


```python
@pytest.mark.parametrize("num_tokens", [1, 7, 32])
@pytest.mark.parametrize("kv_block_size", [16, 32])
@pytest.mark.parametrize("use_fp4", [False, True])
def test_fused_kv_insert_indexer(num_tokens: int, kv_block_size: int, use_fp4: bool):
    """Fused K compress+norm+rope+quant+insert for the indexer KV cache."""
    HEAD_DIM = 128
    ROPE_DIM = 64
    BLOCK_SIZE = 16
    RMS_EPS = 1e-6
    FP8_MAX = 448.0

    device = "cuda"
    torch.manual_seed(42)
    compress_ratio = 4

    if use_fp4:
        TOKEN_STRIDE = HEAD_DIM // 2  # packed nibbles: 64 bytes
        SCALE_DIM = HEAD_DIM // 32  # ue8m0 bytes: 4
        QUANT_BLOCK = 32
        kernel = _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn
    else:
        TOKEN_STRIDE = HEAD_DIM  # FP8 bytes: 128
        SCALE_DIM = 4  # 1 float32: 4 bytes
        QUANT_BLOCK = HEAD_DIM
        kernel = _fused_kv_compress_norm_rope_insert_indexer_attn

    # overlap=1 whenever compress_ratio==4, matching DeepseekCompressor logic.
    overlap = 1 if compress_ratio == 4 else 0
    coff = 1 + overlap  # multiplier for state_dim per entry

    num_pages = (compress_ratio * num_tokens - 1) // BLOCK_SIZE + 2
    state_cache = torch.randn(
        num_pages,
        BLOCK_SIZE,
        2 * coff * HEAD_DIM,  # kv_state + score_state, each coff*HEAD_DIM wide
        dtype=torch.bfloat16,
        device=device,
    )
    block_table = torch.arange(num_pages, dtype=torch.int32, device=device).unsqueeze(0)
    token_to_req = torch.zeros(num_tokens, dtype=torch.int32, device=device)
    slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
    positions = torch.arange(
        compress_ratio - 1,
        compress_ratio * num_tokens,
        compress_ratio,
        dtype=torch.int64,
        device=device,
    )
    rms_weight = torch.randn(HEAD_DIM, dtype=torch.bfloat16, device=device)
    cos_sin_cache = torch.randn(compress_ratio * num_tokens, ROPE_DIM, device=device)

    kv_n_blocks = (num_tokens + kv_block_size - 1) // kv_block_size + 1
    kv_cache = torch.zeros(
        kv_n_blocks,
        kv_block_size * (TOKEN_STRIDE + SCALE_DIM),
        dtype=torch.uint8,
        device=device,
    )

    kernel[(num_tokens,)](
        state_cache,
        state_cache.stride(0),
        state_cache.stride(1),
        token_to_req,
        positions,
        slot_mapping,
        block_table,
        block_table.stride(0),
        BLOCK_SIZE,
        rms_weight,
        RMS_EPS,
        cos_sin_cache,
        cos_sin_cache.stride(0),
        kv_cache,
        slot_mapping,
        kv_block_size,
        HEAD_SIZE=HEAD_DIM,
        TRITON_BLOCK_SIZE=HEAD_DIM,
        STATE_WIDTH=coff * HEAD_DIM,
        COMPRESS_RATIO=compress_ratio,
        OVERLAP=overlap,
        ROPE_HEAD_DIM=ROPE_DIM,
        FP8_MAX=FP8_MAX,
        QUANT_BLOCK=QUANT_BLOCK,
        TOKEN_STRIDE=TOKEN_STRIDE,
        SCALE_DIM=SCALE_DIM,
        KV_BLOCK_STRIDE=kv_cache.stride(0),
        num_warps=1,
    )

    k_quant, scale = _reference_kv_compress_norm_rope(
        state_cache,
        block_table,
        positions,
        rms_weight,
        cos_sin_cache,
        compress_ratio,
        overlap,
        use_fp4,
        rms_eps=RMS_EPS,
        fp8_max=FP8_MAX,
    )

    if use_fp4:
        for i in range(num_tokens):
            blk, pos = i // kv_block_size, i % kv_block_size
            val_off = pos * TOKEN_STRIDE
            fp4_actual = kv_cache[blk, val_off : val_off + TOKEN_STRIDE]
            assert torch.equal(k_quant[i], fp4_actual), (
                f"token {i}: packed nibbles differ, "
                f"{(k_quant[i] != fp4_actual).sum()} "
                f"/ {TOKEN_STRIDE}"
            )

            scale_off = kv_block_size * TOKEN_STRIDE + pos * SCALE_DIM
            scale_actual = kv_cache[blk, scale_off : scale_off + SCALE_DIM]
            assert torch.equal(scale_actual, scale[i]), (
                f"token {i}: ue8m0 {scale_actual.tolist()} != {scale[i].tolist()}"
            )

    else:
        k_quant = k_quant.view(torch.uint8)
        for i in range(num_tokens):
            blk, pos = i // kv_block_size, i % kv_block_size
            val_off = pos * TOKEN_STRIDE
            assert torch.equal(
                k_quant[i], kv_cache[blk, val_off : val_off + TOKEN_STRIDE]
            ), f"token {i}: FP8 bytes differ"

            scale_off = kv_block_size * TOKEN_STRIDE + pos * SCALE_DIM
            actual_scale = kv_cache[blk, scale_off : scale_off + SCALE_DIM].view(
                torch.float32
            )
            assert torch.equal(actual_scale, scale[i : i + 1]), (
                f"token {i}: scale {actual_scale.item()} != {scale[i].item()}"
            )
```