diff --git a/csrc/topk.cu b/csrc/topk.cu index 364ecc21e532..b0f612ba6e4b 100644 --- a/csrc/topk.cu +++ b/csrc/topk.cu @@ -82,18 +82,73 @@ void launch_persistent_topk(const torch::Tensor& logits, size_t smem_size = P::kFixedSmemLarge + chunk_size * sizeof(uint32_t); if (smem_size < P::kSmemMedium) smem_size = P::kSmemMedium; + // Query occupancy for the instantiation that will actually launch; + // overestimating it deadlocks the cooperative barrier. int occupancy = 1; - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &occupancy, P::persistent_topk_kernel, P::kThreadsPerBlock, - smem_size); + cudaError_t occ_err = cudaSuccess; + if (vec_size == 4) { + occ_err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &occupancy, P::persistent_topk_kernel, P::kThreadsPerBlock, + smem_size); + } else if (vec_size == 2) { + occ_err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &occupancy, P::persistent_topk_kernel, P::kThreadsPerBlock, + smem_size); + } else { + occ_err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &occupancy, P::persistent_topk_kernel, P::kThreadsPerBlock, + smem_size); + } + TORCH_CHECK(occ_err == cudaSuccess, + "persistent_topk occupancy query failed: ", + cudaGetErrorString(occ_err)); if (occupancy < 1) occupancy = 1; - uint32_t max_resident_ctas = static_cast(num_sms) * occupancy; + // The cooperative spin-wait barrier only runs when at least one row hits + // the radix path (seq_len > RADIX_THRESHOLD). Below that, non-CTA-0 CTAs + // early-exit, so oversubscription can't deadlock and headroom is wasted. + const bool needs_cooperative = + static_cast(max_seq_len) > P::RADIX_THRESHOLD; + + const uint32_t hw_resident_cap = + static_cast(num_sms) * static_cast(occupancy); + uint32_t max_resident_ctas = hw_resident_cap; + if (needs_cooperative) { + // Reserve one CTA per SM when occupancy allows; fall back to a single + // CTA when occupancy == 1 (the most deadlock-prone case — any straggler + // kernel that takes the only slot on one SM hangs the barrier). Never + // drop below one full group's worth. + uint32_t headroom = (occupancy > 1) ? static_cast(num_sms) : 1u; + if (max_resident_ctas >= headroom + ctas_per_group) { + max_resident_ctas -= headroom; + } + } uint32_t num_groups = std::min(max_resident_ctas / ctas_per_group, static_cast(num_rows)); if (num_groups == 0) num_groups = 1; uint32_t total_ctas = num_groups * ctas_per_group; + // If the cooperative launch wouldn't fit, fall back to FilteredTopK + // instead of deadlocking. Only relevant when needs_cooperative. + if (needs_cooperative && total_ctas > hw_resident_cap) { + TORCH_CHECK(max_smem_per_block >= 128 * 1024, + "persistent_topk would oversubscribe and the FilteredTopK " + "fallback requires >=128KB smem per block (have ", + max_smem_per_block, "). 
total_ctas=", total_ctas, + " > num_sms*occupancy=", hw_resident_cap, " (TopK=", TopK, + ", vec_size=", vec_size, ", ctas_per_group=", ctas_per_group, + ", smem=", smem_size, ")."); + cudaError_t status = + vllm::FilteredTopKRaggedTransform( + logits.data_ptr(), output.data_ptr(), + lengths.data_ptr(), static_cast(num_rows), + static_cast(TopK), static_cast(stride), + stream); + TORCH_CHECK(status == cudaSuccess, + "FilteredTopK fallback failed: ", cudaGetErrorString(status)); + return; + } + size_t state_bytes = num_groups * sizeof(P::RadixRowState); TORCH_CHECK(workspace.size(0) >= static_cast(state_bytes), "workspace too small, need ", state_bytes, " bytes"); diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 4e3706645ef2..54b796fde3bf 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -36,7 +36,7 @@ th { | deepep_high_throughput | standard | fp8 | G(128),A,T2 | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.deepep_ht.DeepEPHTPrepareAndFinalize] | | deepep_low_latency | batched | fp8 | G(128),A,T3 | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.deepep_ll.DeepEPLLPrepareAndFinalize] | | flashinfer_nvlink_two_sided | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferNVLinkTwoSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_two_sided.FlashInferNVLinkTwoSidedPrepareAndFinalize] | -| flashinfer_nvlink_one_sided | standard | nvfp4 | G,A,T | N | N | [`FlashInferNVLinkOneSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_one_sided.FlashInferNVLinkOneSidedPrepareAndFinalize] | +| flashinfer_nvlink_one_sided | standard | nvfp4,bf16,mxfp8 | G,A,T | N | N | [`FlashInferNVLinkOneSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_one_sided.FlashInferNVLinkOneSidedPrepareAndFinalize] | !!! info "Table key" 1. All types: mxfp4, nvfp4, int4, int8, fp8 diff --git a/tests/compile/h100/test_startup.py b/tests/compile/h100/test_startup.py index ff4496c2ba6d..78554a3e93da 100644 --- a/tests/compile/h100/test_startup.py +++ b/tests/compile/h100/test_startup.py @@ -34,7 +34,10 @@ def _run_vllm(vllm_runner): mode=CompilationMode.VLLM_COMPILE, cudagraph_mode=CUDAGraphMode.NONE, ), - num_gpu_blocks_override=8, + # Phi-tiny-MoE uses SWA, whose admission cap is `cdiv(L, block_size) + 1` + # at default block_size=16 — i.e. 17 blocks for max_model_len=256. Use + # 32 for headroom. 
+ num_gpu_blocks_override=32, ): pass @@ -190,7 +193,7 @@ def _run_model(vllm_runner, spec: ModelStartupSpec): cudagraph_mode=CUDAGraphMode.NONE, pass_config=PassConfig(fuse_allreduce_rms=False), ), - num_gpu_blocks_override=8, + num_gpu_blocks_override=16, ): pass diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index bbb9cb1fcbcc..d822b68c5036 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -405,6 +405,9 @@ def test_should_split(): (None, 0, 1, False, 2048, CUDAGraphMode.NONE, 0), # truncated to nearest multiple of 8 or 16 (None, 257, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256), + # max_num_batched_tokens <= max_cudagraph_capture_size should always be + # captured even if not landing on a 16-stride step + (None, 2048, 1, False, 257, CUDAGraphMode.FULL_AND_PIECEWISE, 257), # max from list ([1, 2, 4, 15], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 15), # SP forces full-graph compilation, sizes are filtered by TP diff --git a/tests/kernels/test_compressor_kv_cache.py b/tests/kernels/test_compressor_kv_cache.py index 592b58fbe430..122254bc3c41 100644 --- a/tests/kernels/test_compressor_kv_cache.py +++ b/tests/kernels/test_compressor_kv_cache.py @@ -3,12 +3,11 @@ """ Round-trip tests for compressor → FP8 quant + KV cache insert → gather + dequant. -Two paths tested: +Four test functions cover five paths: A) DeepseekV4 Attention: head_dim=512 (448 FP8 nope + 64 bf16 rope), quant_block=64 B) Indexer: head_dim=128 (all FP8), quant_block=128 - -These serve as golden references for validating the future fused -compressor+quant+cache kernel. + C) DeepseekV4 Attention magnitude range: correctness across small/large values + D) Indexer fused Triton kernel: compress+norm+rope+quant+insert """ import math @@ -21,6 +20,12 @@ dequantize_and_gather_k_cache, quantize_and_insert_k_cache, ) +from vllm.v1.attention.ops.deepseek_v4_ops.fused_compress_quant_cache import ( + _fused_kv_compress_norm_rope_insert_indexer_attn, + _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn, +) + +from .test_fused_indexer_q_rope_quant import quantize_to_mxfp4 def _ue8m0_reference(x: torch.Tensor, block_size: int, fp8_max: float): @@ -309,3 +314,222 @@ def test_deepseek_v4_quant_magnitude_range(): f"Token {t}: rel_err={rel_err:.4f}, abs_diff={abs_diff:.6f}, " f"magnitude={magnitude:.4f}" ) + + +# ── Test D: Indexer fused K-cache insert (Triton kernels) ──────────────────── +# +# Both kernels share the same Triton signature; use_fp4 selects between them. +# Full pipeline: state-cache gather → softmax-weighted compress → RMSNorm → +# GPT-J RoPE → quant (MXFP4 or FP8) → paged cache insert. + + +def _reference_kv_compress_norm_rope( + state_cache: torch.Tensor, + block_table: torch.Tensor, + positions: torch.Tensor, + rms_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + compress_ratio: int = 1, + overlap: int = 0, + use_fp4: bool = False, + rms_eps: float = 1e-6, + fp8_max: float = 448.0, +): + """Compress → RMSNorm → GPT-J RoPE → quantize. + + Gathers (1+overlap)*compress_ratio state entries per output token, applies + per-element softmax over the scores, and computes the weighted kv sum. + Returns (quantized_values, scale) matching the kernel's output layout. 
+ """ + device = state_cache.device + head_dim = rms_weight.shape[0] + rope_dim = cos_sin_cache.shape[-1] + state_block_size = state_cache.shape[1] + state_width = state_cache.shape[-1] // 2 + nope_dim = head_dim - rope_dim + total = (1 + overlap) * compress_ratio + results = [] + for pos in positions.tolist(): + src = torch.arange(pos - total + 1, pos + 1, dtype=torch.int64, device=device) + valid = src >= 0 + idx = src.clamp(min=0) + pages = block_table[0, idx // state_block_size] + offsets = idx % state_block_size + raw = state_cache[pages, offsets].float() # [total, state_dim] + + # Group 0 (tokens 0..cr-1): kv[:H], score[SW:SW+H] + # Group 1 (tokens cr..2cr-1): kv[H:2H], score[SW+H:SW+2H] + if overlap: + sw = state_width + g0_kv = raw[:compress_ratio, :head_dim] + g1_kv = raw[compress_ratio:, head_dim : 2 * head_dim] + g0_scores = raw[:compress_ratio, sw : sw + head_dim] + g1_scores = raw[compress_ratio:, sw + head_dim : sw + 2 * head_dim] + kv = torch.cat([g0_kv, g1_kv]) + scores = torch.cat([g0_scores, g1_scores]) + else: + kv = raw[:, :head_dim] + scores = raw[:, state_width : state_width + head_dim] + + scores[~valid] = float("-inf") + kv[~valid] = 0.0 + weights = torch.softmax(scores, dim=0) + compressed = (kv * weights).sum(dim=0) # [H] + var = (compressed * compressed).mean() + normed = compressed * torch.rsqrt(var + rms_eps) * rms_weight.float() + compressed_pos = (pos // compress_ratio) * compress_ratio + cos, sin = cos_sin_cache[compressed_pos].float().chunk(2) + nope, rope = normed.split([nope_dim, rope_dim]) + rope = torch.stack( + [rope[0::2] * cos - rope[1::2] * sin, rope[1::2] * cos + rope[0::2] * sin], + dim=-1, + ).reshape(rope_dim) + results.append(torch.cat([nope, rope]).to(state_cache.dtype)) + result = torch.stack(results) + + if use_fp4: + return quantize_to_mxfp4(result) + else: + pairs = [ + _ue8m0_reference(result[t], head_dim, fp8_max) for t in range(len(result)) + ] + quants, scales = zip(*pairs) + return torch.stack(quants), torch.cat(scales) + + +@pytest.mark.parametrize("num_tokens", [1, 7, 32]) +@pytest.mark.parametrize("kv_block_size", [16, 32]) +@pytest.mark.parametrize("use_fp4", [False, True]) +def test_fused_kv_insert_indexer(num_tokens: int, kv_block_size: int, use_fp4: bool): + """Fused K compress+norm+rope+quant+insert for the indexer KV cache.""" + HEAD_DIM = 128 + ROPE_DIM = 64 + BLOCK_SIZE = 16 + RMS_EPS = 1e-6 + FP8_MAX = 448.0 + + device = "cuda" + torch.manual_seed(42) + compress_ratio = 4 + + if use_fp4: + TOKEN_STRIDE = HEAD_DIM // 2 # packed nibbles: 64 bytes + SCALE_DIM = HEAD_DIM // 32 # ue8m0 bytes: 4 + QUANT_BLOCK = 32 + kernel = _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn + else: + TOKEN_STRIDE = HEAD_DIM # FP8 bytes: 128 + SCALE_DIM = 4 # 1 float32: 4 bytes + QUANT_BLOCK = HEAD_DIM + kernel = _fused_kv_compress_norm_rope_insert_indexer_attn + + # overlap=1 whenever compress_ratio==4, matching DeepseekCompressor logic. 
+ overlap = 1 if compress_ratio == 4 else 0 + coff = 1 + overlap # multiplier for state_dim per entry + + num_pages = (compress_ratio * num_tokens - 1) // BLOCK_SIZE + 2 + state_cache = torch.randn( + num_pages, + BLOCK_SIZE, + 2 * coff * HEAD_DIM, # kv_state + score_state, each coff*HEAD_DIM wide + dtype=torch.bfloat16, + device=device, + ) + block_table = torch.arange(num_pages, dtype=torch.int32, device=device).unsqueeze(0) + token_to_req = torch.zeros(num_tokens, dtype=torch.int32, device=device) + slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device) + positions = torch.arange( + compress_ratio - 1, + compress_ratio * num_tokens, + compress_ratio, + dtype=torch.int64, + device=device, + ) + rms_weight = torch.randn(HEAD_DIM, dtype=torch.bfloat16, device=device) + cos_sin_cache = torch.randn(compress_ratio * num_tokens, ROPE_DIM, device=device) + + kv_n_blocks = (num_tokens + kv_block_size - 1) // kv_block_size + 1 + kv_cache = torch.zeros( + kv_n_blocks, + kv_block_size * (TOKEN_STRIDE + SCALE_DIM), + dtype=torch.uint8, + device=device, + ) + + kernel[(num_tokens,)]( + state_cache, + state_cache.stride(0), + state_cache.stride(1), + token_to_req, + positions, + slot_mapping, + block_table, + block_table.stride(0), + BLOCK_SIZE, + rms_weight, + RMS_EPS, + cos_sin_cache, + cos_sin_cache.stride(0), + kv_cache, + slot_mapping, + kv_block_size, + HEAD_SIZE=HEAD_DIM, + TRITON_BLOCK_SIZE=HEAD_DIM, + STATE_WIDTH=coff * HEAD_DIM, + COMPRESS_RATIO=compress_ratio, + OVERLAP=overlap, + ROPE_HEAD_DIM=ROPE_DIM, + FP8_MAX=FP8_MAX, + QUANT_BLOCK=QUANT_BLOCK, + TOKEN_STRIDE=TOKEN_STRIDE, + SCALE_DIM=SCALE_DIM, + KV_BLOCK_STRIDE=kv_cache.stride(0), + num_warps=1, + ) + + k_quant, scale = _reference_kv_compress_norm_rope( + state_cache, + block_table, + positions, + rms_weight, + cos_sin_cache, + compress_ratio, + overlap, + use_fp4, + rms_eps=RMS_EPS, + fp8_max=FP8_MAX, + ) + + if use_fp4: + for i in range(num_tokens): + blk, pos = i // kv_block_size, i % kv_block_size + val_off = pos * TOKEN_STRIDE + fp4_actual = kv_cache[blk, val_off : val_off + TOKEN_STRIDE] + assert torch.equal(k_quant[i], fp4_actual), ( + f"token {i}: packed nibbles differ, " + f"{(k_quant[i] != fp4_actual).sum()} " + f"/ {TOKEN_STRIDE}" + ) + + scale_off = kv_block_size * TOKEN_STRIDE + pos * SCALE_DIM + scale_actual = kv_cache[blk, scale_off : scale_off + SCALE_DIM] + assert torch.equal(scale_actual, scale[i]), ( + f"token {i}: ue8m0 {scale_actual.tolist()} != {scale[i].tolist()}" + ) + + else: + k_quant = k_quant.view(torch.uint8) + for i in range(num_tokens): + blk, pos = i // kv_block_size, i % kv_block_size + val_off = pos * TOKEN_STRIDE + assert torch.equal( + k_quant[i], kv_cache[blk, val_off : val_off + TOKEN_STRIDE] + ), f"token {i}: FP8 bytes differ" + + scale_off = kv_block_size * TOKEN_STRIDE + pos * SCALE_DIM + actual_scale = kv_cache[blk, scale_off : scale_off + SCALE_DIM].view( + torch.float32 + ) + assert torch.equal(actual_scale, scale[i : i + 1]), ( + f"token {i}: scale {actual_scale.item()} != {scale[i].item()}" + ) diff --git a/tests/kernels/test_fused_indexer_q_rope_quant.py b/tests/kernels/test_fused_indexer_q_rope_quant.py index 03d5ad4c8ac7..be2039ce513e 100644 --- a/tests/kernels/test_fused_indexer_q_rope_quant.py +++ b/tests/kernels/test_fused_indexer_q_rope_quant.py @@ -30,6 +30,56 @@ MAX_POS = 4096 +def quantize_to_mxfp4( + x: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Reference MXFP4 quantization. 
+ + Args: + x: [..., head_dim] where head_dim is divisible by 32 + Returns: + packed: [..., head_dim//2] uint8 2 E2M1 nibbles/byte, low nibble = even index + scales: [..., head_dim//32] uint8 1 ue8m0 byte + """ + MXFP4_BLOCK_SIZE = 32 + orig_shape = x.shape + head_dim = orig_shape[-1] + n_blocks = head_dim // MXFP4_BLOCK_SIZE + + x_f32 = x.float().reshape(-1, n_blocks, MXFP4_BLOCK_SIZE) + + # Per-block ue8m0 scale: 2^ceil(log2(amax / 6.0)), stored as byte = exp + 127 + # 6 * 2^-126 is from https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro/blob/main/inference/kernel.py#L163 + amax = x_f32.abs().amax(dim=-1, keepdim=True).clamp(min=6 * (2**-126)) + log2_ratio = (amax * (1.0 / 6.0)).log2().ceil().clamp(-127.0, 127.0) + scale = log2_ratio.exp2() + ue8m0 = (log2_ratio + 127.0).to(torch.uint8) # [*, n_blocks] + + # E2M1 round-to-nearest-even: midpoints round to the even code. + # E2M1 values: [0.00, 0.50, 1.00, 1.50, 2.00, 3.00, 4.00, 6.00] + # boundaries: [ 0.25, 0.75, 1.25, 1.75, 2.50, 3.50, 5.00] + x_scaled = (x_f32 / scale).clamp(-6.0, 6.0) + abs_x = x_scaled.abs() + code = torch.zeros_like(abs_x, dtype=torch.int32) + code = torch.where(abs_x > 0.25, 1, code) + code = torch.where(abs_x >= 0.75, 2, code) + code = torch.where(abs_x > 1.25, 3, code) + code = torch.where(abs_x >= 1.75, 4, code) + code = torch.where(abs_x > 2.5, 5, code) + code = torch.where(abs_x >= 3.5, 6, code) + code = torch.where(abs_x > 5.0, 7, code) + sign = ((x_scaled.view(torch.int32) >> 31) & 1).to(torch.uint8) + nibble = code.to(torch.uint8) | (sign << 3) + + # Pack: even-index element → low nibble, odd-index → high nibble + nibble_flat = nibble.reshape(-1, head_dim) + packed = (nibble_flat[:, 0::2] | (nibble_flat[:, 1::2] << 4)).contiguous() + packed = packed.reshape(*orig_shape[:-1], head_dim // 2) + + scales = ue8m0.view(*orig_shape[:-1], n_blocks) + return packed, scales + + def _reference( positions: torch.Tensor, q: torch.Tensor, @@ -37,6 +87,7 @@ def _reference( weights: torch.Tensor, softmax_scale: float, head_scale: float, + use_fp4: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: q_rot = q.clone() ops.rotary_embedding( @@ -49,22 +100,33 @@ def _reference( HEAD_DIM - ROPE_DIM, # rope_dim_offset → rotate the tail False, ) - q_fp8, q_scale = per_token_group_quant_fp8( - q_rot.view(-1, HEAD_DIM).contiguous(), - HEAD_DIM, - use_ue8m0=True, - ) - q_fp8 = q_fp8.view(-1, N_HEAD, HEAD_DIM) - q_scale = q_scale.view(-1, N_HEAD) - weights_out = weights.to(torch.float32) * q_scale * softmax_scale * head_scale - return q_fp8, weights_out + if use_fp4: + q_packed, ue8m0 = quantize_to_mxfp4(q_rot.view(-1, N_HEAD, HEAD_DIM)) + # Pack 4 ue8m0 bytes into 1 int32 + q_scale = ue8m0.view(torch.int32).squeeze(-1) + # FP4 path: q_scale stays separate (cannot be folded into a per-token scalar) + weights_out = weights.to(torch.float32) * softmax_scale * head_scale + return (q_packed, q_scale), weights_out + + else: + q_fp8, q_scale = per_token_group_quant_fp8( + q_rot.view(-1, HEAD_DIM).contiguous(), + HEAD_DIM, + use_ue8m0=True, + ) + q_fp8 = q_fp8.view(-1, N_HEAD, HEAD_DIM) + q_scale = q_scale.view(-1, N_HEAD) + + weights_out = weights.to(torch.float32) * q_scale * softmax_scale * head_scale + return q_fp8, weights_out @pytest.mark.parametrize("num_tokens", [1, 7, 32, 257]) @pytest.mark.parametrize("cache_dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("use_fp4", [False, True]) @torch.inference_mode() -def test_fused_indexer_q_rope_quant_matches_unfused(num_tokens, cache_dtype): +def 
test_fused_indexer_q_rope_quant_matches_unfused(num_tokens, cache_dtype, use_fp4): device = "cuda" torch.manual_seed(0) @@ -77,21 +139,32 @@ def test_fused_indexer_q_rope_quant_matches_unfused(num_tokens, cache_dtype): softmax_scale = HEAD_DIM**-0.5 head_scale = N_HEAD**-0.5 - q_fp8_ref, weights_ref = _reference( - positions, q, cos_sin_cache, weights, softmax_scale, head_scale + q_quant_ref, weights_ref = _reference( + positions, q, cos_sin_cache, weights, softmax_scale, head_scale, use_fp4 ) - q_fp8_fused, weights_fused = fused_indexer_q_rope_quant( - positions, q.clone(), cos_sin_cache, weights, softmax_scale, head_scale + q_quant_fused, weights_fused = fused_indexer_q_rope_quant( + positions, q.clone(), cos_sin_cache, weights, softmax_scale, head_scale, use_fp4 ) + if use_fp4: + q_quant_ref, q_scale_ref = q_quant_ref + q_quant_fused, q_scale_fused = q_quant_fused + + assert torch.equal(q_scale_ref, q_scale_fused), ( + f"q_scale mismatch: " + f"{(q_scale_ref != q_scale_fused).sum().item()} " + f"/ {q_scale_ref.numel()} bytes differ" + ) + # fp8 tensors aren't directly comparable via torch.equal — reinterpret as int8. - ref_bits = q_fp8_ref.view(torch.int8) - fused_bits = q_fp8_fused.view(torch.int8) + ref_bits = q_quant_ref.view(torch.int8) + fused_bits = q_quant_fused.view(torch.int8) assert torch.equal(ref_bits, fused_bits), ( - f"q_fp8 mismatch: " + f"q_quant_fused mismatch: " f"{(ref_bits != fused_bits).sum().item()} / {ref_bits.numel()} bytes differ" ) + assert weights_fused.dtype == torch.float32 assert torch.equal(weights_ref, weights_fused), ( f"weights mismatch: max abs diff " f"{(weights_ref - weights_fused).abs().max().item()}" diff --git a/tests/tool_parsers/test_deepseekv32_tool_parser.py b/tests/tool_parsers/test_deepseekv32_tool_parser.py index 6145253d9f90..c547795e7bf2 100644 --- a/tests/tool_parsers/test_deepseekv32_tool_parser.py +++ b/tests/tool_parsers/test_deepseekv32_tool_parser.py @@ -188,6 +188,30 @@ def test_multiple_tools(self, parser): "location": "NYC" } + def test_type_conversion_in_non_streaming(self): + """Non-streaming extraction must convert params using the tool schema.""" + tool = ChatCompletionToolsParam( + function=FunctionDefinition( + name="toggle", + parameters={ + "type": "object", + "properties": { + "enabled": {"type": "boolean"}, + "count": {"type": "integer"}, + }, + }, + ), + ) + parser = make_parser(tools=[tool]) + model_output = build_tool_call("toggle", {"enabled": "true", "count": "42"}) + result = parser.extract_tool_calls(model_output, None) + assert result.tools_called + assert len(result.tool_calls) == 1 + args = json.loads(result.tool_calls[0].function.arguments) + assert args == {"enabled": True, "count": 42} + assert isinstance(args["enabled"], bool) + assert isinstance(args["count"], int) + # --------------------------------------------------------------------------- # Tests: extract_tool_calls_streaming diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index cfd03c5f687e..985b97c69ca4 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -2074,6 +2074,54 @@ def test_auto_fit_max_model_len_not_triggered(): assert vllm_config.model_config.max_model_len == 16 +def test_auto_fit_max_model_len_respects_num_gpu_blocks_override(): + """Auto-fit must size max_model_len against the override-clamped pool, not + the raw `available_memory`. Without this, auto-fit could pick a + max_model_len that no longer fits once `num_gpu_blocks_override` is applied. 
+ """ + model_config = ModelConfig(max_model_len=16384) + model_config.original_max_model_len = -1 # request auto-fit + vllm_config = VllmConfig(model_config=model_config) + # Cap the cache to 32 blocks regardless of available memory. + vllm_config.cache_config.num_gpu_blocks_override = 32 + + mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2 + kv_cache_specs = { + "layer_1": new_kv_cache_spec(), # block_size=16 + "layer_2": new_kv_cache_spec(), + } + # Plenty of raw memory (1024 blocks per layer would fit max_model_len=16384). + large_available_memory = mem_per_block_per_layer * 2 * 1024 + + get_kv_cache_configs(vllm_config, [kv_cache_specs], [large_available_memory]) + + # 32 blocks * block_size 16 = 512 token slots, so max_model_len must + # auto-fit at or below that. + assert 0 < vllm_config.model_config.max_model_len <= 32 * 16 + + +def test_check_enough_kv_cache_memory_respects_num_gpu_blocks_override(): + """Admission check must use the override-clamped pool size, not raw + `available_memory`. Without this, startup could accept a max_model_len + that does not actually fit in `num_gpu_blocks_override` blocks. + """ + model_config = ModelConfig(max_model_len=16384) + vllm_config = VllmConfig(model_config=model_config) + # 32 blocks is far too small for max_model_len=16384 (would need 1024). + vllm_config.cache_config.num_gpu_blocks_override = 32 + + mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2 + kv_cache_specs = { + "layer_1": new_kv_cache_spec(), + "layer_2": new_kv_cache_spec(), + } + # Plenty of raw memory: a bytes-only check against this would pass. + large_available_memory = mem_per_block_per_layer * 2 * 1024 + + with pytest.raises(ValueError, match="max seq len"): + get_kv_cache_configs(vllm_config, [kv_cache_specs], [large_available_memory]) + + def test_unify_hybrid_kv_cache_specs(): # 1. has_full_attention and has_sliding_window before_spec_1 = new_kv_cache_spec() diff --git a/tests/v1/e2e/general/test_async_scheduling.py b/tests/v1/e2e/general/test_async_scheduling.py index 8e1eddb0f64e..28a1bedbe0b2 100644 --- a/tests/v1/e2e/general/test_async_scheduling.py +++ b/tests/v1/e2e/general/test_async_scheduling.py @@ -324,10 +324,13 @@ def run_test( ): spec_decoding = spec_config is not None cache_arg: dict[str, Any] = ( - # Force preemptions - dict(num_gpu_blocks_override=32) + # Force preemptions: with 32 blocks the cache holds at most a single + # max-length request, so the ~34 concurrent prompts contend and trigger + # preemption. (Prompts here are << max_model_len, so dropping + # max_model_len from 4096 to 512 doesn't change generation behavior.) 
+ dict(num_gpu_blocks_override=32, max_model_len=512) if test_preemption - else dict(gpu_memory_utilization=0.9) + else dict(gpu_memory_utilization=0.9, max_model_len=4096) ) spec_mml = (spec_config or {}).get("max_model_len") spec_method = (spec_config or {}).get("method", "none") @@ -343,7 +346,6 @@ def run_test( with VllmRunner( model, - max_model_len=4096, enable_chunked_prefill=test_prefill_chunking, # Force prefill chunking max_num_batched_tokens=48 if test_prefill_chunking else None, diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index f591605d08c7..56123542ce71 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -1432,6 +1432,10 @@ def _set_cudagraph_sizes(self): cudagraph_capture_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list( range(256, max_graph_size + 1, 16)) + `max_num_batched_tokens` is also appended to the list if it fits + within `max_cudagraph_capture_size`, so the max batch size is captured + even when off-stride. + In the end, `vllm_config.compilation_config.cudagraph_capture_sizes` will be the final sizes to capture cudagraph (in ascending order). @@ -1520,6 +1524,12 @@ def _set_cudagraph_sizes(self): cudagraph_capture_sizes += list( range(256, max_cudagraph_capture_size + 1, 16) ) + # ensure max_num_tokens is captured if within max capture size + if ( + max_num_tokens <= max_cudagraph_capture_size + and max_num_tokens not in cudagraph_capture_sizes + ): + cudagraph_capture_sizes.append(max_num_tokens) # de-duplicate and sort the sizes cudagraph_capture_sizes = sorted(set(cudagraph_capture_sizes)) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 554a34b6a68e..b2202b7d08d4 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -128,13 +128,6 @@ def get_instance() -> "CuMemAllocator": return CuMemAllocator.instance def __init__(self): - conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "") - assert "expandable_segments:True" not in conf, ( - "Expandable segments are not compatible with memory pool. " - "Please track https://github.com/pytorch/pytorch/issues/147851 " - "for the latest updates." - ) - self.pointer_to_data: dict[int, AllocationData] = {} self.current_tag: str = CuMemAllocator.default_tag self.allocator_and_pools: dict[str, Any] = {} @@ -264,34 +257,49 @@ def use_memory_pool(self, tag: str | None = None): assert isinstance(tag, str) + # Expandable segments are incompatible with the memory pool used for + # sleep mode (see https://github.com/pytorch/pytorch/issues/147851). + # If the user has enabled expandable segments via + # PYTORCH_CUDA_ALLOC_CONF, temporarily disable them for the duration + # of the memory pool context and restore on exit. + conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "") + expandable_was_enabled = "expandable_segments:True" in conf + if expandable_was_enabled: + torch.cuda.memory._set_allocator_settings("expandable_segments:False") + old_tag = self.current_tag self.current_tag = tag - with use_memory_pool_with_allocator( - self.python_malloc_callback, self.python_free_callback - ) as data: - # start to hit another PyTorch bug in PyTorch 2.6, - # possibly because of gc-related issue w.r.t. the allocator and - # the memory pool. - # to avoid the issue, we keep a reference of the data. - # see https://github.com/pytorch/pytorch/issues/146431 . - self.allocator_and_pools[tag] = data - yield - # PyTorch's bug, calling torch.cuda.empty_cache() will error - # when using pluggable allocator, see - # https://github.com/pytorch/pytorch/issues/145168 . 
- # if we have some memory allocated and then freed, - # the memory will not be released, e.g. in online quantization, - # where the model is created in higher precision, and then - # quantized in lower precision. - # Find all unused allocations and manually release them. - # TODO: we should expose `empty_cache` method in the memory pool. - # TODO: ask for help from PyTorch team to expose this method. - allocations = data[0].snapshot() - for allocation in allocations: - if allocation["allocated_size"] == 0: - handle = self._python_free_callback(allocation["address"]) - unmap_and_release(handle) + try: + with use_memory_pool_with_allocator( + self.python_malloc_callback, self.python_free_callback + ) as data: + # start to hit another PyTorch bug in PyTorch 2.6, + # possibly because of gc-related issue w.r.t. the allocator + # and the memory pool. + # to avoid the issue, we keep a reference of the data. + # see https://github.com/pytorch/pytorch/issues/146431 . + self.allocator_and_pools[tag] = data + yield + # PyTorch's bug, calling torch.cuda.empty_cache() will error + # when using pluggable allocator, see + # https://github.com/pytorch/pytorch/issues/145168 . + # if we have some memory allocated and then freed, + # the memory will not be released, e.g. in online + # quantization, where the model is created in higher + # precision, and then quantized in lower precision. + # Find all unused allocations and manually release them. + # TODO: we should expose `empty_cache` method in the memory + # pool. + # TODO: ask for help from PyTorch team to expose this method. + allocations = data[0].snapshot() + for allocation in allocations: + if allocation["allocated_size"] == 0: + handle = self._python_free_callback(allocation["address"]) + unmap_and_release(handle) + finally: self.current_tag = old_tag + if expandable_was_enabled: + torch.cuda.memory._set_allocator_settings("expandable_segments:True") def get_current_usage(self) -> int: """ diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 6a15d3f6168a..36837c5b1bea 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -584,6 +584,8 @@ def initialize( top_k: int, num_experts: int, hidden_size: int, + dispatch_dtype_bytes_per_elem: int = 0, + dispatch_scale_bytes_per_token: int = 0, ): """Initialize the MoeAlltoAll workspace.""" if self.initialized: @@ -614,9 +616,13 @@ def initialize( ep_config = MnnvlConfig( comm_backend=CustomCommunicator(self.cpu_group), ) + if dispatch_dtype_bytes_per_elem == 0: + hidden_bytes = hidden_size // 2 + else: + hidden_bytes = hidden_size * dispatch_dtype_bytes_per_elem total_dispatch_payload_size_per_token = ( - hidden_size // 2 # nvfp4 hidden states - + hidden_size // 16 # fp8 scaling factors + hidden_bytes + + dispatch_scale_bytes_per_token + top_k * 4 # int32 topks ids + top_k * 4 # float32 topk weights ) diff --git a/vllm/model_executor/layers/deepseek_compressor.py b/vllm/model_executor/layers/deepseek_compressor.py index af2783f604da..cae80c35316a 100644 --- a/vllm/model_executor/layers/deepseek_compressor.py +++ b/vllm/model_executor/layers/deepseek_compressor.py @@ -14,7 +14,6 @@ from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, ) -from vllm.model_executor.layers.utils import cublas_gemm_bf16_bf16_fp32 from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.v1.attention.backend import ( @@ -271,16 +270,12 @@ def 
__init__( def forward( self, - # [num_tokens, hidden_size] - x: torch.Tensor, + # [num_tokens, 2 * self.coff * self.head_dim] + kv_score: torch.Tensor, # [num_tokens] positions: torch.Tensor, rotary_emb, ) -> None: - num_tokens, _ = x.shape - # bf16 weights/activations but fp32 output for numerical stability of - # the downstream compressor math. - kv_score = cublas_gemm_bf16_bf16_fp32(x, self.fused_wkv_wgate.weight) # Each of shape [num_tokens, coff * self.head_dim] # input bf16, output are fp32 kv, score = kv_score.split( diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 43242eddb5b2..a968a06bb650 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -4,8 +4,9 @@ DeepseekV4 MLA Attention Layer """ +from collections.abc import Callable from dataclasses import dataclass -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, cast import torch import torch.nn as nn @@ -16,6 +17,7 @@ ReplicatedLinear, ) from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer +from vllm.model_executor.layers.utils import cublas_gemm_bf16_bf16_fp32 from vllm.utils.deep_gemm import fp8_einsum from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.ops.deepseek_v4_ops import ( @@ -51,7 +53,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, ) -from vllm.utils.multi_stream_utils import maybe_execute_in_parallel +from vllm.utils.multi_stream_utils import ( + execute_in_parallel, + maybe_execute_in_parallel, +) from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata from vllm.v1.attention.backends.mla.flashmla_sparse import ( DeepseekV4FlashMLASparseBackend, @@ -94,7 +99,7 @@ class DeepseekV4MLAModules: indexer: torch.nn.Module | None indexer_rotary_emb: torch.nn.Module topk_indices_buffer: torch.Tensor | None - aux_stream: torch.cuda.Stream | None = None + aux_stream_list: list[torch.cuda.Stream] | None = None # --8<-- [start:multi_head_latent_attention] @@ -217,8 +222,11 @@ def __init__( + 1 # 1B pad ) - self.aux_stream = mla_modules.aux_stream - self.ln_events = [torch.cuda.Event(), torch.cuda.Event()] + self.aux_stream_list = mla_modules.aux_stream_list + # [0]: GEMM start / post-GEMM event0. [1..3]: GEMM done events; + # [1] doubles as post-GEMM event1. Reuse is safe: GEMM fully joins + # before post-GEMM starts. + self.ln_events = [torch.cuda.Event() for _ in range(4)] assert cache_config is not None, "DeepseekV4 attention requires cache_config" self.swa_cache_layer = DeepseekV4SWACache( @@ -277,9 +285,6 @@ def forward( hidden_states: torch.Tensor, llama_4_scaling: torch.Tensor | None = None, ) -> torch.Tensor: - qr_kv, _ = self.fused_wqa_wkv(hidden_states) - qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1) - # Pre-allocate attention output with FlashMLA-padded head count. # The op writes into `o_padded`; we slice to n_local_heads after. 
num_tokens = hidden_states.shape[0] @@ -292,8 +297,6 @@ def forward( # Attention (inside custom op for torch.compile boundary) torch.ops.vllm.deepseek_v4_attention( hidden_states, - qr, - kv, positions, o_padded, self.layer_name, @@ -332,17 +335,71 @@ def forward( return self.wo_b(z.flatten(1)) + def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]: + assert self.aux_stream_list is not None + assert len(self.aux_stream_list) >= 3 + + # fused_wqa_wkv (heaviest) on default; the three lighter input GEMMs + # on aux streams 0..2 when their owning module exists. ln_events[0] + # is the fan-out start event; ln_events[1..3] are per-aux done events. + aux_fns: list[Callable[[], Any] | None] = [None, None, None] + + if self.compressor is not None: + # Local ref so the closure keeps a non-None type for mypy. + compressor = self.compressor + + def compressor_kv_score() -> torch.Tensor: + return cublas_gemm_bf16_bf16_fp32( + hidden_states, compressor.fused_wkv_wgate.weight + ) + + aux_fns[0] = compressor_kv_score + + if self.indexer is not None: + indexer = self.indexer + + def indexer_weights_proj() -> torch.Tensor: + # ReplicatedLinear returns (output, bias); bias is None. + weights, _ = indexer.weights_proj(hidden_states) + return weights + + def indexer_compressor_kv_score() -> torch.Tensor: + return cublas_gemm_bf16_bf16_fp32( + hidden_states, indexer.compressor.fused_wkv_wgate.weight + ) + + aux_fns[1] = indexer_weights_proj + aux_fns[2] = indexer_compressor_kv_score + + def fused_wqa_wkv() -> torch.Tensor: + # MergedColumnParallelLinear returns (output, bias); bias is None. + qr_kv, _ = self.fused_wqa_wkv(hidden_states) + return qr_kv + + qr_kv, (kv_score, indexer_weights, indexer_kv_score) = execute_in_parallel( + fused_wqa_wkv, + aux_fns, + self.ln_events[0], + self.ln_events[1:4], + self.aux_stream_list[:3], + ) + + return qr_kv, kv_score, indexer_kv_score, indexer_weights + def attention_impl( self, hidden_states: torch.Tensor, - qr: torch.Tensor, - kv: torch.Tensor, positions: torch.Tensor, out: torch.Tensor, # [num_tokens, padded_heads, head_dim], written in place ) -> None: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata + qr_kv, kv_score, indexer_kv_score, indexer_weights = ( + self.attn_gemm_parallel_execute(hidden_states) + ) + + qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1) qr, kv = fused_q_kv_rmsnorm( qr, kv, @@ -350,42 +407,60 @@ def attention_impl( self.kv_norm.weight.data, self.eps, ) - q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) - # Overlap kv_insert with whichever of indexer/compressor is present. - # Indexer implies compressor; when both exist, compressor rides on the - # aux stream alongside kv_insert so the heavy indexer owns default. + # wq_b + kv_insert (+ MLA compressor when an indexer is present) ride + # on the default stream so q stays on its consumer stream (mla_attn + # downstream reads q on default). Indexer/compressor go on aux for + # overlap with default's GEMM + cache write. if self.indexer is not None: + assert self.aux_stream_list is not None + aux_stream = self.aux_stream_list[0] indexer = self.indexer # Local ref so the closure keeps a non-None type for mypy. 
assert self.compressor is not None compressor = self.compressor - def kv_insert_and_compress() -> None: + def wq_b_kv_insert_and_compress() -> torch.Tensor: + q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) - compressor(hidden_states, positions, self.rotary_emb) - - maybe_execute_in_parallel( - lambda: indexer(hidden_states, qr, positions, self.indexer_rotary_emb), - kv_insert_and_compress, + compressor(kv_score, positions, self.rotary_emb) + return q + + q, _ = maybe_execute_in_parallel( + wq_b_kv_insert_and_compress, + lambda: indexer( + hidden_states, + qr, + indexer_kv_score, + indexer_weights, + positions, + self.indexer_rotary_emb, + ), self.ln_events[0], self.ln_events[1], - self.aux_stream, + aux_stream, ) elif self.compressor is not None: - # Compressor on default, kv_insert on aux. + # wq_b + kv_insert on default, compressor on aux. + assert self.aux_stream_list is not None + aux_stream = self.aux_stream_list[0] compressor = self.compressor - maybe_execute_in_parallel( - lambda: compressor(hidden_states, positions, self.rotary_emb), - lambda: self._fused_qnorm_rope_kv_insert( - q, kv, positions, attn_metadata - ), + + def wq_b_kv_insert() -> torch.Tensor: + q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) + self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) + return q + + q, _ = maybe_execute_in_parallel( + wq_b_kv_insert, + lambda: compressor(kv_score, positions, self.rotary_emb), self.ln_events[0], self.ln_events[1], - self.aux_stream, + aux_stream, ) else: # SWA-only layer: no compressor, no overlap. + q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) # Handle dummy run (no metadata). @@ -455,21 +530,17 @@ def _fused_qnorm_rope_kv_insert( def deepseek_v4_attention( hidden_states: torch.Tensor, - qr: torch.Tensor, - kv: torch.Tensor, positions: torch.Tensor, out: torch.Tensor, layer_name: str, ) -> None: forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] - self.attention_impl(hidden_states, qr, kv, positions, out) + self.attention_impl(hidden_states, positions, out) def deepseek_v4_attention_fake( hidden_states: torch.Tensor, - qr: torch.Tensor, - kv: torch.Tensor, positions: torch.Tensor, out: torch.Tensor, layer_name: str, @@ -1057,18 +1128,20 @@ def forward( self, hidden_states: torch.Tensor, qr: torch.Tensor, + compressed_kv_score: torch.Tensor, + indexer_weights: torch.Tensor, positions: torch.Tensor, rotary_emb: nn.Module, ) -> torch.Tensor: + # ReplicatedLinear returns (output, bias); bias is None. 
q, _ = self.wq_b(qr) q = q.view(-1, self.n_head, self.head_dim) - k = self.compressor(hidden_states, positions, rotary_emb) - weights, _ = self.weights_proj(hidden_states) + k = self.compressor(compressed_kv_score, positions, rotary_emb) q_quant, weights = fused_indexer_q_rope_quant( positions, q, rotary_emb.cos_sin_cache, - weights, + indexer_weights, self.softmax_scale, self.n_head**-0.5, use_fp4=self.use_fp4_kv, diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index fba1d4c692af..2a6f0c71d936 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -228,23 +228,37 @@ def maybe_make_prepare_finalize( elif moe.use_fi_nvl_one_sided_kernels: assert quant_config is not None - if quant_config.quant_dtype != "nvfp4": - raise ValueError( - "The 'flashinfer_nvlink_one_sided' all2all backend only " - "supports nvfp4 activation quantization, but got " - f"quant_dtype={quant_config.quant_dtype!r}. Use a different " - "all2all backend (e.g. 'flashinfer_nvlink_two_sided' or " - "'allgather_reducescatter') for non-nvfp4 models." - ) max_num_tokens = ( get_current_vllm_config().scheduler_config.max_num_batched_tokens ) + if quant_config.quant_dtype is None: + dispatch_dtype_bytes_per_elem = 2 + dispatch_scale_bytes_per_token = 0 + elif quant_config.quant_dtype == "nvfp4": + dispatch_dtype_bytes_per_elem = 0 + dispatch_scale_bytes_per_token = moe.hidden_dim // 16 + elif quant_config.quant_dtype == "mxfp8": + dispatch_dtype_bytes_per_elem = 1 + align = quant_config.mx_alignment + if align > 0: + padded_k = ((moe.hidden_dim + align - 1) // align) * align + else: + padded_k = moe.hidden_dim + dispatch_scale_bytes_per_token = padded_k // 32 + else: + raise NotImplementedError( + "flashinfer_nvlink_one_sided dispatch supports nvfp4, mxfp8, " + "and bf16 (quant_dtype=None) today; got " + f"quant_dtype={quant_config.quant_dtype!r}" + ) prepare_finalize = FlashInferNVLinkOneSidedPrepareAndFinalize( max_num_tokens=max_num_tokens, top_k=moe.experts_per_token, num_experts=moe.num_experts, hidden_size=moe.hidden_dim, num_dispatchers=all2all_manager.world_size, + dispatch_dtype_bytes_per_elem=dispatch_dtype_bytes_per_elem, + dispatch_scale_bytes_per_token=dispatch_scale_bytes_per_token, ) elif moe.use_ag_rs_all2all_kernels and allow_new_interface: diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index f958a6322e38..25b8c331bc3e 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -247,6 +247,8 @@ class FusedMoEQuantConfig: gemm1_beta: float | None = None gemm1_clamp_limit: float | None = None + mx_alignment: int = 0 + def __post_init__(self): assert not self.per_act_token_quant or self.block_shape is None, ( "illegal quantization" @@ -705,6 +707,7 @@ def mxfp4_mxfp8_moe_quant_config( gemm1_alpha: float | None = None, gemm1_beta: float | None = None, gemm1_clamp_limit: float | None = None, + mx_alignment: int = 0, ) -> FusedMoEQuantConfig: """ Construct a quant config for mxfp4 activations and mxfp4 weights. 
@@ -717,6 +720,7 @@ def mxfp4_mxfp8_moe_quant_config( gemm1_alpha=gemm1_alpha, gemm1_beta=gemm1_beta, gemm1_clamp_limit=gemm1_clamp_limit, + mx_alignment=mx_alignment, ) diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py index f7af9aea70ad..69e5b7fe4f0e 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py @@ -44,6 +44,9 @@ def __init__( moe_config.intermediate_size_per_partition ) self.hidden_dim = moe_config.hidden_dim + self.hidden_dim_unpadded = ( + moe_config.hidden_dim_unpadded or moe_config.hidden_dim + ) self.local_num_experts = moe_config.num_local_experts self.ep_rank = moe_config.moe_parallel_config.ep_rank @@ -82,9 +85,6 @@ def __init__( get_current_vllm_config().compilation_config.max_cudagraph_capture_size ) - # P1-5 fix: use public quant_dtype property instead of private _a1 - self.use_mxfp8_input = quant_config.quant_dtype == "mxfp8" - @staticmethod def _supports_current_device() -> bool: p = current_platform @@ -121,8 +121,7 @@ def supports_expert_map(self) -> bool: @property def expects_unquantized_inputs(self) -> bool: - # Expert handles MXFP8 quantization internally if needed - return True + return False class TrtLlmMxfp4ExpertsMonolithic( @@ -181,24 +180,19 @@ def apply( ) -> torch.Tensor: from flashinfer import trtllm_fp4_block_scale_moe - # Handle input quantization - if self.use_mxfp8_input: - from flashinfer import mxfp8_quantize - - x_quant, x_scale = mxfp8_quantize( - hidden_states, - is_sf_swizzled_layout=False, - alignment=256, - ) - x_scale = x_scale.view(torch.float8_e4m3fn).reshape( - *hidden_states.shape[:-1], -1 - ) + if a1q_scale is not None: + x_quant = hidden_states + x_scale = a1q_scale.view(torch.float8_e4m3fn) else: assert hidden_states.dtype == torch.bfloat16 x_quant = hidden_states x_scale = None - - output = torch.empty_like(hidden_states) + output = torch.empty( + *hidden_states.shape[:-1], + self.hidden_dim_unpadded, + dtype=torch.bfloat16, + device=hidden_states.device, + ) from vllm.utils.flashinfer import _is_fi_autotuning, autotune @@ -244,10 +238,6 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula Moved from trtllm_moe.py. """ - @property - def expects_unquantized_inputs(self) -> bool: - return True - @staticmethod def _supports_parallel_config( moe_parallel_config: FusedMoEParallelConfig, @@ -284,7 +274,7 @@ def workspace_shapes( # The workspaces for this implementation are managed by flashinfer. 
workspace1 = (0,) workspace2 = (0,) - output = (M, K) + output = (M, self.hidden_dim_unpadded) return (workspace1, workspace2, output) def apply( @@ -310,18 +300,9 @@ def apply( intermediate_size = self.intermediate_size_per_partition local_expert_offset = self.moe_config.ep_rank * local_num_experts - # Handle input quantization - if self.use_mxfp8_input: - from flashinfer import mxfp8_quantize - - x_quant, x_scale = mxfp8_quantize( - hidden_states, - is_sf_swizzled_layout=False, - alignment=256, - ) - x_scale = x_scale.view(torch.float8_e4m3fn).reshape( - *hidden_states.shape[:-1], -1 - ) + if a1q_scale is not None: + x_quant = hidden_states + x_scale = a1q_scale.view(torch.float8_e4m3fn) else: assert hidden_states.dtype == torch.bfloat16 x_quant = hidden_states diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py index f476d980d555..c1423362d737 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -1195,10 +1195,18 @@ def make_mxfp4_moe_quant_config( gemm1_beta=gemm1_beta, gemm1_clamp_limit=swiglu_limit, ) - elif mxfp4_backend in ( - Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, - Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, - ): + elif mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8: + return mxfp4_mxfp8_moe_quant_config( + w1_bias=w1_bias, + w2_bias=w2_bias, + w1_scale=w1_scale, + w2_scale=w2_scale, + gemm1_alpha=gemm1_alpha, + gemm1_beta=gemm1_beta, + gemm1_clamp_limit=swiglu_limit, + mx_alignment=256, + ) + elif mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8: return mxfp4_mxfp8_moe_quant_config( w1_bias=w1_bias, w2_bias=w2_bias, @@ -1250,7 +1258,6 @@ def make_mxfp4_moe_kernel( """Create a FusedMoEKernel for the given MXFP4 backend.""" is_monolithic = issubclass(experts_cls, mk.FusedMoEExpertsMonolithic) - # Create Prepare/Finalize. 
prepare_finalize = maybe_make_prepare_finalize( moe=moe_config, quant_config=moe_quant_config, diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py index a04ff3b8b68f..6cc0d01cde6b 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py @@ -31,6 +31,8 @@ def __init__( num_experts: int, hidden_size: int, num_dispatchers: int = 1, + dispatch_dtype_bytes_per_elem: int = 0, + dispatch_scale_bytes_per_token: int = 0, ): super().__init__() self.max_num_tokens = max_num_tokens @@ -38,6 +40,7 @@ def __init__( self.num_experts = num_experts self.hidden_size = hidden_size self.num_dispatchers_ = num_dispatchers + self.scale_elems_per_token = dispatch_scale_bytes_per_token device_communicator = get_ep_group().device_communicator assert device_communicator is not None @@ -49,6 +52,8 @@ def __init__( top_k=self.top_k, num_experts=self.num_experts, hidden_size=self.hidden_size, + dispatch_dtype_bytes_per_elem=dispatch_dtype_bytes_per_elem, + dispatch_scale_bytes_per_token=dispatch_scale_bytes_per_token, ) @property @@ -92,19 +97,24 @@ def prepare( else a1.shape[0] ) - a1q, a1q_scale = moe_kernel_quantize_input( - a1, - quant_config.a1_gscale, - quant_config.quant_dtype, - quant_config.per_act_token_quant, - quant_config.block_shape, - is_fp4_scale_swizzled=False, # delay swizzle to after comm - ) + if defer_input_quant: + a1q, a1q_scale = a1, None + else: + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + quant_config.a1_gscale, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + is_fp4_scale_swizzled=False, # delay swizzle to after comm + mx_alignment=quant_config.mx_alignment, + ) payloads = [] payloads.append(a1q) if a1q_scale is not None: payloads.append(a1q_scale) + topk_ids_payload_index = len(payloads) payloads.append(topk_ids) payloads.append(topk_weights) @@ -113,6 +123,8 @@ def prepare( token_selected_experts=topk_ids, input_payloads=payloads, runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank, + invalid_token_expert_id=-1, # Follow TRTLLM Pattern + expert_id_payload_index=topk_ids_payload_index, ) if a1q_scale is not None: a1q_recv, a1q_scale_recv, topk_ids_recv, topk_weights_recv = recv_payloads @@ -124,7 +136,8 @@ def prepare( a1q_scale_recv = a1q_scale_recv.view(-1, a1q_scale_recv.shape[-1]) a1q_scale_recv = a1q_scale_recv.view(torch.uint8) a1q_scale_recv = nvfp4_block_scale_interleave(a1q_scale_recv) - a1q_scale_recv = a1q_scale_recv.view(-1, self.hidden_size // 16) + assert self.scale_elems_per_token > 0 + a1q_scale_recv = a1q_scale_recv.view(-1, self.scale_elems_per_token) else: a1q_recv, topk_ids_recv, topk_weights_recv = recv_payloads a1q_scale_recv = None diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py index 47fe293d511e..78be414759f7 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py @@ -174,6 +174,7 @@ def flashinfer_alltoall_dispatch( # the hidden states, breaking the A2A kernel. So, we # delay the swizzling until after the A2A. 
is_fp4_scale_swizzled=False, + mx_alignment=quant_config.mx_alignment, ) x = MnnvlMoe.mnnvl_moe_alltoallv( diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py index 2b21e2db9f68..5b3325ad0195 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py @@ -40,6 +40,7 @@ def _quantize_and_setup_dispatch( per_act_token_quant=quant_config.per_act_token_quant, block_shape=quant_config.block_shape, is_fp4_scale_swizzled=False, + mx_alignment=quant_config.mx_alignment, ) # Skip gathering scales if we have static quantization diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py index b9d57da08326..31a35bd60218 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py @@ -31,6 +31,7 @@ def _quantize_input( per_act_token_quant=quant_config.per_act_token_quant, block_shape=quant_config.block_shape, is_fp4_scale_swizzled=quant_config.is_nvfp4_scale_swizzled, + mx_alignment=quant_config.mx_alignment, ) return a1q, a1q_scale diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index ffab3ca0bfa9..ed24cbe2b233 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -208,11 +208,12 @@ def _mxfp8_e4m3_quantize( per_act_token_quant: bool, block_shape: list[int] | None = None, is_sf_swizzled_layout: bool = False, + mx_alignment: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: assert A_scale is None assert not per_act_token_quant assert block_shape is None or block_shape == [1, 32] - return mxfp8_e4m3_quantize(A, is_sf_swizzled_layout) + return mxfp8_e4m3_quantize(A, is_sf_swizzled_layout, mx_alignment) def _mxfp6_e3m2_quantize( @@ -258,6 +259,7 @@ def moe_kernel_quantize_input( is_fp4_scale_swizzled: bool = True, ocp_mx_scheme: str | None = None, quantization_emulation: bool = False, + mx_alignment: int = 0, ) -> tuple[torch.Tensor, torch.Tensor | None]: # Handle OCP MX scheme that requires QDQ (quantize-dequantize) for emulation if ocp_mx_scheme is not None: @@ -319,7 +321,8 @@ def moe_kernel_quantize_input( A_scale, per_act_token_quant, block_shape, - is_sf_swizzled_layout=is_fp4_scale_swizzled, + is_sf_swizzled_layout=False, + mx_alignment=mx_alignment, ) elif quant_dtype == "mxfp6_e3m2": if not quantization_emulation: diff --git a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py index b9b7bd542738..a12918225348 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py @@ -85,7 +85,9 @@ def _mxfp8_e4m3_quantize_torch( def _mxfp8_e4m3_quantize_impl( - x: torch.Tensor, is_sf_swizzled_layout: bool = False + x: torch.Tensor, + is_sf_swizzled_layout: bool = False, + alignment: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: from vllm.platforms import current_platform @@ -93,7 +95,9 @@ def _mxfp8_e4m3_quantize_impl( from flashinfer import mxfp8_quantize as flashinfer_mxfp8_quantize x_q, x_scales = flashinfer_mxfp8_quantize( - x, is_sf_swizzled_layout=is_sf_swizzled_layout + x, + is_sf_swizzled_layout=is_sf_swizzled_layout, + alignment=alignment if alignment > 0 else 
32, ) if x_scales.ndim == 1 and x.ndim == 2 and not is_sf_swizzled_layout: x_scales = x_scales.view(x.size(0), -1) @@ -103,9 +107,11 @@ def _mxfp8_e4m3_quantize_impl( def mxfp8_e4m3_quantize( - x: torch.Tensor, is_sf_swizzled_layout: bool = False + x: torch.Tensor, + is_sf_swizzled_layout: bool = False, + alignment: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: - return torch.ops.vllm.mxfp8_quantize(x, is_sf_swizzled_layout) + return torch.ops.vllm.mxfp8_quantize(x, is_sf_swizzled_layout, alignment) def dequant_mxfp8_to_bf16(x: torch.Tensor, scales: torch.Tensor) -> torch.Tensor: @@ -125,7 +131,9 @@ def dequant_mxfp8_to_bf16(x: torch.Tensor, scales: torch.Tensor) -> torch.Tensor def mxfp8_e4m3_quantize_fake( - x: torch.Tensor, is_sf_swizzled_layout: bool = False + x: torch.Tensor, + is_sf_swizzled_layout: bool = False, + alignment: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: """Fake implementation for torch.compile tracing.""" fp_data = torch.empty_like(x, dtype=MXFP8_VALUE_DTYPE) diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py index 6cb9101a78b1..9a06eedd0f7d 100644 --- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py @@ -45,6 +45,7 @@ def __init__( beta_slow: int = 1, mscale: float = 1, mscale_all_dim: float = 0, + init_cache: bool = True, ) -> None: self.scaling_factor = scaling_factor self.extrapolation_factor = extrapolation_factor @@ -65,7 +66,13 @@ def __init__( and head_size in [64, 128, 256, 512] ) super().__init__( - head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + head_size, + rotary_dim, + max_position_embeddings, + base, + is_neox_style, + dtype, + init_cache=init_cache, ) def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: @@ -211,7 +218,9 @@ class DeepseekV4ScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding): """ def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + # Avoid compute cache repeatedly + kwargs.pop("init_cache", None) + super().__init__(*args, **kwargs, init_cache=False) cache_fp32 = self._compute_cos_sin_cache() self.register_buffer("cos_sin_cache", cache_fp32, persistent=False) diff --git a/vllm/model_executor/models/deepseek_v4.py b/vllm/model_executor/models/deepseek_v4.py index 7733252804b7..baf28d04581a 100644 --- a/vllm/model_executor/models/deepseek_v4.py +++ b/vllm/model_executor/models/deepseek_v4.py @@ -10,7 +10,7 @@ import torch.nn.functional as F from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import ( get_ep_group, get_tensor_model_parallel_rank, @@ -54,7 +54,6 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.triton_utils import tl, triton -from vllm.utils.multi_stream_utils import AuxStreamType from vllm.utils.torch_utils import direct_register_custom_op from .utils import ( @@ -65,6 +64,8 @@ maybe_prefix, ) +_DEEPSEEK_V4_EXPERT_DTYPES = ("fp4", "fp8") + class DeepseekV4MLP(nn.Module): def __init__( @@ -118,16 +119,59 @@ def forward(self, x): class DeepseekV4FP8Config(Fp8Config): - """FP8 config that routes MoE layers to MXFP4 quantization. - - DeepSeek V4 checkpoints use FP8 for linear/attention layers but - MXFP4 for MoE expert weights. 
This config inherits standard FP8 - behavior and overrides only the MoE dispatch. + """FP8 config for DeepSeek V4 with expert-dtype-aware MoE dispatch. + + DeepSeek V4 checkpoints always use FP8 block quantization for + linear/attention layers. The MoE expert weights vary by checkpoint: + - ``expert_dtype="fp4"`` (e.g. DeepSeek-V4-Flash): MXFP4 experts + with ue8m0 (e8m0fnu) FP8 linear scales. + - ``expert_dtype="fp8"`` (e.g. DeepSeek-V4-Flash-Base): FP8 block + experts with float32 FP8 linear scales. + + The dispatch and the linear scale dtype are both keyed off + ``expert_dtype`` from the model's hf_config; missing values default + to ``"fp4"`` so existing FP4 checkpoints stay unchanged. + + NOTE: ``expert_dtype`` is resolved lazily because this config is + constructed during VllmConfig setup, before ``set_current_vllm_config`` + is active. Reading hf_config eagerly in ``__init__`` would always see + the default ``"fp4"`` and silently misroute Flash-Base checkpoints. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.is_scale_e8m0: bool = True + self._resolved_expert_dtype: str | None = None + # ``is_scale_e8m0`` is a property that resolves on first read, + # by which time the current vllm_config has been set. + + @property + def expert_dtype(self) -> str: + if self._resolved_expert_dtype is None: + try: + hf_config = get_current_vllm_config().model_config.hf_config + except Exception: + # vllm_config not yet set; defer the decision until a + # later call lands inside set_current_vllm_config. + return "fp4" + expert_dtype = getattr(hf_config, "expert_dtype", "fp4") + if expert_dtype not in _DEEPSEEK_V4_EXPERT_DTYPES: + raise ValueError( + f"Unsupported DeepSeek V4 expert_dtype={expert_dtype!r}; " + f"expected one of {_DEEPSEEK_V4_EXPERT_DTYPES}." + ) + self._resolved_expert_dtype = expert_dtype + from vllm.logger import init_logger + + init_logger(__name__).info_once( + "DeepSeek V4 expert_dtype resolved to %r", expert_dtype + ) + return self._resolved_expert_dtype + + @property + def is_scale_e8m0(self) -> bool: + # FP4 checkpoints store FP8 linear scales as e8m0fnu; FP8 expert + # checkpoints (Flash-Base) store them as float32. + return self.expert_dtype == "fp4" @classmethod def get_name(cls) -> QuantizationMethods: @@ -155,11 +199,14 @@ def get_quant_method(self, layer, prefix): fused_mapping=self.packed_modules_mapping, ): return UnquantizedFusedMoEMethod(layer.moe_config) - return Mxfp4MoEMethod(layer.moe_config) + if self.expert_dtype == "fp4": + return Mxfp4MoEMethod(layer.moe_config) + # expert_dtype == "fp8": fall through to Fp8Config which + # returns Fp8MoEMethod with block-wise float32 scales. return super().get_quant_method(layer, prefix) def is_mxfp4_quant(self, prefix, layer): - return isinstance(layer, FusedMoE) + return isinstance(layer, FusedMoE) and self.expert_dtype == "fp4" @triton.jit @@ -689,6 +736,12 @@ def __init__( raise NotImplementedError( "DeepSeek V4 MegaMoE currently supports sqrtsoftplus routing only." ) + if self.use_mega_moe and getattr(config, "expert_dtype", "fp4") != "fp4": + raise NotImplementedError( + "DeepSeek V4 MegaMoE only supports fp4 experts; got expert_dtype=" + f"{config.expert_dtype!r}. Drop --kernel-config moe_backend=" + "deep_gemm_mega_moe for this checkpoint." 
+ ) self.gate = GateLinear( config.hidden_size, @@ -872,7 +925,7 @@ def __init__( vllm_config: VllmConfig, prefix: str, topk_indices_buffer: torch.Tensor | None = None, - aux_stream: torch.cuda.Stream | None = None, + aux_stream_list: list[torch.cuda.Stream] | None = None, ): super().__init__() config = vllm_config.model_config.hf_config @@ -974,7 +1027,6 @@ def __init__( max_position=self.max_position_embeddings, rope_parameters=rope_parameters, is_neox_style=False, - dtype=config.torch_dtype, ) self.indexer = None @@ -1005,7 +1057,7 @@ def __init__( indexer=self.indexer, indexer_rotary_emb=self.rotary_emb, topk_indices_buffer=topk_indices_buffer, - aux_stream=aux_stream, + aux_stream_list=aux_stream_list, ) self.mla_attn = DeepseekV4MultiHeadLatentAttentionWrapper( hidden_size=self.hidden_size, @@ -1041,9 +1093,14 @@ def __init__( vllm_config, prefix, topk_indices_buffer: torch.Tensor | None = None, - aux_stream_dict: dict[AuxStreamType, torch.cuda.Stream] | None = None, + aux_stream_list: list[torch.cuda.Stream] | None = None, ): super().__init__() + + # Lazy import to avoid top-level tilelang dependency. + # Registers both torch.ops.vllm.mhc_pre and mhc_post + import vllm.model_executor.layers.mhc # noqa: F401 + config = vllm_config.model_config.hf_config self.hidden_size = config.hidden_size @@ -1052,9 +1109,7 @@ def __init__( vllm_config, prefix=f"{prefix}.attn", topk_indices_buffer=topk_indices_buffer, - aux_stream=aux_stream_dict.get(AuxStreamType.Attention) - if aux_stream_dict is not None - else None, + aux_stream_list=aux_stream_list, ) self.ffn = DeepseekV4MoE(vllm_config, prefix=f"{prefix}.ffn") @@ -1116,11 +1171,6 @@ def hc_pre( hc_scale: torch.Tensor, hc_base: torch.Tensor, ): - # Lazy import to avoid top-level tilelang dependency. - # Registers both torch.ops.vllm.mhc_pre and mhc_post, - # so hc_post() doesn't need its own import. - import vllm.model_executor.layers.mhc # noqa: F401 - post_mix, res_mix, layer_input = torch.ops.vllm.mhc_pre( residual=x, fn=hc_fn, @@ -1182,10 +1232,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.hc_dim = self.hc_mult * config.hidden_size self.rms_norm_eps = config.rms_norm_eps - aux_stream_list = [torch.cuda.Stream() for _ in range(1)] - self.aux_stream_dict = { - AuxStreamType.Attention: aux_stream_list[0], - } + # Three aux streams: one per non-default input GEMM in + # DeepseekV4MultiHeadLatentAttentionWrapper.attn_gemm_parallel_execute + # (compressor kv_score, indexer.weights_proj, indexer.compressor + # kv_score). fused_wqa_wkv stays on the default stream. + aux_stream_list = [torch.cuda.Stream() for _ in range(3)] self.device = current_platform.device_type # Reserved topk indices buffer for all Indexer layers to reuse. @@ -1209,7 +1260,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config, prefix=prefix, topk_indices_buffer=self.topk_indices_buffer, - aux_stream_dict=self.aux_stream_dict, + aux_stream_list=aux_stream_list, ), prefix=f"{prefix}.layers", ) @@ -1410,10 +1461,24 @@ def hc_head( return y.to(dtype) -class DeepseekV4ForCausalLM(nn.Module): - model_cls = DeepseekV4Model - - hf_to_vllm_mapper = WeightsMapper( +def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper: + if expert_dtype == "fp4": + # MXFP4 experts use Mxfp4MoEMethod, which registers scales as + # ``w{1,2,3}_weight_scale`` (no _inv suffix). FP8 linear and + # shared experts use Fp8LinearMethod's block scales, which + # register as ``weight_scale_inv``. 
+ scale_regex = { + re.compile(r"(\.experts\.\d+\.w[123])\.scale$"): r"\1.weight_scale", + re.compile(r"\.scale$"): ".weight_scale_inv", + } + else: + # FP8 experts use Fp8MoEMethod (block_quant=True), which registers + # scales as ``w{13,2}_weight_scale_inv``. Map all ``.scale`` keys + # there. + scale_regex = { + re.compile(r"\.scale$"): ".weight_scale_inv", + } + return WeightsMapper( orig_to_new_prefix={ "layers.": "model.layers.", "embed.": "model.embed.", @@ -1421,12 +1486,7 @@ class DeepseekV4ForCausalLM(nn.Module): "hc_head": "model.hc_head", "mtp.": "model.mtp.", }, - orig_to_new_regex={ - # Routed MoE expert scales: experts.N.wX.scale -> .weight_scale - re.compile(r"(\.experts\.\d+\.w[123])\.scale$"): r"\1.weight_scale", - # Everything else (FP8 linear + shared experts): .scale -> .weight_scale_inv - re.compile(r"\.scale$"): ".weight_scale_inv", - }, + orig_to_new_regex=scale_regex, orig_to_new_suffix={ "head.weight": "lm_head.weight", "embed.weight": "embed_tokens.weight", @@ -1438,11 +1498,22 @@ class DeepseekV4ForCausalLM(nn.Module): }, ) + +class DeepseekV4ForCausalLM(nn.Module): + model_cls = DeepseekV4Model + + # Default mapper assumes the original FP4-expert checkpoint layout. + # Overridden per-instance in __init__ when expert_dtype != "fp4". + hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp4") + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config self.config = config + expert_dtype = getattr(config, "expert_dtype", "fp4") + if expert_dtype != "fp4": + self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper(expert_dtype) self.model = self.model_cls( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") diff --git a/vllm/model_executor/models/deepseek_v4_mtp.py b/vllm/model_executor/models/deepseek_v4_mtp.py index c1f0e3fb5d3a..a3724e5ebe80 100644 --- a/vllm/model_executor/models/deepseek_v4_mtp.py +++ b/vllm/model_executor/models/deepseek_v4_mtp.py @@ -35,7 +35,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.utils.multi_stream_utils import AuxStreamType from .deepseek_mtp import SharedHead from .deepseek_v2 import get_spec_layer_idx_from_weight_name @@ -48,9 +47,14 @@ logger = init_logger(__name__) -# MoE expert scales are fused into per-layer w13/w2 tensors; other FP8 linear -# scales use `.weight_scale_inv`. Mirrors the regex in -# DeepseekV4ForCausalLM.hf_to_vllm_mapper. +# MoE expert scales are fused into per-layer w13/w2 tensors. The exact +# parameter suffix depends on which FusedMoE method handles the experts: +# - fp4 experts (Mxfp4MoEMethod) register ``w{1,2,3}_weight_scale``; +# - fp8 experts (Fp8MoEMethod with block_quant=True) register +# ``w{1,2,3}_weight_scale_inv``. +# Other FP8 linear scales (including shared experts) always use +# ``.weight_scale_inv``. Mirrors the per-instance mapper built by +# ``_make_deepseek_v4_weights_mapper`` in deepseek_v4.py. 
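For reference, a minimal sketch of how the two regex rules above rewrite checkpoint scale keys. The key names are hypothetical and the sketch assumes the rules are applied in declaration order; once the expert-specific rule fires, the key no longer ends in `.scale`, so the generic rule cannot rewrite it a second time.

```python
import re

# Same regexes as _make_deepseek_v4_weights_mapper("fp4") above.
FP4_RULES = [
    (re.compile(r"(\.experts\.\d+\.w[123])\.scale$"), r"\1.weight_scale"),
    (re.compile(r"\.scale$"), ".weight_scale_inv"),
]


def rename(key: str, rules=FP4_RULES) -> str:
    # Apply rules in order; after the expert rule rewrites a key it no longer
    # ends in ".scale", so the generic rule is a no-op for it.
    for pattern, repl in rules:
        key = pattern.sub(repl, key)
    return key


# Routed expert scale (hypothetical key): drops the "_inv" suffix because
# Mxfp4MoEMethod registers plain w{1,2,3}_weight_scale parameters.
assert rename("layers.3.ffn.experts.7.w1.scale") == (
    "layers.3.ffn.experts.7.w1.weight_scale"
)
# FP8 linear / shared-expert scale (hypothetical key): keeps the block-scale
# "_inv" naming used by Fp8LinearMethod.
assert rename("layers.3.ffn.shared_experts.gate_proj.scale") == (
    "layers.3.ffn.shared_experts.gate_proj.weight_scale_inv"
)
```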
_EXPERT_SCALE_RE = re.compile(r"\.experts\.\d+\.w[123]\.scale$") @@ -60,6 +64,7 @@ def __init__( vllm_config: VllmConfig, topk_indices_buffer: torch.Tensor, prefix: str, + aux_stream_list: list[torch.cuda.Stream] | None = None, ) -> None: super().__init__() @@ -107,14 +112,11 @@ def __init__( self.shared_head = SharedHead( config=config, prefix=prefix, quant_config=quant_config ) - self.aux_stream_dict = { - AuxStreamType.Attention: torch.cuda.Stream(), - } self.mtp_block = DeepseekV4DecoderLayer( vllm_config, prefix, topk_indices_buffer=topk_indices_buffer, - aux_stream_dict=self.aux_stream_dict, + aux_stream_list=aux_stream_list, ) def forward( @@ -164,6 +166,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): device=self.device, ) + # Three aux streams shared across all MTP layers, mirroring + # DeepseekV4Model. + aux_stream_list = [torch.cuda.Stream() for _ in range(3)] + # to map the exact layer index from weights self.layers = torch.nn.ModuleDict( { @@ -171,6 +177,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config, self.topk_indices_buffer, f"{prefix}.layers.{idx}", + aux_stream_list=aux_stream_list, ) for idx in range( self.mtp_start_layer_idx, @@ -326,6 +333,15 @@ def _find_mtp_layer_idx(name: str) -> int: num_experts=self.config.n_routed_experts, ) + # FP8 experts register ``..._weight_scale_inv`` (block_quant) while + # FP4/MXFP4 experts register ``..._weight_scale``. Choose the suffix + # for the rename below based on the model's expert dtype. + expert_scale_suffix = ( + ".weight_scale" + if getattr(self.config, "expert_dtype", "fp4") == "fp4" + else ".weight_scale_inv" + ) + for name, loaded_weight in weights: mtp_layer_idx = _find_mtp_layer_idx(name) # V4 checkpoints store MTP weights as `mtp.{i}.*`; remap to @@ -347,7 +363,7 @@ def _find_mtp_layer_idx(name: str) -> int: continue if name.endswith(".scale"): suffix = ( - ".weight_scale" + expert_scale_suffix if _EXPERT_SCALE_RE.search(name) else ".weight_scale_inv" ) diff --git a/vllm/tool_parsers/deepseekv32_tool_parser.py b/vllm/tool_parsers/deepseekv32_tool_parser.py index b8623592365c..02182e22935a 100644 --- a/vllm/tool_parsers/deepseekv32_tool_parser.py +++ b/vllm/tool_parsers/deepseekv32_tool_parser.py @@ -191,12 +191,13 @@ def extract_tool_calls( tool_call_match ): param_dict = self._parse_invoke_params(invoke_content) + params = self._convert_params_with_schema(invoke_name, param_dict) tool_calls.append( ToolCall( type="function", function=FunctionCall( name=invoke_name, - arguments=json.dumps(param_dict, ensure_ascii=False), + arguments=json.dumps(params, ensure_ascii=False), ), ) ) diff --git a/vllm/utils/multi_stream_utils.py b/vllm/utils/multi_stream_utils.py index cc6bc6462449..c00f08f93329 100644 --- a/vllm/utils/multi_stream_utils.py +++ b/vllm/utils/multi_stream_utils.py @@ -56,3 +56,67 @@ def maybe_execute_in_parallel( result0 = fn0() result1 = fn1() return (result0, result1) + + +def execute_in_parallel( + default_fn: Callable[[], Any], + aux_fns: list[Callable[[], Any] | None], + start_event: torch.cuda.Event, + done_events: list[torch.cuda.Event], + aux_streams: list[torch.cuda.Stream] | None = None, +) -> tuple[Any, list[Any]]: + """Run default_fn on the current stream and aux_fns concurrently on + aux_streams. + + Generalizes maybe_execute_in_parallel to N aux callables. Slots where + aux_fns[i] is None are skipped (no stream switch, no event record); their + corresponding entry in the returned aux_results list is None. 
+ + start_event fans out from the current stream to every launched aux stream; + done_events[i] is recorded after aux_fns[i] so the current stream joins + before returning. When aux_streams is None, all aux_fns run sequentially + on the current stream. + + Args: + default_fn: Callable for the default (current) stream. + aux_fns: Per-aux callables; entries may be None to skip. + start_event: CUDA event recorded on the current stream before + default_fn so each launched aux stream can wait on it. + done_events: One CUDA event per aux slot, recorded after the + corresponding aux_fn. Length must match aux_fns. + aux_streams: Per-aux CUDA streams. Length must match aux_fns. + Multi-stream is disabled when None. + + Returns: + Tuple of (default_result, aux_results) where aux_results[i] is the + result of aux_fns[i] (or None when skipped). + """ + aux_results: list[Any] + if aux_streams is None: + default_result = default_fn() + aux_results = [fn() if fn is not None else None for fn in aux_fns] + return default_result, aux_results + + assert len(aux_fns) == len(aux_streams) == len(done_events), ( + "aux_fns, aux_streams, and done_events must be the same length" + ) + + aux_results = [None] * len(aux_fns) + pending: list[torch.cuda.Event] = [] + + start_event.record() + for i, fn in enumerate(aux_fns): + if fn is None: + continue + with torch.cuda.stream(aux_streams[i]): + start_event.wait() + aux_results[i] = fn() + done_events[i].record() + pending.append(done_events[i]) + + default_result = default_fn() + + for ev in pending: + ev.wait() + + return default_result, aux_results diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py index 26b076f34238..2f97d8733c95 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py @@ -21,7 +21,7 @@ from vllm.triton_utils import tl, triton -from .fused_indexer_q import _e2m1_nibble +from .fused_indexer_q import _fp32x2_to_fp4x2 # ============================================================================= @@ -566,18 +566,18 @@ def _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn( tl.max(tl.abs(even_2d), axis=1), tl.max(tl.abs(odd_2d), axis=1), ) - amax = tl.maximum(amax, 1e-4) + amax = tl.maximum(amax, 6.0 * (2**-126)) # ue8m0 block scale: 2^ceil(log2(amax / 6.0)), stored as (exp + 127) byte. 
- log2_ratio = tl.ceil(tl.log2(amax / 6.0)) + log2_ratio = tl.ceil(tl.log2(amax * (1.0 / 6.0))) log2_ratio = tl.minimum(tl.maximum(log2_ratio, -127.0), 127.0) inv_scale = tl.exp2(-log2_ratio) ue8m0 = (log2_ratio + 127.0).to(tl.uint8) # [N_QUANT_BLOCKS] inv_scale_col = tl.reshape(inv_scale, (N_QUANT_BLOCKS, 1)) - lo_nib = _e2m1_nibble(even_2d * inv_scale_col) # (N_BLOCKS, HALF_BLOCK) uint8 - hi_nib = _e2m1_nibble(odd_2d * inv_scale_col) - packed = lo_nib | (hi_nib << 4) + packed = _fp32x2_to_fp4x2( + even_2d * inv_scale_col, odd_2d * inv_scale_col + ) # (N_BLOCKS, HALF_BLOCK) uint8 packed_flat = tl.reshape(packed, (TOKEN_STRIDE,)) tl.store(val_ptr + tl.arange(0, TOKEN_STRIDE), packed_flat) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py index 0254a46752c6..f94fc013f5c6 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py @@ -24,36 +24,22 @@ def _get_cos_sin( @triton.jit -def _e2m1_nibble(x): - """Quantize fp32 x (already scale-divided) to E2M1 4-bit nibble in uint8. - Matches torch.bucketize with boundaries - [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0] and right=False (each boundary - belongs to the lower bucket), plus sign bit.""" - abs_x = tl.minimum(tl.abs(x), 6.0) - code = tl.where( - abs_x <= 0.25, - 0.0, - tl.where( - abs_x <= 0.75, - 1.0, - tl.where( - abs_x <= 1.25, - 2.0, - tl.where( - abs_x <= 1.75, - 3.0, - tl.where( - abs_x <= 2.5, - 4.0, - tl.where(abs_x <= 3.5, 5.0, tl.where(abs_x <= 5.0, 6.0, 7.0)), - ), - ), - ), - ), - ) - code_u8 = code.to(tl.uint8) - sign = ((x < 0) & (code_u8 != 0)).to(tl.uint8) - return code_u8 | (sign << 3) +def _fp32x2_to_fp4x2(x_lo, x_hi): + # NOTE: $1 is high nibble, $2 is low nibble + return tl.inline_asm_elementwise( + """ + { + .reg .b8 tmp; + cvt.rn.satfinite.e2m1x2.f32 tmp, $1, $2; + cvt.u32.u8 $0, tmp; + } + """, + constraints="=r,f,f", + args=[x_hi, x_lo], + dtype=tl.uint32, + is_pure=True, + pack=1, + ).to(tl.uint8) @triton.jit @@ -65,17 +51,16 @@ def _quantize_mxfp4_pair(x_lo, x_hi): - ue8m0 : scalar uint8 (block scale = 2^(ue8m0 - 127)) """ amax = tl.maximum(tl.max(tl.abs(x_lo)), tl.max(tl.abs(x_hi))) - amax = tl.maximum(amax, 1e-4) + # 6 * 2^-126 is from https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro/blob/main/inference/kernel.py#L163 + amax = tl.maximum(amax, 6.0 * (2**-126)) # ue8m0 block scale: 2^ceil(log2(amax/6.0)). 
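For a cross-check of the block-scale math shared by both kernels (and of what the PTX packer replaces), a scalar Python sketch follows. It mirrors the threshold table of the removed `_e2m1_nibble` helper, so it may differ from the hardware `cvt.rn.satfinite.e2m1x2` rounding at exact midpoints; treat it as illustrative, not as vLLM code.

```python
import math

_E2M1_BOUNDS = [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0]  # table from the removed helper


def ue8m0_block_scale(amax: float) -> tuple[int, float]:
    # scale = 2^ceil(log2(amax / 6.0)); 6.0 is the largest E2M1 magnitude.
    amax = max(amax, 6.0 * 2.0 ** -126)          # keep log2 finite for all-zero blocks
    log2_ratio = min(max(math.ceil(math.log2(amax / 6.0)), -127.0), 127.0)
    ue8m0 = int(log2_ratio + 127.0)              # byte stored next to the packed values
    inv_scale = 2.0 ** -log2_ratio               # multiply inputs by this before rounding
    return ue8m0, inv_scale


def e2m1_nibble(x: float) -> int:
    # Threshold-based E2M1 code (0..7) with the sign in bit 3.
    a = min(abs(x), 6.0)
    code = sum(a > b for b in _E2M1_BOUNDS)
    sign = 1 if (x < 0 and code != 0) else 0
    return code | (sign << 3)


# amax == 6.0 (already representable) -> exponent 0, stored byte 127, values unscaled.
assert ue8m0_block_scale(6.0) == (127, 1.0)
# amax == 12.0 -> one extra power of two: byte 128, values halved before rounding.
assert ue8m0_block_scale(12.0) == (128, 0.5)
# -1.6 lands in the (1.25, 1.75] bucket -> code 3 with the sign bit set.
assert e2m1_nibble(-1.6) == 0b1011
```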
- log2_ratio = tl.math.ceil(tl.math.log2(amax / 6.0)) + log2_ratio = tl.math.ceil(tl.math.log2(amax * (1.0 / 6.0))) log2_ratio = tl.minimum(tl.maximum(log2_ratio, -127.0), 127.0) scale = tl.math.exp2(log2_ratio) ue8m0 = (log2_ratio + 127.0).to(tl.uint8) inv_scale = 1.0 / scale - lo_nib = _e2m1_nibble(x_lo * inv_scale) - hi_nib = _e2m1_nibble(x_hi * inv_scale) - packed = lo_nib | (hi_nib << 4) + packed = _fp32x2_to_fp4x2(x_lo * inv_scale, x_hi * inv_scale) return packed, ue8m0 diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py index 97c9538889a1..84647d6120d8 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py @@ -10,6 +10,7 @@ import torch from vllm.triton_utils import tl, triton +from vllm.utils.torch_utils import direct_register_custom_op @triton.jit @@ -180,34 +181,74 @@ def fused_inv_rope_fp8_quant( fp8_dtype = torch.float8_e4m3fn fp8_max = torch.finfo(fp8_dtype).max - fp8_buf = torch.empty( - (n_groups, num_tokens, d), - dtype=fp8_dtype, - device=o.device, - ) - tma_aligned_T = get_tma_aligned_size(num_tokens, 4) if tma_aligned_scales: packed_sf_k = (num_scale_blocks + 3) // 4 - scale_buf = torch.empty( - n_groups * packed_sf_k * tma_aligned_T, - dtype=torch.int32, - device=o.device, - ).as_strided( - (n_groups, num_tokens, packed_sf_k), - (packed_sf_k * tma_aligned_T, 1, tma_aligned_T), - ) + scale_inner = packed_sf_k else: - scale_buf = torch.empty( - n_groups * num_scale_blocks * tma_aligned_T, - dtype=torch.float32, - device=o.device, - ).as_strided( - (n_groups, num_tokens, num_scale_blocks), - (num_scale_blocks * tma_aligned_T, 1, tma_aligned_T), - ) + scale_inner = num_scale_blocks + + # Run kernel through a custom op so inductor sees an opaque boundary. 
+ # It's a pytorch bug, see https://github.com/vllm-project/vllm/issues/41106 + fp8_buf, scale_buf = torch.ops.vllm.fused_inv_rope_fp8_quant_kernel( + o, + positions, + cos_sin_cache, + heads_per_group, + quant_group_size, + chunks_per_head, + nope_dim % quant_group_size, + rope_dim // 2, + tma_aligned_scales, + fp8_max, + tma_aligned_T, + num_tokens, + n_groups, + d, + scale_inner, + ) + return fp8_buf.transpose(0, 1), scale_buf.transpose(0, 1) + - common_args = dict( +def _fused_inv_rope_fp8_quant_kernel_impl( + o: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + heads_per_group: int, + quant_group_size: int, + chunks_per_head: int, + rope_start: int, + half_rope: int, + tma_aligned_scales: bool, + fp8_max: float, + tma_aligned_T: int, + num_tokens: int, + n_groups: int, + d: int, + scale_inner: int, +) -> tuple[torch.Tensor, torch.Tensor]: + fp8_buf = torch.empty( + (n_groups, num_tokens, d), + dtype=torch.float8_e4m3fn, + device=o.device, + ) + scale_dtype = torch.int32 if tma_aligned_scales else torch.float32 + scale_buf = torch.empty( + n_groups * scale_inner * tma_aligned_T, + dtype=scale_dtype, + device=o.device, + ).as_strided( + (n_groups, num_tokens, scale_inner), + (scale_inner * tma_aligned_T, 1, tma_aligned_T), + ) + grid = (tma_aligned_T, n_groups * heads_per_group) + _fused_inv_rope_fp8_quant_per_head[grid]( + o, + positions, + cos_sin_cache, + fp8_buf, + scale_buf, + num_tokens, heads_per_group=heads_per_group, o_stride_token=o.stride(0), o_stride_head=o.stride(1), @@ -220,23 +261,52 @@ def fused_inv_rope_fp8_quant( eps=1e-10, QUANT_GROUP_SIZE=quant_group_size, CHUNKS_PER_HEAD=chunks_per_head, - ROPE_START=nope_dim % quant_group_size, - HALF_ROPE=rope_dim // 2, + ROPE_START=rope_start, + HALF_ROPE=half_rope, TMA_ALIGNED_SCALES=tma_aligned_scales, num_stages=1, launch_pdl=False, + num_warps=1, ) + return fp8_buf, scale_buf - grid = (tma_aligned_T, n_groups * heads_per_group) - _fused_inv_rope_fp8_quant_per_head[grid]( - o, - positions, - cos_sin_cache, - fp8_buf, - scale_buf, - num_tokens, - **common_args, - num_warps=1, + +def _fused_inv_rope_fp8_quant_kernel_fake( + o: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + heads_per_group: int, + quant_group_size: int, + chunks_per_head: int, + rope_start: int, + half_rope: int, + tma_aligned_scales: bool, + fp8_max: float, + tma_aligned_T: int, + num_tokens: int, + n_groups: int, + d: int, + scale_inner: int, +) -> tuple[torch.Tensor, torch.Tensor]: + fp8_buf = torch.empty( + (n_groups, num_tokens, d), + dtype=torch.float8_e4m3fn, + device=o.device, ) + scale_dtype = torch.int32 if tma_aligned_scales else torch.float32 + scale_buf = torch.empty( + n_groups * scale_inner * tma_aligned_T, + dtype=scale_dtype, + device=o.device, + ).as_strided( + (n_groups, num_tokens, scale_inner), + (scale_inner * tma_aligned_T, 1, tma_aligned_T), + ) + return fp8_buf, scale_buf - return fp8_buf.transpose(0, 1), scale_buf.transpose(0, 1) + +direct_register_custom_op( + op_name="fused_inv_rope_fp8_quant_kernel", + op_func=_fused_inv_rope_fp8_quant_kernel_impl, + fake_impl=_fused_inv_rope_fp8_quant_kernel_fake, +) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 879cd0928c1e..fb4b9ca3b7cb 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -890,31 +890,48 @@ def get_max_concurrency_for_kv_cache_config( return max_concurrency -def may_override_num_blocks( - vllm_config: VllmConfig, num_blocks: int, suppress_log: bool = False -) -> 
int: +def may_override_num_blocks(vllm_config: VllmConfig, num_blocks: int) -> int: """ Override the number of kv cache blocks if `num_gpu_blocks_override` is set. + The override is logged once, at the call site in `get_kv_cache_configs`. """ if vllm_config.cache_config.num_gpu_blocks_override is not None: - num_gpu_blocks_override = vllm_config.cache_config.num_gpu_blocks_override - if not suppress_log: - logger.info( - "Overriding num_gpu_blocks=%d with num_gpu_blocks_override=%d", - num_blocks, - num_gpu_blocks_override, - ) - num_blocks = num_gpu_blocks_override - + num_blocks = vllm_config.cache_config.num_gpu_blocks_override return num_blocks +def _pool_bytes_per_block(kv_cache_groups: list[KVCacheGroupSpec]) -> int: + """ + Bytes consumed by one block in the worker's shared KV cache pool, mirroring + the divisor used by `get_kv_cache_config_from_groups` to convert + `available_memory` into `num_blocks`. Used to compute the effective KV cache + capacity once `num_gpu_blocks_override` is applied. + """ + if len(kv_cache_groups) == 1 and isinstance( + kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs + ): + return kv_cache_groups[0].kv_cache_spec.page_size_bytes + if all( + isinstance(g.kv_cache_spec, UniformTypeKVCacheSpecs) for g in kv_cache_groups + ): + # DeepseekV4: shared layout sized by the largest per-page-size bucket. + full_mla_spec = cast(UniformTypeKVCacheSpecs, kv_cache_groups[0].kv_cache_spec) + layer_tuple_page_bytes = sum(full_mla_spec.get_page_sizes()) + num_layer_tuples = max( + cast(UniformTypeKVCacheSpecs, g.kv_cache_spec).get_num_layer_tuples() + for g in kv_cache_groups + ) + return layer_tuple_page_bytes * num_layer_tuples + group_size = max(len(g.layer_names) for g in kv_cache_groups) + page_size = get_uniform_page_size([g.kv_cache_spec for g in kv_cache_groups]) + return page_size * group_size + + def get_num_blocks( vllm_config: VllmConfig, num_layers: int, available_memory: int, page_size: int, - suppress_log: bool = False, ) -> int: """ Get the number of kv cache blocks. @@ -924,15 +941,10 @@ def get_num_blocks( num_layers: The number of layers available_memory: Memory available for KV cache in bytes. page_size: The page size of the KV cache. - suppress_log: Whether to suppress override log messages. Used when creating a - temporary/dummy KV cache config, e.g. 
during CG memory profiling """ num_blocks = int(available_memory // page_size // num_layers) num_blocks = max(num_blocks, 0) - num_blocks = may_override_num_blocks( - vllm_config, num_blocks, suppress_log=suppress_log - ) - return num_blocks + return may_override_num_blocks(vllm_config, num_blocks) def get_uniform_page_size(kv_cache_specs: Iterable[KVCacheSpec]) -> int: @@ -1220,7 +1232,6 @@ def get_kv_cache_config_from_groups( vllm_config: VllmConfig, kv_cache_groups: list[KVCacheGroupSpec], available_memory: int, - suppress_log: bool = False, ) -> KVCacheConfig: """ Generate the KV cache configuration from the KV cache groups and spec @@ -1252,9 +1263,7 @@ def get_kv_cache_config_from_groups( num_blocks = ( available_memory // kv_cache_groups[0].kv_cache_spec.page_size_bytes ) - num_blocks = may_override_num_blocks( - vllm_config, num_blocks, suppress_log=suppress_log - ) + num_blocks = may_override_num_blocks(vllm_config, num_blocks) per_layer_specs = kv_cache_groups[0].kv_cache_spec.kv_cache_specs kv_cache_tensors = [ KVCacheTensor( @@ -1288,11 +1297,7 @@ def get_kv_cache_config_from_groups( ) assert group_size > 0, "group_size must be greater than 0" num_blocks = get_num_blocks( - vllm_config, - group_size, - available_memory, - page_size, - suppress_log=suppress_log, + vllm_config, group_size, available_memory, page_size ) kv_cache_tensors = [] for i in range(group_size): @@ -1686,36 +1691,24 @@ def _report_kv_cache_config( vllm_config: The global VllmConfig kv_cache_config: The resolved KV cache configuration """ - min_block_size = min( - [group.kv_cache_spec.block_size for group in kv_cache_config.kv_cache_groups] - ) - - # Log the KV cache size and maximum concurrency. - num_tokens = ( - kv_cache_config.num_blocks - // len(kv_cache_config.kv_cache_groups) - * min_block_size - ) - dcp_size = vllm_config.parallel_config.decode_context_parallel_size - pcp_size = vllm_config.parallel_config.prefill_context_parallel_size - if pcp_size * dcp_size > 1: - num_tokens *= pcp_size * dcp_size - logger.info( - "Multiplying the GPU KV cache size by the cp_world_size %d " - "(pcp_world_size %d * dcp_world_size %d).", - pcp_size * dcp_size, - pcp_size, - dcp_size, - ) - num_tokens_str = f"{num_tokens:,}" - logger.info_once("GPU KV cache size: %s tokens", num_tokens_str) - max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" + max_model_len = vllm_config.model_config.max_model_len max_concurrency = get_max_concurrency_for_kv_cache_config( vllm_config, kv_cache_config ) + + # GPU KV cache size in tokens = max_concurrency * max_model_len: the total + # tokens of context the pool can hold at peak utilization. Sourcing this + # from the concurrency calculation handles hybrid layouts correctly: SWA / + # chunked-local groups have a per-request block count that's capped by + # their window, so a naive `num_blocks // num_groups * block_size` formula + # underestimates capacity for these models. DCP/PCP sharding is already + # accounted for in each spec's `max_memory_usage_bytes`. 
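A quick sanity check of the reporting formula with made-up numbers (nothing below comes from a real model):

```python
# Hypothetical hybrid (full-attention + SWA) model.
max_model_len = 8192
max_concurrency = 12.5            # as computed by get_max_concurrency_for_kv_cache_config

num_tokens = int(max_concurrency * max_model_len)   # 102,400 tokens reported
# The old `num_blocks // num_groups * block_size` expression splits the pool
# evenly across groups, but an SWA group never needs more than roughly
# window_size / block_size blocks per request, so the even split under-reports
# how many request-tokens the pool can actually serve.
```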
+ num_tokens = int(max_concurrency * max_model_len) + + logger.info_once("GPU KV cache size: %s tokens", f"{num_tokens:,}") logger.info_once( "Maximum concurrency for %s tokens per request: %.2fx", - max_model_len_str, + f"{max_model_len:,}", max_concurrency, ) @@ -1986,6 +1979,28 @@ def get_kv_cache_configs( for worker_spec in kv_cache_specs ] + # If `num_gpu_blocks_override` is set, the cache size that will actually + # be allocated is decoupled from the profiled `available_memory`: + # `may_override_num_blocks` in `get_kv_cache_config_from_groups` clamps + # `num_blocks` to the override. Reflect that in `available_memory` here so + # auto-fit, the admission check, and the per-worker config builder all + # plan against the same effective capacity. + override = vllm_config.cache_config.num_gpu_blocks_override + if override is not None: + adjusted_memory: list[int] = [] + for groups, avail_mem in zip(projected_groups_per_worker, available_memory): + if not groups: + adjusted_memory.append(avail_mem) + continue + bytes_per_block = _pool_bytes_per_block(groups) + logger.info( + "Overriding num_gpu_blocks=%d with num_gpu_blocks_override=%d", + avail_mem // bytes_per_block, + override, + ) + adjusted_memory.append(override * bytes_per_block) + available_memory = adjusted_memory + if vllm_config.model_config.original_max_model_len == -1: _auto_fit_max_model_len( vllm_config, projected_groups_per_worker, available_memory diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a0ba47f945a7..caf3bfdfc3a8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5874,7 +5874,7 @@ def _init_minimal_kv_cache_for_profiling(self) -> None: saved_override = self.cache_config.num_gpu_blocks_override self.cache_config.num_gpu_blocks_override = min_blocks minimal_config = get_kv_cache_config_from_groups( - self.vllm_config, kv_cache_groups, available_memory=0, suppress_log=True + self.vllm_config, kv_cache_groups, available_memory=0 ) self.cache_config.num_gpu_blocks_override = saved_override
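To make the override accounting concrete, a hedged sketch of the conversion the new block performs; the block size and override values below are invented for illustration.

```python
# Suppose _pool_bytes_per_block(groups) resolves to 2 MiB per block, the user
# passed --num-gpu-blocks-override 4096, and profiling found about 30 GiB free.
bytes_per_block = 2 * 1024 * 1024
override = 4096
profiled_available = 30 * 1024**3

effective_available = override * bytes_per_block      # 8 GiB
# Auto-fit, the admission check, and get_kv_cache_config_from_groups now all
# plan against this 8 GiB figure, matching the num_blocks the builder will
# actually allocate, instead of planning against the profiled 30 GiB while the
# builder later clamps num_blocks to 4096.
```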