diff --git a/CMakeLists.txt b/CMakeLists.txt index 272c48f7ff9a..387386066c21 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -294,6 +294,7 @@ set(VLLM_EXT_SRC "csrc/activation_kernels.cu" "csrc/layernorm_kernels.cu" "csrc/fused_qknorm_rope_kernel.cu" + "csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu" "csrc/layernorm_quant_kernels.cu" "csrc/sampler.cu" "csrc/topk.cu" @@ -357,11 +358,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # are not supported by Machete yet. # marlin arches for fp16 output - cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(MARLIN_ARCHS "8.0+PTX;12.0;12.1" "${CUDA_ARCHS}") # marlin has limited support for turing cuda_archs_loose_intersection(MARLIN_SM75_ARCHS "7.5" "${CUDA_ARCHS}") # marlin arches for bf16 output (we need 9.0 for bf16 atomicAdd PTX) - cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(MARLIN_BF16_ARCHS "8.0+PTX;9.0+PTX;12.0;12.1" "${CUDA_ARCHS}") # marlin arches for fp8 input # - sm80 doesn't support fp8 computation # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction @@ -1045,7 +1046,8 @@ endif() set(VLLM_MOE_EXT_SRC "csrc/moe/torch_bindings.cpp" "csrc/moe/moe_align_sum_kernels.cu" - "csrc/moe/topk_softmax_kernels.cu") + "csrc/moe/topk_softmax_kernels.cu" + "csrc/moe/topk_softplus_sqrt_kernels.cu") if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC @@ -1078,7 +1080,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # moe marlin arches # note that we always set `use_atomic_add=False` for moe marlin now, # so we don't need 9.0 for bf16 atomicAdd PTX - cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0+PTX;12.0;12.1" "${CUDA_ARCHS}") # moe marlin has limited support for turing cuda_archs_loose_intersection(MARLIN_MOE_SM75_ARCHS "7.5" "${CUDA_ARCHS}") # moe marlin arches for fp8 input diff --git a/cmake/external_projects/deepgemm.cmake b/cmake/external_projects/deepgemm.cmake index c3a48a64fc77..f38f1189cf5e 100644 --- a/cmake/external_projects/deepgemm.cmake +++ b/cmake/external_projects/deepgemm.cmake @@ -19,8 +19,8 @@ else() # This ref should be kept in sync with tools/install_deepgemm.sh FetchContent_Declare( deepgemm - GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM.git - GIT_TAG 477618cd51baffca09c4b0b87e97c03fe827ef03 + GIT_REPOSITORY https://github.com/jasl/DeepGEMM.git + GIT_TAG 7a7a41a1bac7dacabe74057e7600e59f98f85bce GIT_SUBMODULES "third-party/cutlass" "third-party/fmt" GIT_PROGRESS TRUE CONFIGURE_COMMAND "" @@ -46,6 +46,9 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.9) elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8) list(APPEND DEEPGEMM_SUPPORT_ARCHS "10.0a") endif() +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + list(APPEND DEEPGEMM_SUPPORT_ARCHS "12.0f") +endif() cuda_archs_loose_intersection(DEEPGEMM_ARCHS "${DEEPGEMM_SUPPORT_ARCHS}" "${CUDA_ARCHS}") @@ -120,6 +123,11 @@ if(DEEPGEMM_ARCHS) COMPONENT _deep_gemm_C FILES_MATCHING PATTERN "*.py") + install(DIRECTORY "${deepgemm_SOURCE_DIR}/deep_gemm/mega/" + DESTINATION vllm/third_party/deep_gemm/mega + COMPONENT _deep_gemm_C + FILES_MATCHING PATTERN "*.py") + # Generate envs.py (normally generated by DeepGEMM's setup.py build step) file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/deep_gemm_envs.py" "# Pre-installed environment variables\npersistent_envs = dict()\n") diff --git a/cmake/external_projects/flashmla.cmake 
b/cmake/external_projects/flashmla.cmake index 0f16b9161fa3..65986df55012 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -19,7 +19,7 @@ else() FetchContent_Declare( flashmla GIT_REPOSITORY https://github.com/vllm-project/FlashMLA - GIT_TAG 692917b1cda61b93ac9ee2d846ec54e75afe87b1 + GIT_TAG a6ec2ba7bd0a7dff98b3f4d3e6b52b159c48d78b GIT_PROGRESS TRUE CONFIGURE_COMMAND "" BUILD_COMMAND "" diff --git a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu new file mode 100644 index 000000000000..56b5d71270ca --- /dev/null +++ b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu @@ -0,0 +1,477 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright contributors to the vLLM project + * + * Horizontally-fused DeepseekV4-MLA kernel: + * - Q side: per-head RMSNorm (no weight) + GPT-J RoPE on last ROPE_DIM + * - KV side: GPT-J RoPE on last ROPE_DIM + UE8M0 FP8 quant on NoPE + paged + * cache insert + * + * Structured after `applyMLARopeAndAssignQKVKernelGeneration` in + * TensorRT-LLM's mlaKernels.cu: one kernel, one grid, with head-slot + * dispatch choosing Q vs KV work per warp. The per-warp RMSNorm/RoPE + * skeleton is adapted from vllm-deepseek_v4's existing + * `fusedQKNormRopeKernel` (csrc/fused_qknorm_rope_kernel.cu). + * + * Assumptions (hard-coded for DeepseekV4 attention): + * HEAD_DIM = 512 + * ROPE_DIM = 64 (RoPE applied to dims [NOPE_DIM, HEAD_DIM)) + * NOPE_DIM = 448 + * QUANT_BLOCK = 64 (UE8M0 FP8 quant block) + * FP8_MAX = 448.0f + * is_neox=false (GPT-J interleaved pairs) + * cos_sin_cache layout [max_pos, rope_dim] = cos || sin (cos first, sin + * second along last dim; each half is rope_dim/2 = 32 values) + * + * Cache layout per paged-cache block (block_size tokens): + * [0, bs*576): token data, 448 fp8 + 128 bf16 each + * [bs*576, bs*576 + bs*8): UE8M0 scales, 7 real + 1 pad per token + */ + +#include +#include +#include +#include + +#include +#include +#include + +#include "cuda_compat.h" +#include "dispatch_utils.h" +#include "type_convert.cuh" + +#ifndef FINAL_MASK + #define FINAL_MASK 0xffffffffu +#endif + +namespace vllm { +namespace deepseek_v4_fused_ops { + +namespace { +inline int getSMVersion() { + auto* props = at::cuda::getCurrentDeviceProperties(); + return props->major * 10 + props->minor; +} +} // namespace + +// ──────────────────────────────────────────────────────────────────────────── +// Constants +// ──────────────────────────────────────────────────────────────────────────── +constexpr int kHeadDim = 512; +constexpr int kRopeDim = 64; +constexpr int kNopeDim = kHeadDim - kRopeDim; // 448 +constexpr int kQuantBlock = 64; +constexpr int kNumQuantBlocks = kNopeDim / kQuantBlock; // 7 +constexpr int kScaleBytesPerToken = kNumQuantBlocks + 1; // 8 (7 real + 1 pad) +constexpr int kTokenDataBytes = kNopeDim + kRopeDim * 2; // 448 + 128 = 576 +constexpr float kFp8Max = 448.0f; + +// Per-warp layout: 32 lanes × 16 elems/lane = 512 elems = HEAD_DIM. +constexpr int kNumLanes = 32; +constexpr int kElemsPerLane = kHeadDim / kNumLanes; // 16 + +// ──────────────────────────────────────────────────────────────────────────── +// Small inline helpers +// ──────────────────────────────────────────────────────────────────────────── +__device__ __forceinline__ float warp4MaxAbs(float val) { + // Reduce absolute max across 4 consecutive lanes (lane id & 3 group). 
+ float peer = __shfl_xor_sync(FINAL_MASK, val, 1); + val = fmaxf(val, peer); + peer = __shfl_xor_sync(FINAL_MASK, val, 2); + val = fmaxf(val, peer); + return val; +} + +template +__device__ __forceinline__ float warpSum(float val) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + val += __shfl_xor_sync(FINAL_MASK, val, mask, 32); + } + return val; +} + +// ──────────────────────────────────────────────────────────────────────────── +// Kernel +// ──────────────────────────────────────────────────────────────────────────── +// +// Grid: 1D, gridDim.x = ceil(num_tokens_full * (num_heads_q + 1) / +// warps_per_block) Block: blockDim.x = 256 threads (8 warps per block) Each +// warp handles one (token, head_slot) pair. head_slot < num_heads_q → +// Q branch (RMSNorm + RoPE, in place) head_slot == num_heads_q → KV +// branch (RoPE + UE8M0 quant + insert) +// +// With DP padding, q/kv/position_ids can have more rows than slot_mapping. +// The Q branch covers all `num_tokens_full` rows (downstream attention uses +// them). The KV branch only inserts the first `num_tokens_insert` tokens +// (= slot_mapping length) into the paged cache. +// +template +__global__ void fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel( + scalar_t_in* __restrict__ q_inout, // [N, H, 512] bf16, in place + scalar_t_in const* __restrict__ kv_in, // [N, 512] bf16 + uint8_t* __restrict__ k_cache, // [num_blocks, block_stride] + int64_t const* __restrict__ slot_mapping, // [num_tokens_insert] i64 + int64_t const* __restrict__ position_ids, // [N] i64 + float const* __restrict__ cos_sin_cache, // [max_pos, 64] fp32 + float const eps, + int const num_tokens_full, // = q.size(0) = kv.size(0) + int const num_tokens_insert, // = slot_mapping.size(0), ≤ num_tokens_full + int const num_heads_q, // H + int const cache_block_size, // tokens per paged-cache block + int const kv_block_stride) { // bytes per paged-cache block +#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM) + // BF16 _typeConvert specialization is unavailable on pre-Ampere. The + // DeepseekV4 kernel only runs with bf16 inputs in practice, so compile a + // no-op stub for sm_70/sm_75 to keep multi-arch builds happy. + if constexpr (std::is_same_v) { + return; + } else { +#endif + using Converter = vllm::_typeConvert; + + int const warpsPerBlock = blockDim.x / 32; + int const warpId = threadIdx.x / 32; + int const laneId = threadIdx.x % 32; + int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId; + + int const total_slots_per_token = num_heads_q + 1; + int const tokenIdx = globalWarpIdx / total_slots_per_token; + int const slotIdx = globalWarpIdx % total_slots_per_token; + if (tokenIdx >= num_tokens_full) return; + + bool const isKV = (slotIdx == num_heads_q); + // KV branch: skip DP-padded tokens (no slot reserved for them). + if (isKV && tokenIdx >= num_tokens_insert) return; + + // PDL: wait for predecessor kernel (upstream q/kv producer) to signal + // before touching any global memory. No-op when PDL is not enabled on + // the launch. The CUDA runtime wrapper emits the griddepcontrol.wait + // PTX with the required memory clobber internally. +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + cudaGridDependencySynchronize(); +#endif + + // Dim range this lane owns within the 512-wide head. 
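Aside: the (token, head-slot) dispatch computed above, where each warp picks either one Q head or the single trailing KV slot, can be mirrored on the host for testing. A minimal sketch under that assumption; warp_to_work_item is a hypothetical helper, not part of this patch:

#include <cstdint>
#include <utility>

// Maps a global warp index to its (token, head-slot) work item, assuming the
// kernel's layout of num_heads_q Q slots plus one trailing KV slot per token.
inline std::pair<int64_t, int> warp_to_work_item(int64_t global_warp_idx,
                                                 int num_heads_q) {
  const int slots_per_token = num_heads_q + 1;
  const int64_t token_idx = global_warp_idx / slots_per_token;
  const int slot_idx = static_cast<int>(global_warp_idx % slots_per_token);
  return {token_idx, slot_idx};  // slot_idx == num_heads_q selects the KV branch
}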
+ int const dim_base = laneId * kElemsPerLane; // in [0, 512) step 16 + + // ── Load 16 bf16 → 16 fp32 registers (one 16-byte + one 16-byte LDG) ──── + float elements[kElemsPerLane]; + float sumOfSquares = 0.0f; + + scalar_t_in const* src_ptr; + if (isKV) { + src_ptr = kv_in + static_cast(tokenIdx) * kHeadDim + dim_base; + } else { + int64_t const q_row_offset = + (static_cast(tokenIdx) * num_heads_q + slotIdx) * kHeadDim + + dim_base; + src_ptr = q_inout + q_row_offset; + } + + // Two 16-byte loads per thread (8 bf16 each). Use uint4 as the vector + // type and bitcast to scalar_t_in packed pairs for conversion. + uint4 v0 = *reinterpret_cast(src_ptr); + uint4 v1 = *reinterpret_cast(src_ptr + 8); + + { + typename Converter::packed_hip_type const* p0 = + reinterpret_cast(&v0); + typename Converter::packed_hip_type const* p1 = + reinterpret_cast(&v1); +// Each packed_hip_type holds 2 bf16 → 4 packed = 8 elems per uint4. +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 f2 = Converter::convert(p0[i]); + elements[2 * i] = f2.x; + elements[2 * i + 1] = f2.y; + } +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 f2 = Converter::convert(p1[i]); + elements[8 + 2 * i] = f2.x; + elements[8 + 2 * i + 1] = f2.y; + } + } + + // ── Q branch: RMSNorm with no weight (has_weight=False) ───────────────── + // Variance + rsqrt + multiply all in fp32, no intermediate bf16 round. + // The downstream bf16 round only happens at the final store. + if (!isKV) { +#pragma unroll + for (int i = 0; i < kElemsPerLane; i++) { + sumOfSquares += elements[i] * elements[i]; + } + sumOfSquares = warpSum(sumOfSquares); + float const rms_rcp = + rsqrtf(sumOfSquares / static_cast(kHeadDim) + eps); +#pragma unroll + for (int i = 0; i < kElemsPerLane; i++) { + elements[i] = elements[i] * rms_rcp; + } + } + + // ── GPT-J RoPE on dims [NOPE_DIM, HEAD_DIM) ───────────────────────────── + // All math in fp32. cos_sin_cache is loaded as fp32 (its native storage). + bool const is_rope_lane = dim_base >= kNopeDim; + if (is_rope_lane) { + int64_t const pos = position_ids[tokenIdx]; + constexpr int kHalfRope = kRopeDim / 2; // 32 + float const* cos_ptr = cos_sin_cache + pos * kRopeDim; + float const* sin_ptr = cos_ptr + kHalfRope; + + int const rope_local_base = dim_base - kNopeDim; // in [0, 64) step 16 +#pragma unroll + for (int p = 0; p < kElemsPerLane / 2; p++) { + int const pair_dim = rope_local_base + 2 * p; + int const half_idx = pair_dim / 2; + float const cos_v = VLLM_LDG(cos_ptr + half_idx); + float const sin_v = VLLM_LDG(sin_ptr + half_idx); + float const x_even = elements[2 * p]; + float const x_odd = elements[2 * p + 1]; + elements[2 * p] = x_even * cos_v - x_odd * sin_v; + elements[2 * p + 1] = x_even * sin_v + x_odd * cos_v; + } + } + + // ═══════════════════════════════════════════════════════════════════════ + // Q branch: cast to bf16 and store back in place. 
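Aside: the Q-side math above (weightless RMSNorm over the full 512-dim head, then GPT-J interleaved RoPE on the last 64 dims using the cos || sin cache row layout) can be checked against a scalar reference. A sketch assuming fp32 throughout; reference_qnorm_rope is a hypothetical test helper, not part of this patch:

#include <cmath>
#include <vector>

// Scalar reference for one Q head: RMSNorm without weight over all 512 dims,
// then GPT-J interleaved RoPE on the last 64 dims. `cos_sin` points at the
// cached row for this position, laid out as 32 cos values then 32 sin values.
inline void reference_qnorm_rope(std::vector<float>& head,  // 512 values
                                 const float* cos_sin,      // 64 values
                                 float eps) {
  constexpr int kHead = 512, kRope = 64, kNope = kHead - kRope;
  double sum_sq = 0.0;
  for (float v : head) sum_sq += static_cast<double>(v) * v;
  const float rms_rcp =
      1.0f / std::sqrt(static_cast<float>(sum_sq / kHead) + eps);
  for (float& v : head) v *= rms_rcp;
  const float* cos_v = cos_sin;
  const float* sin_v = cos_sin + kRope / 2;
  for (int p = 0; p < kRope / 2; ++p) {  // interleaved (even, odd) pairs
    const float x = head[kNope + 2 * p];
    const float y = head[kNope + 2 * p + 1];
    head[kNope + 2 * p] = x * cos_v[p] - y * sin_v[p];
    head[kNope + 2 * p + 1] = x * sin_v[p] + y * cos_v[p];
  }
}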
+ // ═══════════════════════════════════════════════════════════════════════ + if (!isKV) { + uint4 out0, out1; + typename Converter::packed_hip_type* po0 = + reinterpret_cast(&out0); + typename Converter::packed_hip_type* po1 = + reinterpret_cast(&out1); +#pragma unroll + for (int i = 0; i < 4; i++) { + po0[i] = Converter::convert( + make_float2(elements[2 * i], elements[2 * i + 1])); + } +#pragma unroll + for (int i = 0; i < 4; i++) { + po1[i] = Converter::convert( + make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1])); + } + scalar_t_in* dst = + q_inout + + (static_cast(tokenIdx) * num_heads_q + slotIdx) * kHeadDim + + dim_base; + *reinterpret_cast(dst) = out0; + *reinterpret_cast(dst + 8) = out1; +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + cudaTriggerProgrammaticLaunchCompletion(); +#endif + return; + } + + // ═══════════════════════════════════════════════════════════════════════ + // KV branch. + // ═══════════════════════════════════════════════════════════════════════ + int64_t const slot_id = slot_mapping[tokenIdx]; + if (slot_id < 0) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + cudaTriggerProgrammaticLaunchCompletion(); +#endif + return; + } + + int64_t const block_idx = slot_id / cache_block_size; + int64_t const pos_in_block = slot_id % cache_block_size; + uint8_t* block_base = + k_cache + block_idx * static_cast(kv_block_stride); + uint8_t* token_fp8_ptr = block_base + pos_in_block * kTokenDataBytes; + uint8_t* token_bf16_ptr = token_fp8_ptr + kNopeDim; + uint8_t* token_scale_ptr = + block_base + static_cast(cache_block_size) * kTokenDataBytes + + pos_in_block * kScaleBytesPerToken; + + // Round K to bf16 first, matching the unfused reference path where K is + // materialized as bf16 before K quantization. absmax, clamp, and FP8 + // quant below all run on these bf16-rounded values. +#pragma unroll + for (int i = 0; i < kElemsPerLane; i++) { + elements[i] = Converter::convert(Converter::convert(elements[i])); + } + + // Per-quant-block absmax must be computed by ALL 32 lanes (warp-collective + // shuffle requires full participation). RoPE lanes contribute garbage, + // but their values are gated out below via `!is_rope_lane`. + float local_absmax = 0.0f; +#pragma unroll + for (int i = 0; i < kElemsPerLane; i++) { + local_absmax = fmaxf(local_absmax, fabsf(elements[i])); + } + float const absmax = fmaxf(warp4MaxAbs(local_absmax), 1e-4f); + float const exponent = ceilf(log2f(absmax / kFp8Max)); + float const inv_scale = exp2f(-exponent); + + if (!is_rope_lane) { + // ── NoPE lane: UE8M0 FP8 quant ─────────────────────────────────────── + alignas(16) uint8_t out_bytes[kElemsPerLane]; +#pragma unroll + for (int i = 0; i < kElemsPerLane; i++) { + float scaled = elements[i] * inv_scale; + scaled = fminf(fmaxf(scaled, -kFp8Max), kFp8Max); + __nv_fp8_storage_t s = + __nv_cvt_float_to_fp8(scaled, __NV_SATFINITE, __NV_E4M3); + out_bytes[i] = static_cast(s); + } + // One 16-byte STG per lane. + *reinterpret_cast(token_fp8_ptr + dim_base) = + *reinterpret_cast(out_bytes); + + // Lane (4k) of each 4-lane group writes the scale byte for block k<7. + if ((laneId & 3) == 0) { + int const q_block_idx = laneId >> 2; // 0..6 for NoPE lanes + float encoded = fmaxf(fminf(exponent + 127.0f, 255.0f), 0.0f); + token_scale_ptr[q_block_idx] = static_cast(encoded); + } + // Lane 0 also writes the padding byte at index 7. 
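For reference, the UE8M0 scale selection used by the NoPE quantization above can be written as a scalar host helper, for example in a unit test. A sketch over one 64-element quant block, assuming the same 448.0 FP8 max and 1e-4 absmax floor; ue8m0_scale_byte is a hypothetical name:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar sketch of the UE8M0 block-scale selection: for one 64-element NoPE
// block, pick a power-of-two scale 2^e with e = ceil(log2(absmax / 448)) so
// the scaled values fit the FP8 E4M3 range, and store e biased by 127 as one
// byte. The kernel multiplies values by 2^-e before the FP8 cast.
inline uint8_t ue8m0_scale_byte(const float* block /* 64 values */,
                                float* inv_scale_out) {
  float absmax = 1e-4f;  // same floor as the kernel
  for (int i = 0; i < 64; ++i) absmax = std::max(absmax, std::fabs(block[i]));
  const float e = std::ceil(std::log2(absmax / 448.0f));
  *inv_scale_out = std::exp2(-e);
  return static_cast<uint8_t>(std::min(std::max(e + 127.0f, 0.0f), 255.0f));
}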
+ if (laneId == 0) { + token_scale_ptr[kNumQuantBlocks] = 0; // pad + } + } else { + // ── RoPE lane: cast back to bf16 and store to cache bf16 tail ──────── + uint4 out0, out1; + typename Converter::packed_hip_type* po0 = + reinterpret_cast(&out0); + typename Converter::packed_hip_type* po1 = + reinterpret_cast(&out1); +#pragma unroll + for (int i = 0; i < 4; i++) { + po0[i] = Converter::convert( + make_float2(elements[2 * i], elements[2 * i + 1])); + } +#pragma unroll + for (int i = 0; i < 4; i++) { + po1[i] = Converter::convert( + make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1])); + } + int const rope_local_base = dim_base - kNopeDim; // in [0, 64) + scalar_t_in* bf16_dst = + reinterpret_cast(token_bf16_ptr) + rope_local_base; + *reinterpret_cast(bf16_dst) = out0; + *reinterpret_cast(bf16_dst + 8) = out1; + } +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + cudaTriggerProgrammaticLaunchCompletion(); +#endif +#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM) + } +#endif +} + +// ──────────────────────────────────────────────────────────────────────────── +// Launch wrapper +// ──────────────────────────────────────────────────────────────────────────── +template +void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert( + scalar_t_in* q_inout, scalar_t_in const* kv_in, uint8_t* k_cache, + int64_t const* slot_mapping, int64_t const* position_ids, + float const* cos_sin_cache, float const eps, int const num_tokens_full, + int const num_tokens_insert, int const num_heads_q, + int const cache_block_size, int const kv_block_stride, + cudaStream_t stream) { + constexpr int kBlockSize = 256; + constexpr int kWarpsPerBlock = kBlockSize / 32; + int64_t const total_warps = + static_cast(num_tokens_full) * (num_heads_q + 1); + int const grid = + static_cast((total_warps + kWarpsPerBlock - 1) / kWarpsPerBlock); + + // PDL: enable programmatic stream serialization whenever the hardware + // supports it (SM90+). On pre-Hopper GPUs the attribute is unavailable, + // so leave numAttrs = 0 and launch as a regular kernel. + static int const sm_version = getSMVersion(); + // Host-side guard: the device kernel body is compiled as a no-op for + // bf16 on pre-Ampere (sm_70/sm_75) because _typeConvert is + // unavailable there. Refuse the launch loudly instead of silently + // skipping the work. + TORCH_CHECK( + sm_version >= 80, + "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert requires sm_80+ " + "(Ampere or newer); got sm_", + sm_version); + cudaLaunchConfig_t config; + config.gridDim = dim3(grid); + config.blockDim = dim3(kBlockSize); + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = 1; + config.attrs = attrs; + config.numAttrs = (sm_version >= 90) ? 
1 : 0; + + cudaLaunchKernelEx( + &config, fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel, + q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, eps, + num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size, + kv_block_stride); +} + +} // namespace deepseek_v4_fused_ops +} // namespace vllm + +// ──────────────────────────────────────────────────────────────────────────── +// Torch op wrapper +// ──────────────────────────────────────────────────────────────────────────── +void fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( + torch::Tensor& q, // [N, H, 512] bf16, in place + torch::Tensor const& kv, // [N, 512] bf16 (read-only) + torch::Tensor& k_cache, // [num_blocks, block_bytes] uint8 + torch::Tensor const& slot_mapping, // [N] int64 + torch::Tensor const& position_ids, // [N] int64 + torch::Tensor const& cos_sin_cache, // [max_pos, rope_dim] bf16 + double eps, int64_t cache_block_size) { + TORCH_CHECK(q.is_cuda() && q.is_contiguous(), "q must be contiguous CUDA"); + TORCH_CHECK(kv.is_cuda() && kv.is_contiguous(), "kv must be contiguous CUDA"); + TORCH_CHECK(k_cache.is_cuda(), "k_cache must be CUDA"); + TORCH_CHECK(slot_mapping.is_cuda() && slot_mapping.dtype() == torch::kInt64, + "slot_mapping must be int64 CUDA"); + TORCH_CHECK(position_ids.is_cuda() && position_ids.dtype() == torch::kInt64, + "position_ids must be int64 CUDA"); + TORCH_CHECK(cos_sin_cache.is_cuda(), "cos_sin_cache must be CUDA"); + TORCH_CHECK(q.dim() == 3 && q.size(2) == 512, "q shape [N, H, 512]"); + TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]"); + TORCH_CHECK(q.dtype() == kv.dtype(), "q and kv dtype must match"); + TORCH_CHECK(k_cache.dtype() == torch::kUInt8, "k_cache must be uint8"); + TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64, + "cos_sin_cache shape [max_pos, 64]"); + TORCH_CHECK(cos_sin_cache.dtype() == torch::kFloat32, + "cos_sin_cache must be float32"); + + // With DP padding, slot_mapping can be shorter than q/kv/positions. + // Q-norm+RoPE runs on all q.size(0) rows (downstream attention uses them); + // KV quant+insert runs only on the first slot_mapping.size(0) rows. 
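Aside: the byte offsets implied by the paged-cache layout documented in the kernel header (576 data bytes plus 8 scale bytes per token) work out as below. A sketch only; kv_cache_offsets is a hypothetical helper, and a real cache may pad its block stride beyond this minimum:

#include <cstdint>

// Byte offsets for one paged-cache block of `bs` tokens under the layout in
// the kernel header: 576 data bytes per token (448 FP8 NoPE bytes + 64 bf16
// RoPE values = 128 bytes), then 8 scale bytes per token (7 UE8M0 + 1 pad).
struct KvCacheOffsets {
  int64_t token_fp8;        // start of this token's 448 FP8 bytes
  int64_t token_bf16;       // start of this token's 128-byte bf16 RoPE tail
  int64_t token_scale;      // start of this token's 8 scale bytes
  int64_t min_block_bytes;  // minimum bytes per block implied by this layout
};

inline KvCacheOffsets kv_cache_offsets(int64_t bs, int64_t pos_in_block) {
  KvCacheOffsets o;
  o.token_fp8 = pos_in_block * 576;
  o.token_bf16 = o.token_fp8 + 448;
  o.token_scale = bs * 576 + pos_in_block * 8;
  o.min_block_bytes = bs * (576 + 8);
  return o;
}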
+ int const num_tokens_full = static_cast(q.size(0)); + int const num_tokens_insert = static_cast(slot_mapping.size(0)); + TORCH_CHECK(static_cast(kv.size(0)) == num_tokens_full && + static_cast(position_ids.size(0)) == num_tokens_full, + "q/kv/position_ids row counts must match"); + TORCH_CHECK(num_tokens_insert <= num_tokens_full, + "slot_mapping must not exceed q row count"); + int const num_heads_q = static_cast(q.size(1)); + int const cache_block_size_i = static_cast(cache_block_size); + int const kv_block_stride = static_cast(k_cache.stride(0)); + + at::cuda::OptionalCUDAGuard device_guard(device_of(q)); + auto stream = at::cuda::getCurrentCUDAStream(); + + VLLM_DISPATCH_HALF_TYPES( + q.scalar_type(), "fused_deepseek_v4_qnorm_rope_kv_insert", [&] { + using qkv_scalar_t = scalar_t; + vllm::deepseek_v4_fused_ops:: + launchFusedDeepseekV4QNormRopeKVRopeQuantInsert( + reinterpret_cast(q.data_ptr()), + reinterpret_cast(kv.data_ptr()), + reinterpret_cast(k_cache.data_ptr()), + reinterpret_cast(slot_mapping.data_ptr()), + reinterpret_cast(position_ids.data_ptr()), + cos_sin_cache.data_ptr(), static_cast(eps), + num_tokens_full, num_tokens_insert, num_heads_q, + cache_block_size_i, kv_block_stride, stream); + }); +} diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 9766103f7646..e617e45dc58b 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -77,7 +77,8 @@ __global__ void rms_norm_kernel( #pragma unroll for (int j = 0; j < VEC_SIZE; j++) { float x = static_cast(src1.val[j]); - dst.val[j] = ((scalar_t)(x * s_variance)) * src2.val[j]; + float w = static_cast(src2.val[j]); + dst.val[j] = static_cast(x * s_variance * w); } v_out[i] = dst; } @@ -134,10 +135,17 @@ fused_add_rms_norm_kernel( for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { int id = blockIdx.x * vec_hidden_size + idx; int64_t strided_id = blockIdx.x * vec_input_stride + idx; - _f16Vec temp = residual_v[id]; - temp *= s_variance; - temp *= weight_v[idx]; - input_v[strided_id] = temp; + _f16Vec res = residual_v[id]; + _f16Vec w = weight_v[idx]; + _f16Vec out; + using Converter = _typeConvert; +#pragma unroll + for (int j = 0; j < width; ++j) { + float x = Converter::convert(res.data[j]); + float wf = Converter::convert(w.data[j]); + out.data[j] = Converter::convert(x * s_variance * wf); + } + input_v[strided_id] = out; } } @@ -174,8 +182,8 @@ fused_add_rms_norm_kernel( for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { float x = (float)residual[blockIdx.x * hidden_size + idx]; - input[blockIdx.x * input_stride + idx] = - ((scalar_t)(x * s_variance)) * weight[idx]; + float w = (float)weight[idx]; + input[blockIdx.x * input_stride + idx] = (scalar_t)(x * s_variance * w); } } diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index d8d962887dab..973190935dfb 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -12,6 +12,15 @@ void topk_sigmoid(torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& gating_output, bool renormalize, std::optional bias); +void topk_softplus_sqrt(torch::Tensor& topk_weights, + torch::Tensor& topk_indices, + torch::Tensor& token_expert_indices, + torch::Tensor& gating_output, bool renormalize, + double routed_scaling_factor, + const c10::optional& correction_bias, + const c10::optional& input_ids, + const c10::optional& tid2eid); + void moe_sum(torch::Tensor& input, torch::Tensor& output); void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, diff --git 
a/csrc/moe/topk_softplus_sqrt_kernels.cu b/csrc/moe/topk_softplus_sqrt_kernels.cu new file mode 100644 index 000000000000..13db2f1ce968 --- /dev/null +++ b/csrc/moe/topk_softplus_sqrt_kernels.cu @@ -0,0 +1,710 @@ +/* + * Adapted from + * https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu + * Copyright (c) 2024, The vLLM team. + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include "../cuda_compat.h" +#include "../cub_helpers.h" +#ifndef USE_ROCM + #include + #include +#else + #include + #include +typedef __hip_bfloat16 __nv_bfloat16; +typedef __hip_bfloat162 __nv_bfloat162; +#endif + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + +namespace vllm { +namespace moe { + +/// Aligned array type +template +struct alignas(Alignment) AlignedArray { + T data[N]; +}; + +template +__device__ __forceinline__ float toFloat(T value) { + if constexpr (std::is_same_v) { + return value; + } else if constexpr (std::is_same_v) { + return __bfloat162float(value); + } else if constexpr (std::is_same_v) { + return __half2float(value); + } +} + +// ====================== TopK softplus_sqrt things +// =============================== + +/* + A Top-K gating softplus_sqrt written to exploit when the number of experts in + the MoE layers are a small power of 2. This allows us to cleanly share the + rows among the threads in a single warp and eliminate communication between + warps (so no need to use shared mem). + + It fuses the sigmoid, max and argmax into a single kernel. + + Limitations: + 1) This implementation is optimized for when the number of experts is a small + power of 2. Additionally it also supports when number of experts is multiple + of 64 which is still faster than the computing sigmoid and topK separately + (only tested on CUDA yet). 2) This implementation assumes k is small, but will + work for any k. +*/ + +template +__launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ + void topkGatingSoftplusSqrt( + const InputType* input, const bool* finished, float* output, + const int num_rows, IndType* indices, int* source_rows, const int k, + const int start_expert, const int end_expert, const bool renormalize, + double routed_scaling_factor, const float* correction_bias, + const IndType* input_ids, const IndType* tid2eid) { + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v, + "InputType must be float, __nv_bfloat16, or __half"); + + // We begin by enforcing compile time assertions and setting up compile time + // constants. 
+ static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), + "BYTES_PER_LDG must be power of 2"); + static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16"); + + // Number of bytes each thread pulls in per load + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType); + static constexpr int ELTS_PER_ROW = NUM_EXPERTS; + static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT; + static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG; + + if constexpr (std::is_same_v || + std::is_same_v) { + static_assert(ELTS_PER_LDG == 1 || ELTS_PER_LDG % 2 == 0, + "ELTS_PER_LDG must be 1 or even for 16-bit conversion"); + } + + // Restrictions based on previous section. + static_assert( + VPT % ELTS_PER_LDG == 0, + "The elements per thread must be a multiple of the elements per ldg"); + static_assert(WARP_SIZE_PARAM % THREADS_PER_ROW == 0, + "The threads per row must cleanly divide the threads per warp"); + static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), + "THREADS_PER_ROW must be power of 2"); + static_assert(THREADS_PER_ROW <= WARP_SIZE_PARAM, + "THREADS_PER_ROW can be at most warp size"); + + // We have NUM_EXPERTS elements per row. We specialize for small #experts + static constexpr int ELTS_PER_WARP = WARP_SIZE_PARAM * VPT; + static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW; + static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; + + // Restrictions for previous section. + static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, + "The elts per row must cleanly divide the total elt per warp"); + + // ===================== From this point, we finally start computing run-time + // variables. ======================== + + // Compute CTA and warp rows. We pack multiple rows into a single warp, and a + // block contains WARPS_PER_CTA warps. This, each block processes a chunk of + // rows. We start by computing the start row for each block. + const int cta_base_row = blockIdx.x * ROWS_PER_CTA; + + // Now, using the base row per thread block, we compute the base row per warp. + const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP; + + // The threads in a warp are split into sub-groups that will work on a row. + // We compute row offset for each thread sub-group + const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW; + const int thread_row = warp_base_row + thread_row_in_warp; + + // Threads with indices out of bounds should early exit here. + if (thread_row >= num_rows) { + return; + } + const bool row_is_active = finished ? !finished[thread_row] : true; + + // We finally start setting up the read pointers for each thread. First, each + // thread jumps to the start of the row it will read. + const InputType* thread_row_ptr = input + thread_row * ELTS_PER_ROW; + + // Now, we compute the group each thread belong to in order to determine the + // first column to start loads. 
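Aside: a worked example of how these compile-time constants partition a row, assuming NUM_EXPERTS = 256, bf16 gating logits, and the 16-byte BYTES_PER_LDG chosen by the CUDA launcher further down; the names below are illustrative only:

#include <algorithm>

// NUM_EXPERTS = 256, bf16 logits, BYTES_PER_LDG = 16:
constexpr int kEltsPerLdg = 16 / 2;                    // 8 bf16 per 16-byte load
constexpr int kVpt =
    std::max(1, 256 / (kEltsPerLdg * 32)) * kEltsPerLdg;  // 8 values per thread
constexpr int kThreadsPerRow = 256 / kVpt;             // 32: one full warp per row
constexpr int kRowsPerWarp = 32 / kThreadsPerRow;      // 1 row per warp
constexpr int kLdgPerThread = kVpt / kEltsPerLdg;      // 1 vector load per thread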
+ const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; + const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; + const InputType* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; + + // Finally, we pull in the data from global mem + float row_chunk[VPT]; + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif + + // NOTE(zhuhaoran): dispatch different input types loading, BF16/FP16 convert + // to float + if constexpr (std::is_same_v) { + using VecType = AlignedArray; + VecType* row_chunk_vec_ptr = reinterpret_cast(&row_chunk); + const VecType* vec_thread_read_ptr = + reinterpret_cast(thread_read_ptr); +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + } + } else if constexpr (std::is_same_v) { + if constexpr (ELTS_PER_LDG >= 2) { + using VecType = AlignedArray<__nv_bfloat16, ELTS_PER_LDG>; + float2* row_chunk_f2 = reinterpret_cast(row_chunk); + const VecType* vec_thread_read_ptr = + reinterpret_cast(thread_read_ptr); +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + int base_idx_f2 = ii * ELTS_PER_LDG / 2; +#pragma unroll + for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) { + row_chunk_f2[base_idx_f2 + jj] = __bfloat1622float2( + *reinterpret_cast(vec.data + jj * 2)); + } + } + } else { // ELTS_PER_LDG == 1 +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + const __nv_bfloat16* scalar_ptr = + thread_read_ptr + ii * THREADS_PER_ROW; + row_chunk[ii] = __bfloat162float(*scalar_ptr); + } + } + } else if constexpr (std::is_same_v) { + if constexpr (ELTS_PER_LDG >= 2) { + using VecType = AlignedArray<__half, ELTS_PER_LDG>; + float2* row_chunk_f2 = reinterpret_cast(row_chunk); + const VecType* vec_thread_read_ptr = + reinterpret_cast(thread_read_ptr); +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + int base_idx_f2 = ii * ELTS_PER_LDG / 2; +#pragma unroll + for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) { + row_chunk_f2[base_idx_f2 + jj] = __half22float2( + *reinterpret_cast(vec.data + jj * 2)); + } + } + } else { // ELTS_PER_LDG == 1 +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + const __half* scalar_ptr = thread_read_ptr + ii * THREADS_PER_ROW; + row_chunk[ii] = __half2float(*scalar_ptr); + } + } + } + constexpr float threshold = 20.0f; + constexpr float beta = 1.0f; + + // Hash MoE path: indices are predetermined from lookup table + if constexpr (USE_HASH) { + const IndType token_id = input_ids[thread_row]; + const IndType* expert_indices_for_token = tid2eid + token_id * k; +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) { + float val = row_chunk[ii]; + float val_b = val * beta; + val = (val_b > threshold) ? 
val : (__logf(1.0f + __expf(val_b))) / beta; + row_chunk[ii] = sqrtf(val); + } + float selected_sum = 0.f; +#pragma unroll + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int expert = expert_indices_for_token[k_idx]; + const int idx = k * thread_row + k_idx; + for (int ii = 0; ii < VPT; ++ii) { + const int group_id = ii / ELTS_PER_LDG; + const int local_id = ii % ELTS_PER_LDG; + const int expert_idx = first_elt_read_by_thread + + group_id * THREADS_PER_ROW * ELTS_PER_LDG + + local_id; + if (expert == expert_idx) { + indices[idx] = expert; + selected_sum += row_chunk[ii]; + break; + } + } + } + // Compute per-thread scale (using warp reduction when renormalizing). + if (renormalize) { +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + selected_sum += + VLLM_SHFL_XOR_SYNC_WIDTH(selected_sum, mask, THREADS_PER_ROW); + } + } + float scale = static_cast(routed_scaling_factor); + if (renormalize) { + const float denom = selected_sum > 0.f ? selected_sum : 1.f; + scale /= denom; + } + +#pragma unroll + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int expert = expert_indices_for_token[k_idx]; + const int idx = k * thread_row + k_idx; + for (int ii = 0; ii < VPT; ++ii) { + const int group_id = ii / ELTS_PER_LDG; + const int local_id = ii % ELTS_PER_LDG; + const int expert_idx = first_elt_read_by_thread + + group_id * THREADS_PER_ROW * ELTS_PER_LDG + + local_id; + if (expert == expert_idx) { + output[idx] = row_chunk[ii] * scale; + break; + } + } + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif + return; + } + +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) { + float val = row_chunk[ii]; + float val_b = val * beta; + // Compute softplus: log(1 + exp(val)) with numerical stability + // When val > threshold, softplus(x) ≈ x to avoid exp overflow + val = (val_b > threshold) ? val : (__logf(1.0f + __expf(val_b))) / beta; + val = sqrtf(val); + if (correction_bias) { + const int group_id = ii / ELTS_PER_LDG; + const int local_id = ii % ELTS_PER_LDG; + const int expert_idx = first_elt_read_by_thread + + group_id * THREADS_PER_ROW * ELTS_PER_LDG + + local_id; + val = val + correction_bias[expert_idx]; + } + row_chunk[ii] = val; + } + + // Original TopK path: find top-k experts by score + // Now, sigmoid_res contains the sigmoid of the row chunk. Now, I want to find + // the topk elements in each row, along with the max index. + int start_col = first_elt_read_by_thread; + static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + + float selected_sum = 0.f; + for (int k_idx = 0; k_idx < k; ++k_idx) { + // First, each thread does the local argmax + float max_val = row_chunk[0]; + int expert = start_col; +#pragma unroll + for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; + ++ldg, col += COLS_PER_GROUP_LDG) { +#pragma unroll + for (int ii = 0; ii < ELTS_PER_LDG; ++ii) { + float val = row_chunk[ldg * ELTS_PER_LDG + ii]; + + // No check on the experts here since columns with the smallest index + // are processed first and only updated if > (not >=) + if (val > max_val) { + max_val = val; + expert = col + ii; + } + } + } + +// Now, we perform the argmax reduce. We use the butterfly pattern so threads +// reach consensus about the max. This will be useful for K > 1 so that the +// threads can agree on "who" had the max value. That thread can then blank out +// their max with -inf and the warp can run more iterations... 
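The loop below implements this butterfly consensus over THREADS_PER_ROW lanes. For reference, the same idea as a standalone device sketch (illustrative only; warp_argmax is a hypothetical name, not used by the kernel):

// Standalone device sketch of the butterfly argmax consensus: every lane
// applies the same comparison, so after log2(width) exchanges all lanes hold
// the same (value, index) winner, with ties broken toward the lower index.
__device__ inline void warp_argmax(float& val, int& idx, int width) {
  for (int mask = width / 2; mask > 0; mask /= 2) {
    const float other_val = __shfl_xor_sync(0xffffffffu, val, mask, width);
    const int other_idx = __shfl_xor_sync(0xffffffffu, idx, mask, width);
    if (other_val > val || (other_val == val && other_idx < idx)) {
      val = other_val;
      idx = other_idx;
    }
  }
}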
+#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + float other_max = + VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW); + int other_expert = + VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW); + + // We want lower indices to "win" in every thread so we break ties this + // way + if (other_max > max_val || + (other_max == max_val && other_expert < expert)) { + max_val = other_max; + expert = other_expert; + } + } + + // Write the max for this k iteration to global memory. + if (thread_group_idx == 0) { + // Add a guard to ignore experts not included by this node + const bool node_uses_expert = + expert >= start_expert && expert < end_expert; + const bool should_process_row = row_is_active && node_uses_expert; + + // The lead thread from each sub-group will write out the final results to + // global memory. (This will be a single) thread per row of the + // input/output matrices. + const int idx = k * thread_row + k_idx; + if (correction_bias != nullptr) { + max_val -= correction_bias[expert]; + } + output[idx] = max_val; + indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS; + source_rows[idx] = k_idx * num_rows + thread_row; + if (renormalize) { + selected_sum += max_val; + } + } + + // Finally, we clear the value in the thread with the current max if there + // is another iteration to run. + if (k_idx + 1 < k) { + const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG; + const int thread_to_clear_in_group = + (expert / ELTS_PER_LDG) % THREADS_PER_ROW; + + // Only the thread in the group which produced the max will reset the + // "winning" value to -inf. + if (thread_group_idx == thread_to_clear_in_group) { + const int offset_for_expert = expert % ELTS_PER_LDG; + // Safe to set to any negative value since row_chunk values must be + // between 0 and 1. + row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = + -10000.f; + } + } + } + + // Apply renormalization and routed scaling factor to final weights. + if (thread_group_idx == 0) { + float scale = static_cast(routed_scaling_factor); + if (renormalize) { + const float denom = selected_sum > 0.f ? selected_sum : 1.f; + scale /= denom; + } + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int idx = k * thread_row + k_idx; + output[idx] = output[idx] * scale; + } + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + +namespace detail { +// Constructs some constants needed to partition the work across threads at +// compile time. +template +struct TopkConstants { + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType); + static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0 || + EXPERTS % (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0, + ""); + static constexpr int VECs_PER_THREAD = + MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM)); + static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; + static constexpr int THREADS_PER_ROW = EXPERTS / VPT; + static const int ROWS_PER_WARP = WARP_SIZE_PARAM / THREADS_PER_ROW; +}; +} // namespace detail + +#define DISPATCH_HASH(use_hash, USE_HASH, ...) 
\ + if (use_hash) { \ + const bool USE_HASH = true; \ + static_assert(USE_HASH == true, "USE_HASH must be compile-time constant"); \ + __VA_ARGS__ \ + } else { \ + const bool USE_HASH = false; \ + static_assert(USE_HASH == false, \ + "USE_HASH must be compile-time constant"); \ + __VA_ARGS__ \ + } + +template +void topkGatingSoftplusSqrtLauncherHelper( + const InputType* input, const bool* finished, float* output, + IndType* indices, int* source_row, const int num_rows, const int k, + const int start_expert, const int end_expert, const bool renormalize, + double routed_scaling_factor, const float* correction_bias, + const bool use_hash, const IndType* input_ids, const IndType* tid2eid, + cudaStream_t stream) { + static constexpr int BYTES_PER_LDG = + MIN(MAX_BYTES_PER_LDG, sizeof(InputType) * EXPERTS); + using Constants = + detail::TopkConstants; + static constexpr int VPT = Constants::VPT; + static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; + const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; + dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB); + DISPATCH_HASH(use_hash, USE_HASH, { + auto* kernel = + &topkGatingSoftplusSqrt; +#ifndef USE_ROCM + cudaLaunchConfig_t config = {}; + config.gridDim = num_blocks; + config.blockDim = block_dim; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = 1; + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx(&config, kernel, input, finished, output, num_rows, + indices, source_row, k, start_expert, end_expert, + renormalize, routed_scaling_factor, correction_bias, + input_ids, tid2eid); +#else + kernel<<>>( + input, finished, output, num_rows, indices, source_row, k, start_expert, + end_expert, renormalize, routed_scaling_factor, correction_bias, + input_ids, tid2eid); +#endif + }) +} + +#ifndef USE_ROCM + #define LAUNCH_SOFTPLUS_SQRT(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \ + static_assert(WARP_SIZE == 32, \ + "Unsupported warp size. Only 32 is supported for CUDA"); \ + topkGatingSoftplusSqrtLauncherHelper( \ + gating_output, nullptr, topk_weights, topk_indices, \ + token_expert_indices, num_tokens, topk, 0, num_experts, renormalize, \ + routed_scaling_factor, correction_bias, use_hash, input_ids, tid2eid, \ + stream); +#else + #define LAUNCH_SOFTPLUS_SQRT(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \ + if (WARP_SIZE == 64) { \ + topkGatingSoftplusSqrtLauncherHelper( \ + gating_output, nullptr, topk_weights, topk_indices, \ + token_expert_indices, num_tokens, topk, 0, num_experts, renormalize, \ + routed_scaling_factor, correction_bias, use_hash, input_ids, \ + tid2eid, stream); \ + } else if (WARP_SIZE == 32) { \ + topkGatingSoftplusSqrtLauncherHelper( \ + gating_output, nullptr, topk_weights, topk_indices, \ + token_expert_indices, num_tokens, topk, 0, num_experts, renormalize, \ + routed_scaling_factor, correction_bias, use_hash, input_ids, \ + tid2eid, stream); \ + } else { \ + assert(false && \ + "Unsupported warp size. 
Only 32 and 64 are supported for ROCm"); \ + } +#endif + +template +void topkGatingSoftplusSqrtKernelLauncher( + const InputType* gating_output, float* topk_weights, IndType* topk_indices, + int* token_expert_indices, const int num_tokens, const int num_experts, + const int topk, const bool renormalize, double routed_scaling_factor, + const float* correction_bias, const bool use_hash, const IndType* input_ids, + const IndType* tid2eid, cudaStream_t stream) { + static constexpr int WARPS_PER_TB = 4; + static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16; +#ifndef USE_ROCM + // for bfloat16 dtype, we need 4 bytes loading to make sure num_experts + // elements can be loaded by a warp + static constexpr int BYTES_PER_LDG_MULTIPLE_64 = + (std::is_same_v || + std::is_same_v) + ? 4 + : 8; +#endif + switch (num_experts) { + case 1: + LAUNCH_SOFTPLUS_SQRT(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); + break; + case 2: + LAUNCH_SOFTPLUS_SQRT(2, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); + break; + case 4: + LAUNCH_SOFTPLUS_SQRT(4, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); + break; + case 8: + LAUNCH_SOFTPLUS_SQRT(8, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); + break; + case 16: + LAUNCH_SOFTPLUS_SQRT(16, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); + break; + case 32: + LAUNCH_SOFTPLUS_SQRT(32, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); + break; + case 64: + LAUNCH_SOFTPLUS_SQRT(64, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); + break; + case 128: + LAUNCH_SOFTPLUS_SQRT(128, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); + break; + case 256: + LAUNCH_SOFTPLUS_SQRT(256, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); + break; + case 512: + LAUNCH_SOFTPLUS_SQRT(512, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); + break; + // (CUDA only) support multiples of 64 when num_experts is not power of 2. + // ROCm uses WARP_SIZE 64 so 8 bytes loading won't fit for some of + // num_experts, alternatively we can test 4 bytes loading and enable it in + // future. 
+#ifndef USE_ROCM + case 192: + LAUNCH_SOFTPLUS_SQRT(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); + break; + case 320: + LAUNCH_SOFTPLUS_SQRT(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); + break; + case 384: + LAUNCH_SOFTPLUS_SQRT(384, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); + break; + case 448: + LAUNCH_SOFTPLUS_SQRT(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); + break; + case 576: + LAUNCH_SOFTPLUS_SQRT(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); + break; +#endif + default: { + TORCH_CHECK(false, "Unsupported expert number: ", num_experts); + } + } +} + +} // namespace moe +} // namespace vllm + +template +void dispatch_topk_softplus_sqrt_launch( + const ComputeType* gating_output, torch::Tensor& topk_weights, + torch::Tensor& topk_indices, torch::Tensor& token_expert_indices, + int num_tokens, int num_experts, int topk, bool renormalize, + double routed_scaling_factor, + const c10::optional& correction_bias, + const c10::optional& input_ids, + const c10::optional& tid2eid, cudaStream_t stream) { + const float* bias_ptr = nullptr; + if (correction_bias.has_value()) { + bias_ptr = correction_bias.value().data_ptr(); + } + bool use_hash = false; + if (tid2eid.has_value()) { + TORCH_CHECK(input_ids.has_value(), "input_ids is required for hash MoE"); + use_hash = true; + } + if (topk_indices.scalar_type() == at::ScalarType::Int) { + const int* input_ids_ptr = nullptr; + const int* tid2eid_ptr = nullptr; + if (tid2eid.has_value()) { + input_ids_ptr = input_ids.value().data_ptr(); + tid2eid_ptr = tid2eid.value().data_ptr(); + } + + vllm::moe::topkGatingSoftplusSqrtKernelLauncher( + gating_output, topk_weights.data_ptr(), + topk_indices.data_ptr(), token_expert_indices.data_ptr(), + num_tokens, num_experts, topk, renormalize, routed_scaling_factor, + bias_ptr, use_hash, input_ids_ptr, tid2eid_ptr, stream); + } else if (topk_indices.scalar_type() == at::ScalarType::UInt32) { + const uint32_t* input_ids_ptr = nullptr; + const uint32_t* tid2eid_ptr = nullptr; + if (tid2eid.has_value()) { + input_ids_ptr = input_ids.value().data_ptr(); + tid2eid_ptr = tid2eid.value().data_ptr(); + } + vllm::moe::topkGatingSoftplusSqrtKernelLauncher( + gating_output, topk_weights.data_ptr(), + topk_indices.data_ptr(), token_expert_indices.data_ptr(), + num_tokens, num_experts, topk, renormalize, routed_scaling_factor, + bias_ptr, use_hash, input_ids_ptr, tid2eid_ptr, stream); + } else { + TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long); + + const int64_t* input_ids_ptr = nullptr; + const int64_t* tid2eid_ptr = nullptr; + if (tid2eid.has_value()) { + input_ids_ptr = input_ids.value().data_ptr(); + tid2eid_ptr = tid2eid.value().data_ptr(); + } + + vllm::moe::topkGatingSoftplusSqrtKernelLauncher( + gating_output, topk_weights.data_ptr(), + topk_indices.data_ptr(), token_expert_indices.data_ptr(), + num_tokens, num_experts, topk, renormalize, routed_scaling_factor, + bias_ptr, use_hash, input_ids_ptr, tid2eid_ptr, stream); + } +} + +void topk_softplus_sqrt( + torch::Tensor& topk_weights, // [num_tokens, topk] + torch::Tensor& topk_indices, // [num_tokens, topk] + torch::Tensor& token_expert_indices, // [num_tokens, topk] + torch::Tensor& gating_output, // [num_tokens, num_experts] + bool renormalize, double routed_scaling_factor, + const c10::optional& correction_bias, + const c10::optional& input_ids, + const c10::optional& tid2eid) { + const int num_experts = gating_output.size(-1); + const auto num_tokens = gating_output.numel() / num_experts; + const int topk = topk_weights.size(-1); + 
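For reference, the per-expert score this op ranks in the top-k path is sqrt(softplus(logit)), with beta = 1 and the overflow threshold of 20 hard-coded in the kernel; the optional correction bias is added for ranking only and removed from the stored weight before renormalization and routed_scaling_factor are applied. A scalar sketch (softplus_sqrt_score is a hypothetical helper):

#include <cmath>

// Scalar form of the gating score computed in the kernel before top-k
// selection: numerically stable softplus followed by a square root.
inline float softplus_sqrt_score(float logit) {
  constexpr float kBeta = 1.0f, kThreshold = 20.0f;
  const float x = logit * kBeta;
  const float softplus =
      (x > kThreshold) ? logit : std::log1p(std::exp(x)) / kBeta;
  return std::sqrt(softplus);
}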
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (gating_output.scalar_type() == at::ScalarType::Float) { + dispatch_topk_softplus_sqrt_launch( + gating_output.data_ptr(), topk_weights, topk_indices, + token_expert_indices, num_tokens, num_experts, topk, renormalize, + routed_scaling_factor, correction_bias, input_ids, tid2eid, stream); + } else if (gating_output.scalar_type() == at::ScalarType::Half) { + dispatch_topk_softplus_sqrt_launch<__half>( + reinterpret_cast(gating_output.data_ptr()), + topk_weights, topk_indices, token_expert_indices, num_tokens, + num_experts, topk, renormalize, routed_scaling_factor, correction_bias, + input_ids, tid2eid, stream); + } else if (gating_output.scalar_type() == at::ScalarType::BFloat16) { + dispatch_topk_softplus_sqrt_launch<__nv_bfloat16>( + reinterpret_cast( + gating_output.data_ptr()), + topk_weights, topk_indices, token_expert_indices, num_tokens, + num_experts, topk, renormalize, routed_scaling_factor, correction_bias, + input_ids, tid2eid, stream); + } else { + TORCH_CHECK(false, "Unsupported gating_output data type: ", + gating_output.scalar_type()); + } +} \ No newline at end of file diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 7b627a6f8760..ab691e70cd79 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -16,6 +16,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "bias) -> ()"); m.impl("topk_sigmoid", torch::kCUDA, &topk_sigmoid); + m.def( + "topk_softplus_sqrt(Tensor! topk_weights, Tensor! topk_indices, Tensor! " + "token_expert_indices, Tensor gating_output, bool renormalize, float " + "routed_scaling_factor, Tensor? " + "bias, Tensor? input_ids, Tensor? tid2eid) -> ()"); + m.impl("topk_softplus_sqrt", torch::kCUDA, &topk_softplus_sqrt); // Calculate the result of moe by summing up the partial results // from all selected experts. m.def("moe_sum(Tensor input, Tensor! 
output) -> ()"); diff --git a/csrc/ops.h b/csrc/ops.h index f101ab6fd924..821c505b3a02 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -100,6 +100,11 @@ void fused_qk_norm_rope(torch::Tensor& qkv, int64_t num_heads_q, bool is_neox, torch::Tensor& position_ids, int64_t forced_token_heads_per_warp); +void fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( + torch::Tensor& q, torch::Tensor const& kv, torch::Tensor& k_cache, + torch::Tensor const& slot_mapping, torch::Tensor const& position_ids, + torch::Tensor const& cos_sin_cache, double eps, int64_t cache_block_size); + void apply_repetition_penalties_(torch::Tensor& logits, const torch::Tensor& prompt_mask, const torch::Tensor& output_mask, @@ -153,7 +158,8 @@ void silu_and_mul_per_block_quant(torch::Tensor& out, void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, std::optional key, int64_t head_size, - torch::Tensor& cos_sin_cache, bool is_neox); + torch::Tensor& cos_sin_cache, bool is_neox, + int64_t rope_dim_offset, bool inverse); void silu_and_mul(torch::Tensor& out, torch::Tensor& input); diff --git a/csrc/persistent_topk.cuh b/csrc/persistent_topk.cuh index 694fedad39f1..d6162d52998b 100644 --- a/csrc/persistent_topk.cuh +++ b/csrc/persistent_topk.cuh @@ -18,7 +18,6 @@ namespace persistent { // Constants // ============================================================================ -constexpr int TopK = 2048; constexpr int kThreadsPerBlock = 1024; constexpr int RADIX = 256; @@ -128,11 +127,12 @@ struct RadixRowState { struct PersistentTopKParams { const float* __restrict__ input; // [num_rows, stride] - int32_t* __restrict__ output; // [num_rows, TopK] + int32_t* __restrict__ output; // [num_rows, top_k] int32_t* __restrict__ lengths; // [num_rows] RadixRowState* row_states; // large path: per-group state uint32_t num_rows; uint32_t stride; + uint32_t top_k; // actual k value for output stride uint32_t chunk_size; // large path: elements per CTA uint32_t ctas_per_group; // 1=medium, >1=large uint32_t max_seq_len; // max seq_len across all rows (for early CTA exit) @@ -154,6 +154,7 @@ __device__ __forceinline__ uint32_t decode_bin(float x) { return key >> 5; } +template __device__ __noinline__ void histogram_2048_topk( const float* __restrict__ logits, int32_t* __restrict__ output_indices, int32_t seq_len) { @@ -418,6 +419,7 @@ __device__ __noinline__ void histogram_2048_topk( // by: DarkSharpness // which at the same time is an optimized topk kernel copied from tilelang // kernel +template __device__ __noinline__ void histogram_256_topk( const float* __restrict__ logits, int* __restrict__ output_indices, int logits_offset, int seq_len) { @@ -649,7 +651,7 @@ __device__ __forceinline__ void wait_ge(int* ptr, int target_val, // Adapted from https://github.com/flashinfer-ai/flashinfer/pull/2215 // ============================================================================ -template +template __device__ void radix_topk(const float* __restrict__ row_input, int32_t* __restrict__ row_output, uint32_t seq_len, uint32_t my_chunk_start, uint32_t chunk_size, @@ -857,7 +859,7 @@ __device__ void radix_topk(const float* __restrict__ row_input, // see filtered_topk.cuh) // ============================================================================ -template +template __global__ void __launch_bounds__(kThreadsPerBlock, 2) persistent_topk_kernel(PersistentTopKParams params) { const uint32_t tx = threadIdx.x; @@ -915,7 +917,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock, 2) if (row_idx >= params.num_rows) break; const uint32_t 
seq_len = params.lengths[row_idx]; - int32_t* row_output = params.output + row_idx * TopK; + int32_t* row_output = params.output + row_idx * params.top_k; const float* row_input = params.input + row_idx * params.stride; if (seq_len <= RADIX_THRESHOLD) { @@ -927,19 +929,19 @@ __global__ void __launch_bounds__(kThreadsPerBlock, 2) row_output[i] = (i < seq_len) ? static_cast(i) : -1; } } else if (seq_len <= static_cast(HIST2048_THRESHOLD)) { - histogram_2048_topk(row_input, row_output, seq_len); + histogram_2048_topk(row_input, row_output, seq_len); } else { - histogram_256_topk(row_input, row_output, 0, seq_len); + histogram_256_topk(row_input, row_output, 0, seq_len); } } continue; } const uint32_t my_chunk_start = cta_in_group * chunk_size; - radix_topk(row_input, row_output, seq_len, my_chunk_start, - chunk_size, local_histogram, suffix_sum, - shared_scalars, shared_ordered, state, cta_in_group, - ctas_per_group, barrier_phase, iter, tx); + radix_topk( + row_input, row_output, seq_len, my_chunk_start, chunk_size, + local_histogram, suffix_sum, shared_scalars, shared_ordered, state, + cta_in_group, ctas_per_group, barrier_phase, iter, tx); } } @@ -1011,7 +1013,6 @@ struct FilteredTopKTraits { } }; -constexpr uint32_t FILTERED_TOPK_MAX_K = 2048; constexpr uint32_t FILTERED_TOPK_BLOCK_THREADS = 1024; constexpr uint32_t FILTERED_TOPK_SMEM_INPUT_SIZE = 16 * 1024; // 16K indices per buffer @@ -1025,7 +1026,7 @@ constexpr size_t FILTERED_TOPK_SMEM_DYNAMIC = * \tparam IdType Index type (int32_t) * \tparam VEC_SIZE Vector size for input loads (1, 2, 4, or 8) */ -template +template __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) FilteredTopKUnifiedKernel(const DType* __restrict__ input, IdType* __restrict__ output, @@ -1059,7 +1060,7 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) alignas(128) __shared__ int s_counter; alignas(128) __shared__ int s_threshold_bin_id; alignas(128) __shared__ int s_num_input[2]; - alignas(128) __shared__ int s_indices[FILTERED_TOPK_MAX_K]; + alignas(128) __shared__ int s_indices[MAX_K]; auto& s_histogram = s_histogram_buf[0]; @@ -1280,7 +1281,7 @@ constexpr int ComputeFilteredTopKVecSize(uint32_t max_len) { return static_cast(g); } -template +template cudaError_t FilteredTopKRaggedTransform(DType* input, IdType* output_indices, IdType* lengths, uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, @@ -1297,7 +1298,7 @@ cudaError_t FilteredTopKRaggedTransform(DType* input, IdType* output_indices, #define DISPATCH_VEC_SIZE(VS) \ if (vec_size == VS) { \ - auto kernel = FilteredTopKUnifiedKernel; \ + auto kernel = FilteredTopKUnifiedKernel; \ FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( \ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); \ FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, grid, block, args, \ diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index b5645b33b907..c45ebd34729b 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -9,28 +9,29 @@ namespace vllm { template inline __device__ void apply_token_rotary_embedding( - scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr, - const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) { + scalar_t* __restrict__ arr, const float* __restrict__ cos_ptr, + const float* __restrict__ sin_ptr, int rot_offset, int embed_dim, + const bool inverse) { int x_index, y_index; - scalar_t cos, sin; + float cos_f, sin_f; if (IS_NEOX) { - // GPT-NeoX style rotary embedding. 
x_index = rot_offset; y_index = embed_dim + rot_offset; - cos = VLLM_LDG(cos_ptr + x_index); - sin = VLLM_LDG(sin_ptr + x_index); + cos_f = VLLM_LDG(cos_ptr + x_index); + sin_f = VLLM_LDG(sin_ptr + x_index); } else { - // GPT-J style rotary embedding. x_index = 2 * rot_offset; y_index = 2 * rot_offset + 1; - cos = VLLM_LDG(cos_ptr + x_index / 2); - sin = VLLM_LDG(sin_ptr + x_index / 2); + cos_f = VLLM_LDG(cos_ptr + x_index / 2); + sin_f = VLLM_LDG(sin_ptr + x_index / 2); } - - const scalar_t x = arr[x_index]; - const scalar_t y = arr[y_index]; - arr[x_index] = x * cos - y * sin; - arr[y_index] = y * cos + x * sin; + if (inverse) { + sin_f = -sin_f; + } + const float x_f = static_cast(arr[x_index]); + const float y_f = static_cast(arr[y_index]); + arr[x_index] = static_cast(x_f * cos_f - y_f * sin_f); + arr[y_index] = static_cast(y_f * cos_f + x_f * sin_f); } template @@ -42,22 +43,23 @@ inline __device__ void apply_rotary_embedding( // [batch_size, seq_len, num_kv_heads, // head_size] or [num_tokens, num_kv_heads, // head_size] - const scalar_t* cache_ptr, const int head_size, const int num_heads, + const float* cache_ptr, const int head_size, const int num_heads, const int num_kv_heads, const int rot_dim, const int token_idx, const int64_t query_stride, const int64_t key_stride, - const int64_t head_stride) { + const int64_t head_stride, const int64_t rope_dim_offset, + const bool inverse) { const int embed_dim = rot_dim / 2; - const scalar_t* cos_ptr = cache_ptr; - const scalar_t* sin_ptr = cache_ptr + embed_dim; + const float* cos_ptr = cache_ptr; + const float* sin_ptr = cache_ptr + embed_dim; const int nq = num_heads * embed_dim; for (int i = threadIdx.x; i < nq; i += blockDim.x) { const int head_idx = i / embed_dim; const int64_t token_head = - token_idx * query_stride + head_idx * head_stride; + token_idx * query_stride + head_idx * head_stride + rope_dim_offset; const int rot_offset = i % embed_dim; apply_token_rotary_embedding( - query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim, inverse); } if (key != nullptr) { @@ -65,10 +67,10 @@ inline __device__ void apply_rotary_embedding( for (int i = threadIdx.x; i < nk; i += blockDim.x) { const int head_idx = i / embed_dim; const int64_t token_head = - token_idx * key_stride + head_idx * head_stride; + token_idx * key_stride + head_idx * head_stride + rope_dim_offset; const int rot_offset = i % embed_dim; apply_token_rotary_embedding( - key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim, inverse); } } } @@ -84,19 +86,18 @@ __global__ void rotary_embedding_kernel( // [batch_size, seq_len, num_kv_heads, // head_size] or [num_tokens, num_kv_heads, // head_size] - const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // - // 2] + const float* __restrict__ cos_sin_cache, // [max_position, rot_dim] fp32 const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int64_t head_stride, const int num_heads, const int num_kv_heads, - const int head_size) { - // Each thread block is responsible for one token. 
+ const int head_size, const int64_t rope_dim_offset, const bool inverse) { const int token_idx = blockIdx.x; int64_t pos = positions[token_idx]; - const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + const float* cache_ptr = cos_sin_cache + pos * rot_dim; apply_rotary_embedding( query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, - token_idx, query_stride, key_stride, head_stride); + token_idx, query_stride, key_stride, head_stride, rope_dim_offset, + inverse); } } // namespace vllm @@ -115,7 +116,7 @@ void rotary_embedding( // [num_tokens, num_heads, head_size] int64_t head_size, torch::Tensor& cos_sin_cache, // [max_position, rot_dim] - bool is_neox) { + bool is_neox, int64_t rope_dim_offset, bool inverse) { // num_tokens = batch_size * seq_len int64_t num_tokens = positions.numel(); int positions_ndim = positions.dim(); @@ -154,6 +155,8 @@ void rotary_embedding( int seq_dim_idx = positions_ndim - 1; int64_t query_stride = query.stride(seq_dim_idx); int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0; + + TORCH_CHECK((rot_dim + rope_dim_offset) <= head_size); // Determine head stride: for [*, heads, head_size] use stride of last dim; // for flat [*, heads*head_size], heads blocks are contiguous of size // head_size @@ -165,20 +168,23 @@ void rotary_embedding( dim3 block(std::min(num_heads * rot_dim / 2, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + auto cache_f32 = cos_sin_cache.to(torch::kFloat32); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { if (is_neox) { vllm::rotary_embedding_kernel<<>>( positions.data_ptr(), query.data_ptr(), key.has_value() ? key->data_ptr() : nullptr, - cos_sin_cache.data_ptr(), rot_dim, query_stride, key_stride, - head_stride, num_heads, num_kv_heads, head_size); + cache_f32.data_ptr(), rot_dim, query_stride, key_stride, + head_stride, num_heads, num_kv_heads, head_size, rope_dim_offset, + inverse); } else { vllm::rotary_embedding_kernel <<>>( positions.data_ptr(), query.data_ptr(), key.has_value() ? key->data_ptr() : nullptr, - cos_sin_cache.data_ptr(), rot_dim, query_stride, - key_stride, head_stride, num_heads, num_kv_heads, head_size); + cache_f32.data_ptr(), rot_dim, query_stride, key_stride, + head_stride, num_heads, num_kv_heads, head_size, rope_dim_offset, + inverse); } }); } diff --git a/csrc/sampler.cu b/csrc/sampler.cu index c0cc03a08ad7..14d84013c08d 100644 --- a/csrc/sampler.cu +++ b/csrc/sampler.cu @@ -258,7 +258,13 @@ __device__ bool processHistogramStep( auto processBins = [&](float logit, int idx) { if (isPartialMatch(logit, logitPattern)) { uint32_t binIdx = extractBinIdx(logit); - if (binIdx < thresholdBinIdx) { + // Only write elements with binIdx < thresholdBinIdx when: + // 1. This is step 0 and the threshold bin is small enough (no step 1) + // 2. This is step >= 1 (where pattern matching filters correctly) + // This prevents duplicates when step 0 and step 1 both run. 
+ bool shouldWriteDirectly = + (step == 0 && smemFinalBinSize[0] <= kNumFinalItems) || (step >= 1); + if (binIdx < thresholdBinIdx && shouldWriteDirectly) { // The element is part of the top-k selection int dstIdx = atomicAdd(&smemFoundTopKValues[0], 1); diff --git a/csrc/topk.cu b/csrc/topk.cu index f48e7cbc4fc8..364ecc21e532 100644 --- a/csrc/topk.cu +++ b/csrc/topk.cu @@ -10,33 +10,17 @@ #include "persistent_topk.cuh" #endif -void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths, - torch::Tensor& output, torch::Tensor& workspace, int64_t k, - int64_t max_seq_len) { +namespace { + #ifndef USE_ROCM - TORCH_CHECK(logits.is_cuda(), "logits must be CUDA tensor"); - TORCH_CHECK(lengths.is_cuda(), "lengths must be CUDA tensor"); - TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor"); - TORCH_CHECK(logits.dtype() == torch::kFloat32, "Only float32 supported"); - TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32"); - TORCH_CHECK(output.dtype() == torch::kInt32, "output must be int32"); - TORCH_CHECK(logits.dim() == 2, "logits must be 2D"); - TORCH_CHECK(lengths.dim() == 1 || lengths.dim() == 2, - "lengths must be 1D or 2D"); - TORCH_CHECK(lengths.is_contiguous(), "lengths must be contiguous"); - TORCH_CHECK(output.dim() == 2, "output must be 2D"); +template +void launch_persistent_topk(const torch::Tensor& logits, + const torch::Tensor& lengths, torch::Tensor& output, + torch::Tensor& workspace, int64_t max_seq_len) { + namespace P = vllm::persistent; const int64_t num_rows = logits.size(0); const int64_t stride = logits.size(1); - - TORCH_CHECK(lengths.numel() == num_rows, "lengths size mismatch"); - TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k, - "output size mismatch"); - namespace P = vllm::persistent; - - TORCH_CHECK(k == P::TopK, "k must be 2048"); - TORCH_CHECK(k <= stride, "k out of range"); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); static int num_sms = 0; @@ -50,18 +34,17 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths, } if (num_rows > 32 && max_smem_per_block >= 128 * 1024) { - cudaError_t status = vllm::FilteredTopKRaggedTransform( - logits.data_ptr(), output.data_ptr(), - lengths.data_ptr(), static_cast(num_rows), - static_cast(k), static_cast(stride), stream); + cudaError_t status = + vllm::FilteredTopKRaggedTransform( + logits.data_ptr(), output.data_ptr(), + lengths.data_ptr(), static_cast(num_rows), + static_cast(TopK), static_cast(stride), stream); TORCH_CHECK(status == cudaSuccess, "FilteredTopK failed: ", cudaGetErrorString(status)); } else { TORCH_CHECK(workspace.is_cuda(), "workspace must be CUDA tensor"); TORCH_CHECK(workspace.dtype() == torch::kUInt8, "workspace must be uint8"); - // Smem cap: smaller smem → more CTAs/group → more per-row parallelism for - // large path. Empirically tuned. 
int effective_max_smem; if (num_rows <= 4) { effective_max_smem = @@ -101,7 +84,7 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths, int occupancy = 1; cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &occupancy, P::persistent_topk_kernel<4>, P::kThreadsPerBlock, + &occupancy, P::persistent_topk_kernel, P::kThreadsPerBlock, smem_size); if (occupancy < 1) occupancy = 1; @@ -121,15 +104,16 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths, params.lengths = lengths.data_ptr(); params.num_rows = static_cast(num_rows); params.stride = static_cast(stride); + params.top_k = static_cast(TopK); params.chunk_size = chunk_size; params.row_states = reinterpret_cast(workspace.data_ptr()); params.ctas_per_group = ctas_per_group; params.max_seq_len = static_cast(max_seq_len); - #define LAUNCH_PERSISTENT(VS) \ + #define LAUNCH_PERSISTENT(TOPK_VAL, VS) \ do { \ - auto kernel = &P::persistent_topk_kernel; \ + auto kernel = &P::persistent_topk_kernel; \ cudaError_t err = cudaFuncSetAttribute( \ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); \ TORCH_CHECK(err == cudaSuccess, \ @@ -138,11 +122,11 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths, } while (0) if (vec_size == 4) { - LAUNCH_PERSISTENT(4); + LAUNCH_PERSISTENT(TopK, 4); } else if (vec_size == 2) { - LAUNCH_PERSISTENT(2); + LAUNCH_PERSISTENT(TopK, 2); } else { - LAUNCH_PERSISTENT(1); + LAUNCH_PERSISTENT(TopK, 1); } #undef LAUNCH_PERSISTENT } @@ -150,6 +134,46 @@ void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths, cudaError_t err = cudaGetLastError(); TORCH_CHECK(err == cudaSuccess, "persistent_topk failed: ", cudaGetErrorString(err)); +} +#endif + +} // anonymous namespace + +void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths, + torch::Tensor& output, torch::Tensor& workspace, int64_t k, + int64_t max_seq_len) { +#ifndef USE_ROCM + TORCH_CHECK(logits.is_cuda(), "logits must be CUDA tensor"); + TORCH_CHECK(lengths.is_cuda(), "lengths must be CUDA tensor"); + TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor"); + TORCH_CHECK(logits.dtype() == torch::kFloat32, "Only float32 supported"); + TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32"); + TORCH_CHECK(output.dtype() == torch::kInt32, "output must be int32"); + TORCH_CHECK(logits.dim() == 2, "logits must be 2D"); + TORCH_CHECK(lengths.dim() == 1 || lengths.dim() == 2, + "lengths must be 1D or 2D"); + TORCH_CHECK(lengths.is_contiguous(), "lengths must be contiguous"); + TORCH_CHECK(output.dim() == 2, "output must be 2D"); + + const int64_t num_rows = logits.size(0); + const int64_t stride = logits.size(1); + + TORCH_CHECK(lengths.numel() == num_rows, "lengths size mismatch"); + TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k, + "output size mismatch"); + TORCH_CHECK(k == 512 || k == 1024 || k == 2048, + "persistent_topk supports k=512, k=1024, or k=2048, got k=", k); + + if (k == 512) { + launch_persistent_topk<512>(logits, lengths, output, workspace, + max_seq_len); + } else if (k == 1024) { + launch_persistent_topk<1024>(logits, lengths, output, workspace, + max_seq_len); + } else { + launch_persistent_topk<2048>(logits, lengths, output, workspace, + max_seq_len); + } #else TORCH_CHECK(false, "persistent_topk is not supported on ROCm"); #endif diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 48062c3f47b2..b969a7580711 100644 --- a/csrc/torch_bindings.cpp +++ 
b/csrc/torch_bindings.cpp @@ -177,6 +177,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "int forced_token_heads_per_warp=-1) -> ()"); ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope); + // Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and + // GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one + // kernel launch. + ops.def( + "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(" + "Tensor! q, Tensor kv, Tensor! k_cache, " + "Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, " + "float eps, int cache_block_size) -> ()"); + ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA, + &fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert); + // Apply repetition penalties to logits in-place ops.def( "apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, " @@ -240,7 +251,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def( "rotary_embedding(Tensor positions, Tensor! query," " Tensor!? key, int head_size," - " Tensor cos_sin_cache, bool is_neox) -> ()"); + " Tensor cos_sin_cache, bool is_neox, int " + "rope_dim_offset=0, bool inverse=False) -> ()"); ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); // Quantization ops diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index 3f37b9a4024e..c7d171c165aa 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -213,7 +213,7 @@ configuration. | `FLASHINFER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x | | `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x | | `FLASHMLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x | -| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x | +| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 512, 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x | | `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x | | `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A | diff --git a/requirements/cuda.txt b/requirements/cuda.txt index 2145b0237690..7301a69513a6 100644 --- a/requirements/cuda.txt +++ b/requirements/cuda.txt @@ -9,8 +9,10 @@ torchaudio==2.11.0 # These must be updated alongside torch torchvision==0.26.0 # Required for phi3v processor. 
See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version # FlashInfer should be updated together with the Dockerfile -flashinfer-python==0.6.8.post1 -flashinfer-cubin==0.6.8.post1 +flashinfer-python==0.6.9 +flashinfer-cubin==0.6.9 +apache-tvm-ffi==0.1.9 +tilelang # Cap nvidia-cudnn-frontend (transitive dep of flashinfer) due to # breaking changes in 1.19.0 nvidia-cudnn-frontend>=1.13.0,<1.19.0 diff --git a/tests/kernels/attention/test_deepgemm_attention.py b/tests/kernels/attention/test_deepgemm_attention.py index 2dc522598e4e..92ea3c32dd4e 100644 --- a/tests/kernels/attention/test_deepgemm_attention.py +++ b/tests/kernels/attention/test_deepgemm_attention.py @@ -9,6 +9,7 @@ from vllm.utils.deep_gemm import ( _ceil_to_ue8m0, calc_diff, + fp8_fp4_mqa_logits, fp8_mqa_logits, fp8_paged_mqa_logits, get_num_sms, @@ -90,6 +91,53 @@ def _ref_fp8_mqa_logits( return logits +@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only") +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_fp8_mqa_logits_reference_fallback(): + torch.manual_seed(0) + + seq_len, seq_len_kv, num_heads, head_dim = 9, 17, 32, 32 + q = torch.randn( + seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16 + ) + kv = torch.randn(seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16) + weights = torch.randn(seq_len, num_heads, device="cuda", dtype=torch.float32) + cu_seqlen_ks = (torch.arange(seq_len, device="cuda", dtype=torch.int32) % 3) + cu_seqlen_ke = torch.minimum( + torch.arange(seq_len, device="cuda", dtype=torch.int32) + 4, + torch.full((seq_len,), seq_len_kv, device="cuda", dtype=torch.int32), + ) + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_amax = kv.abs().float().amax(dim=1, keepdim=True).clamp(1e-4) + kv_scale = (kv_amax / 448.0).squeeze(1).contiguous() + kv_fp8 = (kv * (1.0 / kv_scale[:, None])).to(torch.float8_e4m3fn) + + logits = fp8_fp4_mqa_logits( + (q_fp8, None), + (kv_fp8, kv_scale), + weights, + cu_seqlen_ks, + cu_seqlen_ke, + clean_logits=True, + ) + + kv_dequant = kv_fp8.float() * kv_scale[:, None] + score = torch.einsum("mhd,nd->hmn", q_fp8.float(), kv_dequant) + ref_logits = (score.relu() * weights.transpose(0, 1).unsqueeze(-1)).sum(dim=0) + offsets = torch.arange(seq_len_kv, device="cuda") + valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & ( + offsets[None, :] < cu_seqlen_ke[:, None] + ) + ref_logits = ref_logits.masked_fill(~valid, float("-inf")) + + assert torch.equal(torch.isneginf(logits), torch.isneginf(ref_logits)) + finite = torch.isfinite(ref_logits) + assert (logits[finite] - ref_logits[finite]).abs().max() < 1e-4 + + @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only") @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available") @pytest.mark.skipif( @@ -127,7 +175,7 @@ def test_deepgemm_fp8_mqa_logits(clean_logits: bool): q_fp8 = q.to(torch.float8_e4m3fn) kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False) - logits = fp8_mqa_logits( + logits = fp8_fp4_mqa_logits( q_fp8, kv_fp8, weights, ks, ke, clean_logits=clean_logits ) diff --git a/tests/kernels/core/test_fused_q_kv_rmsnorm.py b/tests/kernels/core/test_fused_q_kv_rmsnorm.py new file mode 100644 index 000000000000..1017dc52ff98 --- /dev/null +++ b/tests/kernels/core/test_fused_q_kv_rmsnorm.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Correctness + large-token-count launch tests for 
fused_q_kv_rmsnorm. + +Before the grid-dim fix the kernel used grid ``(2, num_tokens)``, which hit +CUDA's 65535 grid-y cap for ``num_tokens >= 65536`` and failed with +``Triton Error [CUDA]: invalid argument`` at every large chunked-prefill +profile run. These tests pin the new grid layout. +""" + +from __future__ import annotations + +import pytest +import torch + +from vllm.platforms import current_platform +from vllm.v1.attention.ops.deepseek_v4_ops import fused_q_kv_rmsnorm + +pytestmark = pytest.mark.skipif( + not current_platform.is_cuda_alike(), + reason="fused_q_kv_rmsnorm requires a CUDA/ROCm device", +) + + +def _ref_rmsnorm(x: torch.Tensor, w: torch.Tensor, eps: float) -> torch.Tensor: + x_f32 = x.to(torch.float32) + variance = x_f32.pow(2).mean(dim=-1, keepdim=True) + y = x_f32 * torch.rsqrt(variance + eps) * w.to(torch.float32) + return y.to(x.dtype) + + +@pytest.mark.parametrize("num_tokens", [1, 17, 1024, 8192]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +def test_fused_q_kv_rmsnorm_correctness(num_tokens: int, dtype: torch.dtype): + torch.manual_seed(0) + device = "cuda" + q_size, kv_size = 192, 576 + qr = torch.randn(num_tokens, q_size, dtype=dtype, device=device) + kv = torch.randn(num_tokens, kv_size, dtype=dtype, device=device) + qw = torch.randn(q_size, dtype=dtype, device=device) + kvw = torch.randn(kv_size, dtype=dtype, device=device) + eps = 1e-6 + + qr_out, kv_out = fused_q_kv_rmsnorm(qr, kv, qw, kvw, eps) + + qr_ref = _ref_rmsnorm(qr, qw, eps) + kv_ref = _ref_rmsnorm(kv, kvw, eps) + + tol = dict(rtol=1e-2, atol=1e-2) + torch.testing.assert_close(qr_out, qr_ref, **tol) + torch.testing.assert_close(kv_out, kv_ref, **tol) + + +@pytest.mark.parametrize("num_tokens", [65535, 65536, 131072]) +def test_fused_q_kv_rmsnorm_launches_past_grid_y_cap(num_tokens: int): + """Regression guard: grid used to be (2, num_tokens), hitting CUDA's + 65535 grid-y cap at num_tokens >= 65536. The new grid (num_tokens, 2) + lifts that bound to 2**31-1.""" + device = "cuda" + dtype = torch.bfloat16 + q_size, kv_size = 192, 576 + qr = torch.randn(num_tokens, q_size, dtype=dtype, device=device) + kv = torch.randn(num_tokens, kv_size, dtype=dtype, device=device) + qw = torch.randn(q_size, dtype=dtype, device=device) + kvw = torch.randn(kv_size, dtype=dtype, device=device) + + qr_out, kv_out = fused_q_kv_rmsnorm(qr, kv, qw, kvw, 1e-6) + # spot-check a couple of rows against the torch reference + for row in (0, num_tokens // 2, num_tokens - 1): + torch.testing.assert_close( + qr_out[row], + _ref_rmsnorm(qr[row : row + 1], qw, 1e-6)[0], + rtol=1e-2, + atol=1e-2, + ) + torch.testing.assert_close( + kv_out[row], + _ref_rmsnorm(kv[row : row + 1], kvw, 1e-6)[0], + rtol=1e-2, + atol=1e-2, + ) diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py index 47700f82a7b3..fd05759ac3df 100644 --- a/tests/kernels/moe/test_deepgemm.py +++ b/tests/kernels/moe/test_deepgemm.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -Unit-test DeepGEMM FP8 kernels (no DeepEP). +Unit-test DeepGEMM FP8 and FP4 kernels (no DeepEP). Compare DeepGEMM path against the Triton fallback inside vLLM's fused_experts. 
""" @@ -21,6 +21,8 @@ maybe_make_prepare_finalize, ) from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, + FusedMoEQuantDesc, fp8_w8a8_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts @@ -204,3 +206,195 @@ def _spy_apply(*args, **kwargs): f"DeepGEMM path was not executed during the test. " f"Call counter: {call_counter['cnt']}" ) + + +# --------------------------------------------------------------------------- +# FP4 weight tests (DeepGEMM m_grouped_fp8_fp4_gemm_nt_contiguous) +# --------------------------------------------------------------------------- + + +def make_mxfp4_weights( + e: int, + n: int, + k: int, +): + """ + Generate (w1, w2) expert weights in MXFP4 packed format with float32 scales, + plus BF16 reference weights for validation. + + w1 shape: (E, 2N, K//2) uint8 — packed FP4 + w2 shape: (E, K, N//2) uint8 — packed FP4 + w1_s shape: (E, 2N, K//32) float32 — per-row block-32 scales + w2_s shape: (E, K, N//32) float32 — per-row block-32 scales + w1_bf16: (E, 2N, K) — original BF16 for reference + w2_bf16: (E, K, N) — original BF16 for reference + """ + from deep_gemm.utils.math import per_token_cast_to_fp4 + + dtype = torch.bfloat16 + gran_k = 32 # MXFP4 block size + + # bf16 reference weights — scale by 1/sqrt(dim) for numerical stability + w1_bf16 = torch.randn(e, 2 * n, k, device="cuda", dtype=dtype) * (k**-0.5) + w2_bf16 = torch.randn(e, k, n, device="cuda", dtype=dtype) * (n**-0.5) + + # Quantize per-expert to FP4 + w1 = torch.empty(e, 2 * n, k // 2, device="cuda", dtype=torch.uint8) + w2 = torch.empty(e, k, n // 2, device="cuda", dtype=torch.uint8) + w1_s = torch.empty( + e, 2 * n, math.ceil(k / gran_k), device="cuda", dtype=torch.float32 + ) + w2_s = torch.empty(e, k, math.ceil(n / gran_k), device="cuda", dtype=torch.float32) + + for i in range(e): + w1[i], w1_s[i] = per_token_cast_to_fp4( + w1_bf16[i].float(), use_ue8m0=True, gran_k=gran_k + ) + w2[i], w2_s[i] = per_token_cast_to_fp4( + w2_bf16[i].float(), use_ue8m0=True, gran_k=gran_k + ) + + return w1, w2, w1_s, w2_s, w1_bf16, w2_bf16 + + +def _bf16_moe_reference(x, w1, w2, topk_weights, topk_ids): + """BF16 token-loop MoE reference for correctness testing.""" + import torch.nn.functional as F + + num_tokens, hidden_size = x.shape + intermediate = w1.shape[1] // 2 + top_k = topk_ids.shape[1] + + output = torch.zeros(num_tokens, hidden_size, dtype=torch.float32, device=x.device) + for t in range(num_tokens): + for kk in range(top_k): + e = topk_ids[t, kk].item() + w = topk_weights[t, kk].item() + fc1 = x[t : t + 1].float() @ w1[e].float().T + linear = fc1[:, :intermediate] + gate = fc1[:, intermediate:] + act = F.silu(gate) * linear + fc2 = act @ w2[e].float().T + output[t] += w * fc2[0] + return output.to(torch.bfloat16) + + +def run_single_fp4_case(m, n, k, topk, num_experts): + """ + Run one (M,N,K) configuration with FP4 weights on DeepGEMM and assert + DeepGEMM FP4 == BF16 reference within tolerance. 
+ """ + tokens_bf16 = torch.randn(m, k, device="cuda", dtype=torch.bfloat16) * (k**-0.5) + + # FP4 expert weight tensors + BF16 originals for reference + w1, w2, w1_s, w2_s, w1_bf16, w2_bf16 = make_mxfp4_weights(num_experts, n, k) + + router_logits = torch.randn(m, num_experts, device="cuda", dtype=torch.float32) + topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1) + topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1) + + from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + ) + from vllm.platforms import current_platform + + _fp8_dtype = current_platform.fp8_dtype() + _block_shape = GroupShape(128, 128) + quant_config = FusedMoEQuantConfig( + _a1=FusedMoEQuantDesc(_fp8_dtype, _block_shape, None, None, None, None), + _a2=FusedMoEQuantDesc(_fp8_dtype, _block_shape, None, None, None, None), + _w1=FusedMoEQuantDesc("mxfp4", None, w1_s, None, None, None), + _w2=FusedMoEQuantDesc("mxfp4", None, w2_s, None, None, None), + ) + moe_config = make_dummy_moe_config() + + from vllm.model_executor.layers.fused_moe.experts.deep_gemm_moe import ( + DeepGemmFP4Experts, + ) + + deep_gemm_fp4_experts = mk.FusedMoEKernel( + prepare_finalize=maybe_make_prepare_finalize( + moe=moe_config, + quant_config=quant_config, + allow_new_interface=True, + use_monolithic=False, + ), + fused_experts=DeepGemmFP4Experts( + moe_config=moe_config, + quant_config=quant_config, + ), + inplace=False, + ) + + # DeepGEMM FP4 path + out_deepgemm_fp4 = deep_gemm_fp4_experts.apply( + hidden_states=tokens_bf16, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + global_num_experts=num_experts, + activation=MoEActivation.SILU, + apply_router_weight_on_input=False, + expert_map=None, + ) + + # BF16 reference using the same original weights + out_ref = _bf16_moe_reference(tokens_bf16, w1_bf16, w2_bf16, topk_weights, topk_ids) + + # FP4 vs BF16 reference: quantization error from FP4 weights + FP8 activations + diff = calc_diff(out_deepgemm_fp4, out_ref) + assert diff < 0.05, f"FP4 diff exceeded 5%: {diff}" + + +# DeepSeek V4 dims: H=4096, I=2048, so N=2*I=4096, K=H=4096. +# FP4 quantization with block_k=32 needs large K for good accuracy. +FP4_MNKs = [ + (128, 4096, 4096), # DeepSeek V4 shape + (256, 2048, 2048), # Half-size variant +] + +FP4_TOPKS = [2] +FP4_NUM_EXPERTS = [8] + + +@pytest.mark.parametrize(("m", "n", "k"), FP4_MNKs) +@pytest.mark.parametrize("topk", FP4_TOPKS) +@pytest.mark.parametrize("num_experts", FP4_NUM_EXPERTS) +@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels") +def test_deepgemm_fp4_vs_triton( + m, n, k, topk, num_experts, monkeypatch, workspace_init +): + pytest.importorskip("deep_gemm.utils.math") + with monkeypatch.context() as mp: + mp.setenv("VLLM_USE_DEEP_GEMM", "1") + + _DeepGemmFP4Experts = importlib.import_module( + "vllm.model_executor.layers.fused_moe.experts.deep_gemm_moe" + ).DeepGemmFP4Experts + + call_counter = {"cnt": 0} + + orig_fn = _DeepGemmFP4Experts.apply + + def _spy_apply(*args, **kwargs): + call_counter["cnt"] += 1 + return orig_fn(*args, **kwargs) + + monkeypatch.setattr(_DeepGemmFP4Experts, "apply", _spy_apply) + if topk > num_experts: + pytest.skip(f"topk={topk} > num_experts={num_experts}") + + run_single_fp4_case( + m=m, + n=n, + k=k, + topk=topk, + num_experts=num_experts, + ) + + # ensure that the DeepGEMM FP4 path was indeed taken. + assert call_counter["cnt"] == 1, ( + f"DeepGEMM FP4 path was not executed during the test. 
" + f"Call counter: {call_counter['cnt']}" + ) diff --git a/tests/kernels/moe/test_topk_softplus_sqrt.py b/tests/kernels/moe/test_topk_softplus_sqrt.py new file mode 100644 index 000000000000..db6924a14349 --- /dev/null +++ b/tests/kernels/moe/test_topk_softplus_sqrt.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import pytest +import torch +import torch.nn.functional as F + +from vllm.model_executor.layers.fused_moe.config import ( + RoutingMethodType, + get_routing_method_type, +) +from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import ( + fused_topk_bias, +) +from vllm.platforms import current_platform + + +def _torch_topk_softplus_sqrt( + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + routed_scaling_factor: float, + e_score_correction_bias: torch.Tensor | None = None, + input_ids: torch.Tensor | None = None, + hash_indices_table: torch.Tensor | None = None, +): + scores = F.softplus(gating_output.float()).sqrt() + original_scores = scores + if e_score_correction_bias is not None: + scores_for_choice = scores + e_score_correction_bias.unsqueeze(0) + else: + scores_for_choice = scores + + if hash_indices_table is not None: + assert input_ids is not None + topk_ids = hash_indices_table[input_ids.long()] + else: + topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=True)[1] + + topk_weights = original_scores.gather(1, topk_ids.long()) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + if routed_scaling_factor != 1.0: + topk_weights = topk_weights * routed_scaling_factor + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + +def test_sqrtsoftplus_bias_uses_deepseek_v4_routing_method(): + assert ( + get_routing_method_type( + scoring_func="sqrtsoftplus", + top_k=8, + renormalize=True, + num_expert_group=None, + has_e_score_bias=True, + ) + == RoutingMethodType.DeepseekV4 + ) + assert ( + get_routing_method_type( + scoring_func="sqrtsoftplus", + top_k=8, + renormalize=False, + num_expert_group=None, + has_e_score_bias=True, + ) + == RoutingMethodType.Unspecified + ) + + +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." 
+) +@pytest.mark.parametrize("num_tokens", [1, 33, 128]) +@pytest.mark.parametrize("hidden_size", [1024, 2048]) +@pytest.mark.parametrize("num_experts", [128, 256, 384, 512]) +@pytest.mark.parametrize("topk", [6, 8, 16]) +@pytest.mark.parametrize("renormalize", [True, False]) +@pytest.mark.parametrize("routed_scaling_factor", [1.0, 1.5]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32]) +def test_fused_topk_softplus_sqrt( + num_tokens: int, + hidden_size: int, + num_experts: int, + topk: int, + renormalize: bool, + routed_scaling_factor: float, + dtype: torch.dtype, +): + torch.manual_seed(0) + hidden_states = torch.randn((num_tokens, hidden_size), dtype=dtype, device="cuda") + gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device="cuda") + e_score_correction_bias = torch.randn( + (num_experts,), dtype=torch.float32, device="cuda" + ) + + topk_weights_ref, topk_ids_ref = _torch_topk_softplus_sqrt( + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + ) + + topk_weights, topk_ids = fused_topk_bias( + hidden_states=hidden_states, + gating_output=gating_output, + scoring_func="sqrtsoftplus", + e_score_correction_bias=e_score_correction_bias, + topk=topk, + renormalize=renormalize, + routed_scaling_factor=routed_scaling_factor, + ) + + # Different kernels may return the topk experts in different orders when + # scores tie; sort by expert id before comparing. + sorted_ref_ids, idx_ref = topk_ids_ref.sort(dim=-1) + sorted_ids, idx_ops = topk_ids.sort(dim=-1) + torch.testing.assert_close(sorted_ref_ids, sorted_ids, atol=0, rtol=0) + + sorted_w_ref = topk_weights_ref.gather(1, idx_ref) + sorted_w = topk_weights.gather(1, idx_ops) + torch.testing.assert_close(sorted_w_ref, sorted_w, atol=2e-2, rtol=1e-2) + + +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." +) +@pytest.mark.parametrize("num_tokens", [1, 33, 128]) +@pytest.mark.parametrize("hidden_size", [1024, 2048]) +@pytest.mark.parametrize("num_experts", [256, 384, 512]) +@pytest.mark.parametrize("topk", [6, 8, 16]) +@pytest.mark.parametrize("renormalize", [True, False]) +@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32]) +@pytest.mark.parametrize("indices_type", [None, torch.int64]) +def test_fused_topk_softplus_sqrt_hash( + num_tokens: int, + hidden_size: int, + num_experts: int, + topk: int, + renormalize: bool, + routed_scaling_factor: float, + dtype: torch.dtype, + indices_type: torch.dtype | None, +): + torch.manual_seed(0) + vocab_size = 1024 + hidden_states = torch.randn((num_tokens, hidden_size), dtype=dtype, device="cuda") + gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device="cuda") + # Per-token fixed expert selection: for each vocab id pick `topk` distinct + # experts. 
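+    # The stacked table below has shape [vocab_size, topk]: row v holds the
+    # fixed expert ids for vocab id v, so hashed routing bypasses the
+    # score-based top-k entirely.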
+ hash_indices_table = torch.stack( + [torch.randperm(num_experts)[:topk] for _ in range(vocab_size)] + ).to(device="cuda", dtype=torch.int32) + input_ids = torch.randint( + 0, vocab_size, (num_tokens,), dtype=torch.int32, device="cuda" + ) + + topk_weights_ref, topk_ids_ref = _torch_topk_softplus_sqrt( + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + routed_scaling_factor=routed_scaling_factor, + input_ids=input_ids, + hash_indices_table=hash_indices_table, + ) + + topk_weights, topk_ids = fused_topk_bias( + hidden_states=hidden_states, + gating_output=gating_output, + scoring_func="sqrtsoftplus", + e_score_correction_bias=None, + topk=topk, + renormalize=renormalize, + input_tokens=input_ids, + hash_indices_table=hash_indices_table, + routed_scaling_factor=routed_scaling_factor, + indices_type=indices_type, + ) + + sorted_ref_ids, idx_ref = topk_ids_ref.sort(dim=-1) + sorted_ids, idx_ops = topk_ids.sort(dim=-1) + torch.testing.assert_close( + sorted_ref_ids.to(sorted_ids.dtype), sorted_ids, atol=0, rtol=0 + ) + + sorted_w_ref = topk_weights_ref.gather(1, idx_ref) + sorted_w = topk_weights.gather(1, idx_ops) + torch.testing.assert_close(sorted_w_ref, sorted_w, atol=2e-2, rtol=1e-2) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 4cb638e47af0..229a75f1deb6 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -14,6 +14,7 @@ from vllm.config import VllmConfig from vllm.model_executor.kernels.linear.scaled_mm.cutlass import cutlass_scaled_mm from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _get_default_w8a8_block_fp8_config, per_token_group_quant_fp8, w8a8_triton_block_scaled_mm, ) @@ -62,6 +63,19 @@ def setup_cuda(): torch.set_default_device("cuda") +def test_w8a8_block_fp8_default_config_extends_low_m_tiles_on_sm12x(): + cfg = _get_default_w8a8_block_fp8_config(32, 128, 128) + capability = current_platform.get_device_capability() + capability_major = getattr(capability, "major", capability[0]) + + if capability_major == 12: + assert cfg["BLOCK_SIZE_M"] == 16 + assert cfg["num_stages"] == 3 + else: + assert cfg["BLOCK_SIZE_M"] == 64 + assert cfg["num_stages"] == 2 + + @pytest.mark.skipif( current_platform.is_fp8_fnuz(), reason="This platform supports e4m3fnuz, not e4m3fn.", @@ -135,6 +149,48 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): assert rel_diff < 0.001 +@pytest.mark.skipif( + not hasattr(torch, "float8_e8m0fnu"), + reason="torch does not expose float8_e8m0fnu", +) +@torch.inference_mode() +def test_w8a8_block_fp8_matmul_accepts_e8m0_scales(): + torch.manual_seed(0) + M, N, K = 7, 256, 256 + block_size = [128, 128] + out_dtype = torch.bfloat16 + fp8_info = torch.finfo(current_platform.fp8_dtype()) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(current_platform.fp8_dtype()) + B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(current_platform.fp8_dtype()) + + scale_choices = torch.tensor([0.00390625, 0.0078125, 0.015625, 0.03125]) + As = scale_choices[torch.randint(0, len(scale_choices), (M, K // 128))] + Bs = scale_choices[torch.randint(0, len(scale_choices), (N // 128, K // 128))] + As_e8m0 = As.to(torch.float8_e8m0fnu) + Bs_e8m0 = Bs.to(torch.float8_e8m0fnu) + + ref_out = native_w8a8_block_matmul( 
+ A_fp8, + B_fp8, + As_e8m0.to(torch.float32), + Bs_e8m0.to(torch.float32), + block_size, + out_dtype, + ) + out = w8a8_triton_block_scaled_mm( + A_fp8, B_fp8, As_e8m0, Bs_e8m0, block_size, out_dtype + ) + + rel_diff = torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)) + ) / torch.mean(torch.abs(ref_out.to(torch.float32))) + assert rel_diff < 0.001 + + @pytest.mark.skipif( not current_platform.is_cuda(), reason="CUTLASS only supported on CUDA platform." ) diff --git a/tests/kernels/quantization/test_scaled_mm_kernel_selection.py b/tests/kernels/quantization/test_scaled_mm_kernel_selection.py index bedebdb59b85..952ddcdef48b 100644 --- a/tests/kernels/quantization/test_scaled_mm_kernel_selection.py +++ b/tests/kernels/quantization/test_scaled_mm_kernel_selection.py @@ -15,12 +15,14 @@ from vllm.model_executor.kernels.linear import ( AiterInt8ScaledMMLinearKernel, CPUInt8ScaledMMLinearKernel, + CutlassFp8BlockScaledMMKernel, Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig, ScaledMMLinearKernel, init_int8_linear_kernel, register_linear_kernel, ) +from vllm.model_executor.kernels.linear.scaled_mm import cutlass as cutlass_scaled_mm from vllm.platforms import PlatformEnum pytestmark = pytest.mark.cpu_test @@ -71,6 +73,35 @@ def test_aiter_kernel_implements_is_supported(): # This validates the method works correctly even on non-ROCm platforms +@pytest.mark.parametrize("compute_capability", [(12, 0), (12, 1), 120, 121]) +def test_cutlass_fp8_block_scaled_mm_rejects_sm12x( + compute_capability, monkeypatch +): + monkeypatch.setattr(cutlass_scaled_mm, "CUTLASS_BLOCK_FP8_SUPPORTED", True) + + supported, reason = CutlassFp8BlockScaledMMKernel.is_supported( + compute_capability + ) + + assert not supported + assert reason is not None + assert "SM12x" in reason + + +@pytest.mark.parametrize("compute_capability", [(9, 0), 90, (10, 0), 100]) +def test_cutlass_fp8_block_scaled_mm_allows_non_sm12x_when_available( + compute_capability, monkeypatch +): + monkeypatch.setattr(cutlass_scaled_mm, "CUTLASS_BLOCK_FP8_SUPPORTED", True) + + supported, reason = CutlassFp8BlockScaledMMKernel.is_supported( + compute_capability + ) + + assert supported + assert reason is None + + def test_cpu_kernel_accepts_all_configs(): """Test that CPUInt8ScaledMMLinearKernel accepts all config combinations.""" configs = [ diff --git a/tests/kernels/test_compressor_kv_cache.py b/tests/kernels/test_compressor_kv_cache.py new file mode 100644 index 000000000000..592b58fbe430 --- /dev/null +++ b/tests/kernels/test_compressor_kv_cache.py @@ -0,0 +1,311 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Round-trip tests for compressor → FP8 quant + KV cache insert → gather + dequant. + +Two paths tested: + A) DeepseekV4 Attention: head_dim=512 (448 FP8 nope + 64 bf16 rope), quant_block=64 + B) Indexer: head_dim=128 (all FP8), quant_block=128 + +These serve as golden references for validating the future fused +compressor+quant+cache kernel. +""" + +import math + +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.v1.attention.ops.deepseek_v4_ops import ( + dequantize_and_gather_k_cache, + quantize_and_insert_k_cache, +) + + +def _ue8m0_reference(x: torch.Tensor, block_size: int, fp8_max: float): + """PyTorch reference for UE8M0 FP8 quantization (per-block, power-of-2 scale). + + Returns (x_fp8, scales) where x_fp8 is float8_e4m3fn and scales are float32. 
+ """ + assert x.dim() == 1 + n = x.numel() + n_blocks = math.ceil(n / block_size) + x_fp8 = torch.zeros(n, dtype=torch.float8_e4m3fn, device=x.device) + scales = torch.zeros(n_blocks, dtype=torch.float32, device=x.device) + + for i in range(n_blocks): + start = i * block_size + end = min(start + block_size, n) + block = x[start:end].float() + amax = block.abs().max().clamp(min=1e-4) + raw_scale = amax / fp8_max + exponent = math.ceil(math.log2(raw_scale.item())) + scale = 2.0**exponent + scales[i] = scale + quantized = (block / scale).clamp(-fp8_max, fp8_max) + x_fp8[start:end] = quantized.to(torch.float8_e4m3fn) + + return x_fp8, scales + + +# ── Test A: DeepseekV4 Attention path ────────────────────────────────────────────── + + +@pytest.mark.parametrize("num_tokens", [1, 4, 8, 17]) +@pytest.mark.parametrize("block_size", [16, 64]) +def test_deepseek_v4_attention_quant_cache_roundtrip(num_tokens: int, block_size: int): + """compressed_kv → quantize_and_insert_k_cache → dequantize_and_gather_k_cache + → compare against original.""" + + HEAD_DIM = 512 + NOPE_DIM = 448 + HEAD_BYTES = 584 # 448 fp8 + 128 bf16 + 8 uint8 scale + FP8_MAX = 448.0 + QUANT_BLOCK = 64 + + num_blocks = (num_tokens + block_size - 1) // block_size + 1 + device = "cuda" + + # Random compressed_kv (simulates compressor output) + compressed_kv = torch.randn( + num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device + ) + + # ── Quant + insert ────────────────────────────────────────────────── + k_cache = torch.zeros( + num_blocks, block_size, HEAD_BYTES, dtype=torch.uint8, device=device + ) + k_cache_2d = k_cache.view(num_blocks, -1) + + # Sequential slot mapping: token i → slot i + slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device) + + quantize_and_insert_k_cache( + compressed_kv, k_cache_2d, slot_mapping, block_size=block_size + ) + + # ── Gather + dequant ──────────────────────────────────────────────── + num_reqs = 1 + max_blocks_per_seq = num_blocks + out = torch.zeros( + num_reqs, num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device + ) + seq_lens = torch.tensor([num_tokens], dtype=torch.int32, device=device) + # block_table: request 0 uses physical blocks 0, 1, ... + block_table = torch.arange( + max_blocks_per_seq, dtype=torch.int32, device=device + ).unsqueeze(0) + + dequantize_and_gather_k_cache( + out, k_cache, seq_lens, None, block_table, block_size, offset=0 + ) + + recovered = out[0, :num_tokens] + + # ── NoPE portion (first 448): FP8 quantized, expect UE8M0 error ── + nope_orig = compressed_kv[:, :NOPE_DIM].float() + nope_recv = recovered[:, :NOPE_DIM].float() + nope_diff = (nope_recv - nope_orig).abs() + + # Per-token check: FP8 e4m3 (3-bit mantissa) worst-case error is + # half-ULP at the largest representable value. At y ≈ 448 (max), + # ULP = 2^(8-3) = 32, so error ≤ 16 * scale. 
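+    # Worked example: a block whose abs-max is 300 has raw scale
+    # 300 / 448 ≈ 0.67, which rounds up to 2**0 = 1.0; if that is the token's
+    # largest block scale, max_allowed below evaluates to 16.0.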
+ for t in range(num_tokens): + _, scales = _ue8m0_reference( + compressed_kv[t, :NOPE_DIM].float(), QUANT_BLOCK, FP8_MAX + ) + max_allowed = 16.0 * scales.max().item() + token_diff = nope_diff[t].max().item() + assert token_diff <= max_allowed, ( + f"Token {t} nope diff {token_diff} exceeds max_allowed " + f"{max_allowed} (scale={scales.max().item()})" + ) + + # ── RoPE portion (last 64): stored as bf16, should be exact ───── + rope_diff = (recovered[:, NOPE_DIM:] - compressed_kv[:, NOPE_DIM:]).abs() + assert rope_diff.max().item() == 0.0, ( + f"RoPE portion should be exact but got max diff {rope_diff.max().item()}" + ) + + +# ── Test B: Indexer path ──────────────────────────────────────────────────── + + +@pytest.mark.parametrize("num_tokens", [1, 4, 8, 17]) +@pytest.mark.parametrize("block_size", [16, 64]) +def test_indexer_quant_cache_roundtrip(num_tokens: int, block_size: int): + """k → indexer_k_quant_and_cache → cp_gather_indexer_k_quant_cache + → manual dequant → compare against original.""" + + HEAD_DIM = 128 + QUANT_BLOCK_SIZE = 128 + # cache_stride = head_dim + (head_dim * 4 / quant_block_size) = 128 + 4 = 132 + CACHE_STRIDE = HEAD_DIM + HEAD_DIM * 4 // QUANT_BLOCK_SIZE + + num_blocks = (num_tokens + block_size - 1) // block_size + 1 + device = "cuda" + + # Random K (simulates compressor output for indexer) + k = torch.randn(num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device) + + # ── Quant + insert ────────────────────────────────────────────────── + kv_cache = torch.zeros( + num_blocks, block_size, CACHE_STRIDE, dtype=torch.uint8, device=device + ) + slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device) + + ops.indexer_k_quant_and_cache(k, kv_cache, slot_mapping, QUANT_BLOCK_SIZE, "ue8m0") + + # ── Gather ────────────────────────────────────────────────────────── + max_blocks_per_seq = num_blocks + block_table = torch.arange( + max_blocks_per_seq, dtype=torch.int32, device=device + ).unsqueeze(0) + cu_seq_lens = torch.tensor([0, num_tokens], dtype=torch.int32, device=device) + + # dst_k: [total_seq_len, head_dim] as uint8 (raw FP8 bytes) + dst_k = torch.zeros(num_tokens, HEAD_DIM, dtype=torch.uint8, device=device) + # dst_scale: [total_seq_len, head_dim/quant_block*4] as uint8 (raw float32 bytes) + num_scale_bytes = HEAD_DIM * 4 // QUANT_BLOCK_SIZE # 4 + dst_scale = torch.zeros( + num_tokens, num_scale_bytes, dtype=torch.uint8, device=device + ) + + ops.cp_gather_indexer_k_quant_cache( + kv_cache, dst_k, dst_scale, block_table, cu_seq_lens + ) + + # ── Manual dequant ────────────────────────────────────────────────── + k_fp8 = dst_k.view(torch.float8_e4m3fn).float() # [num_tokens, 128] + scale = dst_scale.view(torch.float32) # [num_tokens, 1] + k_recovered = k_fp8 * scale # [num_tokens, 128] + + # ── Compare ───────────────────────────────────────────────────────── + diff = (k_recovered - k.float()).abs() + k_abs = k.float().abs() + + for t in range(num_tokens): + amax = k_abs[t].max().clamp(min=1e-4).item() + # UE8M0: scale = 2^ceil(log2(amax / 448)) + exponent = math.ceil(math.log2(amax / 448.0)) + ue8m0_scale = 2.0**exponent + # FP8 e4m3 (3-bit mantissa): worst-case error = 16 * scale + max_allowed = 16.0 * ue8m0_scale + token_diff = diff[t].max().item() + assert token_diff <= max_allowed, ( + f"Token {t} diff {token_diff} exceeds max_allowed " + f"{max_allowed} (scale={ue8m0_scale})" + ) + + +def test_indexer_gather_accepts_upper_bound_output(): + """Gather only exact cu_seq_lens even when dst is over-allocated.""" + + head_dim = 128 + 
quant_block_size = 128 + cache_stride = head_dim + head_dim * 4 // quant_block_size + valid_tokens = 9 + upper_bound_tokens = 13 + block_size = 16 + num_blocks = 2 + sentinel = 123 + device = "cuda" + + k = torch.randn(valid_tokens, head_dim, dtype=torch.bfloat16, device=device) + kv_cache = torch.zeros( + num_blocks, block_size, cache_stride, dtype=torch.uint8, device=device + ) + slot_mapping = torch.arange(valid_tokens, dtype=torch.int64, device=device) + ops.indexer_k_quant_and_cache(k, kv_cache, slot_mapping, quant_block_size, "ue8m0") + + block_table = torch.arange(num_blocks, dtype=torch.int32, device=device).unsqueeze( + 0 + ) + cu_seq_lens = torch.tensor([0, valid_tokens], dtype=torch.int32, device=device) + dst_k = torch.full( + (upper_bound_tokens, head_dim), sentinel, dtype=torch.uint8, device=device + ) + num_scale_bytes = head_dim * 4 // quant_block_size + dst_scale = torch.full( + (upper_bound_tokens, num_scale_bytes), + sentinel, + dtype=torch.uint8, + device=device, + ) + + ops.cp_gather_indexer_k_quant_cache( + kv_cache, dst_k, dst_scale, block_table, cu_seq_lens + ) + torch.accelerator.synchronize() + + k_recovered = dst_k[:valid_tokens].view(torch.float8_e4m3fn).float() * dst_scale[ + :valid_tokens + ].view(torch.float32) + diff = (k_recovered - k.float()).abs() + max_allowed = (16.0 * dst_scale[:valid_tokens].view(torch.float32).max()).item() + assert diff.max().item() <= max_allowed + assert torch.all(dst_k[valid_tokens:] == sentinel) + assert torch.all(dst_scale[valid_tokens:] == sentinel) + + +# ── Test C: DeepseekV4 attention with values at different magnitudes ─────────── + + +def test_deepseek_v4_quant_magnitude_range(): + """Test that quantization handles a range of magnitudes correctly.""" + + HEAD_DIM = 512 + NOPE_DIM = 448 + HEAD_BYTES = 584 + block_size = 16 + num_tokens = 4 + num_blocks = 2 + device = "cuda" + + # Create inputs with varying magnitudes: small, medium, large + compressed_kv = torch.zeros( + num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device + ) + compressed_kv[0] = 0.001 # very small + compressed_kv[1] = 1.0 # unit scale + compressed_kv[2] = 100.0 # large + compressed_kv[3] = torch.randn(HEAD_DIM, dtype=torch.bfloat16, device=device) + + k_cache = torch.zeros( + num_blocks, block_size, HEAD_BYTES, dtype=torch.uint8, device=device + ) + slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device) + + quantize_and_insert_k_cache( + compressed_kv, k_cache.view(num_blocks, -1), slot_mapping, block_size + ) + + out = torch.zeros(1, num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device) + seq_lens = torch.tensor([num_tokens], dtype=torch.int32, device=device) + block_table = torch.arange(num_blocks, dtype=torch.int32, device=device).unsqueeze( + 0 + ) + + dequantize_and_gather_k_cache( + out, k_cache, seq_lens, None, block_table, block_size, offset=0 + ) + + recovered = out[0, :num_tokens] + + # RoPE portion must be exact + rope_diff = (recovered[:, NOPE_DIM:] - compressed_kv[:, NOPE_DIM:]).abs().max() + assert rope_diff.item() == 0.0, f"RoPE diff {rope_diff.item()}" + + # NoPE: relative error should be reasonable + for t in range(num_tokens): + orig = compressed_kv[t, :NOPE_DIM].float() + recv = recovered[t, :NOPE_DIM].float() + abs_diff = (recv - orig).abs().max().item() + magnitude = orig.abs().max().item() + if magnitude > 0.01: + rel_err = abs_diff / magnitude + assert rel_err < 0.15, ( + f"Token {t}: rel_err={rel_err:.4f}, abs_diff={abs_diff:.6f}, " + f"magnitude={magnitude:.4f}" + ) diff --git 
a/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py b/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py new file mode 100644 index 000000000000..46d226e0f74e --- /dev/null +++ b/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py @@ -0,0 +1,359 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Standalone unit test for the horizontally-fused DeepseekV4-MLA kernel: + + fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert + - Q side: per-head RMSNorm (no weight) + GPT-J RoPE on last 64 dims + - KV side: GPT-J RoPE on last 64 + UE8M0 FP8 quant + paged cache insert + +We compare against: + - PyTorch reference for RMSNorm + GPT-J RoPE on Q + - Existing Triton `quantize_and_insert_k_cache` + round-trip via + `dequantize_and_gather_k_cache` for KV + +The kernel is imported via +`torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert`. +""" + +import pytest +import torch + +from vllm.v1.attention.ops.deepseek_v4_ops import ( + dequantize_and_gather_k_cache, + quantize_and_insert_k_cache, +) + +# ── Constants matching the kernel ──────────────────────────────────────────── +HEAD_DIM = 512 +ROPE_DIM = 64 +NOPE_DIM = HEAD_DIM - ROPE_DIM # 448 +QUANT_BLOCK = 64 +FP8_MAX = 448.0 +HEAD_BYTES = NOPE_DIM + ROPE_DIM * 2 + 8 # 448 + 128 + 8 = 584 + + +# ── PyTorch reference implementations ──────────────────────────────────────── + + +def make_cos_sin_cache(max_pos: int, rope_dim: int, dtype, device): + """Build a cos||sin cache matching DeepseekV4ScalingRotaryEmbedding layout. + cos_sin_cache[pos, :rope_dim/2] = cos(theta), [rope_dim/2:] = sin(theta). + """ + base = 10000.0 + inv_freq = 1.0 / ( + base + ** (torch.arange(0, rope_dim, 2, dtype=torch.float32, device=device) / rope_dim) + ) + t = torch.arange(max_pos, dtype=torch.float32, device=device) + freqs = torch.einsum("i,j -> ij", t, inv_freq) # [max_pos, rope_dim/2] + cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1) # [max_pos, rope_dim] + return cache.to(dtype) + + +def apply_rope_gptj_last_k( + x: torch.Tensor, positions: torch.Tensor, cos_sin_cache: torch.Tensor +) -> torch.Tensor: + """GPT-J-style (interleaved-pair) RoPE on the LAST rope_dim elements. + + x: [..., head_dim] float32 + positions: [num_tokens] int64 (positions[i] corresponds to x[i, ...]) + cos_sin_cache: [max_pos, rope_dim] float (cos|sin layout) + + Returns rotated x (same shape/dtype). + """ + rope_dim = cos_sin_cache.shape[-1] + half = rope_dim // 2 + head_dim = x.shape[-1] + nope_dim = head_dim - rope_dim + + # Gather cos/sin for each token position: [num_tokens, rope_dim] + cs = cos_sin_cache[positions].to(torch.float32) # [N, rope_dim] + cos = cs[..., :half] # [N, half] + sin = cs[..., half:] # [N, half] + + # Reshape leading dims so we can broadcast: x shape [..., head_dim]. + # Bring token dim to front; assume x is [num_tokens, ..., head_dim]. + # We rely on positions being per-token and all other dims sharing the same pos. + rope = x[..., nope_dim:].float() # [..., rope_dim] + # Make rope pairs: reshape last dim to [half, 2] + shape = rope.shape + rope = rope.reshape(*shape[:-1], half, 2) + even = rope[..., 0] # [..., half] + odd = rope[..., 1] + + # Broadcast cos/sin over any heads dim in between. cos/sin are [N, half]. + # Add singleton dims for intermediate axes. 
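+    # Shape example: q of shape [N, H, 512] gives rope of shape [N, H, 32, 2],
+    # so a single unsqueeze turns cos/sin into [N, 1, 32] and they broadcast
+    # over the head dimension.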
+ for _ in range(rope.ndim - 3): + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) + + new_even = even * cos - odd * sin + new_odd = even * sin + odd * cos + rope_rotated = torch.stack((new_even, new_odd), dim=-1).reshape(shape) + + out = x.clone().float() + out[..., nope_dim:] = rope_rotated + return out.to(x.dtype) + + +def rmsnorm_no_weight(x: torch.Tensor, eps: float) -> torch.Tensor: + """RMSNorm with no learnable weight, matching + `RMSNorm(head_dim, has_weight=False)`.""" + orig_dtype = x.dtype + xf = x.float() + variance = xf.pow(2).mean(dim=-1, keepdim=True) + return (xf * torch.rsqrt(variance + eps)).to(orig_dtype) + + +# ── Dispatch to the CUDA op (skip test cleanly if it isn't built in) ───────── + + +def _op_available() -> bool: + return hasattr(torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert") + + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available() or not _op_available(), + reason="CUDA not available or fused DeepseekV4 op not built in", +) + + +def _call_fused(q, kv, k_cache, slot_mapping, positions, cos_sin_cache, eps, bs): + torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( + q, kv, k_cache, slot_mapping, positions, cos_sin_cache, eps, bs + ) + + +# ── Test 1: Q path numerical parity ────────────────────────────────────────── + + +@pytest.mark.parametrize("num_tokens", [1, 4, 17, 64]) +@pytest.mark.parametrize("n_heads", [8, 64]) +def test_q_path_matches_reference(num_tokens: int, n_heads: int): + torch.manual_seed(0) + device = "cuda" + dtype = torch.bfloat16 + eps = 1e-6 + max_pos = 4096 + + q = torch.randn(num_tokens, n_heads, HEAD_DIM, dtype=dtype, device=device) + positions = torch.arange(num_tokens, dtype=torch.int64, device=device) + cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device) + + # Reference: RMSNorm (no weight) per head, then GPT-J RoPE on last 64. + q_ref = rmsnorm_no_weight(q, eps) + q_ref = apply_rope_gptj_last_k(q_ref, positions, cos_sin_cache) + + # Fused call with dummy KV tensors (KV branch will write slot_mapping=-1 → noop). 
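+    # slot_mapping == -1 marks padded slots, so the cache-insert side becomes
+    # a no-op and only the in-place Q transform is exercised here.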
+ num_blocks = 2 + bs = 16 + kv = torch.zeros(num_tokens, HEAD_DIM, dtype=dtype, device=device) + k_cache = torch.zeros( + num_blocks, bs, HEAD_BYTES, dtype=torch.uint8, device=device + ).view(num_blocks, -1) + slot_mapping = torch.full((num_tokens,), -1, dtype=torch.int64, device=device) + q_fused = q.clone() + _call_fused(q_fused, kv, k_cache, slot_mapping, positions, cos_sin_cache, eps, bs) + + torch.testing.assert_close(q_fused, q_ref, rtol=1e-2, atol=1e-2) + + +# ── Test 2: KV path round-trip byte/value parity ───────────────────────────── + + +def _ue8m0_per_block_scales(kv_roped_nope_f32: torch.Tensor, qblock: int): + """Return per-token per-block max scale (used to bound FP8 error).""" + n_tok, nope = kv_roped_nope_f32.shape + n_blocks = nope // qblock + blocks = kv_roped_nope_f32.view(n_tok, n_blocks, qblock) + absmax = blocks.abs().amax(dim=-1).clamp(min=1e-4) + raw = absmax / FP8_MAX + exponent = torch.ceil(torch.log2(raw)) + return torch.pow(2.0, exponent) # [n_tok, n_blocks] + + +@pytest.mark.parametrize("num_tokens", [1, 4, 17, 64]) +@pytest.mark.parametrize("block_size", [16, 64]) +def test_kv_path_matches_reference(num_tokens: int, block_size: int): + torch.manual_seed(1) + device = "cuda" + dtype = torch.bfloat16 + eps = 1e-6 + max_pos = 4096 + + kv = torch.randn(num_tokens, HEAD_DIM, dtype=dtype, device=device) + positions = torch.arange(num_tokens, dtype=torch.int64, device=device) + cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device) + + num_blocks = (num_tokens + block_size - 1) // block_size + 1 + slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device) + + # ── Reference path: RoPE on kv, then existing Triton quant+insert ────── + kv_ref = apply_rope_gptj_last_k(kv, positions, cos_sin_cache) + k_cache_ref = torch.zeros( + num_blocks, block_size * HEAD_BYTES, dtype=torch.uint8, device=device + ) + quantize_and_insert_k_cache( + kv_ref, k_cache_ref, slot_mapping, block_size=block_size + ) + + # ── Fused path (dummy q, single head) ────────────────────────────────── + k_cache_fused = torch.zeros_like(k_cache_ref) + q_dummy = torch.zeros(num_tokens, 1, HEAD_DIM, dtype=dtype, device=device) + _call_fused( + q_dummy, + kv, + k_cache_fused, + slot_mapping, + positions, + cos_sin_cache, + eps, + block_size, + ) + + # ── Round-trip compare via dequant+gather ────────────────────────────── + def _dequant(k_cache_2d): + num_reqs = 1 + max_blocks = num_blocks + out = torch.zeros( + num_reqs, num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device + ) + seq_lens = torch.tensor([num_tokens], dtype=torch.int32, device=device) + block_table = torch.arange( + max_blocks, dtype=torch.int32, device=device + ).unsqueeze(0) + # gather_lens arg is None (use seq_lens) + k_cache_3d = k_cache_2d.view(num_blocks, block_size, HEAD_BYTES) + dequantize_and_gather_k_cache( + out, k_cache_3d, seq_lens, None, block_table, block_size, offset=0 + ) + return out[0, :num_tokens] + + recovered_ref = _dequant(k_cache_ref) + recovered_fused = _dequant(k_cache_fused) + + # NoPE: per-block UE8M0 FP8 error bound (half-ULP at max = 16 * scale). 
+ scales = _ue8m0_per_block_scales(kv_ref[:, :NOPE_DIM].float(), QUANT_BLOCK) + for t in range(num_tokens): + max_allowed = 16.0 * scales[t].max().item() + diff_ref = ( + (recovered_ref[t, :NOPE_DIM] - kv_ref[t, :NOPE_DIM]).abs().max().item() + ) + diff_fused = ( + (recovered_fused[t, :NOPE_DIM] - kv_ref[t, :NOPE_DIM]).abs().max().item() + ) + assert diff_ref <= max_allowed, ( + f"ref NoPE token {t} diff {diff_ref} > {max_allowed}" + ) + assert diff_fused <= max_allowed, ( + f"fused NoPE token {t} diff {diff_fused} > {max_allowed}" + ) + + # RoPE region: bf16 stored exactly → zero diff. + rope_diff = (recovered_fused[:, NOPE_DIM:] - kv_ref[:, NOPE_DIM:]).abs().max() + assert rope_diff.item() == 0.0, f"RoPE portion not exact: {rope_diff.item()}" + + # Exact byte equality of the two cache buffers — strong parity. + torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0) + + +# ── Test 2b: DP padding (slot_mapping shorter than q/kv) ───────────────────── + + +@pytest.mark.parametrize("num_tokens", [4, 17]) +@pytest.mark.parametrize("pad", [1, 5]) +@pytest.mark.parametrize("block_size", [16, 64]) +def test_kv_path_with_dp_padding(num_tokens: int, pad: int, block_size: int): + """slot_mapping.size(0) < q.size(0): the kernel must skip padded + tokens in the KV branch while still running Q-norm+RoPE on all rows.""" + torch.manual_seed(3) + device = "cuda" + dtype = torch.bfloat16 + eps = 1e-6 + max_pos = 4096 + total = num_tokens + pad + + kv = torch.randn(total, HEAD_DIM, dtype=dtype, device=device) + positions = torch.arange(total, dtype=torch.int64, device=device) + cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device) + + num_blocks = (num_tokens + block_size - 1) // block_size + 1 + slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device) + + # Reference: only the first num_tokens kv rows get inserted. + kv_ref = apply_rope_gptj_last_k( + kv[:num_tokens], positions[:num_tokens], cos_sin_cache + ) + k_cache_ref = torch.zeros( + num_blocks, block_size * HEAD_BYTES, dtype=torch.uint8, device=device + ) + quantize_and_insert_k_cache( + kv_ref, k_cache_ref, slot_mapping, block_size=block_size + ) + + # Fused: pass full-sized q/kv/positions, shorter slot_mapping. + q_dummy = torch.zeros(total, 1, HEAD_DIM, dtype=dtype, device=device) + k_cache_fused = torch.zeros_like(k_cache_ref) + _call_fused( + q_dummy, + kv, + k_cache_fused, + slot_mapping, + positions, + cos_sin_cache, + eps, + block_size, + ) + + torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0) + + +# ── Test 3: combined single-call Q + KV parity ─────────────────────────────── + + +@pytest.mark.parametrize("num_tokens", [1, 4, 17]) +@pytest.mark.parametrize("n_heads", [8, 64]) +@pytest.mark.parametrize("block_size", [16, 64]) +def test_combined_q_and_kv(num_tokens: int, n_heads: int, block_size: int): + torch.manual_seed(2) + device = "cuda" + dtype = torch.bfloat16 + eps = 1e-6 + max_pos = 4096 + + q = torch.randn(num_tokens, n_heads, HEAD_DIM, dtype=dtype, device=device) + kv = torch.randn(num_tokens, HEAD_DIM, dtype=dtype, device=device) + positions = torch.arange(num_tokens, dtype=torch.int64, device=device) + cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device) + + num_blocks = (num_tokens + block_size - 1) // block_size + 1 + slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device) + + # Reference. 
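+ # Same reference pipelines as tests 1 and 2, built side by side: Q gets
+ # RMSNorm (no weight) + GPT-J RoPE, KV gets RoPE + quant + cache insert, and a
+ # single fused call must reproduce both results at once.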
+ q_ref = rmsnorm_no_weight(q, eps) + q_ref = apply_rope_gptj_last_k(q_ref, positions, cos_sin_cache) + kv_ref = apply_rope_gptj_last_k(kv, positions, cos_sin_cache) + k_cache_ref = torch.zeros( + num_blocks, block_size * HEAD_BYTES, dtype=torch.uint8, device=device + ) + quantize_and_insert_k_cache( + kv_ref, k_cache_ref, slot_mapping, block_size=block_size + ) + + # Fused single call. + q_fused = q.clone() + k_cache_fused = torch.zeros_like(k_cache_ref) + _call_fused( + q_fused, + kv, + k_cache_fused, + slot_mapping, + positions, + cos_sin_cache, + eps, + block_size, + ) + + torch.testing.assert_close(q_fused, q_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0) diff --git a/tests/kernels/test_fused_indexer_q_rope_quant.py b/tests/kernels/test_fused_indexer_q_rope_quant.py new file mode 100644 index 000000000000..03d5ad4c8ac7 --- /dev/null +++ b/tests/kernels/test_fused_indexer_q_rope_quant.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit test for fused_indexer_q_rope_quant. + +Compares the fused Triton kernel against the unfused reference flow used by +the DeepseekV4 indexer in model_tracking: + q_rot = ops.rotary_embedding(positions, q, None, head_dim, cos_sin_cache, + is_neox_style=False, + rope_dim_offset=head_dim - rope_dim) + q_fp8, q_scale = per_token_group_quant_fp8(q_rot, head_dim, use_ue8m0=True) + weights_out = weights * q_scale * softmax_scale * head_scale + +Expects bit-exact equality on both q_fp8 and weights_out. +""" + +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8, +) +from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import ( + fused_indexer_q_rope_quant, +) + +HEAD_DIM = 128 +ROPE_DIM = 64 +N_HEAD = 64 +MAX_POS = 4096 + + +def _reference( + positions: torch.Tensor, + q: torch.Tensor, + cos_sin_cache: torch.Tensor, + weights: torch.Tensor, + softmax_scale: float, + head_scale: float, +) -> tuple[torch.Tensor, torch.Tensor]: + q_rot = q.clone() + ops.rotary_embedding( + positions, + q_rot, + None, + HEAD_DIM, + cos_sin_cache, + False, # is_neox_style=False → GPT-J interleaved + HEAD_DIM - ROPE_DIM, # rope_dim_offset → rotate the tail + False, + ) + q_fp8, q_scale = per_token_group_quant_fp8( + q_rot.view(-1, HEAD_DIM).contiguous(), + HEAD_DIM, + use_ue8m0=True, + ) + q_fp8 = q_fp8.view(-1, N_HEAD, HEAD_DIM) + q_scale = q_scale.view(-1, N_HEAD) + + weights_out = weights.to(torch.float32) * q_scale * softmax_scale * head_scale + return q_fp8, weights_out + + +@pytest.mark.parametrize("num_tokens", [1, 7, 32, 257]) +@pytest.mark.parametrize("cache_dtype", [torch.float32, torch.bfloat16]) +@torch.inference_mode() +def test_fused_indexer_q_rope_quant_matches_unfused(num_tokens, cache_dtype): + device = "cuda" + torch.manual_seed(0) + + q = torch.randn(num_tokens, N_HEAD, HEAD_DIM, dtype=torch.bfloat16, device=device) + positions = torch.randint( + 0, MAX_POS, (num_tokens,), dtype=torch.int64, device=device + ) + cos_sin_cache = torch.randn(MAX_POS, ROPE_DIM, dtype=cache_dtype, device=device) + weights = torch.randn(num_tokens, N_HEAD, dtype=torch.bfloat16, device=device) + softmax_scale = HEAD_DIM**-0.5 + head_scale = N_HEAD**-0.5 + + q_fp8_ref, weights_ref = _reference( + positions, q, cos_sin_cache, weights, softmax_scale, head_scale + ) + q_fp8_fused, weights_fused = fused_indexer_q_rope_quant( + positions, 
q.clone(), cos_sin_cache, weights, softmax_scale, head_scale + ) + + # fp8 tensors aren't directly comparable via torch.equal — reinterpret as int8. + ref_bits = q_fp8_ref.view(torch.int8) + fused_bits = q_fp8_fused.view(torch.int8) + assert torch.equal(ref_bits, fused_bits), ( + f"q_fp8 mismatch: " + f"{(ref_bits != fused_bits).sum().item()} / {ref_bits.numel()} bytes differ" + ) + + assert torch.equal(weights_ref, weights_fused), ( + f"weights mismatch: max abs diff " + f"{(weights_ref - weights_fused).abs().max().item()}" + ) diff --git a/tests/kernels/test_fused_inv_rope_fp8_quant.py b/tests/kernels/test_fused_inv_rope_fp8_quant.py new file mode 100644 index 000000000000..10561a8a0304 --- /dev/null +++ b/tests/kernels/test_fused_inv_rope_fp8_quant.py @@ -0,0 +1,908 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Unit tests for the fused inverse RoPE + block-scaled FP8 quantization kernel. + +Tests compare the fused kernel against a reference implementation built from +the existing separate operations (inverse RoPE via rotate_neox + FP8 quant +via per_token_group_quant_fp8). + +The reference faithfully reproduces the exact flow in deepseek_v4_attention.py:295-310: + 1. Apply inverse RoPE (NeoX style, last rope_dim=64 dims of each head) + 2. Reshape [T, H, head_dim] -> [T, G, D] + 3. Transpose+flatten to [G*T, D], quantize, reshape back + 4. Return o_fp8 and o_scale with strides (D, T*D, 1) and (S, T*S, 1) + (non-contiguous [T, G, ...] view backed by contiguous [G, T, ...] memory) + +Usage: + pytest tests/kernels/test_fused_inv_rope_fp8_quant.py -v +""" + +import pytest +import torch + +from vllm.v1.attention.ops.deepseek_v4_ops import fused_inv_rope_fp8_quant + +# -- Default dimensions matching DeepSeek V3/V4 -------------------------- +HEAD_DIM = 512 +NOPE_DIM = 448 +ROPE_DIM = 64 +QUANT_GROUP_SIZE = 128 +FP8_MAX = 448.0 # torch.finfo(torch.float8_e4m3fn).max +FP8_DTYPE = torch.float8_e4m3fn +EPS = 1e-10 + + +# ========================================================================= +# Helpers +# ========================================================================= + + +def assert_dequant_close( + fp8_a: torch.Tensor, + scale_a: torch.Tensor, + fp8_b: torch.Tensor, + scale_b: torch.Tensor, + msg: str = "", +): + """Compare two FP8-quantized tensors via their dequantized values. + + Uses cosine-similarity-based diff (same as deep_gemm calc_diff). + Both fused and reference paths rotate in fp32 using an fp32 + cos_sin_cache, so differences are only fp32 ordering ULPs that can + occasionally shift FP8 values at quantization boundaries. + """ + S = scale_a.shape[-1] + shape = fp8_a.shape + + dq_a = fp8_a.float() * scale_a.unsqueeze(-1).expand( + *shape[:-1], S, QUANT_GROUP_SIZE + ).reshape(shape) + dq_b = fp8_b.float() * scale_b.unsqueeze(-1).expand( + *shape[:-1], S, QUANT_GROUP_SIZE + ).reshape(shape) + + # Cosine diff: 1 - cos_sim (0 = identical, higher = worse) + dq_a_flat = dq_a.flatten().float() + dq_b_flat = dq_b.flatten().float() + cos_sim = torch.nn.functional.cosine_similarity( + dq_a_flat.unsqueeze(0), dq_b_flat.unsqueeze(0) + ).item() + diff = 1.0 - cos_sim + + assert diff < 1e-4, f"Dequant diff too large: {diff:.8f} (expected < 1e-4). {msg}" + + +def rotate_gptj(x: torch.Tensor) -> torch.Tensor: + """GPT-J style rotation: interleaved pairs, negate-swap. + + Matches vllm/model_executor/layers/rotary_embedding/common.py:23-27. + DeepseekV4 uses is_neox_style=False, so this is the correct rotation. 
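+ Example: (x0, x1, x2, x3) -> (-x1, x0, -x3, x2).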
+ """ + x1 = x[..., ::2] + x2 = x[..., 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +def make_cos_sin_cache( + max_pos: int, + rope_dim: int = ROPE_DIM, + dtype: torch.dtype = torch.float32, + device: str = "cuda", +) -> torch.Tensor: + """Create a synthetic cos_sin_cache matching the layout used by + DeepseekV4ScalingRotaryEmbedding._compute_cos_sin_cache. + + Shape: [max_pos, rope_dim] where first half is cos, second half is sin. + The fused kernel requires fp32; callers can override dtype if passing + the cache into the bf16-only paths. + """ + half = rope_dim // 2 + # Use random but bounded frequencies so cos/sin are well-behaved + inv_freq = 1.0 / ( + 10000.0 ** (torch.arange(0, half, device=device, dtype=torch.float32) / half) + ) + t = torch.arange(max_pos, device=device, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) # [max_pos, half] + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) # [max_pos, rope_dim] + return cache.to(dtype) + + +def reference_inv_rope( + o: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + nope_dim: int = NOPE_DIM, + rope_dim: int = ROPE_DIM, +) -> torch.Tensor: + """Apply inverse RoPE to the last rope_dim dimensions of each head. + + Matches the GPT-J inverse rotation in pos_encoding_kernels.cu, which + promotes the cache to fp32 and performs the rotation in fp32. The + result is cast back to the input dtype. + + Args: + o: [T, H, head_dim] bf16 + positions: [T] int64 + cos_sin_cache: [max_pos, rope_dim] fp32 + + Returns: + o with inverse RoPE applied on the rope portion (bf16). + """ + assert cos_sin_cache.dtype == torch.float32 + cos_sin = cos_sin_cache[positions] # [T, rope_dim] fp32 + half = rope_dim // 2 + cos = cos_sin[:, :half] + sin = cos_sin[:, half:] + + # GPT-J style: repeat_interleave (not repeat) to match interleaved pairs + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(1) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(1) + sin = -sin # inverse + + o_pass = o[..., :nope_dim] + o_rot_f32 = o[..., nope_dim:].float() + o_rot_f32 = o_rot_f32 * cos + rotate_gptj(o_rot_f32) * sin + o_rot = o_rot_f32.to(o.dtype) + + return torch.cat([o_pass, o_rot], dim=-1) + + +def _ref_ue8m0_quant_block(x_f32: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Per-block UE8M0 FP8 quantization in pure float32. + + Matches the Triton kernel logic exactly: + absmax -> 2^ceil(log2(absmax / fp8_max)) -> clamp(x / scale) -> fp8 + + Args: + x_f32: [..., quant_group_size] float32 — one or more 128-element blocks. + + Returns: + x_fp8: same shape, float8_e4m3fn + scales: [...] float32, one scale per block + """ + assert x_f32.shape[-1] == QUANT_GROUP_SIZE + absmax = x_f32.abs().amax(dim=-1, keepdim=True).clamp(min=EPS) + scale_raw = absmax * (1.0 / FP8_MAX) + scale = torch.exp2(torch.ceil(torch.log2(scale_raw))) + x_scaled = (x_f32 / scale).clamp(-FP8_MAX, FP8_MAX) + x_fp8 = x_scaled.to(FP8_DTYPE) + return x_fp8, scale.squeeze(-1) + + +def reference_inv_rope_fp8_quant( + o: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + n_groups: int, + heads_per_group: int, + nope_dim: int = NOPE_DIM, + rope_dim: int = ROPE_DIM, + quant_group_size: int = QUANT_GROUP_SIZE, +) -> tuple[torch.Tensor, torch.Tensor]: + """Full reference: inverse RoPE in fp32 + UE8M0 FP8 quant in fp32. 
+ + Mimics the Triton kernel's precision path exactly: + Load bf16 -> cast to fp32 -> apply inverse RoPE with fp32 cos/sin -> + UE8M0 quant in fp32 -> write fp8 + scale + + Returns: + o_fp8: [T, G, D] FP8 with strides (D, T*D, 1) + o_scale: [T, G, S] FP32 with strides (S, T*S, 1) + """ + assert cos_sin_cache.dtype == torch.float32 + T, _H, head_dim = o.shape + d = heads_per_group * head_dim + S = d // quant_group_size + half_rope = rope_dim // 2 + chunks_per_head = head_dim // quant_group_size + + # Reshape [T, H, head_dim] -> [T, G, heads_per_group, head_dim] + o_4d = o.view(T, n_groups, heads_per_group, head_dim) + + # Lookup cos/sin directly in fp32 + cos_sin = cos_sin_cache[positions] # [T, rope_dim] fp32 + cos = cos_sin[:, :half_rope] # [T, half_rope] fp32 + sin = cos_sin[:, half_rope:] # [T, half_rope] fp32 + + # Allocate outputs in [G, T, ...] contiguous layout + fp8_buf = torch.empty(n_groups, T, d, dtype=FP8_DTYPE, device=o.device) + scale_buf = torch.empty(n_groups, T, S, dtype=torch.float32, device=o.device) + + # Process each quant block, matching the Triton kernel's per-program logic + for g in range(n_groups): + for qb in range(S): + head_in_group = qb // chunks_per_head + chunk_in_head = qb % chunks_per_head + offset = chunk_in_head * quant_group_size + + # Load 128 bf16 elements and promote to fp32 for rotation+quant + block = o_4d[:, g, head_in_group, offset : offset + quant_group_size] + x = block.float() + + # Apply inverse RoPE in fp32 if this is the last chunk + # GPT-J style: interleaved pairs (even=x, odd=y) + if chunk_in_head == chunks_per_head - 1: + rope_start = nope_dim % quant_group_size # 64 + rope_region = x[:, rope_start:].clone() + x_vals = rope_region[:, ::2] + y_vals = rope_region[:, 1::2] + x_new = x_vals * cos + y_vals * sin + y_new = y_vals * cos - x_vals * sin + x = x.clone() + x[:, rope_start::2] = x_new + x[:, rope_start + 1 :: 2] = y_new + + # UE8M0 quant in fp32 + x_fp8, scale = _ref_ue8m0_quant_block(x) + + # Write to [G, T, D] contiguous memory + fp8_buf[g, :, qb * quant_group_size : (qb + 1) * quant_group_size] = x_fp8 + scale_buf[g, :, qb] = scale + + # Return transposed views + return fp8_buf.transpose(0, 1), scale_buf.transpose(0, 1) + + +# ========================================================================= +# Tests +# ========================================================================= + + +@pytest.mark.parametrize("num_tokens", [1, 7, 32, 128]) +@pytest.mark.parametrize( + "num_heads,n_groups", + [(64, 8), (32, 4), (128, 8)], + ids=["H64_G8", "H32_G4", "H128_G8"], +) +@pytest.mark.parametrize("seed", [0, 42]) +@torch.inference_mode() +def test_correctness(num_tokens, num_heads, n_groups, seed): + """Compare fused kernel against reference for FP8 values and scales.""" + torch.manual_seed(seed) + + heads_per_group = num_heads // n_groups + max_pos = 4096 + device = "cuda" + + # Create inputs + o = torch.randn( + num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16 + ) + positions = torch.randint( + 0, max_pos, (num_tokens,), device=device, dtype=torch.long + ) + cos_sin_cache = make_cos_sin_cache( + max_pos, ROPE_DIM, dtype=torch.float32, device=device + ) + + # Reference + ref_fp8, ref_scale = reference_inv_rope_fp8_quant( + o.clone(), + positions, + cos_sin_cache, + n_groups, + heads_per_group, + ) + + # Fused kernel + fused_fp8, fused_scale = fused_inv_rope_fp8_quant( + o.clone(), + positions, + cos_sin_cache, + n_groups, + heads_per_group, + ) + + # Check shapes + d = heads_per_group * HEAD_DIM + S = d // 
QUANT_GROUP_SIZE + assert ref_fp8.shape == (num_tokens, n_groups, d) + assert fused_fp8.shape == (num_tokens, n_groups, d) + assert ref_scale.shape == (num_tokens, n_groups, S) + assert fused_scale.shape == (num_tokens, n_groups, S) + + # Scales: exact match (both use identical UE8M0 algorithm) + # Scales may differ by one UE8M0 step (factor of 2) if fp32 rotation + # ordering shifts absmax across a power-of-2 boundary. Check ratio is + # close to 1. + scale_ratio = fused_scale / ref_scale.clamp(min=1e-30) + assert scale_ratio.max() <= 2.0 and scale_ratio.min() >= 0.5, ( + f"Scale ratio out of [0.5, 2]: min={scale_ratio.min():.4f} " + f"max={scale_ratio.max():.4f}" + ) + + # Compare via dequant (Triton vs PyTorch fp32 may differ by ULPs) + assert_dequant_close(ref_fp8, ref_scale, fused_fp8, fused_scale) + + +@pytest.mark.parametrize("num_tokens", [1, 7, 32, 128]) +@pytest.mark.parametrize( + "num_heads,n_groups", + [(64, 8), (128, 8)], + ids=["H64_G8", "H128_G8"], +) +@torch.inference_mode() +def test_output_strides(num_tokens, num_heads, n_groups): + """Verify fused output layout: + - FP8: logical [T, G, D] backed by contiguous [G, T, D]. + - Scale: MN-major TMA-aligned (column-major: T-stride=1). + """ + + heads_per_group = num_heads // n_groups + max_pos = 4096 + device = "cuda" + + o = torch.randn( + num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16 + ) + positions = torch.randint( + 0, max_pos, (num_tokens,), device=device, dtype=torch.long + ) + cos_sin_cache = make_cos_sin_cache(max_pos, device=device) + + fused_fp8, fused_scale = fused_inv_rope_fp8_quant( + o.clone(), + positions, + cos_sin_cache, + n_groups, + heads_per_group, + ) + + # FP8: logical [T, G, D] backed by [G, T, D] row-major + d = heads_per_group * HEAD_DIM + expected_fp8_stride = (d, num_tokens * d, 1) + assert fused_fp8.stride() == expected_fp8_stride, ( + f"FP8 stride mismatch: got {fused_fp8.stride()}, expected {expected_fp8_stride}" + ) + + # Scale: MN-major TMA-aligned layout. After fp8_einsum permutes + # [T,G,S] -> [G,T,S], T-dim should have stride 1. + # Our output is [T,G,S] = transpose of [G,T,S]. + # So fused_scale.permute(1,0,2) should have T-stride=1. + perm = fused_scale.permute(1, 0, 2) # [G, T, S] + assert perm.stride(1) == 1 or num_tokens == 1, ( + f"Scale T-stride (after permute to [G,T,S]) should be 1, got {perm.stride(1)}" + ) + + +@pytest.mark.parametrize("num_tokens", [1, 7, 32, 128]) +@torch.inference_mode() +def test_per_group_contiguity(num_tokens): + """FP8 per-group slices must be contiguous. 
Scale per-group slices + are column-major (T-stride=1) — not row-major contiguous, which is + correct for TMA loads.""" + num_heads, n_groups = 64, 8 + heads_per_group = num_heads // n_groups + max_pos = 4096 + device = "cuda" + + o = torch.randn( + num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16 + ) + positions = torch.randint( + 0, max_pos, (num_tokens,), device=device, dtype=torch.long + ) + cos_sin_cache = make_cos_sin_cache(max_pos, device=device) + + fused_fp8, fused_scale = fused_inv_rope_fp8_quant( + o.clone(), + positions, + cos_sin_cache, + n_groups, + heads_per_group, + ) + + for g in range(n_groups): + fp8_slice = fused_fp8[:, g, :] + assert fp8_slice.is_contiguous(), ( + f"o_fp8[:, {g}, :] is not contiguous: " + f"shape={list(fp8_slice.shape)}, stride={list(fp8_slice.stride())}" + ) + + +@torch.inference_mode() +def test_scales_are_power_of_two(): + """Verify all scales are exact powers of 2 (UE8M0 property).""" + num_tokens, num_heads, n_groups = 32, 64, 8 + heads_per_group = num_heads // n_groups + max_pos = 4096 + device = "cuda" + + o = torch.randn( + num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16 + ) + positions = torch.randint( + 0, max_pos, (num_tokens,), device=device, dtype=torch.long + ) + cos_sin_cache = make_cos_sin_cache(max_pos, device=device) + + _, fused_scale = fused_inv_rope_fp8_quant( + o.clone(), + positions, + cos_sin_cache, + n_groups, + heads_per_group, + ) + + # log2 of a power-of-two is an exact integer + log2_scales = torch.log2(fused_scale) + residual = (log2_scales - log2_scales.round()).abs() + assert residual.max() < 1e-5, ( + f"Not all scales are powers of 2: max log2 residual = {residual.max().item()}" + ) + + +@torch.inference_mode() +def test_nope_dims_unchanged(): + """Nope dimensions (first 448 per head) should only be quantized, + not rotated. 
Verify by dequantizing and comparing against + quantize-only reference (no RoPE).""" + num_tokens, num_heads, n_groups = 16, 64, 8 + heads_per_group = num_heads // n_groups + max_pos = 4096 + device = "cuda" + torch.manual_seed(0) + + o = torch.randn( + num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16 + ) + positions = torch.randint( + 0, max_pos, (num_tokens,), device=device, dtype=torch.long + ) + cos_sin_cache = make_cos_sin_cache(max_pos, device=device) + + # Fused kernel result + fused_fp8, fused_scale = fused_inv_rope_fp8_quant( + o.clone(), + positions, + cos_sin_cache, + n_groups, + heads_per_group, + ) + + # Reference: quantize without RoPE (identity rotation) + # Create a zero-sin cache so RoPE is identity + zero_cache = torch.zeros_like(cos_sin_cache) + half = ROPE_DIM // 2 + zero_cache[:, :half] = 1.0 # cos = 1 + # sin = 0 (already zero) + + norope_fp8, norope_scale = fused_inv_rope_fp8_quant( + o.clone(), + positions, + zero_cache, + n_groups, + heads_per_group, + ) + + # Extract nope quant blocks only (first 3 of every 4 blocks per head) + chunks_per_head = HEAD_DIM // QUANT_GROUP_SIZE # 4 + + for h in range(heads_per_group): + for c in range(chunks_per_head - 1): # skip last chunk (has rope) + qb = h * chunks_per_head + c + start = qb * QUANT_GROUP_SIZE + end = start + QUANT_GROUP_SIZE + + fused_nope = fused_fp8[:, :, start:end].view(torch.uint8) + norope_nope = norope_fp8[:, :, start:end].view(torch.uint8) + assert torch.equal(fused_nope, norope_nope), ( + f"Nope block (head={h}, chunk={c}) differs between " + f"fused and no-rope reference" + ) + + fused_s = fused_scale[:, :, qb] + norope_s = norope_scale[:, :, qb] + assert torch.equal(fused_s, norope_s), ( + f"Nope scale (head={h}, chunk={c}) differs" + ) + + +@torch.inference_mode() +def test_single_token(): + """Edge case: single token.""" + num_tokens, num_heads, n_groups = 1, 64, 8 + heads_per_group = num_heads // n_groups + max_pos = 4096 + device = "cuda" + + o = torch.randn( + num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16 + ) + positions = torch.tensor([42], device=device, dtype=torch.long) + cos_sin_cache = make_cos_sin_cache(max_pos, device=device) + + ref_fp8, ref_scale = reference_inv_rope_fp8_quant( + o.clone(), + positions, + cos_sin_cache, + n_groups, + heads_per_group, + ) + fused_fp8, fused_scale = fused_inv_rope_fp8_quant( + o.clone(), + positions, + cos_sin_cache, + n_groups, + heads_per_group, + ) + + assert_dequant_close(ref_fp8, ref_scale, fused_fp8, fused_scale) + + +@torch.inference_mode() +def test_zero_positions(): + """Edge case: all positions are 0.""" + num_tokens, num_heads, n_groups = 16, 64, 8 + heads_per_group = num_heads // n_groups + max_pos = 4096 + device = "cuda" + + o = torch.randn( + num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16 + ) + positions = torch.zeros(num_tokens, device=device, dtype=torch.long) + cos_sin_cache = make_cos_sin_cache(max_pos, device=device) + + ref_fp8, ref_scale = reference_inv_rope_fp8_quant( + o.clone(), + positions, + cos_sin_cache, + n_groups, + heads_per_group, + ) + fused_fp8, fused_scale = fused_inv_rope_fp8_quant( + o.clone(), + positions, + cos_sin_cache, + n_groups, + heads_per_group, + ) + + assert_dequant_close(ref_fp8, ref_scale, fused_fp8, fused_scale) + + +@torch.inference_mode() +def test_large_values(): + """Edge case: values near FP8 saturation to test clamping.""" + num_tokens, num_heads, n_groups = 8, 64, 8 + heads_per_group = num_heads // n_groups + max_pos = 4096 + device = 
"cuda" + + # Create inputs with large values that will saturate FP8 + o = torch.randn( + num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16 + ) + o = o * 1000.0 # scale up to force saturation + positions = torch.randint( + 0, max_pos, (num_tokens,), device=device, dtype=torch.long + ) + cos_sin_cache = make_cos_sin_cache(max_pos, device=device) + + ref_fp8, ref_scale = reference_inv_rope_fp8_quant( + o.clone(), + positions, + cos_sin_cache, + n_groups, + heads_per_group, + ) + fused_fp8, fused_scale = fused_inv_rope_fp8_quant( + o.clone(), + positions, + cos_sin_cache, + n_groups, + heads_per_group, + ) + + assert_dequant_close(ref_fp8, ref_scale, fused_fp8, fused_scale) + + +@torch.inference_mode() +def test_dequant_numerical_accuracy(): + """Verify dequantized values are close to the original (after inv RoPE).""" + num_tokens, num_heads, n_groups = 32, 64, 8 + heads_per_group = num_heads // n_groups + max_pos = 4096 + device = "cuda" + torch.manual_seed(0) + + o = torch.randn( + num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16 + ) + positions = torch.randint( + 0, max_pos, (num_tokens,), device=device, dtype=torch.long + ) + cos_sin_cache = make_cos_sin_cache(max_pos, device=device) + + # Get the post-inv-RoPE values (ground truth before quantization) + o_after_rope = reference_inv_rope(o.clone(), positions, cos_sin_cache) + d = heads_per_group * HEAD_DIM + o_after_rope = o_after_rope.view(num_tokens, n_groups, d) + + # Get fused quantized output + fused_fp8, fused_scale = fused_inv_rope_fp8_quant( + o.clone(), + positions, + cos_sin_cache, + n_groups, + heads_per_group, + ) + + # Dequantize: broadcast scale [T, G, S] to [T, G, D] via repeat + S = d // QUANT_GROUP_SIZE + scale_expanded = ( + fused_scale.unsqueeze(-1) + .expand(num_tokens, n_groups, S, QUANT_GROUP_SIZE) + .reshape(num_tokens, n_groups, d) + ) + dequant = fused_fp8.float() * scale_expanded + + # Check relative error. + # FP8 e4m3 with UE8M0 (power-of-two scales that round UP) quantizes more + # coarsely than optimal scaling. Both paths rotate in fp32, so the bulk + # of the error comes from UE8M0 quantization itself (~10-12% typical). + o_gt = o_after_rope.transpose(0, 1).contiguous().transpose(0, 1) + dequant_contig = dequant.transpose(0, 1).contiguous().transpose(0, 1) + + abs_err = (dequant_contig.float() - o_gt.float()).abs() + rel_err = abs_err / (o_gt.float().abs().clamp(min=1e-6)) + mean_rel_err = rel_err.mean().item() + + assert mean_rel_err < 0.15, ( + f"Mean relative error too high: {mean_rel_err:.4f} (expected < 0.15)" + ) + + +def _unfused_inv_rope_fp8_quant( + o: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + n_groups: int, + heads_per_group: int, + nope_dim: int = NOPE_DIM, + rope_dim: int = ROPE_DIM, +) -> tuple[torch.Tensor, torch.Tensor]: + """Unfused path matching deepseek_v4_attention.py:295-310. + + Uses the production CUDA RoPE kernel + per_token_group_quant_fp8. 
+ """ + from vllm import _custom_ops as ops + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8, + ) + + head_dim = o.shape[-1] + rope_dim_offset = head_dim - rope_dim + + # Step 1: In-place CUDA RoPE (same as production) + ops.rotary_embedding( + positions, + o, + None, + head_dim, + cos_sin_cache, + False, # is_neox=False for DeepseekV4 (GPT-J style) + rope_dim_offset=rope_dim_offset, + inverse=True, + ) + + # Step 2: Reshape + quant + reshape (same as production) + T = o.shape[0] + d = heads_per_group * head_dim + o = o.view(T, n_groups, -1) + o_flat = o.transpose(0, 1).contiguous().reshape(-1, d) + o_fp8, o_scale = per_token_group_quant_fp8( + o_flat, + group_size=QUANT_GROUP_SIZE, + use_ue8m0=True, + ) + o_fp8 = o_fp8.view(n_groups, T, d).transpose(0, 1) + o_scale = o_scale.view(n_groups, T, -1).transpose(0, 1) + return o_fp8, o_scale + + +# ========================================================================= +# End-to-end test including fp8_einsum +# ========================================================================= + + +@pytest.mark.parametrize("num_tokens", [1, 7, 32, 128, 1024]) +@pytest.mark.parametrize( + "num_heads,n_groups", + [(64, 8)], + ids=["H64_G8"], +) +@torch.inference_mode() +def test_einsum_end_to_end(num_tokens, num_heads, n_groups): + """End-to-end: fused inv_rope+quant → fp8_einsum must match + unfused CUDA_rope+quant → fp8_einsum bitwise. + + This catches stride/layout bugs that only manifest when the einsum + kernel actually consumes the quantized activations. + """ + from deep_gemm.utils.math import ceil_div + + from vllm.utils.deep_gemm import ( + fp8_einsum, + per_block_cast_to_fp8, + transform_sf_into_required_layout, + ) + + heads_per_group = num_heads // n_groups + d = heads_per_group * HEAD_DIM + o_lora_rank = 1024 + max_pos = 4096 + device = "cuda" + torch.manual_seed(0) + + o = torch.randn( + num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16 + ) + positions = torch.randint( + 0, max_pos, (num_tokens,), device=device, dtype=torch.long + ) + cos_sin_cache = make_cos_sin_cache(max_pos, device=device) + + # -- Weight quantization (shared between both paths) -- + w = torch.randn(n_groups, o_lora_rank, d, device=device, dtype=torch.bfloat16) + w_fp8 = torch.empty_like(w, dtype=torch.float8_e4m3fn) + w_scale = torch.empty( + n_groups, + ceil_div(o_lora_rank, 128), + ceil_div(d, 128), + device=device, + dtype=torch.float32, + ) + for g in range(n_groups): + w_fp8[g], w_scale[g] = per_block_cast_to_fp8(w[g], use_ue8m0=True) + + recipe = (1, 1, 128) + w_scale_t = transform_sf_into_required_layout( + sf=w_scale, + mn=o_lora_rank, + k=d, + recipe=(1, 128, 128), + num_groups=n_groups, + is_sfa=False, + ) + + # -- UNFUSED path -- + ref_fp8, ref_scale = _unfused_inv_rope_fp8_quant( + o.clone(), + positions, + cos_sin_cache, + n_groups, + heads_per_group, + ) + z_ref = torch.empty( + num_tokens, n_groups, o_lora_rank, device=device, dtype=torch.bfloat16 + ) + fp8_einsum( + "bhr,hdr->bhd", (ref_fp8, ref_scale), (w_fp8, w_scale_t), z_ref, recipe=recipe + ) + + # -- FUSED path -- + fused_fp8, fused_scale = fused_inv_rope_fp8_quant( + o.clone(), + positions, + cos_sin_cache, + n_groups, + heads_per_group, + ) + z_fused = torch.empty( + num_tokens, n_groups, o_lora_rank, device=device, dtype=torch.bfloat16 + ) + fp8_einsum( + "bhr,hdr->bhd", + (fused_fp8, fused_scale), + (w_fp8, w_scale_t), + z_fused, + recipe=recipe, + ) + + # -- Checks -- + # Einsum output: Triton and CUDA both rotate in fp32 
now, so diffs + # come from fp32 ordering and UE8M0 boundary shifts only. + # Use relative diff (same metric as test_fp8_einsum.py). + from deep_gemm.testing import calc_diff + + z_diff = calc_diff(z_fused, z_ref) + assert z_diff < 0.01, ( + f"Einsum output diff too large: {z_diff:.6f} (expected < 0.01)" + ) + + +@pytest.mark.parametrize("num_tokens", [1, 32, 256]) +@torch.inference_mode() +def test_with_real_deepseek_v4_rope(num_tokens, default_vllm_config): + """Test with real DeepseekV4ScalingRotaryEmbedding (GPT-J style, + mscale=0, YaRN scaling) matching the production config.""" + + num_heads = 64 + n_groups = 8 + heads_per_group = num_heads // n_groups + device = "cuda" + torch.manual_seed(0) + + # Build YaRN-scaled cos_sin_cache matching real DeepSeek V3/V4 config + # (mscale=0 → mscale=1.0, so no magnitude scaling) + from vllm.model_executor.layers.rotary_embedding.common import ( + yarn_find_correction_range, + yarn_linear_ramp_mask, + ) + + scaling_factor = 16 + base = 10000.0 + max_pos = 65536 + beta_fast, beta_slow = 32, 1 + + pos_freqs = base ** ( + torch.arange(0, ROPE_DIM, 2, dtype=torch.float32, device=device) / ROPE_DIM + ) + inv_freq_extra = 1.0 / pos_freqs + inv_freq_interp = 1.0 / (scaling_factor * pos_freqs) + low, high = yarn_find_correction_range( + beta_fast, beta_slow, ROPE_DIM, base, max_pos + ) + mask = 1 - yarn_linear_ramp_mask(low, high, ROPE_DIM // 2, dtype=torch.float32).to( + device + ) + inv_freq = inv_freq_interp * (1 - mask) + inv_freq_extra * mask + t = torch.arange(max_pos * scaling_factor, device=device, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + # mscale=0 → yarn_get_mscale returns 1.0 + cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1) # fp32 + + o = torch.randn( + num_tokens, num_heads, HEAD_DIM, device=device, dtype=torch.bfloat16 + ) + positions = torch.randint(0, 4096, (num_tokens,), device=device, dtype=torch.long) + + # UNFUSED: CUDA RoPE with is_neox=False (GPT-J) + from vllm import _custom_ops as ops + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8, + ) + + o_unfused = o.clone() + ops.rotary_embedding( + positions, + o_unfused, + None, + HEAD_DIM, + cos_sin_cache, + False, # is_neox=False (GPT-J style) + rope_dim_offset=NOPE_DIM, + inverse=True, + ) + d = heads_per_group * HEAD_DIM + T = num_tokens + o_unfused = o_unfused.view(T, n_groups, d) + o_flat = o_unfused.transpose(0, 1).contiguous().reshape(-1, d) + ref_fp8, ref_scale = per_token_group_quant_fp8( + o_flat, + group_size=QUANT_GROUP_SIZE, + use_ue8m0=True, + ) + ref_fp8 = ref_fp8.view(n_groups, T, d).transpose(0, 1) + ref_scale = ref_scale.view(n_groups, T, -1).transpose(0, 1) + + # FUSED: use the real YaRN-scaled cos_sin_cache + fused_fp8, fused_scale = fused_inv_rope_fp8_quant( + o.clone(), + positions, + cos_sin_cache, + n_groups, + heads_per_group, + ) + + # Scales must match exactly (same UE8M0 algorithm) + # Compare via dequant (Triton bf16 rotation may differ from CUDA by 1 ULP) + assert_dequant_close( + ref_fp8, ref_scale, fused_fp8, fused_scale, msg="Real DeepSeek V4 rope" + ) diff --git a/tests/kernels/test_top_k_per_row.py b/tests/kernels/test_top_k_per_row.py index 40bd2af65ea1..7b9c11495e8b 100644 --- a/tests/kernels/test_top_k_per_row.py +++ b/tests/kernels/test_top_k_per_row.py @@ -718,7 +718,6 @@ def test_persistent_topk_stress() -> None: pytest.param( { "seq_lens": [2000, 6000, 30000, 80000], - "top_k": 2048, "data_type": "random", }, id="mixed_all_paths", @@ -727,7 +726,6 @@ def 
test_persistent_topk_stress() -> None: pytest.param( { "seq_lens": [2048, 4096, 8192, 16000], - "top_k": 2048, "data_type": "random", }, id="all_decode_medium", @@ -736,7 +734,6 @@ def test_persistent_topk_stress() -> None: pytest.param( { "seq_lens": [70000, 100000, 163840], - "top_k": 2048, "data_type": "random", }, id="all_large", @@ -745,7 +742,6 @@ def test_persistent_topk_stress() -> None: pytest.param( { "seq_lens": [32767, 32768, 32769, 32772], - "top_k": 2048, "data_type": "random", }, id="large_threshold_boundary", @@ -754,7 +750,6 @@ def test_persistent_topk_stress() -> None: pytest.param( { "seq_lens": [5000], - "top_k": 2048, "data_type": "random", }, id="single_row_medium", @@ -772,15 +767,15 @@ def test_persistent_topk_stress() -> None: pytest.param( { "seq_lens": [100, 2048, 10000, 80000], - "top_k": 2048, "data_type": "random", }, id="trivial_medium_large_mix", ), ], ) +@pytest.mark.parametrize("top_k", [512, 2048]) @torch.inference_mode() -def test_persistent_topk(test_config: dict) -> None: +def test_persistent_topk(test_config: dict, top_k: int) -> None: """ Tests specific to the persistent_topk kernel: - Mixed medium/large rows in the same batch (dynamic per-row dispatch) @@ -790,14 +785,15 @@ def test_persistent_topk(test_config: dict) -> None: run_large_context_topk_test( batch_size=len(test_config["seq_lens"]), seq_lens=test_config["seq_lens"], - top_k=test_config["top_k"], + top_k=top_k, data_type=test_config.get("data_type", "random"), ) @pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA") +@pytest.mark.parametrize("top_k", [512, 2048]) @torch.inference_mode() -def test_persistent_topk_padded_stride() -> None: +def test_persistent_topk_padded_stride(top_k: int) -> None: """ Test persistent_topk with padded logits (large stride, small seq_len) to simulate the e2e CUDAGraph scenario where fp8_paged_mqa_logits @@ -806,7 +802,6 @@ def test_persistent_topk_padded_stride() -> None: set_random_seed(42) torch.set_default_device("cuda:0") - top_k = 2048 batch_size = 4 padded_stride = 163840 # DeepSeek-V3.2 max_model_len actual_seq_lens = [3000, 5000, 8000, 12000] diff --git a/tests/model_executor/test_routed_experts_capture.py b/tests/model_executor/test_routed_experts_capture.py index 770a3fa53850..0527417d1506 100644 --- a/tests/model_executor/test_routed_experts_capture.py +++ b/tests/model_executor/test_routed_experts_capture.py @@ -41,7 +41,9 @@ class DummyRouter(BaseRouter): def routing_method_type(self) -> RoutingMethodType: return RoutingMethodType.FUSED_TOPK - def _compute_routing(self, hidden_states, router_logits, indices_type): + def _compute_routing( + self, hidden_states, router_logits, indices_type, *, input_ids=None + ): topk_ids = torch.tensor([[1, 2], [3, 4]], dtype=torch.int64) topk_weights = torch.ones_like(topk_ids, dtype=torch.float32) return topk_weights, topk_ids diff --git a/tests/models/registry.py b/tests/models/registry.py index 09e1ee42f2e8..962cb7f62faa 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -260,6 +260,7 @@ def check_available_online( trust_remote_code=True, ), "DeepseekV32ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3.2-Exp"), + "DeepseekV4ForCausalLM": _HfExamplesInfo("Placeholder", is_available_online=False), "Ernie4_5ForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-0.3B-PT"), "Ernie4_5_MoeForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT"), "ExaoneForCausalLM": _HfExamplesInfo( @@ -1482,6 +1483,7 @@ def check_available_online( 
speculative_model="luccafong/deepseek_mtp_draft_random", trust_remote_code=True, ), + "DeepSeekV4MTPModel": _HfExamplesInfo("Placeholder", is_available_online=False), "ErnieMTPModel": _HfExamplesInfo( "baidu/ERNIE-4.5-21B-A3B-PT", trust_remote_code=True, diff --git a/tests/models/test_deepseek_v4_mega_moe.py b/tests/models/test_deepseek_v4_mega_moe.py new file mode 100644 index 000000000000..dd2c25e2622f --- /dev/null +++ b/tests/models/test_deepseek_v4_mega_moe.py @@ -0,0 +1,178 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from types import SimpleNamespace + +import pytest +import torch + +from vllm.model_executor.models.deepseek_v4 import ( + DeepseekV4MegaMoEExperts, + _stage_deepseek_v4_mega_moe_inputs, + make_deepseek_v4_expert_params_mapping, +) +from vllm.third_party.deep_gemm.utils import per_token_cast_to_fp8 + + +def test_deepseek_v4_mega_moe_expert_mapping(): + mapping = make_deepseek_v4_expert_params_mapping(2) + + assert mapping == [ + ("experts.w13_", "experts.0.w1.", 0, "w1"), + ("experts.w2_", "experts.0.w2.", 0, "w2"), + ("experts.w13_", "experts.0.w3.", 0, "w3"), + ("experts.w13_", "experts.1.w1.", 1, "w1"), + ("experts.w2_", "experts.1.w2.", 1, "w2"), + ("experts.w13_", "experts.1.w3.", 1, "w3"), + ] + + +def test_deepseek_v4_mega_moe_ue8m0_uint8_to_float(): + raw = torch.tensor([0, 126, 127, 128], dtype=torch.uint8) + + decoded = DeepseekV4MegaMoEExperts._ue8m0_uint8_to_float(raw) + + assert torch.equal(decoded.view(torch.int32), raw.to(torch.int32) << 23) + assert decoded[0].item() == 0.0 + assert decoded[1].item() == 0.5 + assert decoded[2].item() == 1.0 + assert decoded[3].item() == 2.0 + + +def test_deepseek_v4_mega_moe_weight_loader_uses_ep_expert_ownership(): + vllm_config = SimpleNamespace( + scheduler_config=SimpleNamespace(max_num_batched_tokens=4), + compilation_config=SimpleNamespace(static_forward_context={}), + ) + experts = DeepseekV4MegaMoEExperts( + vllm_config, + num_experts=4, + num_local_experts=2, + experts_start_idx=2, + top_k=2, + hidden_size=128, + intermediate_size=128, + ) + + nonlocal_weight = torch.ones(128, 64, dtype=torch.uint8) + assert ( + experts.weight_loader( + experts.w13_weight, + nonlocal_weight, + "experts.w13_weight", + shard_id="w1", + expert_id=1, + return_success=True, + ) + is False + ) + + w1 = torch.full((128, 64), 3, dtype=torch.uint8) + w3 = torch.full((128, 64), 7, dtype=torch.uint8) + w2 = torch.full((128, 64), 11, dtype=torch.uint8) + + assert experts.weight_loader( + experts.w13_weight, + w1, + "experts.w13_weight", + shard_id="w1", + expert_id=2, + return_success=True, + ) + assert experts.weight_loader( + experts.w13_weight, + w3, + "experts.w13_weight", + shard_id="w3", + expert_id=2, + return_success=True, + ) + assert experts.weight_loader( + experts.w2_weight, + w2, + "experts.w2_weight", + shard_id="w2", + expert_id=2, + return_success=True, + ) + + assert torch.equal(experts.w13_weight[0, :128], w1) + assert torch.equal(experts.w13_weight[0, 128:], w3) + assert torch.equal(experts.w2_weight[0], w2) + assert torch.count_nonzero(experts.w13_weight[1]) == 0 + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="DeepSeek V4 MegaMoE fused input staging requires CUDA.", +) +def test_deepseek_v4_mega_moe_fused_input_staging_is_bitwise_exact(): + device = torch.device("cuda") + num_tokens = 7 + hidden_size = 256 + top_k = 8 + + generator = torch.Generator(device=device) + generator.manual_seed(0) + hidden_states = ( + torch.randn( + 
num_tokens, + hidden_size, + device=device, + dtype=torch.float32, + generator=generator, + ) + * 17.0 + ).to(torch.bfloat16) + hidden_states[0, :32] = 0 + hidden_states[1, 32:64] = 1.0e-6 + hidden_states[2, 64:96] = -1.0e-6 + + topk_ids = torch.randint( + 0, + 256, + (num_tokens, top_k), + device=device, + dtype=torch.int32, + generator=generator, + ) + topk_weights = torch.randn( + num_tokens, + top_k, + device=device, + dtype=torch.float32, + generator=generator, + ) + + ref_x, ref_x_sf = per_token_cast_to_fp8( + hidden_states, + use_ue8m0=True, + gran_k=32, + use_packed_ue8m0=True, + ) + ref_topk_idx = topk_ids.to(torch.int64) + ref_topk_weights = topk_weights.clone() + + fused_x = torch.empty_like(ref_x) + fused_x_sf = torch.empty_like(ref_x_sf) + fused_topk_idx = torch.empty_like(ref_topk_idx) + fused_topk_weights = torch.empty_like(ref_topk_weights) + + _stage_deepseek_v4_mega_moe_inputs( + hidden_states, + topk_weights, + topk_ids, + fused_x, + fused_x_sf, + fused_topk_idx, + fused_topk_weights, + ) + torch.accelerator.synchronize() + + assert torch.equal(fused_x.view(torch.uint8), ref_x.view(torch.uint8)) + assert torch.equal(fused_x_sf, ref_x_sf) + assert torch.equal(fused_topk_idx, ref_topk_idx) + assert torch.equal( + fused_topk_weights.view(torch.uint8), + ref_topk_weights.view(torch.uint8), + ) diff --git a/tests/models/test_deepseek_v4_pp.py b/tests/models/test_deepseek_v4_pp.py new file mode 100644 index 000000000000..7c0ae5dfd725 --- /dev/null +++ b/tests/models/test_deepseek_v4_pp.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.model_executor.models.deepseek_v4 import DeepseekV4ForCausalLM +from vllm.model_executor.models.interfaces import supports_pp + + +def test_deepseek_v4_declares_pipeline_parallel_support(): + assert supports_pp(DeepseekV4ForCausalLM) diff --git a/tests/quantization/test_mxfp4.py b/tests/quantization/test_mxfp4.py new file mode 100644 index 000000000000..9295a6708c08 --- /dev/null +++ b/tests/quantization/test_mxfp4.py @@ -0,0 +1,37 @@ + + +def test_mxfp4_e8m0_scale_loading_preserves_raw_bytes(): + from types import SimpleNamespace + + import pytest + import torch + + from vllm.model_executor.layers.fused_moe.layer import FusedMoE + + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if e8m0_dtype is None: + pytest.skip("torch does not expose float8_e8m0fnu") + + layer = object.__new__(FusedMoE) + layer.moe_config = SimpleNamespace(is_act_and_mul=True) + + expert_data = torch.zeros((4, 2), dtype=torch.uint8) + loaded_scale = torch.tensor( + [[0.0078125, 0.015625], [0.5, 1.0]], + dtype=e8m0_dtype, + ) + + layer._load_w13( + expert_data=expert_data, + shard_dim=0, + shard_id="w1", + loaded_weight=loaded_scale, + tp_rank=0, + ) + + torch.testing.assert_close( + expert_data[:2], + loaded_scale.view(torch.uint8), + rtol=0, + atol=0, + ) diff --git a/tests/reasoning/test_deepseekv3_reasoning_parser.py b/tests/reasoning/test_deepseekv3_reasoning_parser.py index 4b0938d15520..f5b37194f927 100644 --- a/tests/reasoning/test_deepseekv3_reasoning_parser.py +++ b/tests/reasoning/test_deepseekv3_reasoning_parser.py @@ -6,6 +6,7 @@ from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest from vllm.entrypoints.openai.engine.protocol import DeltaMessage +from vllm.reasoning import ReasoningParserManager from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser from vllm.reasoning.deepseek_v3_reasoning_parser import 
DeepSeekV3ReasoningParser from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser @@ -33,6 +34,12 @@ def test_parser_selection(tokenizer, thinking, expected_parser_type): assert isinstance(parser._parser, expected_parser_type) +def test_deepseek_v4_reasoning_parser_alias(): + parser_cls = ReasoningParserManager.get_reasoning_parser("deepseek_v4") + + assert parser_cls is DeepSeekV3ReasoningParser + + def test_identity_reasoning_parser_basic(tokenizer): parser = IdentityReasoningParser(tokenizer) diff --git a/tests/tokenizers_/fixtures/deepseek_v4/test_input_1.json b/tests/tokenizers_/fixtures/deepseek_v4/test_input_1.json new file mode 100644 index 000000000000..35e49588dfa3 --- /dev/null +++ b/tests/tokenizers_/fixtures/deepseek_v4/test_input_1.json @@ -0,0 +1,81 @@ +{ + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather for a specific location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city name" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature unit" + } + }, + "required": ["location"] + } + } + }, + { + "type": "function", + "function": { + "name": "search", + "description": "Search the web for information", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query" + }, + "num_results": { + "type": "integer", + "description": "Number of results to return" + } + }, + "required": ["query"] + } + } + } + ], + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "What's the weather in Beijing?" + }, + { + "role": "assistant", + "reasoning": "The user wants to know the weather in Beijing. I should use the get_weather tool.", + "tool_calls": [ + { + "id": "call_001", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\"location\": \"Beijing\", \"unit\": \"celsius\"}" + } + } + ] + }, + { + "role": "tool", + "tool_call_id": "call_001", + "content": "{\"temperature\": 22, \"condition\": \"sunny\", \"humidity\": 45}" + }, + { + "role": "assistant", + "reasoning": "Got the weather data. Let me format a nice response.", + "content": "The weather in Beijing is currently sunny with a temperature of 22°C and 45% humidity." + } + ] +} diff --git a/tests/tokenizers_/fixtures/deepseek_v4/test_input_2.json b/tests/tokenizers_/fixtures/deepseek_v4/test_input_2.json new file mode 100644 index 000000000000..a301609ac2b7 --- /dev/null +++ b/tests/tokenizers_/fixtures/deepseek_v4/test_input_2.json @@ -0,0 +1,24 @@ +[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "Hello" + }, + { + "role": "assistant", + "reasoning": "The user said hello, I should greet back.", + "content": "Hi there! How can I help you?" + }, + { + "role": "user", + "content": "What is the capital of France?" + }, + { + "role": "assistant", + "reasoning": "The user asks about the capital of France. It is Paris.", + "content": "The capital of France is Paris." 
+ } +] \ No newline at end of file diff --git a/tests/tokenizers_/fixtures/deepseek_v4/test_input_3.json b/tests/tokenizers_/fixtures/deepseek_v4/test_input_3.json new file mode 100644 index 000000000000..d2dc42e3de20 --- /dev/null +++ b/tests/tokenizers_/fixtures/deepseek_v4/test_input_3.json @@ -0,0 +1,159 @@ +[ + { + "role": "system", + "content": "该助手为DeepSeek,由深度求索公司创造。" + }, + { + "role": "latest_reminder", + "content": "2026-02-21,星期六,广州,App,中文" + }, + { + "role": "developer", + "content": "小柴胡冲剂和布洛芬能一起吃吗?\n\nCITATION FORMAT: 【{cursor_id}†L{start_line_id}(-L{end_line_id})?】", + "tools": [ + { + "type": "function", + "function": { + "name": "search", + "description": "Web search. Split multiple queries with '||'.", + "parameters": { + "type": "object", + "properties": { + "queries": { + "type": "string", + "description": "query1||query2" + } + }, + "required": [ + "queries" + ], + "additionalProperties": false, + "$schema": "http://json-schema.org/draft-07/schema#" + } + } + }, + { + "type": "function", + "function": { + "name": "open", + "description": "Batch open IDs (format 【{id}†...】) or URLs.", + "parameters": { + "type": "object", + "properties": { + "open_list": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": { + "description": "ID or URL", + "anyOf": [ + { + "type": "integer" + }, + { + "type": "string" + } + ], + "default": -1 + }, + "cursor": { + "type": "integer", + "description": "", + "default": -1 + }, + "loc": { + "type": "integer", + "description": "Start line", + "default": -1 + }, + "num_lines": { + "type": "integer", + "description": "", + "default": -1 + }, + "view_source": { + "type": "boolean", + "description": "", + "default": false + } + }, + "additionalProperties": false + }, + "description": "" + } + }, + "required": [ + "open_list" + ], + "additionalProperties": false, + "$schema": "http://json-schema.org/draft-07/schema#" + } + } + }, + { + "type": "function", + "function": { + "name": "find", + "description": "Find exact text pattern in pages.", + "parameters": { + "type": "object", + "properties": { + "find_list": { + "type": "array", + "items": { + "type": "object", + "properties": { + "pattern": { + "type": "string", + "description": "" + }, + "cursor": { + "type": "integer", + "description": "", + "default": -1 + } + }, + "required": [ + "pattern" + ], + "additionalProperties": false + }, + "description": "" + } + }, + "required": [ + "find_list" + ], + "additionalProperties": false, + "$schema": "http://json-schema.org/draft-07/schema#" + } + } + } + ] + }, + { + "role": "assistant", + "content": "", + "reasoning": "用户想知道小柴胡冲剂和布洛芬能否一起服用。", + "tool_calls": [ + { + "type": "function", + "function": { + "name": "search", + "arguments": "{\"queries\": \"小柴胡冲剂 布洛芬 相互作用 一起吃\"}" + } + } + ] + }, + { + "role": "tool", + "content": "[0]" + }, + { + "role": "assistant", + "content": "请及时就医。", + "reasoning": "现在开始组织回答。", + "tool_calls": [] + } +] \ No newline at end of file diff --git a/tests/tokenizers_/fixtures/deepseek_v4/test_input_4.json b/tests/tokenizers_/fixtures/deepseek_v4/test_input_4.json new file mode 100644 index 000000000000..d5e0791dd699 --- /dev/null +++ b/tests/tokenizers_/fixtures/deepseek_v4/test_input_4.json @@ -0,0 +1,28 @@ +[ + { + "role": "system", + "content": "该助手为DeepSeek-V3,由深度求索公司创造。\n今天是2025年10月17日,星期五。" + }, + { + "role": "latest_reminder", + "content": "2024-11-15,上海市,App,中文" + }, + { + "role": "user", + "content": "热海大滚锅是世界著名温泉吗" + }, + { + "role": "assistant", + "content": 
"关于热海大滚锅是否是世界著名温泉,可以这样总结:它在**中国乃至全球的地热奇观中占有重要地位**,但“世界著名”的称号更侧重于它作为**独特的地质现象和旅游景点**,而非普遍的温泉疗养体验。\n\n为了让你快速了解,我整理了一个简要的表格:\n\n| 维度 | 热海大滚锅的地位与特点 |\n| :--- | :--- |\n| **地质奇观** | **世界罕见**的高温地热系统,泉眼多、水温高、形态多样。 |\n| **国内声誉** | **中国三大地热区之一**,**国家5A级旅游景区**,自古闻名(徐霞客曾游历并记载)。 |\n| **国际知名度** | 在一些旅游平台被提及为“世界六大温泉”之一,但此说法流传不广,其国际声誉更多建立在地质独特性上。 |\n| **核心体验** | **观赏地热奇观**(如97℃沸腾的“大滚锅”)、**体验温泉煮鸡蛋**。 |\n\n### 💡 游玩攻略与温馨提示\n\n如果你计划前往热海大滚锅,这里有一些实用信息供你参考:\n\n- **门票与开放时间**:\n - **门票**:景区门票约为**50元/人**。如果选择包含温泉沐浴的套餐,价格会更高,例如约**288元**。\n - **开放时间**:景区一般**08:00-18:00**开放,但具体时间可能变动,建议提前核实。\n\n- **特色体验**:\n - **温泉煮鸡蛋**:这几乎是必试项目。可以在景区门口购买用草绳串起的生鸡蛋(约5-8元/串),然后到“大滚锅”旁的指定区域蒸煮,几分钟便可熟食,趣味十足。\n - **金汤足浴**:可以直接用从“大滚锅”流出的温泉水泡脚,缓解旅途疲劳。\n\n- **注意事项**:\n - **安全第一**:“大滚锅”水温极高,务必遵守游览规则,在指定区域内观赏,切勿随意触碰泉水。\n - **规划行程**:建议为热海景区预留**3-4小时**的游览时间。景区内步道不走回头路,出入口有观光车接送。\n\n希望这些信息能帮助你更好地了解热海大滚锅。如果你对腾冲的其他景点或者行程规划有更多疑问,我很乐意提供进一步的信息。", + "mask": 1 + }, + { + "role": "user", + "content": "世界著名温泉有哪些", + "task": "action" + }, + { + "role": "assistant", + "content": "Search" + } +] \ No newline at end of file diff --git a/tests/tokenizers_/fixtures/deepseek_v4/test_output_1.txt b/tests/tokenizers_/fixtures/deepseek_v4/test_output_1.txt new file mode 100644 index 000000000000..dbd823476c1c --- /dev/null +++ b/tests/tokenizers_/fixtures/deepseek_v4/test_output_1.txt @@ -0,0 +1,36 @@ +<|begin▁of▁sentence|> + +## Tools + +You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<|DSML|tool_calls>" block like the following: + +<|DSML|tool_calls> +<|DSML|invoke name="$TOOL_NAME"> +<|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE +... + +<|DSML|invoke name="$TOOL_NAME2"> +... + + + +String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`. + +If thinking_mode is enabled (triggered by ), you MUST output your complete reasoning inside ... BEFORE any tool calls or final response. + +Otherwise, output directly after with tool calls or final response. + +### Available Tool Schemas + +{"name": "get_weather", "description": "Get the weather for a specific location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city name"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "Temperature unit"}}, "required": ["location"]}} +{"name": "search", "description": "Search the web for information", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Search query"}, "num_results": {"type": "integer", "description": "Number of results to return"}}, "required": ["query"]}} + +You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls. +You are a helpful assistant.<|User|>What's the weather in Beijing?<|Assistant|>The user wants to know the weather in Beijing. I should use the get_weather tool. + +<|DSML|tool_calls> +<|DSML|invoke name="get_weather"> +<|DSML|parameter name="location" string="true">Beijing +<|DSML|parameter name="unit" string="true">celsius + +<|end▁of▁sentence|><|User|>{"temperature": 22, "condition": "sunny", "humidity": 45}<|Assistant|>Got the weather data. 
Let me format a nice response.The weather in Beijing is currently sunny with a temperature of 22°C and 45% humidity.<|end▁of▁sentence|> \ No newline at end of file diff --git a/tests/tokenizers_/fixtures/deepseek_v4/test_output_2.txt b/tests/tokenizers_/fixtures/deepseek_v4/test_output_2.txt new file mode 100644 index 000000000000..fc397ef54972 --- /dev/null +++ b/tests/tokenizers_/fixtures/deepseek_v4/test_output_2.txt @@ -0,0 +1 @@ +<|begin▁of▁sentence|>You are a helpful assistant.<|User|>Hello<|Assistant|>Hi there! How can I help you?<|end▁of▁sentence|><|User|>What is the capital of France?<|Assistant|>The user asks about the capital of France. It is Paris.The capital of France is Paris.<|end▁of▁sentence|> \ No newline at end of file diff --git a/tests/tokenizers_/fixtures/deepseek_v4/test_output_3.txt b/tests/tokenizers_/fixtures/deepseek_v4/test_output_3.txt new file mode 100644 index 000000000000..edee563300d4 --- /dev/null +++ b/tests/tokenizers_/fixtures/deepseek_v4/test_output_3.txt @@ -0,0 +1,38 @@ +<|begin▁of▁sentence|>该助手为DeepSeek,由深度求索公司创造。<|latest_reminder|>2026-02-21,星期六,广州,App,中文<|User|>小柴胡冲剂和布洛芬能一起吃吗? + +CITATION FORMAT: 【{cursor_id}†L{start_line_id}(-L{end_line_id})?】 + +## Tools + +You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<|DSML|tool_calls>" block like the following: + +<|DSML|tool_calls> +<|DSML|invoke name="$TOOL_NAME"> +<|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE +... + +<|DSML|invoke name="$TOOL_NAME2"> +... + + + +String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`. + +If thinking_mode is enabled (triggered by ), you MUST output your complete reasoning inside ... BEFORE any tool calls or final response. + +Otherwise, output directly after with tool calls or final response. + +### Available Tool Schemas + +{"name": "search", "description": "Web search. 
Split multiple queries with '||'.", "parameters": {"type": "object", "properties": {"queries": {"type": "string", "description": "query1||query2"}}, "required": ["queries"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}} +{"name": "open", "description": "Batch open IDs (format 【{id}†...】) or URLs.", "parameters": {"type": "object", "properties": {"open_list": {"type": "array", "items": {"type": "object", "properties": {"id": {"description": "ID or URL", "anyOf": [{"type": "integer"}, {"type": "string"}], "default": -1}, "cursor": {"type": "integer", "description": "", "default": -1}, "loc": {"type": "integer", "description": "Start line", "default": -1}, "num_lines": {"type": "integer", "description": "", "default": -1}, "view_source": {"type": "boolean", "description": "", "default": false}}, "additionalProperties": false}, "description": ""}}, "required": ["open_list"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}} +{"name": "find", "description": "Find exact text pattern in pages.", "parameters": {"type": "object", "properties": {"find_list": {"type": "array", "items": {"type": "object", "properties": {"pattern": {"type": "string", "description": ""}, "cursor": {"type": "integer", "description": "", "default": -1}}, "required": ["pattern"], "additionalProperties": false}, "description": ""}}, "required": ["find_list"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}} + +You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls. +<|Assistant|>用户想知道小柴胡冲剂和布洛芬能否一起服用。 + +<|DSML|tool_calls> +<|DSML|invoke name="search"> +<|DSML|parameter name="queries" string="true">小柴胡冲剂 布洛芬 相互作用 一起吃 + +<|end▁of▁sentence|><|User|>[0]<|Assistant|>现在开始组织回答。请及时就医。<|end▁of▁sentence|> \ No newline at end of file diff --git a/tests/tokenizers_/fixtures/deepseek_v4/test_output_4.txt b/tests/tokenizers_/fixtures/deepseek_v4/test_output_4.txt new file mode 100644 index 000000000000..d30bd5d06cf3 --- /dev/null +++ b/tests/tokenizers_/fixtures/deepseek_v4/test_output_4.txt @@ -0,0 +1,29 @@ +<|begin▁of▁sentence|>该助手为DeepSeek-V3,由深度求索公司创造。 +今天是2025年10月17日,星期五。<|latest_reminder|>2024-11-15,上海市,App,中文<|User|>热海大滚锅是世界著名温泉吗<|Assistant|>关于热海大滚锅是否是世界著名温泉,可以这样总结:它在**中国乃至全球的地热奇观中占有重要地位**,但“世界著名”的称号更侧重于它作为**独特的地质现象和旅游景点**,而非普遍的温泉疗养体验。 + +为了让你快速了解,我整理了一个简要的表格: + +| 维度 | 热海大滚锅的地位与特点 | +| :--- | :--- | +| **地质奇观** | **世界罕见**的高温地热系统,泉眼多、水温高、形态多样。 | +| **国内声誉** | **中国三大地热区之一**,**国家5A级旅游景区**,自古闻名(徐霞客曾游历并记载)。 | +| **国际知名度** | 在一些旅游平台被提及为“世界六大温泉”之一,但此说法流传不广,其国际声誉更多建立在地质独特性上。 | +| **核心体验** | **观赏地热奇观**(如97℃沸腾的“大滚锅”)、**体验温泉煮鸡蛋**。 | + +### 💡 游玩攻略与温馨提示 + +如果你计划前往热海大滚锅,这里有一些实用信息供你参考: + +- **门票与开放时间**: + - **门票**:景区门票约为**50元/人**。如果选择包含温泉沐浴的套餐,价格会更高,例如约**288元**。 + - **开放时间**:景区一般**08:00-18:00**开放,但具体时间可能变动,建议提前核实。 + +- **特色体验**: + - **温泉煮鸡蛋**:这几乎是必试项目。可以在景区门口购买用草绳串起的生鸡蛋(约5-8元/串),然后到“大滚锅”旁的指定区域蒸煮,几分钟便可熟食,趣味十足。 + - **金汤足浴**:可以直接用从“大滚锅”流出的温泉水泡脚,缓解旅途疲劳。 + +- **注意事项**: + - **安全第一**:“大滚锅”水温极高,务必遵守游览规则,在指定区域内观赏,切勿随意触碰泉水。 + - **规划行程**:建议为热海景区预留**3-4小时**的游览时间。景区内步道不走回头路,出入口有观光车接送。 + +希望这些信息能帮助你更好地了解热海大滚锅。如果你对腾冲的其他景点或者行程规划有更多疑问,我很乐意提供进一步的信息。<|end▁of▁sentence|><|User|>世界著名温泉有哪些<|Assistant|><|action|>Search<|end▁of▁sentence|> \ No newline at end of file diff --git a/tests/tokenizers_/test_deepseek_v4.py b/tests/tokenizers_/test_deepseek_v4.py new file mode 100644 index 000000000000..9f3b88cf658d --- /dev/null +++ b/tests/tokenizers_/test_deepseek_v4.py @@ -0,0 +1,224 @@ +# SPDX-License-Identifier: 
Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +from pathlib import Path +from types import SimpleNamespace + +import pytest + +from vllm.entrypoints.chat_utils import parse_chat_messages +from vllm.renderers.registry import RENDERER_REGISTRY +from vllm.tokenizers.deepseek_v4 import get_deepseek_v4_tokenizer +from vllm.tokenizers.registry import TokenizerRegistry + +FIXTURES_DIR = Path(__file__).parent / "fixtures" / "deepseek_v4" + + +class FakeHfTokenizer: + vocab_size = 100 + + def get_added_vocab(self) -> dict[str, int]: + return {"": 100} + + def encode( + self, + text: str, + add_special_tokens: bool = False, + **kwargs, + ) -> list[int]: + self.last_encode = (text, add_special_tokens, kwargs) + return [len(text)] + + +def _tokenizer(): + return get_deepseek_v4_tokenizer(FakeHfTokenizer()) + + +def _model_config(): + return SimpleNamespace( + multimodal_config=None, + allowed_local_media_path="", + allowed_media_domains=None, + ) + + +def _load_reference_case(case_id: int): + data = json.loads((FIXTURES_DIR / f"test_input_{case_id}.json").read_text()) + if isinstance(data, dict): + return data["messages"], data.get("tools") + return data, None + + +def _render_reference_case(case_id: int, **kwargs): + messages, tools = _load_reference_case(case_id) + conversation, _, _ = parse_chat_messages( + messages, + _model_config(), + content_format="string", + ) + return _tokenizer().apply_chat_template( + conversation=conversation, + messages=messages, + tools=tools, + tokenize=False, + **kwargs, + ) + + +def test_deepseek_v4_tokenizer_registered(): + assert TokenizerRegistry.load_tokenizer_cls("deepseek_v4").__name__ == ( + "DeepseekV4Tokenizer" + ) + assert RENDERER_REGISTRY.load_renderer_cls("deepseek_v4").__name__ == ( + "DeepseekV4Renderer" + ) + + +def test_deepseek_v4_defaults_to_chat_mode(): + prompt = _tokenizer().apply_chat_template( + [{"role": "user", "content": "Hello"}], + tokenize=False, + ) + + assert prompt == ("<|begin▁of▁sentence|><|User|>Hello<|Assistant|>") + + +@pytest.mark.parametrize("kwargs", [{"thinking": True}, {"enable_thinking": True}]) +def test_deepseek_v4_enables_thinking_with_compatible_kwargs(kwargs): + prompt = _tokenizer().apply_chat_template( + [{"role": "user", "content": "Hello"}], + tokenize=False, + **kwargs, + ) + + assert prompt == ("<|begin▁of▁sentence|><|User|>Hello<|Assistant|>") + + +def test_deepseek_v4_uses_v4_tool_prompt_from_request_tools(): + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather for a city", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + }, + } + ] + + prompt = _tokenizer().apply_chat_template( + [{"role": "user", "content": "Weather?"}], + tools=tools, + tokenize=False, + ) + + assert "## Tools" in prompt + assert "<|DSML|tool_calls>" in prompt + assert "" in prompt + assert "function_calls" not in prompt + assert '"name": "get_weather"' in prompt + assert prompt.endswith("<|User|>Weather?<|Assistant|>") + + +def test_deepseek_v4_renders_parsed_history_tool_arguments(): + messages = [ + {"role": "user", "content": "List the repo"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "str_replace_editor", + "arguments": '{"command": "view", "path": "/testbed"}', + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call_1", + "content": "file list", + }, + ] + tools = [ + { + "type": 
"function", + "function": { + "name": "str_replace_editor", + "description": "Edit files", + "parameters": { + "type": "object", + "properties": { + "command": {"type": "string"}, + "path": {"type": "string"}, + }, + "required": ["command", "path"], + }, + }, + } + ] + conversation, _, _ = parse_chat_messages( + messages, + _model_config(), + content_format="string", + ) + + prompt = _tokenizer().apply_chat_template( + conversation=conversation, + messages=messages, + tools=tools, + tokenize=False, + ) + + assert '<|DSML|parameter name="command" string="true">view' in prompt + assert '<|DSML|parameter name="path" string="true">/testbed' in prompt + assert 'parameter name="arguments"' not in prompt + + +@pytest.mark.parametrize("reasoning_effort", ["none", "low", "medium", "high"]) +def test_deepseek_v4_accepts_openai_reasoning_effort_values(reasoning_effort): + prompt = _tokenizer().apply_chat_template( + [{"role": "user", "content": "Hello"}], + tokenize=False, + enable_thinking=True, + reasoning_effort=reasoning_effort, + ) + + assert prompt.endswith("<|Assistant|>") + assert "Reasoning Effort: Absolute maximum" not in prompt + + +def test_deepseek_v4_preserves_reference_max_reasoning_effort(): + prompt = _tokenizer().apply_chat_template( + [{"role": "user", "content": "Hello"}], + tokenize=False, + enable_thinking=True, + reasoning_effort="max", + ) + + assert prompt.startswith( + "<|begin▁of▁sentence|>Reasoning Effort: Absolute maximum" + ) + + +@pytest.mark.parametrize( + ("case_id", "kwargs"), + [ + (1, {"thinking": True}), + (2, {"thinking": True}), + (3, {"thinking": True}), + (4, {}), + ], +) +def test_deepseek_v4_matches_reference_golden_fixtures(case_id, kwargs): + prompt = _render_reference_case(case_id, **kwargs) + + expected = (FIXTURES_DIR / f"test_output_{case_id}.txt").read_text() + assert prompt == expected diff --git a/tests/tool_parsers/test_deepseekv4_tool_parser.py b/tests/tool_parsers/test_deepseekv4_tool_parser.py new file mode 100644 index 000000000000..631d0fb97b33 --- /dev/null +++ b/tests/tool_parsers/test_deepseekv4_tool_parser.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Unit tests for DeepSeekV4ToolParser.""" + +import json +from unittest.mock import MagicMock + +from vllm.tool_parsers import ToolParserManager +from vllm.tool_parsers.deepseekv4_tool_parser import DeepSeekV4ToolParser + +MOCK_TOKENIZER = MagicMock() +MOCK_TOKENIZER.get_vocab.return_value = {} + +TC_START = "<|DSML|tool_calls>" +TC_END = "" +INV_START = '<|DSML|invoke name="' +INV_END = "" +PARAM_START = '<|DSML|parameter name="' +PARAM_END = "" + + +def make_parser(tools=None) -> DeepSeekV4ToolParser: + return DeepSeekV4ToolParser(MOCK_TOKENIZER, tools=tools) + + +def make_request(tools=None) -> MagicMock: + req = MagicMock() + req.tools = tools + return req + + +def build_tool_call(func_name: str, params: dict[str, str]) -> str: + param_strs = "".join( + f'{PARAM_START}{k}" string="true">{v}{PARAM_END}\n' for k, v in params.items() + ) + return f'{TC_START}\n{INV_START}{func_name}">\n{param_strs}{INV_END}\n{TC_END}' + + +def stream(parser: DeepSeekV4ToolParser, full_text: str, chunk_size: int = 7): + deltas = [] + previous_text = "" + for start in range(0, len(full_text), chunk_size): + delta_text = full_text[start : start + chunk_size] + current_text = previous_text + delta_text + delta = parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + 
delta_text=delta_text, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[1], + request=make_request(), + ) + previous_text = current_text + if delta is not None: + deltas.append(delta) + return deltas + + +def reconstruct_args(deltas, tool_index: int = 0) -> str: + fragments = [] + for delta in deltas: + if delta.tool_calls: + for tool_call in delta.tool_calls: + if ( + tool_call.index == tool_index + and tool_call.function + and tool_call.function.arguments + ): + fragments.append(tool_call.function.arguments) + return "".join(fragments) + + +def test_registered(): + assert ToolParserManager.get_tool_parser("deepseek_v4") is DeepSeekV4ToolParser + + +def test_extract_tool_calls(): + parser = make_parser() + model_output = "Let me check. " + build_tool_call( + "get_weather", {"location": "Beijing", "unit": "celsius"} + ) + + result = parser.extract_tool_calls(model_output, make_request()) + + assert result.tools_called + assert result.content == "Let me check. " + assert len(result.tool_calls) == 1 + tool_call = result.tool_calls[0] + assert tool_call.function.name == "get_weather" + assert json.loads(tool_call.function.arguments) == { + "location": "Beijing", + "unit": "celsius", + } + + +def test_function_calls_block_is_not_accepted(): + parser = make_parser() + model_output = build_tool_call("search", {"query": "vllm"}).replace( + "tool_calls", "function_calls" + ) + + result = parser.extract_tool_calls(model_output, make_request()) + + assert not result.tools_called + assert result.content == model_output + + +def test_streaming_extracts_complete_invokes(): + parser = make_parser() + full_text = build_tool_call("search", {"query": "deepseek v4"}) + + deltas = stream(parser, full_text, chunk_size=5) + + names = [ + tool_call.function.name + for delta in deltas + if delta.tool_calls + for tool_call in delta.tool_calls + ] + assert names == ["search"] + assert json.loads(reconstruct_args(deltas)) == {"query": "deepseek v4"} diff --git a/tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py b/tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py new file mode 100644 index 000000000000..c3a7cb0a20ae --- /dev/null +++ b/tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py @@ -0,0 +1,2465 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Correctness tests for the DeepSeek V4 sparse MLA reference path.""" + +from types import SimpleNamespace + +import pytest +import torch + +from vllm.config.compilation import CompilationMode, CUDAGraphMode +from vllm.model_executor.layers import ( + deepseek_v4_attention as deepseek_v4_attention_module, +) +from vllm.model_executor.layers.deepseek_v4_attention import ( + _deepseek_v4_fp8_einsum_config, + _sparse_mla_prefill_workspace_bounds, + deepseek_v4_fp8_einsum, +) +from vllm.utils.deep_gemm import fp8_einsum +from vllm.v1.attention.backend import AttentionCGSupport +from vllm.v1.attention.backends.mla.flashmla_sparse import ( + FlashMLASparseMetadataBuilder, +) +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + disable_sparse_mla_reference_cudagraphs_if_enabled, + sparse_mla_reference_topk_chunk_size, +) +from vllm.v1.attention.backends.mla.sparse_mla_kernels import ( + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk, + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead, + accumulate_fp8ds_paged_sparse_mla_attention_chunk, + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead, + 
accumulate_gathered_sparse_mla_attention_chunk, + accumulate_indexed_sparse_mla_attention_chunk, + build_combined_sparse_mla_decode_valid_mask, + finish_gathered_sparse_mla_attention, + finish_sparse_mla_attention_with_sink, + finish_two_sparse_mla_attention_states_with_sink, + fp8ds_global_paged_sparse_mla_attention_with_sink_multihead, + fp8ds_paged_sparse_mla_attention_with_sink_multihead, + matmul_sparse_mla_attention_with_sink, + merge_sparse_mla_subset_with_sink, + merge_two_sparse_mla_subsets_with_sink, + sparse_mla_decode_head_block_size, +) +from vllm.v1.attention.backends.mla.sparse_mla_reference import ( + accumulate_reference_attention_chunk, + finish_reference_attention_no_sink, + merge_reference_attention_with_sink, + new_reference_attention_state, + reference_attention_no_sink, + reference_sparse_mla_prefill, + sink_aware_reference_attention, +) +from vllm.v1.attention.backends.mla.sparse_swa import DeepseekSparseSWAMetadataBuilder +from vllm.v1.attention.ops.deepseek_v4_ops import ( + dequantize_and_gather_k_cache, + dequantize_combined_sparse_mla_decode_kv, + dequantize_global_slots_k_cache, +) +from vllm.v1.attention.ops.deepseek_v4_ops.fp8_einsum import ( + deepseek_v4_sm12_fp8_einsum, +) +from vllm.v1.kv_cache_interface import MLAAttentionSpec, SlidingWindowMLASpec + +_FP8_DIM = 448 +_ROPE_DIM = 64 +_SCALE_DIM = 8 +_TOKEN_DATA_SIZE = _FP8_DIM + _ROPE_DIM * 2 + + +class _FakeWorkspaceManager: + + def get_simultaneous(self, *specs): + return tuple(torch.empty(shape, dtype=dtype) for shape, dtype in specs) + + +def test_triton_sparse_mla_default_topk_chunk_size(monkeypatch) -> None: + monkeypatch.delenv("VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE", raising=False) + + assert sparse_mla_reference_topk_chunk_size() == 512 + + +def test_sparse_mla_prefill_workspace_bounds_use_active_prefill_lengths() -> None: + seq_lens_cpu = torch.tensor([15_000, 2_048], dtype=torch.int32) + gather_lens_cpu = torch.tensor([15_000, 2_048], dtype=torch.int32) + + compressed_region_size, row_stride = _sparse_mla_prefill_workspace_bounds( + seq_lens_cpu=seq_lens_cpu, + gather_lens_cpu=gather_lens_cpu, + compress_ratio=4, + swa_only=False, + ) + + assert compressed_region_size == 3_750 + assert row_stride == 18_750 + + +def test_sparse_mla_prefill_workspace_bounds_for_swa_only() -> None: + seq_lens_cpu = torch.tensor([15_000], dtype=torch.int32) + gather_lens_cpu = torch.tensor([15_000], dtype=torch.int32) + + compressed_region_size, row_stride = _sparse_mla_prefill_workspace_bounds( + seq_lens_cpu=seq_lens_cpu, + gather_lens_cpu=gather_lens_cpu, + compress_ratio=1, + swa_only=True, + ) + + assert compressed_region_size == 0 + assert row_stride == 15_000 + + +@pytest.mark.parametrize( + ("num_decode_tokens", "expected_head_block_size"), + [ + (0, 1), + (1, 1), + (4, 1), + (5, 2), + (8, 2), + (15, 2), + (16, 4), + (32, 4), + ], +) +def test_triton_sparse_mla_decode_head_block_size( + num_decode_tokens: int, + expected_head_block_size: int, + monkeypatch, +) -> None: + monkeypatch.delenv("VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE", raising=False) + + assert ( + sparse_mla_decode_head_block_size(num_decode_tokens) + == expected_head_block_size + ) + + +@pytest.mark.parametrize("configured_head_block_size", ["1", "2", "4"]) +def test_triton_sparse_mla_decode_head_block_size_env_override( + configured_head_block_size: str, + monkeypatch, +) -> None: + monkeypatch.setenv( + "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE", + configured_head_block_size, + ) + + assert sparse_mla_decode_head_block_size(1) == 
int(configured_head_block_size) + assert sparse_mla_decode_head_block_size(32) == int(configured_head_block_size) + + +@pytest.mark.parametrize("configured_head_block_size", ["0", "3", "invalid"]) +def test_triton_sparse_mla_decode_head_block_size_ignores_invalid_env_override( + configured_head_block_size: str, + monkeypatch, +) -> None: + monkeypatch.setenv( + "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE", + configured_head_block_size, + ) + + assert sparse_mla_decode_head_block_size(8) == 2 + + +def test_swa_mtp_decode_reference_uses_global_swa_slots(monkeypatch) -> None: + captured: dict[str, torch.Tensor] = {} + + def fail_paged_attention_with_sink_multihead(**kwargs) -> None: + raise AssertionError("MTP SWA decode must use explicit SWA indices") + + def fake_accumulate_global_slots(**kwargs) -> None: + captured["slot_ids"] = kwargs["slot_ids"] + captured["lens"] = kwargs["lens"] + + def fake_finish_with_sink(*args, **kwargs) -> None: + kwargs["output"].zero_() + + monkeypatch.setattr( + deepseek_v4_attention_module, + "current_workspace_manager", + lambda: _FakeWorkspaceManager(), + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "fp8ds_paged_sparse_mla_attention_with_sink_multihead", + fail_paged_attention_with_sink_multihead, + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead", + fake_accumulate_global_slots, + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "finish_sparse_mla_attention_with_sink", + fake_finish_with_sink, + ) + + attention = SimpleNamespace( + num_heads=2, + scale=0.1, + attn_sink=torch.zeros(2, dtype=torch.float32), + ) + swa_indices = torch.arange(48, dtype=torch.int32).reshape(6, 1, 8) + swa_lens = torch.tensor([2, 3, 4, 2, 3, 4], dtype=torch.int32) + metadata = SimpleNamespace( + num_decodes=2, + num_decode_tokens=6, + decode_swa_lens=swa_lens, + decode_swa_indices=swa_indices, + seq_lens=torch.tensor([11, 22], dtype=torch.int32), + block_table=torch.empty((2, 4), dtype=torch.int32), + block_size=256, + token_to_req_indices=torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.int32), + ) + + deepseek_v4_attention_module.DeepseekV4MLAAttention._forward_sparse_mla_swa_decode_reference( + attention, + q=torch.empty((6, 1, 2, 512), dtype=torch.bfloat16), + swa_k_cache=torch.empty((1, 256, 584), dtype=torch.uint8), + swa_metadata=metadata, + output=torch.empty((6, 2, 512), dtype=torch.bfloat16), + ) + + torch.testing.assert_close(captured["slot_ids"], swa_indices) + torch.testing.assert_close(captured["lens"], swa_lens) + + +def test_compressed_mtp_decode_reference_uses_global_swa_slots(monkeypatch) -> None: + captured: list[torch.Tensor] = [] + + def fail_matmul_decode(**kwargs) -> None: + raise AssertionError("MTP compressed decode must not stage paged SWA") + + def fail_direct_global_paged(**kwargs) -> None: + raise AssertionError("MTP compressed decode must not use paged SWA window") + + def fake_accumulate_global_slots(**kwargs) -> None: + captured.append(kwargs["slot_ids"]) + + def fake_finish_two_states(*args, **kwargs) -> None: + kwargs["output"].zero_() + + monkeypatch.setattr( + deepseek_v4_attention_module, + "current_workspace_manager", + lambda: _FakeWorkspaceManager(), + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "dequantize_combined_sparse_mla_decode_kv", + fail_matmul_decode, + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "fp8ds_global_paged_sparse_mla_attention_with_sink_multihead", + fail_direct_global_paged, + ) + 
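# The compressed MTP decode reference is expected to route both subsets through the global-slot accumulator (first the top-k compressed slots, then the SWA window) and merge the two softmax states with the sink; the fail_* stubs above turn any use of the paged fast paths into an immediate test failure. +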
monkeypatch.setattr( + deepseek_v4_attention_module, + "accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead", + fake_accumulate_global_slots, + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "finish_two_sparse_mla_attention_states_with_sink", + fake_finish_two_states, + ) + + attention = SimpleNamespace( + num_heads=2, + scale=0.1, + attn_sink=torch.zeros(2, dtype=torch.float32), + compress_ratio=4, + ) + swa_indices = torch.arange(48, dtype=torch.int32).reshape(6, 1, 8) + topk_slot_ids = torch.arange(24, dtype=torch.int32).reshape(6, 1, 4) + swa_metadata = SimpleNamespace( + num_decodes=2, + num_decode_tokens=6, + decode_swa_lens=torch.full((6,), 3, dtype=torch.int32), + decode_swa_indices=swa_indices, + seq_lens=torch.tensor([11, 22], dtype=torch.int32), + block_table=torch.empty((2, 4), dtype=torch.int32), + block_size=256, + token_to_req_indices=torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.int32), + ) + + deepseek_v4_attention_module.DeepseekV4MLAAttention._forward_sparse_mla_compressed_decode_reference( + attention, + q=torch.empty((6, 1, 2, 512), dtype=torch.bfloat16), + compressed_k_cache=torch.empty((1, 64, 584), dtype=torch.uint8), + swa_k_cache=torch.empty((1, 256, 584), dtype=torch.uint8), + topk_indices=topk_slot_ids, + topk_lens=torch.full((6,), 4, dtype=torch.int32), + swa_metadata=swa_metadata, + attn_metadata=SimpleNamespace(block_size=256), + output=torch.empty((6, 2, 512), dtype=torch.bfloat16), + ) + + assert len(captured) == 2 + torch.testing.assert_close(captured[0], topk_slot_ids[:, 0]) + torch.testing.assert_close(captured[1], swa_indices) + + +@pytest.mark.parametrize( + ("capability_major", "expected_recipe", "expected_tma_aligned"), + [ + (9, (1, 128, 128), False), + (10, (1, 1, 128), True), + (12, (1, 128, 128), False), + ], +) +def test_deepseek_v4_fp8_einsum_config_for_sm12x( + capability_major: int, + expected_recipe: tuple[int, int, int], + expected_tma_aligned: bool, +) -> None: + assert _deepseek_v4_fp8_einsum_config(capability_major) == ( + expected_recipe, + expected_tma_aligned, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("use_e8m0_scale", [False, True]) +def test_deepseek_v4_sm12_triton_fp8_einsum_matches_deepgemm_reference( + use_e8m0_scale: bool, +) -> None: + if use_e8m0_scale and not hasattr(torch, "float8_e8m0fnu"): + pytest.skip("torch does not expose float8_e8m0fnu") + torch.manual_seed(0) + num_tokens = 17 + num_groups = 4 + hidden_size = 4096 + out_rank = 1024 + recipe = (1, 128, 128) + + a_backing = torch.empty( + (num_groups, num_tokens, hidden_size), + device="cuda", + dtype=torch.float8_e4m3fn, + ) + a = a_backing.transpose(0, 1) + a_scale_backing = torch.empty( + (num_groups, num_tokens, hidden_size // 128), + device="cuda", + dtype=torch.float32, + ).uniform_(0.01, 0.02) + a_scale = a_scale_backing.transpose(0, 1) + b_flat = torch.empty( + (num_groups * out_rank, hidden_size), + device="cuda", + dtype=torch.float8_e4m3fn, + ) + b = b_flat.view(num_groups, out_rank, hidden_size) + if use_e8m0_scale: + scale_choices = torch.tensor( + [0.00390625, 0.0078125, 0.015625, 0.03125], + device="cuda", + dtype=torch.float32, + ) + scale_indices = torch.randint( + 0, + len(scale_choices), + (num_groups * (out_rank // 128), hidden_size // 128), + device="cuda", + ) + b_scale_flat = scale_choices[scale_indices].to(torch.float8_e8m0fnu) + b_scale_ref_flat = b_scale_flat.to(torch.float32) + else: + b_scale_flat = torch.empty( + (num_groups * (out_rank // 128), 
hidden_size // 128), + device="cuda", + dtype=torch.float32, + ).uniform_(0.01, 0.02) + b_scale_ref_flat = b_scale_flat + b_scale_ref = b_scale_ref_flat.view( + num_groups, out_rank // 128, hidden_size // 128 + ) + expected = torch.empty( + (num_tokens, num_groups, out_rank), + device="cuda", + dtype=torch.bfloat16, + ) + actual = torch.empty_like(expected) + + fp8_einsum( + "bhr,hdr->bhd", + (a, a_scale), + (b, b_scale_ref), + expected, + recipe=recipe, + ) + deepseek_v4_fp8_einsum( + a, + a_scale, + b_flat, + b_scale_flat, + actual, + "bhr,hdr->bhd", + list(recipe), + ) + + torch.testing.assert_close(actual, expected, rtol=0, atol=0) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_deepseek_v4_sm12_triton_fp8_einsum_primitive_matches_reference() -> None: + torch.manual_seed(0) + num_tokens = 17 + num_groups = 4 + hidden_size = 4096 + out_rank = 1024 + recipe = (1, 128, 128) + + a_backing = torch.empty( + (num_groups, num_tokens, hidden_size), + device="cuda", + dtype=torch.float8_e4m3fn, + ) + a = a_backing.transpose(0, 1) + a_scale_backing = torch.empty( + (num_groups, num_tokens, hidden_size // 128), + device="cuda", + dtype=torch.float32, + ).uniform_(0.01, 0.02) + a_scale = a_scale_backing.transpose(0, 1) + b_flat = torch.empty( + (num_groups * out_rank, hidden_size), + device="cuda", + dtype=torch.float8_e4m3fn, + ) + b = b_flat.view(num_groups, out_rank, hidden_size) + b_scale_flat = torch.empty( + (num_groups * (out_rank // 128), hidden_size // 128), + device="cuda", + dtype=torch.float32, + ).uniform_(0.01, 0.02) + b_scale = b_scale_flat.view(num_groups, out_rank // 128, hidden_size // 128) + expected = torch.empty( + (num_tokens, num_groups, out_rank), + device="cuda", + dtype=torch.bfloat16, + ) + actual = torch.empty_like(expected) + + fp8_einsum("bhr,hdr->bhd", (a, a_scale), (b, b_scale), expected, recipe=recipe) + deepseek_v4_sm12_fp8_einsum(a, a_scale, b, b_scale, actual) + + torch.testing.assert_close(actual, expected, rtol=0, atol=0) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("num_groups", [1, 2, 4]) +def test_deepseek_v4_sm12_triton_fp8_einsum_supports_tp_local_group_counts( + num_groups: int, +) -> None: + torch.manual_seed(18 + num_groups) + num_tokens = 5 + hidden_size = 4096 + out_rank = 1024 + recipe = (1, 128, 128) + + a_backing = torch.empty( + (num_groups, num_tokens, hidden_size), + device="cuda", + dtype=torch.float8_e4m3fn, + ) + a = a_backing.transpose(0, 1) + a_scale_backing = torch.empty( + (num_groups, num_tokens, hidden_size // 128), + device="cuda", + dtype=torch.float32, + ).uniform_(0.01, 0.02) + a_scale = a_scale_backing.transpose(0, 1) + b_flat = torch.empty( + (num_groups * out_rank, hidden_size), + device="cuda", + dtype=torch.float8_e4m3fn, + ) + b = b_flat.view(num_groups, out_rank, hidden_size) + b_scale_flat = torch.empty( + (num_groups * (out_rank // 128), hidden_size // 128), + device="cuda", + dtype=torch.float32, + ).uniform_(0.01, 0.02) + b_scale = b_scale_flat.view(num_groups, out_rank // 128, hidden_size // 128) + expected = torch.empty( + (num_tokens, num_groups, out_rank), + device="cuda", + dtype=torch.bfloat16, + ) + actual = torch.empty_like(expected) + + fp8_einsum("bhr,hdr->bhd", (a, a_scale), (b, b_scale), expected, recipe=recipe) + deepseek_v4_sm12_fp8_einsum(a, a_scale, b, b_scale, actual) + + torch.testing.assert_close(actual, expected, rtol=0, atol=0) + + +def _masked_scores( + q: torch.Tensor, + kv: torch.Tensor, + valid_tokens: 
torch.Tensor, + scale: float, +) -> torch.Tensor: + q_bhd = q[:, 0].float() if q.dim() == 4 else q.float() + scores = torch.einsum("bhd,btd->bht", q_bhd, kv.float()) * scale + return scores.masked_fill(~valid_tokens[:, None, :], float("-inf")) + + +def _golden_no_sink_attention( + q: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + scale: float, +) -> tuple[torch.Tensor, torch.Tensor]: + scores = _masked_scores(q, kv, valid_tokens, scale) + lse = torch.logsumexp(scores, dim=-1) + weights = torch.exp(scores - lse[:, :, None]) + weights = torch.where( + valid_tokens[:, None, :], + weights, + torch.zeros((), dtype=weights.dtype, device=weights.device), + ) + weights = torch.nan_to_num(weights) + output = torch.einsum("bht,btd->bhd", weights, kv.float()) + valid = valid_tokens.any(dim=-1) + output = torch.where( + valid[:, None, None], + output, + torch.zeros((), dtype=output.dtype, device=output.device), + ) + return output, lse + + +def _golden_sink_attention( + q: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + scale: float, + attn_sink: torch.Tensor, +) -> torch.Tensor: + scores = _masked_scores(q, kv, valid_tokens, scale) + sink = attn_sink[None, :].float() + score_max = scores.amax(dim=-1) + merge_max = torch.maximum(score_max, sink) + + weights = torch.exp(scores - merge_max[:, :, None]) + weights = torch.where( + valid_tokens[:, None, :], + weights, + torch.zeros((), dtype=weights.dtype, device=weights.device), + ) + weights = torch.nan_to_num(weights) + + sink_weight = torch.exp(sink - merge_max) + sink_weight = torch.nan_to_num(sink_weight) + denom = weights.sum(dim=-1) + sink_weight + numerator = torch.einsum("bht,btd->bhd", weights, kv.float()) + return numerator / denom[:, :, None] + + +def _chunked_no_sink_attention( + q: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + scale: float, + chunk_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + q_bhd, max_score, denom, acc = new_reference_attention_state(q) + for chunk_start in range(0, kv.shape[1], chunk_size): + chunk_end = min(chunk_start + chunk_size, kv.shape[1]) + max_score, denom, acc = accumulate_reference_attention_chunk( + q_bhd=q_bhd, + kv=kv[:, chunk_start:chunk_end], + valid_tokens=valid_tokens[:, chunk_start:chunk_end], + max_score=max_score, + denom=denom, + acc=acc, + scale=scale, + ) + return finish_reference_attention_no_sink(max_score, denom, acc) + + +def _write_fp8_ds_mla_token( + k_cache: torch.Tensor, + slot: int, + block_size: int, +) -> torch.Tensor: + block_idx = slot // block_size + block_offset = slot % block_size + + values = ( + (torch.arange(_FP8_DIM, device=k_cache.device, dtype=torch.float32) % 17) + - 8 + ) / 16.0 + values = values + float(slot) / 32.0 + scale_exponents = torch.tensor( + [-2, -1, 0, 1, 2, -2, 1], + device=k_cache.device, + dtype=torch.float32, + ) + scales = torch.exp2(scale_exponents) + scale_per_dim = scales.repeat_interleave(64) + + fp8_values = (values / scale_per_dim).to(torch.float8_e4m3fn) + expected_nope = fp8_values.float() * scale_per_dim + rope = ( + torch.linspace(-1.0, 1.0, _ROPE_DIM, device=k_cache.device) + + float(slot) / 16.0 + ).to(torch.bfloat16) + + flat_block = k_cache[block_idx].view(-1) + token_data_start = block_offset * _TOKEN_DATA_SIZE + token_scale_start = block_size * _TOKEN_DATA_SIZE + block_offset * _SCALE_DIM + flat_block[token_data_start : token_data_start + _FP8_DIM] = fp8_values.view( + torch.uint8 + ) + flat_block[ + token_data_start + _FP8_DIM : token_data_start + _TOKEN_DATA_SIZE + ] = 
rope.view(torch.uint8) + + encoded_scales = (scale_exponents.to(torch.int32) + 127).to(torch.uint8) + flat_block[token_scale_start : token_scale_start + encoded_scales.numel()] = ( + encoded_scales + ) + flat_block[ + token_scale_start + encoded_scales.numel() : token_scale_start + _SCALE_DIM + ] = 127 + + return torch.cat([expected_nope, rope.float()]).to(torch.bfloat16) + + +def test_reference_attention_no_sink_matches_logsumexp() -> None: + torch.manual_seed(0) + scale = 0.25 + q = torch.randn(3, 4, 5) + kv = torch.randn(3, 6, 5) + valid_tokens = torch.tensor( + [ + [True, True, False, True, False, False], + [False, False, False, False, False, False], + [True, False, True, True, True, False], + ], + dtype=torch.bool, + ) + output, lse = reference_attention_no_sink(q, kv, valid_tokens, scale) + expected_output, expected_lse = _golden_no_sink_attention( + q, + kv, + valid_tokens, + scale, + ) + + torch.testing.assert_close(output, expected_output, rtol=1e-6, atol=1e-6) + torch.testing.assert_close(lse, expected_lse, rtol=1e-6, atol=1e-6) + + + +def test_reference_attention_ignores_nan_kv_for_invalid_tokens() -> None: + torch.manual_seed(24) + q = torch.randn(2, 1, 3, 8) + kv = torch.randn(2, 4, 8) + kv[:, 2:] = float("nan") + valid_tokens = torch.tensor( + [[True, True, False, False], [True, False, False, False]], + dtype=torch.bool, + ) + + output, lse = reference_attention_no_sink( + q=q, + kv=kv, + valid_tokens=valid_tokens, + scale=0.125, + ) + + assert torch.isfinite(output).all() + assert torch.isfinite(lse).all() + + +def test_sink_aware_reference_attention_matches_dense_golden() -> None: + torch.manual_seed(1) + scale = 0.125 + q = torch.randn(3, 1, 4, 5) + kv = torch.randn(3, 6, 5) + valid_tokens = torch.tensor( + [ + [True, True, False, True, False, False], + [False, False, False, False, False, False], + [False, True, True, False, True, True], + ], + dtype=torch.bool, + ) + sink = torch.tensor([-1.0, 0.25, 1.5, -0.5]) + output = torch.empty(3, 4, 5) + sink_aware_reference_attention(q, kv, valid_tokens, scale, sink, output) + expected = _golden_sink_attention(q, kv, valid_tokens, scale, sink) + + torch.testing.assert_close(output, expected, rtol=1e-6, atol=1e-6) + + +def test_lse_merge_with_sink_matches_concatenated_attention() -> None: + torch.manual_seed(2) + scale = 0.2 + q = torch.randn(4, 3, 7) + compressed_kv = torch.randn(4, 5, 7) + swa_kv = torch.randn(4, 3, 7) + compressed_kv[:, 1] = compressed_kv[:, 0] + swa_kv[:, 2] = compressed_kv[:, 0] + compressed_valid = torch.tensor( + [ + [True, True, False, True, False], + [False, False, False, False, False], + [True, False, True, True, False], + [False, False, False, False, False], + ], + dtype=torch.bool, + ) + swa_valid = torch.tensor( + [ + [True, False, True], + [True, True, False], + [False, False, False], + [False, False, False], + ], + dtype=torch.bool, + ) + sink = torch.tensor([-0.25, 0.75, 1.25]) + output = torch.empty(4, 3, 7) + comp_output, comp_lse = reference_attention_no_sink( + q, + compressed_kv, + compressed_valid, + scale, + ) + swa_output, swa_lse = reference_attention_no_sink(q, swa_kv, swa_valid, scale) + merge_reference_attention_with_sink( + subset_outputs=[comp_output, swa_output], + subset_lses=[comp_lse, swa_lse], + attn_sink=sink, + output=output, + ) + + expected = _golden_sink_attention( + q, + torch.cat([compressed_kv, swa_kv], dim=1), + torch.cat([compressed_valid, swa_valid], dim=1), + scale, + sink, + ) + torch.testing.assert_close(output, expected, rtol=1e-6, atol=1e-6) + assert 
torch.equal(output[3], torch.zeros_like(output[3])) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_lse_merge_with_sink_matches_reference() -> None: + torch.manual_seed(5) + comp_output = torch.randn(3, 4, 9, device="cuda", dtype=torch.float32) + swa_output = torch.randn(3, 4, 9, device="cuda", dtype=torch.float32) + comp_lse = torch.randn(3, 4, device="cuda", dtype=torch.float32) + swa_lse = torch.randn(3, 4, device="cuda", dtype=torch.float32) + comp_lse[1, 2] = float("-inf") + swa_lse[2, 1] = float("-inf") + sink = torch.tensor([-0.5, 0.25, 1.0, -1.5], device="cuda") + + output = torch.empty(3, 4, 9, device="cuda", dtype=torch.bfloat16) + expected = torch.empty_like(output) + merge_two_sparse_mla_subsets_with_sink( + subset0_output=comp_output, + subset0_lse=comp_lse, + subset1_output=swa_output, + subset1_lse=swa_lse, + attn_sink=sink, + output=output, + ) + merge_reference_attention_with_sink( + subset_outputs=[comp_output, swa_output], + subset_lses=[comp_lse, swa_lse], + attn_sink=sink, + output=expected, + ) + + torch.testing.assert_close(output.float(), expected.float(), rtol=1e-2, atol=1e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_single_lse_merge_with_sink_matches_reference() -> None: + torch.manual_seed(14) + subset_output = torch.randn(3, 4, 9, device="cuda", dtype=torch.float32) + subset_lse = torch.randn(3, 4, device="cuda", dtype=torch.float32) + subset_lse[1, 2] = float("-inf") + sink = torch.tensor([-0.5, 0.25, 1.0, -1.5], device="cuda") + + output = torch.empty(3, 4, 9, device="cuda", dtype=torch.bfloat16) + expected = torch.empty_like(output) + merge_sparse_mla_subset_with_sink( + subset_output=subset_output, + subset_lse=subset_lse, + attn_sink=sink, + output=output, + ) + merge_reference_attention_with_sink( + subset_outputs=[subset_output], + subset_lses=[subset_lse], + attn_sink=sink, + output=expected, + ) + + torch.testing.assert_close(output.float(), expected.float(), rtol=1e-2, atol=1e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_finish_with_sink_matches_finish_then_merge_reference() -> None: + torch.manual_seed(18) + max_score = torch.randn(4, 3, device="cuda", dtype=torch.float32) + denom = torch.rand(4, 3, device="cuda", dtype=torch.float32) + 0.1 + denom[1, 2] = 0.0 + max_score[1, 2] = float("-inf") + acc = torch.randn(4, 3, 17, device="cuda", dtype=torch.float32) + sink = torch.tensor( + [-0.5, 0.25, 1.0, -float("inf"), -float("inf")], + device="cuda", + dtype=torch.float32, + ) + + output = torch.full((4, 5, 17), -7.0, device="cuda", dtype=torch.bfloat16) + finish_sparse_mla_attention_with_sink(max_score, denom, acc, sink, output) + + subset_output = torch.empty_like(acc) + subset_lse = torch.empty_like(max_score) + finish_gathered_sparse_mla_attention( + max_score=max_score, + denom=denom, + acc=acc, + output=subset_output, + lse=subset_lse, + ) + expected = torch.empty(4, 3, 17, device="cuda", dtype=torch.bfloat16) + merge_reference_attention_with_sink( + subset_outputs=[subset_output], + subset_lses=[subset_lse], + attn_sink=sink[:3], + output=expected, + ) + + torch.testing.assert_close( + output[:, :3].float(), expected.float(), rtol=1e-2, atol=1e-2 + ) + torch.testing.assert_close( + output[:, 3:].float(), + torch.full_like(output[:, 3:].float(), -7.0), + rtol=0, + atol=0, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def 
test_triton_finish_with_sink_returns_zero_when_no_tokens_or_sink() -> None: + max_score = torch.full((2, 3), float("-inf"), device="cuda") + denom = torch.zeros((2, 3), device="cuda") + acc = torch.full((2, 3, 17), float("nan"), device="cuda") + sink = torch.full((3,), float("-inf"), device="cuda") + + single_output = torch.full( + (2, 3, 17), 7.0, device="cuda", dtype=torch.bfloat16 + ) + finish_sparse_mla_attention_with_sink( + max_score, + denom, + acc, + sink, + output=single_output, + ) + torch.testing.assert_close( + single_output.float(), + torch.zeros_like(single_output.float()), + rtol=0, + atol=0, + ) + + two_output = torch.full((2, 3, 17), 7.0, device="cuda", dtype=torch.bfloat16) + finish_two_sparse_mla_attention_states_with_sink( + max_score, + denom, + acc, + max_score, + denom, + acc, + sink, + output=two_output, + ) + torch.testing.assert_close( + two_output.float(), + torch.zeros_like(two_output.float()), + rtol=0, + atol=0, + ) + + +def test_triton_finish_two_states_with_sink_matches_finish_then_merge() -> None: + torch.manual_seed(22) + comp_max = torch.randn(4, 3, device="cuda", dtype=torch.float32) + comp_denom = torch.rand(4, 3, device="cuda", dtype=torch.float32) + 0.1 + comp_acc = torch.randn(4, 3, 17, device="cuda", dtype=torch.float32) + swa_max = torch.randn(4, 3, device="cuda", dtype=torch.float32) + swa_denom = torch.rand(4, 3, device="cuda", dtype=torch.float32) + 0.1 + swa_acc = torch.randn(4, 3, 17, device="cuda", dtype=torch.float32) + sink = torch.tensor( + [-0.5, 0.25, 1.0, -float("inf"), -float("inf")], + device="cuda", + dtype=torch.float32, + ) + + comp_denom[0, 1] = 0.0 + comp_max[0, 1] = float("-inf") + swa_denom[2, 0] = 0.0 + swa_max[2, 0] = float("-inf") + comp_denom[3, 2] = 0.0 + comp_max[3, 2] = float("-inf") + swa_denom[3, 2] = 0.0 + swa_max[3, 2] = float("-inf") + + output = torch.full((4, 5, 17), -7.0, device="cuda", dtype=torch.bfloat16) + finish_two_sparse_mla_attention_states_with_sink( + comp_max, + comp_denom, + comp_acc, + swa_max, + swa_denom, + swa_acc, + sink, + output, + ) + + comp_output = torch.empty_like(comp_acc) + comp_lse = torch.empty_like(comp_max) + swa_output = torch.empty_like(swa_acc) + swa_lse = torch.empty_like(swa_max) + finish_gathered_sparse_mla_attention( + comp_max, + comp_denom, + comp_acc, + comp_output, + comp_lse, + ) + finish_gathered_sparse_mla_attention( + swa_max, + swa_denom, + swa_acc, + swa_output, + swa_lse, + ) + expected = torch.empty(4, 3, 17, device="cuda", dtype=torch.bfloat16) + merge_two_sparse_mla_subsets_with_sink( + subset0_output=comp_output, + subset0_lse=comp_lse, + subset1_output=swa_output, + subset1_lse=swa_lse, + attn_sink=sink[:3], + output=expected, + ) + + torch.testing.assert_close( + output[:, :3].float(), expected.float(), rtol=1e-2, atol=1e-2 + ) + torch.testing.assert_close( + output[:, 3:].float(), + torch.full_like(output[:, 3:].float(), -7.0), + rtol=0, + atol=0, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("head_dim", [16, 512]) +def test_triton_gathered_attention_chunk_matches_reference(head_dim: int) -> None: + torch.manual_seed(6) + scale = 0.125 + q = torch.randn(2, 1, 5, head_dim, device="cuda", dtype=torch.bfloat16) + q_active = q[:, :, :3] + kv = torch.randn(2, 5, head_dim, device="cuda", dtype=torch.bfloat16) + slot_ids = torch.tensor( + [ + [0, 1, -1, 3, 4], + [5, -1, 7, 8, -1], + ], + dtype=torch.int32, + device="cuda", + ) + lens = torch.tensor([4, 5], dtype=torch.int32, device="cuda") + max_score = 
torch.full((2, 3), float("-inf"), device="cuda") + denom = torch.zeros((2, 3), device="cuda") + acc = torch.zeros((2, 3, head_dim), device="cuda") + + accumulate_gathered_sparse_mla_attention_chunk( + q=q, + kv=kv[:, :2], + slot_ids=slot_ids[:, :2], + lens=lens, + candidate_offset=0, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + accumulate_gathered_sparse_mla_attention_chunk( + q=q, + kv=kv[:, 2:], + slot_ids=slot_ids[:, 2:], + lens=lens, + candidate_offset=2, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + + output = torch.empty_like(acc) + lse = torch.empty_like(max_score) + finish_gathered_sparse_mla_attention( + max_score=max_score, + denom=denom, + acc=acc, + output=output, + lse=lse, + ) + + offsets = torch.arange(slot_ids.shape[1], device="cuda") + valid_tokens = (offsets[None, :] < lens[:, None]) & (slot_ids >= 0) + expected_output, expected_lse = reference_attention_no_sink( + q_active, + kv, + valid_tokens, + scale, + ) + torch.testing.assert_close(output, expected_output, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(lse, expected_lse, rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_gathered_attention_chunk_matches_reference_without_slot_ids() -> None: + torch.manual_seed(8) + scale = 0.2 + q = torch.randn(3, 1, 2, 32, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(3, 6, 32, device="cuda", dtype=torch.bfloat16) + lens = torch.tensor([6, 3, 0], dtype=torch.int32, device="cuda") + max_score = torch.full((3, 2), float("-inf"), device="cuda") + denom = torch.zeros((3, 2), device="cuda") + acc = torch.zeros((3, 2, 32), device="cuda") + + accumulate_gathered_sparse_mla_attention_chunk( + q=q, + kv=kv, + slot_ids=None, + lens=lens, + candidate_offset=0, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + + output = torch.empty_like(acc) + lse = torch.empty_like(max_score) + finish_gathered_sparse_mla_attention( + max_score=max_score, + denom=denom, + acc=acc, + output=output, + lse=lse, + ) + + offsets = torch.arange(kv.shape[1], device="cuda") + valid_tokens = offsets[None, :] < lens[:, None] + expected_output, expected_lse = reference_attention_no_sink( + q, + kv, + valid_tokens, + scale, + ) + torch.testing.assert_close(output, expected_output, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(lse, expected_lse, rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_dequantize_global_slots_k_cache_fp8_ds_mla_layout() -> None: + block_size = 4 + num_blocks = 2 + k_cache = torch.zeros( + num_blocks, + block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + expected_by_slot = { + slot: _write_fp8_ds_mla_token(k_cache, slot, block_size) + for slot in (0, 3, 4) + } + slot_ids = torch.tensor( + [ + [0, 3, -1, 4], + [4, 0, 3, -1], + ], + dtype=torch.int32, + device="cuda", + ) + + output = torch.empty(2, 4, 512, dtype=torch.bfloat16, device="cuda") + dequantize_global_slots_k_cache(output, k_cache, slot_ids, block_size) + + expected = torch.zeros_like(output) + for token_idx in range(slot_ids.shape[0]): + for topk_idx in range(slot_ids.shape[1]): + slot = int(slot_ids[token_idx, topk_idx].item()) + if slot >= 0: + expected[token_idx, topk_idx] = expected_by_slot[slot] + + torch.testing.assert_close(output.float(), expected.float(), rtol=0, atol=0) + + output_from_3d_indices = torch.empty_like(output) + dequantize_global_slots_k_cache( + output_from_3d_indices, + 
k_cache, + slot_ids.unsqueeze(1), + block_size, + ) + torch.testing.assert_close( + output_from_3d_indices.float(), + expected.float(), + rtol=0, + atol=0, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_dequantize_combined_sparse_mla_decode_kv_writes_direct_views() -> None: + compressed_block_size = 4 + swa_block_size = 4 + compressed_cache = torch.zeros( + 2, + compressed_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + swa_cache = torch.zeros( + 3, + swa_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + for slot in (0, 3, 4): + _write_fp8_ds_mla_token(compressed_cache, slot, compressed_block_size) + for slot in (0, 1, 2, 3, 4): + _write_fp8_ds_mla_token(swa_cache, slot, swa_block_size) + + compressed_slot_ids = torch.tensor( + [[0, 3, -1], [4, 0, 3]], + dtype=torch.int32, + device="cuda", + ) + seq_lens = torch.tensor([5, 7], dtype=torch.int32, device="cuda") + swa_lens = torch.tensor([2, 3], dtype=torch.int32, device="cuda") + block_table = torch.tensor( + [[0, 1, 2], [2, 0, 1]], + dtype=torch.int32, + device="cuda", + ) + + combined = torch.full( + (2, 6, 512), + -7, + dtype=torch.bfloat16, + device="cuda", + ) + dequantize_combined_sparse_mla_decode_kv( + combined, + compressed_cache, + compressed_slot_ids, + compressed_block_size, + swa_cache, + seq_lens, + swa_lens, + block_table, + swa_block_size, + ) + + expected_comp = torch.empty(2, 3, 512, dtype=torch.bfloat16, device="cuda") + expected_swa = torch.full( + (2, 3, 512), + -7, + dtype=torch.bfloat16, + device="cuda", + ) + dequantize_global_slots_k_cache( + expected_comp, + compressed_cache, + compressed_slot_ids, + compressed_block_size, + ) + dequantize_and_gather_k_cache( + expected_swa, + swa_cache, + seq_lens=seq_lens, + gather_lens=swa_lens, + block_table=block_table, + block_size=swa_block_size, + offset=0, + ) + expected = torch.full_like(combined, -7) + expected[:, :3].copy_(expected_comp) + expected[:, 3:].copy_(expected_swa) + + torch.testing.assert_close(combined.float(), expected.float(), rtol=0, atol=0) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_fp8ds_global_slots_attention_chunk_matches_reference() -> None: + torch.manual_seed(10) + block_size = 4 + num_blocks = 3 + k_cache = torch.zeros( + num_blocks, + block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + expected_by_slot = { + slot: _write_fp8_ds_mla_token(k_cache, slot, block_size) + for slot in (0, 1, 3, 4, 7, 8) + } + slot_ids = torch.tensor( + [ + [0, 3, -1, 8, 1], + [7, -1, 4, 0, 8], + ], + dtype=torch.int32, + device="cuda", + ) + lens = torch.tensor([4, 5], dtype=torch.int32, device="cuda") + q = torch.randn(2, 1, 3, 512, device="cuda", dtype=torch.bfloat16) + scale = 0.0625 + + max_score = torch.full((2, 3), float("-inf"), device="cuda") + denom = torch.zeros((2, 3), device="cuda") + acc = torch.zeros((2, 3, 512), device="cuda") + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk( + q=q, + k_cache=k_cache, + slot_ids=slot_ids[:, :2], + lens=lens, + block_size=block_size, + candidate_offset=0, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk( + q=q, + k_cache=k_cache, + slot_ids=slot_ids[:, 2:], + lens=lens, + block_size=block_size, + candidate_offset=2, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + + output = torch.empty_like(acc) + lse = 
torch.empty_like(max_score) + finish_gathered_sparse_mla_attention( + max_score=max_score, + denom=denom, + acc=acc, + output=output, + lse=lse, + ) + + gathered = torch.zeros(2, 5, 512, device="cuda", dtype=torch.bfloat16) + for token_idx in range(slot_ids.shape[0]): + for topk_idx in range(slot_ids.shape[1]): + slot = int(slot_ids[token_idx, topk_idx].item()) + if slot >= 0: + gathered[token_idx, topk_idx] = expected_by_slot[slot] + offsets = torch.arange(slot_ids.shape[1], device="cuda") + valid_tokens = (offsets[None, :] < lens[:, None]) & (slot_ids >= 0) + expected_output, expected_lse = reference_attention_no_sink( + q, + gathered, + valid_tokens, + scale, + ) + + torch.testing.assert_close(output, expected_output, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(lse, expected_lse, rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("head_block_size", [1, 2, 4]) +def test_triton_fp8ds_global_slots_multihead_attention_matches_reference( + head_block_size: int, +) -> None: + torch.manual_seed(19) + block_size = 4 + num_blocks = 3 + k_cache = torch.zeros( + num_blocks, + block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + expected_by_slot = { + slot: _write_fp8_ds_mla_token(k_cache, slot, block_size) + for slot in (0, 1, 3, 4, 7, 8) + } + slot_ids = torch.tensor( + [ + [0, 3, -1, 8, 1], + [7, -1, 4, 0, 8], + ], + dtype=torch.int32, + device="cuda", + ) + lens = torch.tensor([4, 5], dtype=torch.int32, device="cuda") + q = torch.randn(2, 1, 8, 512, device="cuda", dtype=torch.bfloat16) + q_active = q[:, :, :5] + scale = 0.0625 + + max_score = torch.full((2, 5), float("-inf"), device="cuda") + denom = torch.zeros((2, 5), device="cuda") + acc = torch.zeros((2, 5, 512), device="cuda") + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=k_cache, + slot_ids=slot_ids[:, :2], + lens=lens, + block_size=block_size, + candidate_offset=0, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + head_block_size=head_block_size, + ) + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=k_cache, + slot_ids=slot_ids[:, 2:], + lens=lens, + block_size=block_size, + candidate_offset=2, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + head_block_size=head_block_size, + ) + + output = torch.empty_like(acc) + lse = torch.empty_like(max_score) + finish_gathered_sparse_mla_attention( + max_score=max_score, + denom=denom, + acc=acc, + output=output, + lse=lse, + ) + + gathered = torch.zeros(2, 5, 512, device="cuda", dtype=torch.bfloat16) + for token_idx in range(slot_ids.shape[0]): + for topk_idx in range(slot_ids.shape[1]): + slot = int(slot_ids[token_idx, topk_idx].item()) + if slot >= 0: + gathered[token_idx, topk_idx] = expected_by_slot[slot] + offsets = torch.arange(slot_ids.shape[1], device="cuda") + valid_tokens = (offsets[None, :] < lens[:, None]) & (slot_ids >= 0) + expected_output, expected_lse = reference_attention_no_sink( + q_active, + gathered, + valid_tokens, + scale, + ) + + torch.testing.assert_close(output, expected_output, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(lse, expected_lse, rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_fp8ds_paged_attention_chunk_matches_reference() -> None: + torch.manual_seed(12) + block_size = 4 + k_cache = torch.zeros( + 3, + block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + 
dtype=torch.uint8, + device="cuda", + ) + block_table = torch.tensor( + [ + [1, 0, 2], + [2, 1, 0], + ], + dtype=torch.int32, + device="cuda", + ) + seq_lens = torch.tensor([6, 9], dtype=torch.int32, device="cuda") + gather_lens = torch.tensor([3, 4], dtype=torch.int32, device="cuda") + q = torch.randn(2, 1, 3, 512, device="cuda", dtype=torch.bfloat16) + scale = 0.0625 + + gathered = torch.zeros(2, 4, 512, device="cuda", dtype=torch.bfloat16) + expected_by_slot: dict[int, torch.Tensor] = {} + for token_idx in range(seq_lens.shape[0]): + start_pos = int(seq_lens[token_idx].item() - gather_lens[token_idx].item()) + for gather_idx in range(int(gather_lens[token_idx].item())): + pos = start_pos + gather_idx + block_idx = pos // block_size + block_offset = pos % block_size + physical_block = int(block_table[token_idx, block_idx].item()) + slot = physical_block * block_size + block_offset + expected_by_slot.setdefault( + slot, + _write_fp8_ds_mla_token(k_cache, slot, block_size), + ) + gathered[token_idx, gather_idx] = expected_by_slot[slot] + + max_score = torch.full((2, 3), float("-inf"), device="cuda") + denom = torch.zeros((2, 3), device="cuda") + acc = torch.zeros((2, 3, 512), device="cuda") + accumulate_fp8ds_paged_sparse_mla_attention_chunk( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=0, + num_candidates=2, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + accumulate_fp8ds_paged_sparse_mla_attention_chunk( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=2, + num_candidates=2, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + + output = torch.empty_like(acc) + lse = torch.empty_like(max_score) + finish_gathered_sparse_mla_attention( + max_score=max_score, + denom=denom, + acc=acc, + output=output, + lse=lse, + ) + + offsets = torch.arange(gathered.shape[1], device="cuda") + valid_tokens = offsets[None, :] < gather_lens[:, None] + expected_output, expected_lse = reference_attention_no_sink( + q, + gathered, + valid_tokens, + scale, + ) + + torch.testing.assert_close(output, expected_output, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(lse, expected_lse, rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("head_block_size", [1, 2, 4]) +def test_triton_fp8ds_paged_multihead_attention_matches_singlehead_and_reference( + head_block_size: int, +) -> None: + torch.manual_seed(23) + block_size = 4 + k_cache = torch.zeros( + 4, + block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + block_table = torch.tensor( + [ + [1, 0, 2, 3], + [2, 3, 1, 0], + ], + dtype=torch.int32, + device="cuda", + ) + seq_lens = torch.tensor([7, 11], dtype=torch.int32, device="cuda") + gather_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda") + q = torch.randn(2, 1, 8, 512, device="cuda", dtype=torch.bfloat16) + q_active = q[:, :, :5] + scale = 0.0625 + + gathered = torch.zeros(2, 5, 512, device="cuda", dtype=torch.bfloat16) + expected_by_slot: dict[int, torch.Tensor] = {} + for token_idx in range(seq_lens.shape[0]): + start_pos = int(seq_lens[token_idx].item() - gather_lens[token_idx].item()) + for gather_idx in range(int(gather_lens[token_idx].item())): + pos = start_pos + gather_idx + block_idx = pos // block_size + block_offset = pos % block_size + physical_block 
= int(block_table[token_idx, block_idx].item()) + slot = physical_block * block_size + block_offset + expected_by_slot.setdefault( + slot, + _write_fp8_ds_mla_token(k_cache, slot, block_size), + ) + gathered[token_idx, gather_idx] = expected_by_slot[slot] + + single_max = torch.full((2, 5), float("-inf"), device="cuda") + single_denom = torch.zeros((2, 5), device="cuda") + single_acc = torch.zeros((2, 5, 512), device="cuda") + multi_max = torch.full_like(single_max, float("-inf")) + multi_denom = torch.zeros_like(single_denom) + multi_acc = torch.zeros_like(single_acc) + + for candidate_offset, num_candidates in ((0, 2), (2, 3)): + accumulate_fp8ds_paged_sparse_mla_attention_chunk( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=candidate_offset, + num_candidates=num_candidates, + scale=scale, + max_score=single_max, + denom=single_denom, + acc=single_acc, + ) + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=candidate_offset, + num_candidates=num_candidates, + scale=scale, + max_score=multi_max, + denom=multi_denom, + acc=multi_acc, + head_block_size=head_block_size, + ) + + torch.testing.assert_close(multi_max, single_max, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(multi_denom, single_denom, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(multi_acc, single_acc, rtol=2e-2, atol=2e-2) + + output = torch.empty_like(multi_acc) + lse = torch.empty_like(multi_max) + finish_gathered_sparse_mla_attention( + max_score=multi_max, + denom=multi_denom, + acc=multi_acc, + output=output, + lse=lse, + ) + offsets = torch.arange(gathered.shape[1], device="cuda") + valid_tokens = offsets[None, :] < gather_lens[:, None] + expected_output, expected_lse = reference_attention_no_sink( + q_active, + gathered, + valid_tokens, + scale, + ) + + torch.testing.assert_close(output, expected_output, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(lse, expected_lse, rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_fp8ds_paged_attention_with_sink_matches_reference() -> None: + torch.manual_seed(15) + block_size = 4 + k_cache = torch.zeros( + 3, + block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + block_table = torch.tensor([[1, 0, 2]], dtype=torch.int32, device="cuda") + seq_lens = torch.tensor([7], dtype=torch.int32, device="cuda") + gather_lens = torch.tensor([4], dtype=torch.int32, device="cuda") + q = torch.randn(1, 1, 3, 512, device="cuda", dtype=torch.bfloat16) + sink = torch.tensor([-0.25, 0.5, 1.25], device="cuda") + scale = 0.0625 + + gathered = torch.zeros(1, 4, 512, device="cuda", dtype=torch.bfloat16) + expected_by_slot: dict[int, torch.Tensor] = {} + start_pos = int(seq_lens[0].item() - gather_lens[0].item()) + for gather_idx in range(int(gather_lens[0].item())): + pos = start_pos + gather_idx + physical_block = int(block_table[0, pos // block_size].item()) + slot = physical_block * block_size + pos % block_size + expected_by_slot.setdefault( + slot, + _write_fp8_ds_mla_token(k_cache, slot, block_size), + ) + gathered[0, gather_idx] = expected_by_slot[slot] + + max_score = torch.full((1, 3), float("-inf"), device="cuda") + denom = torch.zeros((1, 3), device="cuda") + acc = torch.zeros((1, 3, 512), device="cuda") + 
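# One chunk covers all four gathered positions: accumulate the online-softmax state over the paged fp8 cache, finish it into a per-head (output, lse) pair, then fold the learned sink logit in with merge_sparse_mla_subset_with_sink and compare against the dense sink-aware golden reference. +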
accumulate_fp8ds_paged_sparse_mla_attention_chunk( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=0, + num_candidates=4, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + subset_output = torch.empty_like(acc) + subset_lse = torch.empty_like(max_score) + finish_gathered_sparse_mla_attention( + max_score=max_score, + denom=denom, + acc=acc, + output=subset_output, + lse=subset_lse, + ) + + output = torch.empty(1, 3, 512, device="cuda", dtype=torch.bfloat16) + merge_sparse_mla_subset_with_sink( + subset_output=subset_output, + subset_lse=subset_lse, + attn_sink=sink, + output=output, + ) + valid_tokens = torch.ones(1, 4, device="cuda", dtype=torch.bool) + expected = _golden_sink_attention(q, gathered, valid_tokens, scale, sink) + + torch.testing.assert_close(output.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("head_block_size", [1, 2, 4]) +def test_triton_fp8ds_paged_attention_with_sink_direct_matches_state_path( + head_block_size: int, +) -> None: + torch.manual_seed(29) + block_size = 4 + k_cache = torch.zeros( + 4, + block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + block_table = torch.tensor( + [[1, 0, 2, 3], [2, 3, 1, 0]], + dtype=torch.int32, + device="cuda", + ) + seq_lens = torch.tensor([7, 11], dtype=torch.int32, device="cuda") + gather_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda") + q = torch.randn(2, 1, 8, 512, device="cuda", dtype=torch.bfloat16) + sink = torch.linspace(-0.5, 0.5, 5, device="cuda") + scale = 0.0625 + + for token_idx in range(seq_lens.shape[0]): + start_pos = int(seq_lens[token_idx].item() - gather_lens[token_idx].item()) + for gather_idx in range(int(gather_lens[token_idx].item())): + pos = start_pos + gather_idx + physical_block = int(block_table[token_idx, pos // block_size].item()) + slot = physical_block * block_size + pos % block_size + _write_fp8_ds_mla_token(k_cache, slot, block_size) + + max_score = torch.full((2, 5), float("-inf"), device="cuda") + denom = torch.zeros((2, 5), device="cuda") + acc = torch.zeros((2, 5, 512), device="cuda") + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=0, + num_candidates=5, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + head_block_size=1, + ) + expected = torch.empty(2, 5, 512, device="cuda", dtype=torch.bfloat16) + finish_sparse_mla_attention_with_sink(max_score, denom, acc, sink, expected) + + actual = torch.empty_like(expected) + fp8ds_paged_sparse_mla_attention_with_sink_multihead( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=0, + num_candidates=5, + scale=scale, + attn_sink=sink, + output=actual, + head_block_size=head_block_size, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("head_block_size", [1, 2, 4]) +def test_triton_fp8ds_global_paged_attention_with_sink_direct_matches_state_path( + head_block_size: int, +) -> None: + torch.manual_seed(31) + compressed_block_size = 4 + swa_block_size = 4 + compressed_cache = 
torch.zeros( + 4, + compressed_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + swa_cache = torch.zeros( + 4, + swa_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + slot_ids = torch.tensor( + [[0, 3, -1, 8, 1], [7, -1, 4, 0, 8]], + dtype=torch.int32, + device="cuda", + ) + topk_lens = torch.tensor([4, 5], dtype=torch.int32, device="cuda") + block_table = torch.tensor( + [[1, 0, 2, 3], [2, 3, 1, 0]], + dtype=torch.int32, + device="cuda", + ) + seq_lens = torch.tensor([7, 11], dtype=torch.int32, device="cuda") + gather_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda") + q = torch.randn(2, 1, 8, 512, device="cuda", dtype=torch.bfloat16) + sink = torch.linspace(-1.0, 1.0, 5, device="cuda") + scale = 0.0625 + + for slot in (0, 1, 3, 4, 7, 8): + _write_fp8_ds_mla_token(compressed_cache, slot, compressed_block_size) + for token_idx in range(seq_lens.shape[0]): + start_pos = int(seq_lens[token_idx].item() - gather_lens[token_idx].item()) + for gather_idx in range(int(gather_lens[token_idx].item())): + pos = start_pos + gather_idx + physical_block = int(block_table[token_idx, pos // swa_block_size].item()) + slot = physical_block * swa_block_size + pos % swa_block_size + _write_fp8_ds_mla_token(swa_cache, slot, swa_block_size) + + comp_max = torch.full((2, 5), float("-inf"), device="cuda") + comp_denom = torch.zeros((2, 5), device="cuda") + comp_acc = torch.zeros((2, 5, 512), device="cuda") + swa_max = torch.full((2, 5), float("-inf"), device="cuda") + swa_denom = torch.zeros((2, 5), device="cuda") + swa_acc = torch.zeros((2, 5, 512), device="cuda") + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=compressed_cache, + slot_ids=slot_ids, + lens=topk_lens, + block_size=compressed_block_size, + candidate_offset=0, + scale=scale, + max_score=comp_max, + denom=comp_denom, + acc=comp_acc, + head_block_size=1, + ) + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=swa_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=swa_block_size, + candidate_offset=0, + num_candidates=5, + scale=scale, + max_score=swa_max, + denom=swa_denom, + acc=swa_acc, + head_block_size=1, + ) + expected = torch.empty(2, 5, 512, device="cuda", dtype=torch.bfloat16) + finish_two_sparse_mla_attention_states_with_sink( + comp_max, + comp_denom, + comp_acc, + swa_max, + swa_denom, + swa_acc, + sink, + expected, + ) + + actual = torch.empty_like(expected) + fp8ds_global_paged_sparse_mla_attention_with_sink_multihead( + q=q, + compressed_k_cache=compressed_cache, + slot_ids=slot_ids, + topk_lens=topk_lens, + compressed_block_size=compressed_block_size, + swa_k_cache=swa_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + swa_block_size=swa_block_size, + num_compressed_candidates=5, + num_swa_candidates=5, + scale=scale, + attn_sink=sink, + output=actual, + head_block_size=head_block_size, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_matmul_sparse_mla_attention_with_sink_matches_reference() -> None: + torch.manual_seed(41) + q = torch.randn(2, 1, 5, 512, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(2, 7, 512, device="cuda", dtype=torch.bfloat16) + valid_tokens = torch.tensor( + [ + [True, True, False, True, False, True, True], + [False, True, True, False, True, False, 
False], + ], + dtype=torch.bool, + device="cuda", + ) + sink = torch.linspace(-0.25, 0.25, 5, device="cuda") + scale = 0.0625 + + expected = torch.empty(2, 5, 512, device="cuda", dtype=torch.bfloat16) + sink_aware_reference_attention( + q, + kv, + valid_tokens, + scale, + sink, + expected, + ) + + actual = torch.empty_like(expected) + matmul_sparse_mla_attention_with_sink( + q, + kv, + valid_tokens, + scale, + sink, + actual, + num_heads=5, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_build_combined_sparse_mla_decode_valid_mask_matches_torch() -> None: + compressed_slot_ids = torch.tensor( + [ + [7, 4, -1, 9, 11], + [2, -1, 3, 8, 10], + [-1, -1, -1, -1, -1], + ], + device="cuda", + dtype=torch.int32, + ) + topk_lens = torch.tensor([4, 3, 0], device="cuda", dtype=torch.int32) + swa_lens = torch.tensor([3, 1, 0], device="cuda", dtype=torch.int32) + valid_tokens = torch.empty(3, 9, device="cuda", dtype=torch.bool) + + build_combined_sparse_mla_decode_valid_mask( + valid_tokens, + compressed_slot_ids, + topk_lens, + swa_lens, + ) + + comp_offsets = torch.arange(5, device="cuda", dtype=torch.int32) + swa_offsets = torch.arange(4, device="cuda", dtype=torch.int32) + expected = torch.empty_like(valid_tokens) + expected[:, :5] = (comp_offsets[None, :] < topk_lens[:, None]) & ( + compressed_slot_ids >= 0 + ) + expected[:, 5:] = swa_offsets[None, :] < swa_lens[:, None] + + torch.testing.assert_close(valid_tokens, expected) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("num_heads", [8, 16, 32, 64]) +def test_triton_fp8ds_paged_attention_with_sink_supports_tp_local_heads( + num_heads: int, +) -> None: + torch.manual_seed(37 + num_heads) + block_size = 4 + k_cache = torch.zeros( + 4, + block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + block_table = torch.tensor( + [[1, 0, 2, 3], [2, 3, 1, 0]], + dtype=torch.int32, + device="cuda", + ) + seq_lens = torch.tensor([7, 11], dtype=torch.int32, device="cuda") + gather_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda") + q = torch.randn(2, 1, num_heads, 512, device="cuda", dtype=torch.bfloat16) + sink = torch.linspace(-0.5, 0.5, num_heads, device="cuda") + scale = 0.0625 + + for token_idx in range(seq_lens.shape[0]): + start_pos = int(seq_lens[token_idx].item() - gather_lens[token_idx].item()) + for gather_idx in range(int(gather_lens[token_idx].item())): + pos = start_pos + gather_idx + physical_block = int(block_table[token_idx, pos // block_size].item()) + slot = physical_block * block_size + pos % block_size + _write_fp8_ds_mla_token(k_cache, slot, block_size) + + max_score = torch.full((2, num_heads), float("-inf"), device="cuda") + denom = torch.zeros((2, num_heads), device="cuda") + acc = torch.zeros((2, num_heads, 512), device="cuda") + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=0, + num_candidates=5, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + head_block_size=1, + ) + expected = torch.empty(2, num_heads, 512, device="cuda", dtype=torch.bfloat16) + finish_sparse_mla_attention_with_sink(max_score, denom, acc, sink, expected) + + actual = torch.empty_like(expected) + fp8ds_paged_sparse_mla_attention_with_sink_multihead( + q=q, + 
k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=0, + num_candidates=5, + scale=scale, + attn_sink=sink, + output=actual, + head_block_size=4, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("num_heads", [8, 16, 32, 64]) +def test_triton_fp8ds_global_paged_attention_with_sink_supports_tp_local_heads( + num_heads: int, +) -> None: + torch.manual_seed(41 + num_heads) + compressed_block_size = 4 + swa_block_size = 4 + compressed_cache = torch.zeros( + 4, + compressed_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + swa_cache = torch.zeros( + 4, + swa_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + slot_ids = torch.tensor( + [[0, 3, -1, 8, 1], [7, -1, 4, 0, 8]], + dtype=torch.int32, + device="cuda", + ) + topk_lens = torch.tensor([4, 5], dtype=torch.int32, device="cuda") + block_table = torch.tensor( + [[1, 0, 2, 3], [2, 3, 1, 0]], + dtype=torch.int32, + device="cuda", + ) + seq_lens = torch.tensor([7, 11], dtype=torch.int32, device="cuda") + gather_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda") + q = torch.randn(2, 1, num_heads, 512, device="cuda", dtype=torch.bfloat16) + sink = torch.linspace(-1.0, 1.0, num_heads, device="cuda") + scale = 0.0625 + + for slot in (0, 1, 3, 4, 7, 8): + _write_fp8_ds_mla_token(compressed_cache, slot, compressed_block_size) + for token_idx in range(seq_lens.shape[0]): + start_pos = int(seq_lens[token_idx].item() - gather_lens[token_idx].item()) + for gather_idx in range(int(gather_lens[token_idx].item())): + pos = start_pos + gather_idx + physical_block = int(block_table[token_idx, pos // swa_block_size].item()) + slot = physical_block * swa_block_size + pos % swa_block_size + _write_fp8_ds_mla_token(swa_cache, slot, swa_block_size) + + comp_max = torch.full((2, num_heads), float("-inf"), device="cuda") + comp_denom = torch.zeros((2, num_heads), device="cuda") + comp_acc = torch.zeros((2, num_heads, 512), device="cuda") + swa_max = torch.full((2, num_heads), float("-inf"), device="cuda") + swa_denom = torch.zeros((2, num_heads), device="cuda") + swa_acc = torch.zeros((2, num_heads, 512), device="cuda") + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=compressed_cache, + slot_ids=slot_ids, + lens=topk_lens, + block_size=compressed_block_size, + candidate_offset=0, + scale=scale, + max_score=comp_max, + denom=comp_denom, + acc=comp_acc, + head_block_size=1, + ) + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=swa_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=swa_block_size, + candidate_offset=0, + num_candidates=5, + scale=scale, + max_score=swa_max, + denom=swa_denom, + acc=swa_acc, + head_block_size=1, + ) + expected = torch.empty(2, num_heads, 512, device="cuda", dtype=torch.bfloat16) + finish_two_sparse_mla_attention_states_with_sink( + comp_max, + comp_denom, + comp_acc, + swa_max, + swa_denom, + swa_acc, + sink, + expected, + ) + + actual = torch.empty_like(expected) + fp8ds_global_paged_sparse_mla_attention_with_sink_multihead( + q=q, + compressed_k_cache=compressed_cache, + slot_ids=slot_ids, + topk_lens=topk_lens, + compressed_block_size=compressed_block_size, + swa_k_cache=swa_cache, + seq_lens=seq_lens, + 
gather_lens=gather_lens, + block_table=block_table, + swa_block_size=swa_block_size, + num_compressed_candidates=5, + num_swa_candidates=5, + scale=scale, + attn_sink=sink, + output=actual, + head_block_size=4, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_indexed_bf16_prefill_chunks_match_reference() -> None: + torch.manual_seed(17) + q = torch.randn(5, 5, 16, device="cuda", dtype=torch.bfloat16) + q_active = q[:, :3] + kv = torch.randn(2, 7, 16, device="cuda", dtype=torch.bfloat16) + kv_flat = kv.reshape(-1, q.shape[-1]) + combined_indices = torch.tensor( + [ + [0, 3, -1, 5, 3, 1], + [4, -1, 2, 2, 1, 8], + [-1, -1, -1, -1, -1, -1], + [8, 0, 9, -1, 7, 4], + [13, 12, 0, 12, -1, 3], + ], + dtype=torch.int64, + device="cuda", + ) + combined_lens = torch.tensor([5, 4, 0, 6, 5], dtype=torch.int32, device="cuda") + sink = torch.tensor([-0.5, 1.0, 0.25], dtype=torch.float32, device="cuda") + scale = 0.375 + output = torch.empty_like(q_active) + + for token_start in (0, 2, 4): + token_end = min(token_start + 2, q.shape[0]) + q_chunk = q[token_start:token_end] + indices_chunk = combined_indices[token_start:token_end] + lens_chunk = combined_lens[token_start:token_end] + max_score = torch.full( + (q_chunk.shape[0], q_active.shape[1]), + float("-inf"), + device="cuda", + ) + denom = torch.zeros_like(max_score) + acc = torch.zeros( + q_chunk.shape[0], + q_active.shape[1], + q_chunk.shape[-1], + device="cuda", + dtype=torch.float32, + ) + for index_start in (0, 3): + index_end = min(index_start + 3, combined_indices.shape[-1]) + accumulate_indexed_sparse_mla_attention_chunk( + q=q_chunk, + kv_flat=kv_flat, + indices=indices_chunk[:, index_start:index_end], + lens=lens_chunk, + candidate_offset=index_start, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + subset_output = torch.empty_like(acc) + subset_lse = torch.empty_like(max_score) + finish_gathered_sparse_mla_attention( + max_score=max_score, + denom=denom, + acc=acc, + output=subset_output, + lse=subset_lse, + ) + merge_sparse_mla_subset_with_sink( + subset_output=subset_output, + subset_lse=subset_lse, + attn_sink=sink, + output=output[token_start:token_end], + ) + + expected = torch.empty_like(q_active) + reference_sparse_mla_prefill( + q=q_active, + kv=kv, + combined_indices=combined_indices, + combined_lens=combined_lens, + scale=scale, + attn_sink=sink, + output=expected, + topk_chunk_size=3, + query_chunk_size=2, + ) + + torch.testing.assert_close(output.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.parametrize( + ("topk_chunk_size", "query_chunk_size"), + [(1, 1), (2, 3), (5, 2)], +) +def test_reference_sparse_mla_prefill_matches_dense_golden( + topk_chunk_size: int, + query_chunk_size: int, +) -> None: + torch.manual_seed(4) + scale = 0.375 + q = torch.randn(4, 2, 3) + kv = torch.randn(2, 5, 3) + combined_indices = torch.tensor( + [ + [0, 3, -1, 5, 3], + [4, -1, 2, 2, 1], + [-1, -1, -1, -1, -1], + [8, 0, 9, -1, 7], + ], + dtype=torch.int64, + ) + combined_lens = torch.tensor([4, 3, 0, 5], dtype=torch.int32) + sink = torch.tensor([-0.5, 1.0]) + output = torch.empty_like(q) + + reference_sparse_mla_prefill( + q=q, + kv=kv, + combined_indices=combined_indices, + combined_lens=combined_lens, + scale=scale, + attn_sink=sink, + output=output, + topk_chunk_size=topk_chunk_size, + query_chunk_size=query_chunk_size, + ) + + kv_flat = kv.reshape(-1, q.shape[-1]) + offsets 
= torch.arange(combined_indices.shape[-1]) + valid_tokens = (offsets[None, :] < combined_lens[:, None]) & ( + combined_indices >= 0 + ) + safe_indices = torch.where( + valid_tokens, + combined_indices, + torch.zeros((), dtype=combined_indices.dtype), + ).long() + gathered_kv = kv_flat[safe_indices] + expected = _golden_sink_attention(q, gathered_kv, valid_tokens, scale, sink) + + torch.testing.assert_close(output, expected, rtol=1e-6, atol=1e-6) + + +@pytest.mark.parametrize("chunk_size", [1, 2, 5]) +def test_chunked_reference_accumulation_matches_one_shot(chunk_size: int) -> None: + torch.manual_seed(3) + scale = 0.3 + q = torch.randn(3, 2, 4) + kv = torch.randn(3, 9, 4) + valid_tokens = torch.tensor( + [ + [True, False, True, True, False, False, True, False, True], + [False, False, False, False, False, False, False, False, False], + [True, True, True, False, True, False, True, True, False], + ], + dtype=torch.bool, + ) + output, lse = _chunked_no_sink_attention( + q, + kv, + valid_tokens, + scale, + chunk_size, + ) + expected_output, expected_lse = _golden_no_sink_attention( + q, + kv, + valid_tokens, + scale, + ) + + torch.testing.assert_close(output, expected_output, rtol=1e-6, atol=1e-6) + torch.testing.assert_close(lse, expected_lse, rtol=1e-6, atol=1e-6) + +def test_triton_sparse_mla_fallback_allows_cudagraph_support_by_default( + monkeypatch, +) -> None: + monkeypatch.setenv("VLLM_TRITON_MLA_SPARSE", "1") + monkeypatch.delenv("VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH", raising=False) + + mla_spec = MLAAttentionSpec( + block_size=256, + num_kv_heads=1, + head_size=512, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + alignment=576, + compress_ratio=4, + model_version="deepseek_v4", + ) + swa_spec = SlidingWindowMLASpec( + block_size=64, + num_kv_heads=1, + head_size=512, + dtype=torch.uint8, + sliding_window=128, + cache_dtype_str="fp8_ds_mla", + alignment=576, + model_version="deepseek_v4", + ) + + assert FlashMLASparseMetadataBuilder.get_cudagraph_support(None, mla_spec) is ( + AttentionCGSupport.UNIFORM_BATCH + ) + assert DeepseekSparseSWAMetadataBuilder.get_cudagraph_support(None, swa_spec) is ( + AttentionCGSupport.UNIFORM_BATCH + ) + + vllm_config = SimpleNamespace( + compilation_config=SimpleNamespace( + mode=CompilationMode.VLLM_COMPILE, + compile_sizes=[1, 2], + compile_ranges_endpoints=[8192], + cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE, + cudagraph_capture_sizes=[1, 2, 4], + max_cudagraph_capture_size=4, + ) + ) + disable_sparse_mla_reference_cudagraphs_if_enabled(vllm_config) + + assert vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE + assert vllm_config.compilation_config.compile_sizes == [1, 2] + assert vllm_config.compilation_config.compile_ranges_endpoints == [8192] + assert ( + vllm_config.compilation_config.cudagraph_mode + == CUDAGraphMode.FULL_AND_PIECEWISE + ) + assert vllm_config.compilation_config.cudagraph_capture_sizes == [1, 2, 4] + assert vllm_config.compilation_config.max_cudagraph_capture_size == 4 + + + +def test_triton_sparse_mla_fallback_can_disable_cudagraphs(monkeypatch) -> None: + monkeypatch.setenv("VLLM_TRITON_MLA_SPARSE", "1") + monkeypatch.setenv("VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH", "0") + + mla_spec = MLAAttentionSpec( + block_size=256, + num_kv_heads=1, + head_size=512, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + alignment=576, + compress_ratio=4, + model_version="deepseek_v4", + ) + swa_spec = SlidingWindowMLASpec( + block_size=64, + num_kv_heads=1, + head_size=512, + dtype=torch.uint8, + 
sliding_window=128, + cache_dtype_str="fp8_ds_mla", + alignment=576, + model_version="deepseek_v4", + ) + + assert FlashMLASparseMetadataBuilder.get_cudagraph_support(None, mla_spec) is ( + AttentionCGSupport.NEVER + ) + assert DeepseekSparseSWAMetadataBuilder.get_cudagraph_support(None, swa_spec) is ( + AttentionCGSupport.NEVER + ) + + vllm_config = SimpleNamespace( + compilation_config=SimpleNamespace( + mode=CompilationMode.VLLM_COMPILE, + compile_sizes=[1, 2], + compile_ranges_endpoints=[8192], + cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE, + cudagraph_capture_sizes=[1, 2, 4], + max_cudagraph_capture_size=4, + ) + ) + disable_sparse_mla_reference_cudagraphs_if_enabled(vllm_config) + + assert vllm_config.compilation_config.mode == CompilationMode.NONE + assert vllm_config.compilation_config.compile_sizes == [] + assert vllm_config.compilation_config.compile_ranges_endpoints == [] + assert vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE + assert vllm_config.compilation_config.cudagraph_capture_sizes == [] + assert vllm_config.compilation_config.max_cudagraph_capture_size == 0 + + +def test_triton_sparse_mla_fallback_disables_cudagraphs_for_mtp( + monkeypatch, +) -> None: + monkeypatch.setenv("VLLM_TRITON_MLA_SPARSE", "1") + monkeypatch.delenv("VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH", raising=False) + + mla_spec = MLAAttentionSpec( + block_size=256, + num_kv_heads=1, + head_size=512, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + alignment=576, + compress_ratio=4, + model_version="deepseek_v4", + ) + swa_spec = SlidingWindowMLASpec( + block_size=64, + num_kv_heads=1, + head_size=512, + dtype=torch.uint8, + sliding_window=128, + cache_dtype_str="fp8_ds_mla", + alignment=576, + model_version="deepseek_v4", + ) + vllm_config = SimpleNamespace( + speculative_config=SimpleNamespace( + method="mtp", + num_speculative_tokens=2, + ), + compilation_config=SimpleNamespace( + mode=CompilationMode.VLLM_COMPILE, + compile_sizes=[1, 2], + compile_ranges_endpoints=[8192], + cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE, + cudagraph_capture_sizes=[1, 2, 4], + max_cudagraph_capture_size=4, + ), + ) + + assert FlashMLASparseMetadataBuilder.get_cudagraph_support( + vllm_config, + mla_spec, + ) is AttentionCGSupport.NEVER + assert DeepseekSparseSWAMetadataBuilder.get_cudagraph_support( + vllm_config, + swa_spec, + ) is AttentionCGSupport.NEVER + + disable_sparse_mla_reference_cudagraphs_if_enabled(vllm_config) + + assert vllm_config.compilation_config.mode == CompilationMode.NONE + assert vllm_config.compilation_config.compile_sizes == [] + assert vllm_config.compilation_config.compile_ranges_endpoints == [] + assert vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE + assert vllm_config.compilation_config.cudagraph_capture_sizes == [] + assert vllm_config.compilation_config.max_cudagraph_capture_size == 0 diff --git a/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py b/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py new file mode 100644 index 000000000000..7c0e4a1bcb1b --- /dev/null +++ b/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py @@ -0,0 +1,91 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from tests.v1.attention.utils import create_vllm_config +from vllm.v1.attention.backend import CommonAttentionMetadata +from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadataBuilder +from 
vllm.v1.kv_cache_interface import MLAAttentionSpec + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def test_indexer_builder_deepseek_v4_compressed_slot_mapping_uses_storage_block_size(): + """Regression test: DeepseekV4 compression path must compute slot_mapping from + compressed positions, not reuse the uncompressed common metadata mapping. + """ + device = torch.device("cuda") + + # storage_block_size = block_size // compress_ratio = 256 // 4 = 64 + kv_cache_spec = MLAAttentionSpec( + block_size=256, + num_kv_heads=1, + head_size=128, + dtype=torch.bfloat16, + compress_ratio=4, + ) + vllm_config = create_vllm_config(max_model_len=1024) + builder = DeepseekV32IndexerMetadataBuilder( + kv_cache_spec=kv_cache_spec, + layer_names=["dummy"], + vllm_config=vllm_config, + device=device, + ) + + # Construct a single request where: + # - num_computed = 240 (=> compressed_pos_start = 60) + # - query_len = 40 (=> num_groups = 10) + # => compressed positions are 60..69 which cross the storage block boundary at 64. + query_start_loc = torch.tensor([0, 40], dtype=torch.int32, device=device) + query_start_loc_cpu = query_start_loc.cpu() + seq_lens = torch.tensor([280], dtype=torch.int32, device=device) # 240 + 40 + + # Two blocks: compressed positions 0..63 map to block 5, 64..127 map to block 7. + block_table_tensor = torch.tensor([[5, 7]], dtype=torch.int32, device=device) + + # Dummy uncompressed slot mapping (length == uncompressed num_actual_tokens). + slot_mapping = torch.full((40,), -123, dtype=torch.int64, device=device) + + common = CommonAttentionMetadata( + query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens=seq_lens, + num_reqs=1, + num_actual_tokens=40, + max_query_len=40, + max_seq_len=280, + block_table_tensor=block_table_tensor, + slot_mapping=slot_mapping, + causal=True, + ) + + md = builder.build(common_prefix_len=0, common_attn_metadata=common) + + # The compressed slot_mapping retains the original uncompressed size (40). + # Only every compress_ratio-th position gets a valid slot; the rest are -1. + assert md.slot_mapping.numel() == 40 + valid_slots = md.slot_mapping[md.slot_mapping >= 0] + assert valid_slots.numel() == 10 # 40 tokens / compress_ratio 4 + + storage_bs = kv_cache_spec.storage_block_size # 64 + # Compressed positions 60..63 land in block 5, positions 64..69 in block 7. 
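+    # Illustrative formula, inferred from the expected values below rather than
+    # quoted from the builder: for each compressed position cpos of this request,
+    #   slot = block_table[0, cpos // storage_bs] * storage_bs + cpos % storage_bs
+    # so cpos 60..63 -> 5 * 64 + 60..63 and cpos 64..69 -> 7 * 64 + 0..5.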
+ expected = torch.tensor( + [ + 5 * storage_bs + 60, + 5 * storage_bs + 61, + 5 * storage_bs + 62, + 5 * storage_bs + 63, + ] + + [ + 7 * storage_bs + 0, + 7 * storage_bs + 1, + 7 * storage_bs + 2, + 7 * storage_bs + 3, + 7 * storage_bs + 4, + 7 * storage_bs + 5, + ], + dtype=torch.int64, + device=device, + ) + torch.testing.assert_close(valid_slots, expected) diff --git a/tests/v1/attention/test_sparse_attn_indexer.py b/tests/v1/attention/test_sparse_attn_indexer.py new file mode 100644 index 000000000000..3b781a0f3807 --- /dev/null +++ b/tests/v1/attention/test_sparse_attn_indexer.py @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.model_executor.layers.sparse_attn_indexer import ( + SM120_SHORT_ROW_TOPK_ALWAYS_WIDTH, + SM120_SHORT_ROW_TOPK_MAX_ROWS, + SM120_SHORT_ROW_TOPK_MAX_WIDTH, + _should_use_sm120_short_row_topk_decode, +) + + +@pytest.mark.parametrize( + ("topk_tokens", "logits_width", "num_rows", "is_cuda_sm120", "expected"), + [ + (512, SM120_SHORT_ROW_TOPK_ALWAYS_WIDTH, 32, True, True), + (512, 8192, SM120_SHORT_ROW_TOPK_MAX_ROWS, True, True), + (512, 8192, SM120_SHORT_ROW_TOPK_MAX_ROWS + 1, True, False), + (512, SM120_SHORT_ROW_TOPK_MAX_WIDTH, 1, True, False), + (512, 4096, 1, False, False), + (2048, 4096, 1, True, False), + ], +) +def test_sm120_short_row_topk_decode_selector( + topk_tokens: int, + logits_width: int, + num_rows: int, + is_cuda_sm120: bool, + expected: bool, +) -> None: + assert ( + _should_use_sm120_short_row_topk_decode( + topk_tokens, + logits_width, + num_rows, + is_cuda_sm120, + ) + is expected + ) diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py index 22acc748d24b..ce3bee22c9e6 100644 --- a/tests/v1/attention/test_sparse_mla_backends.py +++ b/tests/v1/attention/test_sparse_mla_backends.py @@ -42,7 +42,10 @@ FlashMLASparseBackend, triton_convert_req_index_to_global_index, ) -from vllm.v1.attention.backends.mla.indexer import split_indexer_prefill_chunks +from vllm.v1.attention.backends.mla.indexer import ( + sparse_indexer_max_logits_bytes, + split_indexer_prefill_chunks, +) from vllm.v1.attention.backends.utils import split_prefill_chunks from vllm.v1.attention.ops import flashmla @@ -218,8 +221,14 @@ def test_sparse_backend_decode_correctness( if not ok: pytest.skip(reason) elif backend_cls == FlashInferMLASparseBackend: - if not current_platform.has_device_capability(100): - pytest.skip("FlashInferMLASparseBackend requires SM 10.0 or higher") + capability = current_platform.get_device_capability() + if capability is None or not backend_cls.supports_compute_capability( + capability + ): + pytest.skip( + "FlashInferMLASparseBackend does not support " + f"{capability} on this platform" + ) batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name] use_fp8_ds_mla_quantization = kv_cache_dtype == "fp8_ds_mla" @@ -781,6 +790,20 @@ def test_split_indexer_prefill_chunks( assert out == expected +def test_sparse_indexer_max_logits_bytes_uses_sm12x_safe_default(monkeypatch): + monkeypatch.delenv("VLLM_SPARSE_INDEXER_MAX_LOGITS_MB", raising=False) + + assert sparse_indexer_max_logits_bytes(is_sm12x=True) == 256 * 1024 * 1024 + assert sparse_indexer_max_logits_bytes(is_sm12x=False) == 512 * 1024 * 1024 + + +def test_sparse_indexer_max_logits_bytes_honors_env_override(monkeypatch): + monkeypatch.setenv("VLLM_SPARSE_INDEXER_MAX_LOGITS_MB", "384") + + assert sparse_indexer_max_logits_bytes(is_sm12x=True) == 
384 * 1024 * 1024 + assert sparse_indexer_max_logits_bytes(is_sm12x=False) == 384 * 1024 * 1024 + + def test_split_indexer_prefill_chunks_single_request_overflow(): """Test that single request exceeding budget is sub-chunked on query dim.""" seq_lens = torch.tensor([1000, 50]) diff --git a/tests/v1/attention/test_sparse_mla_env.py b/tests/v1/attention/test_sparse_mla_env.py new file mode 100644 index 000000000000..89bb9be58ee2 --- /dev/null +++ b/tests/v1/attention/test_sparse_mla_env.py @@ -0,0 +1,113 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +from collections.abc import Iterator +from contextlib import contextmanager + +import torch + +from vllm.envs import environment_variables +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + is_sparse_mla_attention_dump_enabled, + is_sparse_mla_reference_attention_enabled, + sparse_mla_attention_dump_path, + sparse_mla_reference_cudagraphs_allowed, + sparse_mla_reference_head_block_size, + sparse_mla_reference_query_chunk_size, + sparse_mla_reference_topk_chunk_size, +) + +_SPARSE_MLA_ENV_NAMES = ( + "VLLM_TRITON_MLA_SPARSE", + "VLLM_TRITON_MLA_SPARSE_DUMP", + "VLLM_TRITON_MLA_SPARSE_DUMP_PATH", + "VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE", + "VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE", + "VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH", + "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE", +) + + +@contextmanager +def _patched_sparse_mla_env(**updates: str) -> Iterator[None]: + previous = {name: os.environ.get(name) for name in _SPARSE_MLA_ENV_NAMES} + try: + for name in _SPARSE_MLA_ENV_NAMES: + os.environ.pop(name, None) + os.environ.update(updates) + yield + finally: + for name, value in previous.items(): + if value is None: + os.environ.pop(name, None) + else: + os.environ[name] = value + + +def test_sparse_mla_reference_env_uses_new_name() -> None: + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE="0"): + assert not is_sparse_mla_reference_attention_enabled(torch.device("cpu")) + + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE="1"): + assert is_sparse_mla_reference_attention_enabled(torch.device("cpu")) + + +def test_sparse_mla_dump_env_uses_new_name() -> None: + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_DUMP="0"): + assert not is_sparse_mla_attention_dump_enabled() + + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_DUMP="1"): + assert is_sparse_mla_attention_dump_enabled() + + +def test_sparse_mla_cudagraph_env_defaults_to_allowed() -> None: + with _patched_sparse_mla_env(): + assert sparse_mla_reference_cudagraphs_allowed() + + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH="0"): + assert not sparse_mla_reference_cudagraphs_allowed() + + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH="1"): + assert sparse_mla_reference_cudagraphs_allowed() + + +def test_sparse_mla_head_block_env_accepts_supported_values() -> None: + with _patched_sparse_mla_env(): + assert sparse_mla_reference_head_block_size() is None + + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE="1"): + assert sparse_mla_reference_head_block_size() == 1 + + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE="2"): + assert sparse_mla_reference_head_block_size() == 2 + + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE="4"): + assert sparse_mla_reference_head_block_size() == 4 + + +def test_sparse_mla_head_block_env_ignores_invalid_values() -> None: + for value in ("0", "3", "invalid"): + with 
_patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE=value): + assert sparse_mla_reference_head_block_size() is None + + +def test_sparse_mla_head_block_env_is_registered_with_vllm_envs() -> None: + assert "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE" in environment_variables + + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE="4"): + assert environment_variables["VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE"]() == 4 + + +def test_sparse_mla_chunk_env_defaults_invalid_values() -> None: + with _patched_sparse_mla_env( + VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE="invalid", + VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE="-7", + ): + assert sparse_mla_reference_topk_chunk_size() == 512 + assert sparse_mla_reference_query_chunk_size() == 1 + + +def test_sparse_mla_dump_path_uses_new_name() -> None: + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_DUMP_PATH="/tmp/new.jsonl"): + assert sparse_mla_attention_dump_path() == "/tmp/new.jsonl" diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 046f04e0c79a..cfd03c5f687e 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -1855,10 +1855,11 @@ def test_generate_scheduler_kv_cache_config(): def new_mla_spec(cache_dtype_str=None): + # head_size = kv_lora_rank(512) + qk_rope_head_dim(64) = 576 return MLAAttentionSpec( block_size=16, - num_kv_heads=16, - head_size=64, + num_kv_heads=1, + head_size=576, dtype=torch.float32, cache_dtype_str=cache_dtype_str, ) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 22220599f158..f271a214d3a4 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -557,19 +557,19 @@ def test_prefill_hybrid_model_eagle(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == num_full_blocks assert computed_blocks.get_block_ids() == ( - [1, 2, 3, 4], - [0, 9, 10, 11], - [0, 16, 17, 18], + [1, 2, 3, 4, 5], + [0, 0, 10, 11, 12], + [0, 0, 17, 18, 19], ) - assert num_computed_tokens == 4 * block_size + assert num_computed_tokens == 5 * block_size num_new_tokens = len(all_token_ids) - num_computed_tokens blocks = manager.allocate_slots( req1, num_new_tokens, num_computed_tokens, computed_blocks ) assert blocks is not None and blocks.get_block_ids() == ( - [22, 23, 24], - [25, 26, 27], - [28, 29, 30], + [22, 23], + [24, 25], + [26, 27], ) for block_per_group in computed_blocks.blocks: for block in block_per_group: @@ -591,7 +591,7 @@ def test_prefill_hybrid_model_eagle(): make_block_hash_with_group_id(block_hashes[0], 1), make_block_hash_with_group_id(block_hashes[0], 2), ], - 4, + 5, ) # Evict the first block of full attention, makes total cache miss. @@ -605,7 +605,7 @@ def test_prefill_hybrid_model_eagle(): 0, ) - # Evict the last block of all layers, reduces the hit length to 3. + # Evict the last block of all layers, reduces the hit length to 4. _test_partial_request_hit( manager, block_size, @@ -617,10 +617,10 @@ def test_prefill_hybrid_model_eagle(): make_block_hash_with_group_id(block_hashes[-1], 1), make_block_hash_with_group_id(block_hashes[-1], 2), ], - 3, + 4, ) - # Evict the last block of full attention, reduces the hit length to 3. + # Evict the last block of full attention, reduces the hit length to 4. 
_test_partial_request_hit( manager, block_size, @@ -628,7 +628,7 @@ def test_prefill_hybrid_model_eagle(): "5", all_token_ids, [make_block_hash_with_group_id(block_hashes[-1], 0)], - 3, + 4, ) # Since the last block of full attention is dropped for eagle, evict @@ -655,12 +655,11 @@ def test_prefill_hybrid_model_eagle(): 3, ) - # Evict different set of blocks for full attention and sliding window makes - # total cache miss. - # The cache hit length of full attention is 4 * block_size. - # The cache hit length of sliding window is 3 * block_size. - # Then it is cache miss as the two type of layers - # have different hit length. + # Evict different set of blocks for full attention and sliding window. + # Full loses its last block so it drops to 4 full blocks after the eagle + # pop; SWA lost block 0 (outside the sliding window of the final hit), + # which is not required for the K+1 anchor at position 4. Coordinated + # single-drop aligns both groups at hit=4. _test_partial_request_hit( manager, block_size, @@ -672,7 +671,7 @@ def test_prefill_hybrid_model_eagle(): make_block_hash_with_group_id(block_hashes[0], 1), make_block_hash_with_group_id(block_hashes[0], 2), ], - 0, + 4, ) @@ -893,7 +892,7 @@ def test_prefill_hybrid_model_combinations(spec_types: list[str]): # - 2 groups: 1 full + 1 other _EAGLE_HYBRID_MODEL_TEST_CASES = [ # 2 groups: 1 full + 1 other - pytest.param(["full", "sliding_window"], 2, id="2g-full+sw"), + pytest.param(["full", "sliding_window"], 3, id="2g-full+sw"), ] diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index f825220800f3..42f4825e2b3b 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1892,6 +1892,7 @@ def create_scheduler_with_priority( log_stats=True, structured_output_manager=StructuredOutputManager(vllm_config), block_size=block_size, + hash_block_size=block_size, ) @@ -4008,6 +4009,7 @@ def _create_encoder_decoder_scheduler( vllm_config=vllm_config, kv_cache_config=kv_cache_config, block_size=block_size, + hash_block_size=block_size, structured_output_manager=StructuredOutputManager(vllm_config), ) diff --git a/tests/v1/kv_connector/unit/test_mooncake_connector.py b/tests/v1/kv_connector/unit/test_mooncake_connector.py index 83202e023c9e..c3ce836423fa 100644 --- a/tests/v1/kv_connector/unit/test_mooncake_connector.py +++ b/tests/v1/kv_connector/unit/test_mooncake_connector.py @@ -91,8 +91,10 @@ def test_basic_interface(): assert request_id in kv_connector_metadata.reqs_to_recv["my-engine-id"] req_meta = kv_connector_metadata.reqs_to_recv["my-engine-id"][request_id] + # local_block_ids is list[list[int]] (per-group); flatten for comparison. 
+ all_block_ids = [bid for group in req_meta.local_block_ids for bid in group] for block_id, block in zip( - req_meta.local_block_ids, + all_block_ids, scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[ request_id ], @@ -228,15 +230,15 @@ def test_scheduler_request_finished(): # Case: Capped length (Successful prefill, need to send to decoder) request.status = RequestStatus.FINISHED_LENGTH_CAPPED - delay_free, _ = scheduler_connector.request_finished(request, block_ids=[10, 11]) + delay_free, _ = scheduler_connector.request_finished(request, block_ids=([10, 11],)) assert delay_free is True assert "id-1" in scheduler_connector._reqs_need_send - assert scheduler_connector._reqs_need_send["id-1"][1] == [10, 11] + assert scheduler_connector._reqs_need_send["id-1"][1] == [[10, 11]] # Case: Aborted (No need to transfer, free blocks immediately) scheduler_connector._reqs_need_send.clear() request.status = RequestStatus.FINISHED_ABORTED - delay_free, _ = scheduler_connector.request_finished(request, block_ids=[12]) + delay_free, _ = scheduler_connector.request_finished(request, block_ids=([12],)) assert delay_free is False assert len(scheduler_connector._reqs_need_send) == 0 assert "id-1" in scheduler_connector._reqs_not_processed @@ -334,7 +336,7 @@ async def test_kv_producer(monkeypatch): send_meta = SendBlockMeta( p_req_id="p-req-1", transfer_id=transfer_id, - local_block_ids=[10, 11], + local_block_ids=[[10, 11]], ready=asyncio.Event(), ) prefill_worker.reqs_need_send[transfer_id] = send_meta @@ -346,7 +348,7 @@ async def test_kv_producer(monkeypatch): remote_port=54321, remote_tp_size=1, remote_tp_rank=0, - req_blocks={"d-req-1": (transfer_id, [20, 21])}, + req_blocks={"d-req-1": (transfer_id, [[20, 21]])}, kv_caches_base_addr=[0x2000], block_lens=[block_len], ) @@ -389,7 +391,7 @@ async def test_kv_producer(monkeypatch): prefill_worker.reqs_need_send[transfer_id] = send_meta send_meta.sent = 0 send_meta.ready.set() - xfer_meta.req_blocks["d-req-1"] = (transfer_id, [20]) + xfer_meta.req_blocks["d-req-1"] = (transfer_id, [[20]]) # Worker processes the consumer's request await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta) # Verify transfer parameters are correct: 11 to 20 @@ -407,7 +409,7 @@ async def test_kv_producer(monkeypatch): prefill_worker.reqs_need_send[transfer_id] = send_meta send_meta.sent = 0 send_meta.ready.set() - xfer_meta.req_blocks["d-req-1"] = (transfer_id, [20, 21, 22]) + xfer_meta.req_blocks["d-req-1"] = (transfer_id, [[20, 21, 22]]) # Worker processes the consumer's request await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta) # This should not be called because error. @@ -424,7 +426,7 @@ async def test_kv_producer(monkeypatch): prefill_worker.reqs_need_send[transfer_id] = send_meta send_meta.sent = 0 send_meta.ready.clear() - xfer_meta.req_blocks["d-req-1"] = (transfer_id, [20, 21]) + xfer_meta.req_blocks["d-req-1"] = (transfer_id, [[20, 21]]) # Worker processes the consumer's request await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta) # This should not be called because timeout. 
@@ -443,7 +445,7 @@ async def test_kv_producer(monkeypatch): prefill_worker.reqs_need_send[transfer_id] = send_meta send_meta.sent = 0 send_meta.ready.set() - xfer_meta.req_blocks["d-req-1"] = (transfer_id, [20, 21]) + xfer_meta.req_blocks["d-req-1"] = (transfer_id, [[20, 21]]) # Worker processes the consumer's request await prefill_worker.send_kv_to_decode(identity, mock_socket, xfer_meta) mock_send_blocks.assert_called_once() @@ -481,7 +483,7 @@ async def test_kv_consumuer(monkeypatch): "d-req-1": PullReqMeta( d_req_id="d-req-1", transfer_id="xfer-req-1", - local_block_ids=[100, 101], + local_block_ids=[[100, 101]], remote_engine_id="p-engine", remote_bootstrap_addr="http://bootstrap:33333", pull_tasks_count=1, @@ -514,7 +516,7 @@ async def test_kv_consumuer(monkeypatch): assert sent_meta.remote_hostname == "127.0.0.1" assert sent_meta.remote_port == 54321 - assert sent_meta.req_blocks["d-req-1"] == ("xfer-req-1", [100, 101]) + assert sent_meta.req_blocks["d-req-1"] == ("xfer-req-1", [[100, 101]]) # Verify internal state is updated correctly. assert "d-req-1" in decode_worker.finished_recving_reqs @@ -538,7 +540,7 @@ async def test_worker_get_finished_timeout(monkeypatch): prefill_worker.reqs_need_send["tx-expired"] = SendBlockMeta( p_req_id="p-req-expired", transfer_id="tx-expired", - local_block_ids=[1, 2], + local_block_ids=[[1, 2]], ready=MagicMock(), expire_time=time.perf_counter() - 100, ) @@ -547,7 +549,7 @@ async def test_worker_get_finished_timeout(monkeypatch): prefill_worker.reqs_need_send["tx-active"] = SendBlockMeta( p_req_id="p-req-active", transfer_id="tx-active", - local_block_ids=[3, 4], + local_block_ids=[[3, 4]], ready=MagicMock(), expire_time=time.perf_counter() + 100, ) @@ -703,7 +705,7 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size): prefill_worker.sender_loop = asyncio.get_event_loop() transfer_id = "xfer-hetero-1" - local_block_ids = [10, 11] + local_block_ids = [[10, 11]] send_meta = SendBlockMeta( p_req_id="p-req-h1", transfer_id=transfer_id, @@ -720,9 +722,9 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size): mock_socket.send_multipart = AsyncMock() identity = b"consumer-hetero" - # Assign different remote block IDs per D rank + # Assign different remote block IDs per D rank (nested per-group) d_rank_remote_blocks = { - rank: [20 + i * 10, 21 + i * 10] for i, rank in enumerate(target_d_ranks) + rank: [[20 + i * 10, 21 + i * 10]] for i, rank in enumerate(target_d_ranks) } with patch.object( @@ -757,11 +759,15 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size): dst_ptrs = call_args[2] lengths = call_args[3] + # Flatten nested per-group block IDs for assertions + flat_local = [b for g in local_block_ids for b in g] + flat_remote = [b for g in remote_block_ids for b in g] + # Heterogeneous TP: blocks cannot be coalesced because # local and remote block_lens differ - assert len(src_ptrs) == len(local_block_ids) - assert len(dst_ptrs) == len(local_block_ids) - assert len(lengths) == len(local_block_ids) + assert len(src_ptrs) == len(flat_local) + assert len(dst_ptrs) == len(flat_local) + assert len(lengths) == len(flat_local) # Compute expected offsets based on TP ratio if d_tp_size <= P_TP_SIZE: @@ -775,9 +781,7 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size): expected_dst_off = 0 expected_xfer_len = remote_block_len - for idx, (lblk, rblk) in enumerate( - zip(local_block_ids, remote_block_ids) - ): + for idx, (lblk, rblk) in enumerate(zip(flat_local, flat_remote)): assert 
src_ptrs[idx] == ( 0x1000 + lblk * local_block_len + expected_src_off ) diff --git a/tests/v1/kv_connector/unit/test_mooncake_connector_hma.py b/tests/v1/kv_connector/unit/test_mooncake_connector_hma.py new file mode 100644 index 000000000000..974a722d8c25 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_mooncake_connector_hma.py @@ -0,0 +1,400 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for MooncakeConnector HMA (Hybrid Memory Architecture) support. + +Covers sliding-window clipping, multi-group metadata shape, multi-group +send trimming, and group-count invariant checking in _build_transfer_params. +""" + +import asyncio +from unittest.mock import patch + +import pytest + +from vllm.config import set_current_vllm_config +from vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector import ( + KVConnectorRole, + MooncakeConnector, + MooncakeConnectorMetadata, + MooncakeConnectorScheduler, + MooncakeXferMetadata, + SendBlockMeta, + TransferRegion, +) + +from .test_mooncake_connector import FakeMooncakeWrapper, patch_worker_dependencies +from .utils import create_request, create_vllm_config, make_kv_cache_config + + +# --------------------------------------------------------------------------- +# test_sw_sizes: blocks_per_sw computed from KVCacheConfig +# --------------------------------------------------------------------------- +@pytest.mark.cpu_test +@pytest.mark.parametrize( + "swa_enabled,expected_blocks_per_sw", + [ + # SWA enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128+1) + (True, [0, 128 + 1]), + # SWA disabled: only FullAttentionSpec (0) + (False, [0]), + ], +) +def test_sw_sizes(swa_enabled, expected_blocks_per_sw): + """blocks_per_sw is correctly computed based on SWA enabled/disabled.""" + block_size = 16 + vllm_config = create_vllm_config( + kv_connector="MooncakeConnector", + kv_role="kv_both", + block_size=block_size, + ) + # Override so HMA detection works + vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False + kv_cache_config = make_kv_cache_config( + block_size=block_size, swa_enabled=swa_enabled, sw_size=2048 + ) + + scheduler = MooncakeConnectorScheduler( + vllm_config=vllm_config, + engine_id="test-engine", + kv_cache_config=kv_cache_config, + ) + assert scheduler.blocks_per_sw == expected_blocks_per_sw + + +# --------------------------------------------------------------------------- +# test_is_hma_required: derived from kv_cache_config groups +# --------------------------------------------------------------------------- +@pytest.mark.cpu_test +@pytest.mark.parametrize( + "swa_enabled,disable_hma,expected_is_hma", + [ + (True, False, True), # SWA group present, HMA enabled + (True, True, False), # SWA group present, but HMA disabled + (False, False, False), # FA only, HMA not needed + ], +) +def test_is_hma_required(swa_enabled, disable_hma, expected_is_hma): + """_is_hma_required is correctly derived from kv_cache_config.""" + block_size = 16 + vllm_config = create_vllm_config( + kv_connector="MooncakeConnector", + kv_role="kv_both", + block_size=block_size, + ) + vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = disable_hma + kv_cache_config = make_kv_cache_config( + block_size=block_size, swa_enabled=swa_enabled + ) + + scheduler = MooncakeConnectorScheduler( + vllm_config=vllm_config, + engine_id="test-engine", + kv_cache_config=kv_cache_config, + ) + assert scheduler._is_hma_required is expected_is_hma + + +# 
--------------------------------------------------------------------------- +# test_get_sw_clipped_blocks: sliding-window clipping logic +# --------------------------------------------------------------------------- +@pytest.mark.cpu_test +def test_get_sw_clipped_blocks(): + """get_sw_clipped_blocks clips SWA group but keeps FA group intact.""" + block_size = 16 + vllm_config = create_vllm_config( + kv_connector="MooncakeConnector", + kv_role="kv_both", + block_size=block_size, + ) + vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False + # SW=128 tokens → 128/16 = 8 blocks + 1 = 9 blocks_per_sw + kv_cache_config = make_kv_cache_config( + block_size=block_size, swa_enabled=True, sw_size=128 + ) + + scheduler = MooncakeConnectorScheduler( + vllm_config=vllm_config, + engine_id="test-engine", + kv_cache_config=kv_cache_config, + ) + assert scheduler.blocks_per_sw == [0, 9] + + # FA group: 20 blocks, SW group: 20 blocks (exceeds window) + fa_blocks = list(range(20)) + sw_blocks = list(range(100, 120)) + block_ids = (fa_blocks, sw_blocks) + + clipped = scheduler.get_sw_clipped_blocks(block_ids) + + # FA: untouched (blocks_per_sw[0] = 0) + assert clipped[0] == fa_blocks + # SW: clipped to last 9 blocks + assert clipped[1] == sw_blocks[-9:] + assert len(clipped[1]) == 9 + + +@pytest.mark.cpu_test +def test_get_sw_clipped_blocks_noop_no_hma(): + """get_sw_clipped_blocks is a no-op when HMA is not required.""" + block_size = 16 + vllm_config = create_vllm_config( + kv_connector="MooncakeConnector", + kv_role="kv_both", + block_size=block_size, + ) + # FA only → _is_hma_required = False + kv_cache_config = make_kv_cache_config(block_size=block_size, swa_enabled=False) + + scheduler = MooncakeConnectorScheduler( + vllm_config=vllm_config, + engine_id="test-engine", + kv_cache_config=kv_cache_config, + ) + assert scheduler._is_hma_required is False + + block_ids = ([1, 2, 3],) + clipped = scheduler.get_sw_clipped_blocks(block_ids) + assert clipped == [[1, 2, 3]] + + +# --------------------------------------------------------------------------- +# test_metadata_hma_block_ids: MooncakeConnectorMetadata stores per-group IDs +# --------------------------------------------------------------------------- +@pytest.mark.cpu_test +def test_metadata_hma_block_ids(): + """MooncakeConnectorMetadata.add_new_req stores per-group block IDs.""" + metadata = MooncakeConnectorMetadata() + + # FA group: 6 blocks, SW group: 3 blocks (clipped) + fa_blocks = [0, 1, 2, 3, 4, 5] + sw_blocks = [10, 11, 12] + + # Test recv path + metadata.add_new_req( + request_id="recv-req", + local_block_ids=[fa_blocks, sw_blocks], + kv_transfer_params={ + "transfer_id": "recv-req", + "remote_engine_id": "remote-engine", + "remote_bootstrap_addr": "http://bootstrap:33333", + }, + load_remote_cache=True, + ) + + assert "recv-req" in metadata.reqs_to_recv["remote-engine"] + req_meta = metadata.reqs_to_recv["remote-engine"]["recv-req"] + assert len(req_meta.local_block_ids) == 2 + assert req_meta.local_block_ids[0] == fa_blocks + assert req_meta.local_block_ids[1] == sw_blocks + + # Test send path + metadata.add_new_req( + request_id="send-req", + local_block_ids=[fa_blocks, sw_blocks], + kv_transfer_params={ + "transfer_id": "send-req", + }, + load_remote_cache=False, + ) + + assert "send-req" in metadata.reqs_to_send + transfer_id, stored_blocks = metadata.reqs_to_send["send-req"] + assert transfer_id == "send-req" + assert len(stored_blocks) == 2 + assert stored_blocks[0] == fa_blocks + assert stored_blocks[1] == sw_blocks + 
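+
+# ---------------------------------------------------------------------------
+# Illustrative sketch (not part of the connector): the per-group clipping rule
+# exercised above, assuming blocks_per_sw[g] == 0 means "full attention, keep
+# every block" and a positive value means "keep only the last blocks_per_sw[g]
+# blocks". The helper name is hypothetical.
+# ---------------------------------------------------------------------------
+def _sketch_clip_group(blocks: list[int], blocks_per_sw: int) -> list[int]:
+    # Full-attention groups (blocks_per_sw == 0) are returned untouched;
+    # sliding-window groups keep only the trailing window of blocks.
+    if blocks_per_sw <= 0 or len(blocks) <= blocks_per_sw:
+        return list(blocks)
+    return blocks[-blocks_per_sw:]
+
+
+# Example: _sketch_clip_group(list(range(100, 120)), 9) == list(range(111, 120)),
+# matching the sw_blocks[-9:] expectation in test_get_sw_clipped_blocks.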
+ +# --------------------------------------------------------------------------- +# test_build_transfer_params_multi_group_trimming +# --------------------------------------------------------------------------- +@pytest.mark.cpu_test +@pytest.mark.asyncio +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.mooncake" + ".mooncake_connector.TransferEngine", + FakeMooncakeWrapper, +) +async def test_build_transfer_params_multi_group_trimming(monkeypatch): + """_build_transfer_params trims per-group blocks when local > remote.""" + + monkeypatch.setenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "5") + vllm_config = create_vllm_config( + kv_connector="MooncakeConnector", kv_role="kv_producer" + ) + + with set_current_vllm_config(vllm_config), patch_worker_dependencies(): + connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER) + worker = connector.connector_worker + + block_len = 4096 + # Call _build_transfer_params directly (avoids send_kv_to_decode + # async event loop complexity). + transfer_id = "xfer-hma-trim" + send_meta = SendBlockMeta( + p_req_id="p-trim", + transfer_id=transfer_id, + # FA: 4 blocks, SW: 3 blocks (producer has more) + local_block_ids=[[10, 11, 12, 13], [20, 21, 22]], + ready=asyncio.Event(), + ) + + xfer_meta = MooncakeXferMetadata( + remote_hostname="consumer-host", + remote_port=54321, + remote_tp_size=1, + remote_tp_rank=0, + req_blocks={ + "d-trim": ( + transfer_id, + # FA: 2 blocks, SW: 2 blocks (consumer needs fewer) + [[30, 31], [40, 41]], + ) + }, + kv_caches_base_addr=[0x2000], + block_lens=[block_len], + ) + + local_regions = [ + TransferRegion( + base_addr=0x1000, block_len=block_len, kv_block_len=block_len + ), + ] + remote_regions = [ + TransferRegion( + base_addr=0x2000, block_len=block_len, kv_block_len=block_len + ), + ] + + ready_reqs = [("d-trim", send_meta)] + ( + src_ptrs, + dst_ptrs, + lengths, + err_reqs, + err_msg, + ) = await worker._build_transfer_params( + ready_reqs, xfer_meta, local_regions, remote_regions + ) + + # No errors + assert err_reqs == [] + assert err_msg is None + # After trimming: FA [10..13] → last 2 → [12,13]; SW [20..22] → last 2 → [21,22] + # Flattened: [12,13,21,22] = 4 blocks → coalesced into some transfers + assert len(src_ptrs) > 0 + assert len(dst_ptrs) == len(src_ptrs) + assert len(lengths) == len(src_ptrs) + + worker.shutdown() + + +# --------------------------------------------------------------------------- +# test_build_transfer_params_group_count_mismatch +# --------------------------------------------------------------------------- +@pytest.mark.cpu_test +@pytest.mark.asyncio +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.mooncake" + ".mooncake_connector.TransferEngine", + FakeMooncakeWrapper, +) +async def test_build_transfer_params_group_count_mismatch(monkeypatch): + """_build_transfer_params asserts when group counts differ.""" + + monkeypatch.setenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "5") + vllm_config = create_vllm_config( + kv_connector="MooncakeConnector", kv_role="kv_producer" + ) + + with set_current_vllm_config(vllm_config), patch_worker_dependencies(): + connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER) + worker = connector.connector_worker + + block_len = 4096 + transfer_id = "xfer-mismatch" + send_meta = SendBlockMeta( + p_req_id="p-mismatch", + transfer_id=transfer_id, + # Producer has 2 groups + local_block_ids=[[10, 11], [20, 21]], + ready=asyncio.Event(), + ) + + # Consumer has only 1 group — group count mismatch + xfer_meta = MooncakeXferMetadata( + 
remote_hostname="consumer-host", + remote_port=54321, + remote_tp_size=1, + remote_tp_rank=0, + req_blocks={ + "d-mismatch": (transfer_id, [[30, 31]]), + }, + kv_caches_base_addr=[0x2000], + block_lens=[block_len], + ) + + local_regions = [ + TransferRegion( + base_addr=0x1000, block_len=block_len, kv_block_len=block_len + ), + ] + remote_regions = [ + TransferRegion( + base_addr=0x2000, block_len=block_len, kv_block_len=block_len + ), + ] + + ready_reqs = [("d-mismatch", send_meta)] + with pytest.raises(AssertionError, match="KV group count mismatch"): + await worker._build_transfer_params( + ready_reqs, xfer_meta, local_regions, remote_regions + ) + + worker.shutdown() + + +# --------------------------------------------------------------------------- +# test_request_finished_with_hma_groups +# --------------------------------------------------------------------------- +@pytest.mark.cpu_test +def test_request_finished_with_hma_groups(): + """request_finished correctly handles per-group block_ids.""" + block_size = 16 + vllm_config = create_vllm_config( + kv_connector="MooncakeConnector", + kv_role="kv_producer", + block_size=block_size, + ) + vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False + kv_cache_config = make_kv_cache_config( + block_size=block_size, swa_enabled=True, sw_size=128 + ) + + scheduler = MooncakeConnectorScheduler( + vllm_config=vllm_config, + engine_id="test-engine", + kv_cache_config=kv_cache_config, + ) + + request = create_request(request_id=1, do_remote_decode=True) + request.kv_transfer_params["transfer_id"] = request.request_id + + from vllm.v1.request import RequestStatus + + request.status = RequestStatus.FINISHED_LENGTH_CAPPED + + # 2 groups: FA with 10 blocks, SW with 20 blocks (will be clipped) + fa_blocks = list(range(10)) + sw_blocks = list(range(100, 120)) + block_ids = (fa_blocks, sw_blocks) + + delay_free, _ = scheduler.request_finished(request, block_ids) + assert delay_free is True + assert request.request_id in scheduler._reqs_need_send + + _, stored_blocks = scheduler._reqs_need_send[request.request_id] + # FA: untouched + assert stored_blocks[0] == fa_blocks + # SW: clipped to last 9 blocks (sw_size=128, block_size=16 → 8+1=9) + assert stored_blocks[1] == sw_blocks[-9:] diff --git a/tests/v1/streaming_input/test_scheduler_streaming.py b/tests/v1/streaming_input/test_scheduler_streaming.py index fd9f6b17f9a9..7d680895b836 100644 --- a/tests/v1/streaming_input/test_scheduler_streaming.py +++ b/tests/v1/streaming_input/test_scheduler_streaming.py @@ -76,6 +76,7 @@ def create_scheduler() -> Scheduler: log_stats=True, structured_output_manager=StructuredOutputManager(vllm_config), block_size=16, + hash_block_size=16, ) diff --git a/tests/v1/worker/test_kv_cache_view_utils.py b/tests/v1/worker/test_kv_cache_view_utils.py new file mode 100644 index 000000000000..7888046f4954 --- /dev/null +++ b/tests/v1/worker/test_kv_cache_view_utils.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.v1.worker.kv_cache_view_utils import view_kv_cache_with_layout + + +def test_padded_kv_cache_view_uses_block_axis_for_standard_layout(): + num_blocks = 3 + semantic_shape = (2, num_blocks, 2, 1, 2) + stride_order = (0, 1, 2, 3, 4) + block_axis = 1 + page_elements = 12 + raw = torch.arange(num_blocks * page_elements, dtype=torch.int16) + + kv_cache = view_kv_cache_with_layout( + raw_tensor=raw, + kv_cache_shape=semantic_shape, + 
kv_cache_stride_order=stride_order, + block_axis=block_axis, + dtype=torch.int16, + page_size_bytes=page_elements * raw.element_size(), + page_size_padded=page_elements * raw.element_size(), + ) + + for kv in range(semantic_shape[0]): + for block in range(num_blocks): + for token in range(semantic_shape[2]): + for dim in range(semantic_shape[4]): + expected_offset = block * page_elements + kv * 4 + token * 2 + dim + assert kv_cache[kv, block, token, 0, dim] == raw[expected_offset] + + +def test_padded_kv_cache_view_handles_block_first_layout(): + num_blocks = 3 + semantic_shape = (num_blocks, 2, 4) + stride_order = (0, 1, 2) + block_axis = 0 + page_elements = 12 + raw = torch.arange(num_blocks * page_elements, dtype=torch.int16) + + kv_cache = view_kv_cache_with_layout( + raw_tensor=raw, + kv_cache_shape=semantic_shape, + kv_cache_stride_order=stride_order, + block_axis=block_axis, + dtype=torch.int16, + page_size_bytes=page_elements * raw.element_size(), + page_size_padded=page_elements * raw.element_size(), + ) + + for block in range(num_blocks): + for token in range(semantic_shape[1]): + for dim in range(semantic_shape[2]): + expected_offset = block * page_elements + token * 4 + dim + assert kv_cache[block, token, dim] == raw[expected_offset] diff --git a/tools/install_deepgemm.sh b/tools/install_deepgemm.sh index 9d1edee04720..b44ce5075f77 100755 --- a/tools/install_deepgemm.sh +++ b/tools/install_deepgemm.sh @@ -6,8 +6,8 @@ set -e # Default values # Keep DEEPGEMM_GIT_REF in sync with cmake/external_projects/deepgemm.cmake -DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git" -DEEPGEMM_GIT_REF="477618cd51baffca09c4b0b87e97c03fe827ef03" +DEEPGEMM_GIT_REPO="https://github.com/jasl/DeepGEMM.git" +DEEPGEMM_GIT_REF="7a7a41a1bac7dacabe74057e7600e59f98f85bce" WHEEL_DIR="" # Parse command line arguments diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index d9be6a4c4332..e193fc94f34e 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -404,9 +404,18 @@ def rotary_embedding( head_size: int, cos_sin_cache: torch.Tensor, is_neox: bool, + rope_dim_offset: int = 0, + inverse: bool = False, ) -> None: torch.ops._C.rotary_embedding( - positions, query, key, head_size, cos_sin_cache, is_neox + positions, + query, + key, + head_size, + cos_sin_cache, + is_neox, + rope_dim_offset, + inverse, ) @@ -2503,6 +2512,30 @@ def topk_sigmoid( ) +def topk_hash_softplus_sqrt( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool = False, + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + input_tokens: torch.Tensor | None = None, + hash_indices_table: torch.Tensor | None = None, +) -> None: + torch.ops._moe_C.topk_softplus_sqrt( + topk_weights, + topk_indices, + token_expert_indices, + gating_output, + renormalize, + routed_scaling_factor, + e_score_correction_bias, + input_tokens, + hash_indices_table, + ) + + def grouped_topk( scores: torch.Tensor, num_expert_group: int, diff --git a/vllm/compilation/passes/utility/fix_functionalization.py b/vllm/compilation/passes/utility/fix_functionalization.py index 15eb23e6f949..dd000eb97b36 100644 --- a/vllm/compilation/passes/utility/fix_functionalization.py +++ b/vllm/compilation/passes/utility/fix_functionalization.py @@ -153,6 +153,58 @@ def __call__(self, graph: torch.fx.Graph) -> None: "input_global_scale", ), ) + elif ( + hasattr(torch.ops._C, "per_token_group_fp8_quant") + and at_target == 
torch.ops._C.per_token_group_fp8_quant.default + ): + mutated_args = {1: "output_q", 2: "output_s"} + self.defunctionalize( + graph, + node, + mutated_args, + args=( + "input", + "output_q", + "output_s", + "group_size", + "eps", + "fp8_min", + "fp8_max", + "scale_ue8m0", + "dummy_is_scale_transposed", + "dummy_is_tma_aligned", + ), + ) + elif ( + hasattr(torch.ops._C, "cutlass_scaled_mm") + and at_target == torch.ops._C.cutlass_scaled_mm.default + ): + mutated_args = {1: "out"} + self.defunctionalize( + graph, + node, + mutated_args, + args=("out", "a", "b", "a_scales", "b_scales", "bias"), + ) + elif ( + hasattr(torch.ops.vllm, "deepseek_v4_fp8_einsum") + and at_target == torch.ops.vllm.deepseek_v4_fp8_einsum.default + ): + mutated_args = {1: "out"} + self.defunctionalize( + graph, + node, + mutated_args, + args=( + "a", + "a_scale", + "b", + "b_scale", + "out", + "equation", + "recipe", + ), + ) # Defunctionalize fused_qk_norm_rope to remove higher-order wrapper. elif at_target == torch.ops._C.fused_qk_norm_rope.default: mutated_args = {1: "qkv"} diff --git a/vllm/config/attention.py b/vllm/config/attention.py index 561367173d5f..826cef3c6d3f 100644 --- a/vllm/config/attention.py +++ b/vllm/config/attention.py @@ -51,6 +51,9 @@ class AttentionConfig: use_prefill_query_quantization: bool = False """If set, quantize query for attention in prefill.""" + use_fp4_indexer_cache: bool = False + """If set, use fp4 indexer cache for dsv32 family model (not support yet)""" + def compute_hash(self) -> str: """ Provide a hash that uniquely identifies all the configs diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 48ff1a32ec05..ae5023f1e348 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -51,6 +51,18 @@ class CacheConfig: """Whether block_size was explicitly provided. Derived automatically.""" user_specified_mamba_block_size: bool = field(default=False, init=False) """Whether mamba_block_size was explicitly provided. Derived automatically.""" + hash_block_size: SkipValidation[int] | None = None # type: ignore + """Block size (in tokens) used for computing Request's block_hashes. + + This can be set to a finer granularity than the physical KV cache block + sizes (e.g. 8) as long as every KV cache group's `block_size` is divisible + by it. This enables prefix-caching keys to be computed at the finest common + granularity and then merged for larger physical block sizes. + + This config is not static default. If left unspecified, vLLM will choose a + default based on the resolved KV cache groups (typically the smallest KV + cache block size when there are multiple groups). + """ gpu_memory_utilization: float = Field(default=0.92, gt=0, le=1) """The fraction of GPU memory to be used for the model executor, which can range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory @@ -182,6 +194,8 @@ def compute_hash(self) -> str: "num_gpu_blocks_override", "enable_prefix_caching", "prefix_caching_hash_algo", + # Prefix-caching implementation detail (doesn't affect compiled graph). 
+ "hash_block_size", "mamba_page_size_padded", "user_specified_block_size", "user_specified_mamba_block_size", diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 5b726899c2f5..1f0ed0aafdd4 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -749,6 +749,7 @@ class CompilationConfig: "vllm::kda_attention", "vllm::sparse_attn_indexer", "vllm::rocm_aiter_sparse_attn_indexer", + "vllm::deepseek_v4_attention", ] def compute_hash(self) -> str: diff --git a/vllm/config/kernel.py b/vllm/config/kernel.py index 8d8e37a0549a..e450cb26e6e7 100644 --- a/vllm/config/kernel.py +++ b/vllm/config/kernel.py @@ -108,6 +108,7 @@ def with_default( MoEBackend = Literal[ "auto", "triton", + "triton_unfused", "deep_gemm", "cutlass", "flashinfer_trtllm", @@ -136,7 +137,8 @@ class KernelConfig: """Backend for MoE expert computation kernels. Available options: - "auto": Automatically select the best backend based on model and hardware - - "triton": Use Triton-based fused MoE kernels + - "triton": Use Triton-based fused MoE kernels (SWIGLUOAI activation only) + - "triton_unfused": Use Triton-based unfused MoE kernels (supports SILU/GELU) - "deep_gemm": Use DeepGEMM kernels (FP8 block-quantized only) - "cutlass": Use vLLM CUTLASS kernels - "flashinfer_trtllm": Use FlashInfer with TRTLLM-GEN kernels diff --git a/vllm/config/model.py b/vllm/config/model.py index 20c291ea281d..76cd734c75a5 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -83,7 +83,7 @@ RunnerOption = Literal["auto", RunnerType] ConvertType = Literal["none", "embed", "classify"] ConvertOption = Literal["auto", ConvertType] -TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32"] +TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32", "deepseek_v4"] ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] LogprobsMode = Literal[ "raw_logits", "raw_logprobs", "processed_logits", "processed_logprobs" @@ -134,6 +134,7 @@ class ModelConfig: - "slow" will always use the slow tokenizer. - "mistral" will always use the tokenizer from `mistral_common`. - "deepseek_v32" will always use the tokenizer from `deepseek_v32`. + - "deepseek_v4" will always use the tokenizer from `deepseek_v4`. - "qwen_vl" will always use the tokenizer from `qwen_vl`. 
- Other custom values can be supported via plugins.""" trust_remote_code: bool = False @@ -565,6 +566,8 @@ def __post_init__( self.tokenizer_mode = "qwen_vl" elif arch == "DeepseekV32ForCausalLM": self.tokenizer_mode = "deepseek_v32" + elif arch == "DeepseekV4ForCausalLM": + self.tokenizer_mode = "deepseek_v4" if self.tokenizer_mode != "auto": logger.info( @@ -952,6 +955,7 @@ def _verify_quantization(self) -> None: # imports during override detection (e.g., MXFP4 imports Triton) "mxfp4", "gpt_oss_mxfp4", + "deepseek_v4_fp8", "cpu_awq", "humming", "gguf", diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index a0c5cd04a16f..612cc3a1f281 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -287,13 +287,23 @@ def compute_hash(self) -> str: @staticmethod def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: initial_architecture = hf_config.architectures[0] - if hf_config.model_type in ("deepseek_v3", "deepseek_v32", "glm_moe_dsa"): + if hf_config.model_type in ( + "deepseek_v3", + "deepseek_v32", + "glm_moe_dsa", + ): hf_config.model_type = "deepseek_mtp" if hf_config.model_type == "deepseek_mtp": n_predict = getattr(hf_config, "num_nextn_predict_layers", None) hf_config.update( {"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]} ) + if hf_config.model_type == "deepseek_v4": + hf_config.model_type = "deepseek_mtp" + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update( + {"n_predict": n_predict, "architectures": ["DeepSeekV4MTPModel"]} + ) if hf_config.model_type in ("pangu_ultra_moe"): hf_config.model_type = "pangu_ultra_moe_mtp" if hf_config.model_type == "pangu_ultra_moe_mtp": diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py index 2057c79fa58c..715fcbde16c9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py @@ -29,6 +29,7 @@ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, + SupportsHMA, ) from vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_utils import ( MooncakeBootstrapServer, @@ -43,10 +44,12 @@ from vllm.forward_context import ForwardContext from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.utils.math_utils import cdiv from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import FullAttentionSpec, SlidingWindowSpec from vllm.v1.request import RequestStatus from vllm.v1.worker.utils import select_common_block_size @@ -252,7 +255,7 @@ class MooncakeXferMetadata( remote_port: int remote_tp_size: int remote_tp_rank: int - req_blocks: dict[ReqId, tuple[TransferId, list[int]]] + req_blocks: dict[ReqId, tuple[TransferId, list[list[int]]]] kv_caches_base_addr: list[int] block_lens: list[int] @@ -280,7 +283,7 @@ class MooncakeXferResponse( class PullReqMeta: d_req_id: ReqId transfer_id: TransferId - local_block_ids: list[int] + local_block_ids: list[list[int]] remote_engine_id: EngineId remote_bootstrap_addr: str # Set expire time to avoid infinitely sending requests. 
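Throughout the connector, block IDs move from a flat list[int] to a per-KV-cache-group list[list[int]]: index 0 is the full-attention group and later indices are sliding-window groups, each carrying its own (possibly clipped) block list. A hedged sketch of the new shape, with illustrative values only:

    # One request, two KV cache groups (full attention + sliding window).
    local_block_ids: list[list[int]] = [
        [0, 1, 2, 3, 4, 5],  # group 0: full attention, never clipped
        [10, 11, 12],        # group 1: sliding window, clipped to the window budget
    ]
    # The worker later trims each group against the remote side and flattens
    # the result before doing address math.
    flat = [block for group in local_block_ids for block in group]
    assert flat == [0, 1, 2, 3, 4, 5, 10, 11, 12]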
@@ -293,7 +296,7 @@ class PullReqMeta: class SendBlockMeta: p_req_id: ReqId transfer_id: TransferId - local_block_ids: list[int] + local_block_ids: list[list[int]] ready: asyncio.Event expire_time: float = float("inf") need_send: int = 0 @@ -306,13 +309,13 @@ def __init__(self): # Use (engine_id, dp_rank) to group reqs with same dp. # See comments in MooncakeBootstrapServer. self.reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict) - self.reqs_to_send: dict[ReqId, tuple[TransferId, list[int]]] = {} + self.reqs_to_send: dict[ReqId, tuple[TransferId, list[list[int]]]] = {} self.reqs_not_processed: set[TransferId] = set() def add_new_req( self, request_id: ReqId, - local_block_ids: list[int], + local_block_ids: list[list[int]], kv_transfer_params: dict[str, Any], load_remote_cache: bool = True, ): @@ -330,7 +333,7 @@ def add_new_req( self.reqs_to_send[request_id] = (transfer_id, local_block_ids) -class MooncakeConnector(KVConnectorBase_V1): +class MooncakeConnector(KVConnectorBase_V1, SupportsHMA): def __init__( self, vllm_config: VllmConfig, @@ -344,13 +347,18 @@ def __init__( self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id if role == KVConnectorRole.SCHEDULER: + assert kv_cache_config is not None, ( + "kv_cache_config is required for SCHEDULER role" + ) self.connector_scheduler: MooncakeConnectorScheduler | None = ( - MooncakeConnectorScheduler(vllm_config, self.engine_id) + MooncakeConnectorScheduler(vllm_config, self.engine_id, kv_cache_config) ) self.connector_worker: MooncakeConnectorWorker | None = None elif role == KVConnectorRole.WORKER: self.connector_scheduler = None - self.connector_worker = MooncakeConnectorWorker(vllm_config, self.engine_id) + self.connector_worker = MooncakeConnectorWorker( + vllm_config, self.engine_id, kv_cache_config + ) @classmethod def get_required_kvcache_layout(cls, vllm_config: VllmConfig): @@ -401,6 +409,14 @@ def request_finished( self, request: "Request", block_ids: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, (block_ids,)) + + def request_finished_all_groups( + self, + request: "Request", + block_ids: tuple[list[int], ...], ) -> tuple[bool, dict[str, Any] | None]: assert self.connector_scheduler is not None return self.connector_scheduler.request_finished(request, block_ids) @@ -445,8 +461,14 @@ def wait_for_save(self): class MooncakeConnectorScheduler: """Implementation of Scheduler side methods""" - def __init__(self, vllm_config: VllmConfig, engine_id: str): + def __init__( + self, + vllm_config: VllmConfig, + engine_id: str, + kv_cache_config: "KVCacheConfig", + ): self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size assert vllm_config.kv_transfer_config self.is_kv_producer: bool = ( @@ -457,15 +479,49 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): ) logger.info("Initializing Mooncake Transfer Engine Scheduler %s", engine_id) + self._is_hma_required = ( + not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager + and any( + not isinstance(g.kv_cache_spec, FullAttentionSpec) + for g in kv_cache_config.kv_cache_groups + ) + ) + # Requests that need to start recv/send. # New requests are added by update_state_after_alloc in # the scheduler. Used to make metadata passed to Worker. 
- self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} - self._reqs_need_send: dict[ReqId, tuple[Request, list[int]]] = {} + self._reqs_need_recv: dict[ReqId, tuple[Request, list[list[int]]]] = {} + self._reqs_need_send: dict[ReqId, tuple[Request, list[list[int]]]] = {} # Reqs to remove from processed set because they're not to send after # remote prefill or aborted. self._reqs_not_processed: set[TransferId] = set() + # Compute sliding window block counts per KV cache group. + sw_sizes_tokens: list[tuple[int, int]] = [ + (g.kv_cache_spec.sliding_window, g.kv_cache_spec.block_size) + if isinstance(g.kv_cache_spec, SlidingWindowSpec) + else (0, self.block_size) + for g in kv_cache_config.kv_cache_groups + ] + # cdiv(n_tokens, block_size) gives blocks/window; add 1 to + # conservatively account for boundary overlap. + self.blocks_per_sw = [ + cdiv(n_tokens, block_size) + 1 if n_tokens else 0 + for n_tokens, block_size in sw_sizes_tokens + ] + + def get_sw_clipped_blocks( + self, + block_ids: tuple[list[int], ...] | list[list[int]], + ) -> list[list[int]]: + """Clip per-group block IDs to sliding window size.""" + if len(block_ids) == 0 or not self._is_hma_required: + return list(block_ids) + return [ + blocks[-self.blocks_per_sw[i] :] if self.blocks_per_sw[i] > 0 else blocks + for i, blocks in enumerate(block_ids) + ] + def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int ) -> tuple[int, bool]: @@ -530,9 +586,12 @@ def update_state_after_alloc( # If remote_blocks and num_external_tokens = 0, we have # a full prefix cache hit on the D worker. We need to call # send_notif in _read_blocks to free the memory on the P. - local_block_ids = ( - blocks.get_unhashed_block_ids() if num_external_tokens > 0 else [] + unhashed_block_ids = ( + blocks.get_unhashed_block_ids_all_groups() + if num_external_tokens > 0 + else () ) + local_block_ids = self.get_sw_clipped_blocks(unhashed_block_ids) # Get unhashed blocks to pull from remote. self._reqs_need_recv[request.request_id] = (request, local_block_ids) else: @@ -587,7 +646,7 @@ def build_connector_meta( def request_finished( self, request: "Request", - block_ids: list[int], + block_ids: tuple[list[int], ...], ) -> tuple[bool, dict[str, Any] | None]: """ Once a request is finished, determine whether request blocks @@ -630,10 +689,13 @@ def request_finished( # TODO: check whether block_ids actually ever be 0. 
If not we could # remove the conditional below - delay_free_blocks = len(block_ids) > 0 + delay_free_blocks = any(len(group) > 0 for group in block_ids) if delay_free_blocks: - self._reqs_need_send[request.request_id] = (request, block_ids) + self._reqs_need_send[request.request_id] = ( + request, + self.get_sw_clipped_blocks(block_ids), + ) return delay_free_blocks, None @@ -641,7 +703,12 @@ def request_finished( class MooncakeConnectorWorker: """Implementation of Worker side methods""" - def __init__(self, vllm_config: VllmConfig, engine_id: str): + def __init__( + self, + vllm_config: VllmConfig, + engine_id: str, + kv_cache_config: "KVCacheConfig | None" = None, + ): if TransferEngine is None: logger.error("Mooncake is not available") raise RuntimeError("Mooncake is not available") @@ -752,6 +819,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.block_size = vllm_config.cache_config.block_size self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config + self.kv_cache_config = kv_cache_config self.use_mla = self.model_config.use_mla self._sync_block_size_with_kernel() @@ -1103,27 +1171,61 @@ async def _build_transfer_params( remote_session = f"{agent_meta.remote_hostname}:{agent_meta.remote_port}" for d_req_id, send_meta in ready_reqs: - _, remote_block_ids = agent_meta.req_blocks[d_req_id] - num_remote_blocks = len(remote_block_ids) - if num_remote_blocks == 0: + _, remote_block_ids_per_group = agent_meta.req_blocks[d_req_id] + + if not remote_block_ids_per_group or all( + len(g) == 0 for g in remote_block_ids_per_group + ): continue - local_block_ids = send_meta.local_block_ids - # Partial prefix cache hit: just read uncomputed blocks. - num_local_blocks = len(local_block_ids) - if num_local_blocks < num_remote_blocks: + # Per-group partial hit trimming, then flatten. + # With HMA, groups share the same KV tensor but use different + # block ranges. We trim and concatenate so the coalescer and + # address math see one flat block list — same as non-HMA, but + # now including blocks from every group. + local_block_ids: list[int] = [] + remote_block_ids: list[int] = [] + has_block_error = False + if len(send_meta.local_block_ids) != len(remote_block_ids_per_group): logger.error( - "req %s: local blocks(%d) less than remote blocks(%d)!", + "req %s: KV group count mismatch: local=%d, remote=%d", d_req_id, - num_local_blocks, - num_remote_blocks, + len(send_meta.local_block_ids), + len(remote_block_ids_per_group), ) + err_reqs.append(d_req_id) + if err_msg is None: + err_msg = "KV group count mismatch" + continue + for local_group, remote_group in zip( + send_meta.local_block_ids, remote_block_ids_per_group + ): + n_local = len(local_group) + n_remote = len(remote_group) + if n_local < n_remote: + logger.error( + "req %s: local blocks(%d) < remote blocks(%d) " + "in a KV cache group", + d_req_id, + n_local, + n_remote, + ) + has_block_error = True + break + if n_local > n_remote: + # Partial prefix cache hit: just read uncomputed blocks. 
+ local_group = local_group[-n_remote:] + local_block_ids.extend(local_group) + remote_block_ids.extend(remote_group) + + if has_block_error: err_reqs.append(d_req_id) if err_msg is None: err_msg = "P num blocks less than D" continue - if num_local_blocks > num_remote_blocks: - local_block_ids = local_block_ids[-num_remote_blocks:] + + if not local_block_ids: + continue # Group by indices group_local_block_ids, group_remote_block_ids = group_concurrent_contiguous( @@ -1215,7 +1317,7 @@ async def _build_transfer_params( logger.debug( "Sending kv_caches for request %s (%d blocks) to %s", d_req_id, - num_remote_blocks, + len(local_block_ids), remote_session, ) @@ -1273,23 +1375,24 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): continue seen_base_addresses.append(base_addr) - curr_tensor_size_bytes = cache.nbytes if tensor_size_bytes is None: - tensor_size_bytes = curr_tensor_size_bytes + tensor_size_bytes = cache.nbytes self.num_blocks = cache.shape[0] assert cache.shape[0] == self.num_blocks, ( "All kv cache tensors must have the same number of blocks" ) - assert curr_tensor_size_bytes % self.num_blocks == 0, ( - "Mooncake expects each kv cache tensor size to be " - "divisible by the number of blocks." - ) - self.block_len_per_layer.append( - curr_tensor_size_bytes // self.num_blocks - ) + + # Use stride-based block length so RDMA reaches the last + # block's padding (e.g. DeepseekV4 MLA alignment). stride(0) + # reflects the actual byte distance between consecutive + # blocks in GPU memory, which matches or exceeds the + # shape-based size. + block_len = cache.stride(0) * cache.element_size() + + self.block_len_per_layer.append(block_len) kv_data_ptrs.append(base_addr) - kv_data_lens.append(curr_tensor_size_bytes) + kv_data_lens.append(self.num_blocks * block_len) self.kv_caches_base_addr = seen_base_addresses self.seen_base_addresses = seen_base_addresses diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index c7c8c0421693..bd4f29a00410 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -299,6 +299,9 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): tools: list[ChatCompletionFunctionToolParam] | None """The tools for developer role.""" + task: str | None + """Model-specific task marker. Currently passed through for DeepSeek V4.""" + ChatCompletionMessageParam: TypeAlias = ( OpenAIChatCompletionMessageParam @@ -333,6 +336,9 @@ class ConversationMessage(TypedDict, total=False): tools: list[ChatCompletionFunctionToolParam] | None """The tools for developer role.""" + task: str | None + """Model-specific task marker. 
Currently passed through for DeepSeek V4.""" + # Passed in by user ChatTemplateContentFormatOption = Literal["auto", "string", "openai"] @@ -1566,6 +1572,9 @@ def _parse_chat_message_content( if "name" in message and isinstance(message["name"], str): result_msg["name"] = message["name"] + if "task" in message and isinstance(message["task"], str): + result_msg["task"] = message["task"] + if role == "developer": result_msg["tools"] = message.get("tools", None) return result diff --git a/vllm/envs.py b/vllm/envs.py index 806aed2a0414..3139015545af 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -166,6 +166,15 @@ VLLM_MOE_USE_DEEP_GEMM: bool = True VLLM_USE_DEEP_GEMM_E8M0: bool = True VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES: bool = True + VLLM_DEEP_GEMM_SM120_PAGED_MQA_TILED: bool = True + VLLM_TRITON_MLA_SPARSE: bool | None = None + VLLM_TRITON_MLA_SPARSE_DUMP: bool = False + VLLM_TRITON_MLA_SPARSE_DUMP_PATH: str = "" + VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE: int = 512 + VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE: int = 256 + VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH: bool = True + VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE: int | None = None + VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE: bool | None = None VLLM_DEEP_GEMM_WARMUP: Literal[ "skip", "full", @@ -247,6 +256,7 @@ VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" VLLM_USE_V2_MODEL_RUNNER: bool = False + VLLM_DEEPSEEK_V4_USE_MEGA_MOE: bool = False VLLM_LOG_MODEL_INSPECTION: bool = False VLLM_DEBUG_MFU_METRICS: bool = False VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY: bool = False @@ -1265,6 +1275,45 @@ def _get_or_set_default() -> str: "VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES": lambda: bool( int(os.getenv("VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES", "1")) ), + # Enable DeepGEMM's SM120 paged MQA tiled kernel when available. + "VLLM_DEEP_GEMM_SM120_PAGED_MQA_TILED": lambda: bool( + int(os.getenv("VLLM_DEEP_GEMM_SM120_PAGED_MQA_TILED", "1")) + ), + # Experimental sparse MLA reference fallback controls. + # ``VLLM_TRITON_MLA_SPARSE`` unset means auto-select where FlashMLA sparse + # is unavailable; set 0/1 to force-disable/force-enable the fallback. + "VLLM_TRITON_MLA_SPARSE": lambda: ( + None + if os.getenv("VLLM_TRITON_MLA_SPARSE") is None + else os.getenv("VLLM_TRITON_MLA_SPARSE", "").lower() + in ("1", "true", "yes", "on") + ), + "VLLM_TRITON_MLA_SPARSE_DUMP": lambda: ( + os.getenv("VLLM_TRITON_MLA_SPARSE_DUMP", "").lower() + in ("1", "true", "yes", "on") + ), + "VLLM_TRITON_MLA_SPARSE_DUMP_PATH": lambda: os.getenv( + "VLLM_TRITON_MLA_SPARSE_DUMP_PATH", "" + ), + "VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE": lambda: maybe_convert_int( + os.getenv("VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE", "512") + ), + "VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE": lambda: maybe_convert_int( + os.getenv("VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE", "256") + ), + "VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH": lambda: ( + os.getenv("VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH", "1").lower() + in ("1", "true", "yes", "on") + ), + "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE": lambda: maybe_convert_int( + os.getenv("VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE") + ), + "VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE": lambda: ( + None + if os.getenv("VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE") is None + else os.getenv("VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE", "").lower() + in ("1", "true", "yes", "on") + ), # DeepGemm JITs the kernels on-demand. 
The warmup attempts to make DeepGemm # JIT all the required kernels before model execution so there is no # JIT'ing in the hot-path. However, this warmup increases the engine @@ -1675,6 +1724,11 @@ def _get_or_set_default() -> str: "VLLM_USE_V2_MODEL_RUNNER": lambda: bool( int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0")) ), + # Use the DeepGEMM MegaMoE fused expert kernel for DeepSeek V4 routed + # experts. Set to 0 to fall back to the standard SharedFusedMoE path. + "VLLM_DEEPSEEK_V4_USE_MEGA_MOE": lambda: bool( + int(os.getenv("VLLM_DEEPSEEK_V4_USE_MEGA_MOE", "0")) + ), # Log model inspection after loading. # If enabled, logs a transformers-style hierarchical view of the model # with quantization methods and attention backends. diff --git a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py index 618084029159..6baedd3bbcbc 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py @@ -7,6 +7,9 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _upcast_e8m0_to_fp32, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, ) @@ -26,6 +29,20 @@ ) +def _is_sm12x_compute_capability(compute_capability) -> bool: + if compute_capability is None: + return current_platform.is_device_capability_family(120) + + if isinstance(compute_capability, tuple): + return compute_capability[0] == 12 + + to_int = getattr(compute_capability, "to_int", None) + if callable(to_int): + return to_int() // 10 == 12 + + return int(compute_capability) // 10 == 12 + + class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel): @classmethod def is_supported( @@ -196,6 +213,9 @@ def __init__(self, config: FP8ScaledMMLinearLayerConfig) -> None: @classmethod def is_supported(cls, compute_capability=None): + if _is_sm12x_compute_capability(compute_capability): + return False, "CUTLASS block-scaled FP8 GEMM is not supported on SM12x." 
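The early return above hinges on _is_sm12x_compute_capability, which normalizes the several shapes a compute capability can arrive in. A simplified, illustrative model of that normalization (the real helper additionally handles None, by querying the current platform, and objects exposing a to_int() method):

    def _major_of(compute_capability) -> int:
        # (major, minor) tuples report the major directly; packed ints such as
        # 120 or 121 are divided by 10, mirroring the helper's int branch.
        if isinstance(compute_capability, tuple):
            return compute_capability[0]
        return int(compute_capability) // 10

    assert _major_of((12, 0)) == 12   # SM120 as a (major, minor) tuple
    assert _major_of(121) == 12       # SM121 as a packed int
    assert _major_of((9, 0)) == 9     # Hopper keeps the CUTLASS block-FP8 path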
+ if not CUTLASS_BLOCK_FP8_SUPPORTED: return ( False, @@ -219,6 +239,31 @@ def can_implement(cls, config: FP8ScaledMMLinearLayerConfig): ) return True, None + def process_weights_after_loading(self, layer: torch.nn.Module): + super().process_weights_after_loading(layer) + params = self._get_layer_params(layer) + weight_scale = ( + params.weight_scale + if params.weight_scale_inv is None + else params.weight_scale_inv + ) + scale_attr_name = ( + params.WEIGHT_SCALE + if params.weight_scale_inv is None + else params.WEIGHT_SCALE_INV + ) + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if ( + e8m0_dtype is not None + and weight_scale is not None + and weight_scale.dtype == e8m0_dtype + ): + replace_parameter( + layer, + scale_attr_name, + _upcast_e8m0_to_fp32(weight_scale), + ) + def apply_block_scaled_mm( self, A: torch.Tensor, diff --git a/vllm/model_executor/kernels/linear/scaled_mm/deep_gemm.py b/vllm/model_executor/kernels/linear/scaled_mm/deep_gemm.py index a369623a3b17..70122f7b4ac6 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/deep_gemm.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/deep_gemm.py @@ -100,6 +100,8 @@ def process_weights_after_loading(self, layer): else params.weight_scale, quant_block_shape=tuple(layer.weight_block_size), use_e8m0=self.use_deep_gemm_e8m0, + is_bmm=getattr(layer, "is_bmm", False), + bmm_batch_size=getattr(layer, "bmm_batch_size", 0), ) replace_parameter(layer, params.WEIGHT, dg_weight) replace_parameter(layer, scale_attr, dg_weight_scale) diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index e649d790e82a..4afe2319570e 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -1422,6 +1422,20 @@ class MLADims: def get_mla_dims(model_config: ModelConfig) -> MLADims: hf_text_config = model_config.hf_text_config + # Check if this is a DeepseekV4 config (uses unified head_dim + rope_head_dim) + if hasattr(hf_text_config, "compress_ratios"): + # DeepseekV4 style config: unified head_dim with rope_head_dim + head_dim = hf_text_config.head_dim + rope_head_dim = hf_text_config.qk_rope_head_dim + return MLADims( + q_lora_rank=hf_text_config.q_lora_rank, + kv_lora_rank=head_dim, + qk_nope_head_dim=head_dim - rope_head_dim, + qk_rope_head_dim=rope_head_dim, + v_head_dim=head_dim, + ) + + # DeepseekV2/V3 style config return MLADims( q_lora_rank=getattr(hf_text_config, "q_lora_rank", None), kv_lora_rank=hf_text_config.kv_lora_rank, @@ -2191,6 +2205,7 @@ def __init__( qk_head_dim: int, v_head_dim: int, kv_b_proj: ColumnParallelLinear, + # DSV3.2 MLA Specific Arguments indexer: object | None = None, q_pad_num_heads: int | None = None, ) -> None: @@ -2213,6 +2228,7 @@ def __init__( self.indexer = indexer self.q_pad_num_heads = q_pad_num_heads self.supports_quant_query_input = True + self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled() # Use flashinfer's optimized concat_mla_k kernel when available. 
# The kernel is optimized for DeepSeek V3 dimensions: diff --git a/vllm/model_executor/layers/deepseek_compressor.py b/vllm/model_executor/layers/deepseek_compressor.py new file mode 100644 index 000000000000..af2783f604da --- /dev/null +++ b/vllm/model_executor/layers/deepseek_compressor.py @@ -0,0 +1,438 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import Any, ClassVar, cast + +import torch +from torch import nn + +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, +) +from vllm.model_executor.layers.utils import cublas_gemm_bf16_bf16_fp32 +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton +from vllm.v1.attention.backend import ( + AttentionBackend, + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + MultipleOf, +) +from vllm.v1.attention.ops.deepseek_v4_ops.fused_compress_quant_cache import ( + _fused_kv_compress_norm_rope_insert_indexer_attn, + _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn, + _fused_kv_compress_norm_rope_insert_sparse_attn, +) +from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import ( + MXFP4_BLOCK_SIZE, +) +from vllm.v1.kv_cache_interface import ( + KVCacheSpec, + MLAAttentionSpec, + SlidingWindowMLASpec, +) + + +class CompressorBackend(AttentionBackend): + def __init__(self): + super().__init__() + + @staticmethod + def get_name() -> str: + return "CompressorBackend" + + @staticmethod + def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: + return [MultipleOf(1)] + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [512, 1024] + + @staticmethod + def get_builder_cls() -> type["CompressorMetadataBuilder"]: + return CompressorMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + assert num_kv_heads == 1 + return (num_blocks, block_size, head_size) + + @staticmethod + def get_kv_cache_stride_order( + include_num_layers_dimension: bool = False, + ) -> tuple[int, ...]: + if include_num_layers_dimension: + return (0, 1, 2, 3) + return (0, 1, 2) + + +@dataclass +class CompressorMetadata: + block_table: torch.Tensor + slot_mapping: torch.Tensor + block_size: int + + token_to_req_indices: torch.Tensor | None = None # [num_tokens] + + +class CompressorMetadataBuilder(AttentionMetadataBuilder): + _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert isinstance(self.kv_cache_spec, SlidingWindowMLASpec | MLAAttentionSpec) + mla_spec = cast(SlidingWindowMLASpec | MLAAttentionSpec, self.kv_cache_spec) + self.block_size = mla_spec.block_size + + self.token_to_req_indices = torch.zeros( + self.vllm_config.scheduler_config.max_num_batched_tokens, + dtype=torch.int32, + device=self.device, + ) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> CompressorMetadata: + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + num_reqs = 
common_attn_metadata.num_reqs + query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + x = torch.repeat_interleave(torch.arange(num_reqs), query_lens).pin_memory() + token_to_req_indices = self.token_to_req_indices[: x.shape[0]] + token_to_req_indices.copy_(x, non_blocking=True) + return CompressorMetadata( + block_table=common_attn_metadata.block_table_tensor.clamp_(min=0), + slot_mapping=common_attn_metadata.slot_mapping, + block_size=self.block_size, + token_to_req_indices=token_to_req_indices, + ) + + +class CompressorStateCache(torch.nn.Module, AttentionLayerBase): + def __init__( + self, + state_dim: int, + dtype: torch.dtype, + compress_ratio: int, + prefix: str, + ): + super().__init__() + self.state_dim = state_dim + self.dtype = dtype + self.prefix = prefix + self.kv_cache = torch.tensor([]) + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + assert self.dtype == torch.float32 + assert compress_ratio in [4, 128] + coff = 1 + (compress_ratio == 4) + self.sliding_window = coff * compress_ratio + # Block size is constrained by tensor sharing between compressor states + # and KV blocks. Since compressor states share the same physical tensor + # as KV blocks, they must use the same page size. + # The KV block shape [256//4, head_dim] = [64, 584] determines: + # - C4 compressor block shape [4, 2*512*2*4] -> block_size = 4 + # - C128 compressor block shape [8, 512*2*4] -> block_size = 8 + # TODO(yifan): make block size automatically determined and configurable. + if compress_ratio == 4: + self.block_size = 4 + elif compress_ratio == 128: + self.block_size = 8 + else: + raise ValueError(f"Invalid compress ratio: {compress_ratio}") + + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: + return SlidingWindowMLASpec( # only has one vector instead of K + V + block_size=self.block_size, + num_kv_heads=1, + head_size=self.state_dim, + dtype=self.dtype, + sliding_window=self.sliding_window, + alignment=576, # NOTE: FlashMLA requires 576B alignment + ) + + def forward(self): ... 
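As the comments above note, the compressor cache geometry is fully determined by compress_ratio: coff doubles the state width for the overlapping C4 variant, the sliding window is coff * compress_ratio, and the block size is pinned so compressor states share a page size with the KV blocks. A small sketch of that derivation, for illustration only:

    def compressor_cache_geometry(compress_ratio: int) -> tuple[int, int, int]:
        # Returns (coff, sliding_window, block_size) under the constraints above.
        coff = 1 + (compress_ratio == 4)
        sliding_window = coff * compress_ratio
        block_size = {4: 4, 128: 8}[compress_ratio]
        return coff, sliding_window, block_size

    assert compressor_cache_geometry(4) == (2, 8, 4)
    assert compressor_cache_geometry(128) == (1, 128, 8)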
+ + def get_attn_backend(self) -> type[AttentionBackend]: + return CompressorBackend + + +class DeepseekCompressor(nn.Module): + def __init__( + self, + vllm_config: VllmConfig, + compress_ratio: int, + hidden_size: int, + head_dim: int, + rotate: bool = False, + prefix: str = "", + k_cache_prefix="", + use_fp4_cache: bool = False, + ): + super().__init__() + self.compress_ratio = compress_ratio + self.hidden_size = hidden_size + self.head_dim = head_dim + self.rotate = rotate + self.prefix = prefix + self.k_cache_prefix = k_cache_prefix + self.use_fp4_cache = use_fp4_cache + + config = vllm_config.model_config.hf_config + self.rope_head_dim = config.qk_rope_head_dim + self.nope_head_dim = self.head_dim - self.rope_head_dim + self.rms_norm_eps = config.rms_norm_eps + self.device = current_platform.device_type + self.max_num_reqs = vllm_config.scheduler_config.max_num_seqs + self.max_model_len = vllm_config.model_config.max_model_len + + self.overlap = compress_ratio == 4 + self.coff = 1 + self.overlap + + state_dtype = torch.float32 + self.ape = nn.Parameter( + torch.empty( + (compress_ratio, self.coff * self.head_dim), + dtype=state_dtype, + device=self.device, + ), + requires_grad=False, + ) + + self.fused_wkv_wgate = MergedColumnParallelLinear( + self.hidden_size, + [self.coff * self.head_dim, self.coff * self.head_dim], + bias=False, + return_bias=False, + quant_config=None, + disable_tp=True, + prefix=f"{prefix}.fused_wkv_wgate", + ) + self.norm = RMSNorm(self.head_dim, self.rms_norm_eps) + + self.state_cache = CompressorStateCache( + state_dim=2 * self.coff * self.head_dim, # kv_state + score_state + dtype=state_dtype, + compress_ratio=compress_ratio, + prefix=f"{prefix}.state_cache", + ) + + # Save reference to static_forward_context for forward-time KV cache lookup. + # get_current_vllm_config() is only available during __init__, not forward. + self._static_forward_context = ( + vllm_config.compilation_config.static_forward_context + ) + + if self.head_dim == 512: + assert not use_fp4_cache, ( + "MXFP4 cache is only supported for indexer (head=128)" + ) + self._fused_kernel = _fused_kv_compress_norm_rope_insert_sparse_attn + self._quant_block = 64 + self._token_stride = self.nope_head_dim + self.rope_head_dim * 2 + self._scale_dim = self.nope_head_dim // 64 + 1 # 7 real + 1 pad + self._num_warps = 4 + elif self.head_dim == 128: + if use_fp4_cache: + self._fused_kernel = ( + _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn + ) + self._quant_block = MXFP4_BLOCK_SIZE + self._token_stride = self.head_dim // 2 + self._scale_dim = self.head_dim // MXFP4_BLOCK_SIZE + else: + self._fused_kernel = _fused_kv_compress_norm_rope_insert_indexer_attn + self._quant_block = 128 + self._token_stride = self.head_dim + self._scale_dim = 4 # single float32 scale + self._num_warps = 1 + else: + raise ValueError( + f"Unsupported head_dim for fused quant+cache: {self.head_dim}" + ) + + def forward( + self, + # [num_tokens, hidden_size] + x: torch.Tensor, + # [num_tokens] + positions: torch.Tensor, + rotary_emb, + ) -> None: + num_tokens, _ = x.shape + # bf16 weights/activations but fp32 output for numerical stability of + # the downstream compressor math. + kv_score = cublas_gemm_bf16_bf16_fp32(x, self.fused_wkv_wgate.weight) + # Each of shape [num_tokens, coff * self.head_dim] + # input bf16, output are fp32 + kv, score = kv_score.split( + [self.coff * self.head_dim, self.coff * self.head_dim], dim=-1 + ) + + # Get the metadata and handle dummy profiling run. 
+ attn_metadata = get_forward_context().attn_metadata + if not isinstance(attn_metadata, dict): + return + + state_metadata = cast( + CompressorMetadata, attn_metadata[self.state_cache.prefix] + ) + token_to_req_indices = state_metadata.token_to_req_indices + slot_mapping = state_metadata.slot_mapping + num_actual = slot_mapping.shape[0] + block_table = state_metadata.block_table + block_size = state_metadata.block_size + + # [num_blocks, block_size, kv_dim+score_dim], where kv_dim == score_dim + state_cache = self.state_cache.kv_cache + # kv_state stored in first half, score_state stored in second half + state_width = state_cache.shape[-1] // 2 + + # Store the KV and score (with fused APE addition) in the state. + # NOTE: PDL is disabled — both this kernel and _fused_kernel below + # depend on preceding kernel outputs (kv/score from the cublas GEMM; + # state_cache from this kernel) but neither emits/waits on PDL grid + # dependency primitives, so launch_pdl=True caused a read-after-write + # race and non-deterministic output. + _save_partial_states_kernel[(num_actual,)]( + kv, + kv.stride(0), + score, + score.stride(0), + self.ape, + self.ape.stride(0), + positions, + state_cache, + state_cache.stride(0), + state_cache.stride(1), + slot_mapping, + block_size, + HEAD_SIZE=kv.shape[-1], + TRITON_BLOCK_SIZE=triton.next_power_of_2(kv.shape[-1]), + STATE_WIDTH=state_width, + COMPRESS_RATIO=self.compress_ratio, + launch_pdl=False, + ) + + # Fused: compress → RMSNorm → RoPE → FP8 quant → KV cache write. + # RoPE requirements (kernel applies forward GPT-J style rotation): + # - is_neox_style=False (interleaved pairs, NOT split-half) + # - cos_sin_cache layout: [max_pos, rope_head_dim] with first half cos, + # second half sin (per-pair, length rope_head_dim // 2 each) + # - applied to LAST rope_head_dim elements of head_dim + # - position used: (positions // compress_ratio) * compress_ratio + cos_sin_cache = rotary_emb.cos_sin_cache + k_cache_metadata = cast(Any, attn_metadata[self.k_cache_prefix]) + kv_cache = self._static_forward_context[self.k_cache_prefix].kv_cache + + self._fused_kernel[(num_actual,)]( + # state cache + state_cache, + state_cache.stride(0), + state_cache.stride(1), + # metadata + token_to_req_indices, + positions, + slot_mapping, + block_table, + block_table.stride(0), + block_size, + # RMSNorm + self.norm.weight, + self.rms_norm_eps, + # RoPE + cos_sin_cache, + cos_sin_cache.stride(0), + # KV cache + kv_cache, + k_cache_metadata.slot_mapping, + kv_cache.shape[1], # paged KV cache block size (tokens per block) + # constexprs + HEAD_SIZE=self.head_dim, + TRITON_BLOCK_SIZE=triton.next_power_of_2(self.head_dim), + STATE_WIDTH=state_width, + COMPRESS_RATIO=self.compress_ratio, + OVERLAP=self.overlap, + ROPE_HEAD_DIM=self.rope_head_dim, + FP8_MAX=448.0, + QUANT_BLOCK=self._quant_block, + TOKEN_STRIDE=self._token_stride, + SCALE_DIM=self._scale_dim, + KV_BLOCK_STRIDE=kv_cache.stride(0), + num_warps=self._num_warps, + launch_pdl=False, + ) + + +@triton.jit +def _save_partial_states_kernel( + kv_ptr, + kv_stride, + score_ptr, + score_stride, + ape_ptr, + ape_stride, + positions_ptr, + state_cache_ptr, + state_cache_stride0, + state_cache_stride1, + slot_mapping_ptr, + block_size, + HEAD_SIZE: tl.constexpr, + TRITON_BLOCK_SIZE: tl.constexpr, + # state_cache last dim packs [kv_state, score_state], each STATE_WIDTH wide. 
+ STATE_WIDTH: tl.constexpr, + COMPRESS_RATIO: tl.constexpr, +): + token_idx = tl.program_id(0) + slot_id = tl.load(slot_mapping_ptr + token_idx) + + # Skip padded / invalid tokens (slot_id == -1 is the PAD sentinel used + # by vLLM). During CUDA graph replay the batch may contain padding + # tokens whose slot_mapping is -1; writing to kv_state[-1] would be an + # illegal memory access. + if slot_id < 0: + return + + block_idx = slot_id // block_size + pos_in_block = slot_id % block_size + base_ptr = ( + state_cache_ptr + + block_idx * state_cache_stride0 + + pos_in_block * state_cache_stride1 + ) + + block = tl.arange(0, TRITON_BLOCK_SIZE) + mask = block < HEAD_SIZE + + kv = tl.load(kv_ptr + token_idx * kv_stride + block, mask=mask) + tl.store(base_ptr + block, kv, mask=mask) + + # Fused: score += ape[position % compress_ratio] + position = tl.load(positions_ptr + token_idx) + ape_row = position % COMPRESS_RATIO + ape = tl.load(ape_ptr + ape_row * ape_stride + block, mask=mask) + score = tl.load(score_ptr + token_idx * score_stride + block, mask=mask) + tl.store( + base_ptr + STATE_WIDTH + block, + score + ape, + mask=mask, + ) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py new file mode 100644 index 000000000000..4738a55165ba --- /dev/null +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -0,0 +1,1722 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +DeepseekV4 MLA Attention Layer +""" + +import json +from dataclasses import dataclass +from typing import TYPE_CHECKING, cast + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import DeepseekV2Config, DeepseekV3Config + +from vllm.model_executor.layers.linear import ( + ReplicatedLinear, +) +from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer +from vllm.platforms import current_platform +from vllm.utils.deep_gemm import fp8_einsum +from vllm.utils.torch_utils import direct_register_custom_op +from vllm.v1.attention.ops.deepseek_v4_ops import ( + combine_topk_swa_indices, + compute_global_topk_indices_and_lens, + dequantize_and_gather_k_cache, + dequantize_combined_sparse_mla_decode_kv, + fused_indexer_q_rope_quant, + fused_inv_rope_fp8_quant, + fused_q_kv_rmsnorm, +) +from vllm.v1.attention.ops.deepseek_v4_ops.fp8_einsum import ( + deepseek_v4_sm12_fp8_einsum, +) + +if TYPE_CHECKING: + from vllm.v1.attention.backends.mla.sparse_swa import ( + DeepseekSparseSWAMetadata, + ) + +from vllm.config import ( + CacheConfig, + VllmConfig, + get_current_vllm_config, +) +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.custom_op import PluggableLayer +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.deepseek_compressor import DeepseekCompressor +from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.input_quant_fp8 import ( + QuantFP8, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, +) +from vllm.utils.multi_stream_utils import maybe_execute_in_parallel +from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata +from 
vllm.v1.attention.backends.mla.flashmla_sparse import ( + DeepseekV4FlashMLASparseBackend, + FlashMLASparseBackend, + FlashMLASparseMetadata, +) +from vllm.v1.attention.backends.mla.indexer import ( + DeepseekV4IndexerBackend, + get_max_prefill_buffer_size, +) +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + disable_sparse_mla_reference_cudagraphs_if_enabled, + is_sparse_mla_attention_dump_enabled, + is_sparse_mla_reference_attention_enabled, + sparse_mla_attention_dump_path, + sparse_mla_matmul_decode_enabled, + sparse_mla_reference_query_chunk_size, + sparse_mla_reference_topk_chunk_size, +) +from vllm.v1.attention.backends.mla.sparse_mla_kernels import ( + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead, + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead, + accumulate_indexed_sparse_mla_attention_chunk, + build_combined_sparse_mla_decode_valid_mask, + finish_sparse_mla_attention_with_sink, + finish_two_sparse_mla_attention_states_with_sink, + fp8ds_global_paged_sparse_mla_attention_with_sink_multihead, + fp8ds_paged_sparse_mla_attention_with_sink_multihead, + matmul_sparse_mla_attention_with_sink, + sparse_mla_decode_head_block_size, +) +from vllm.v1.attention.backends.mla.sparse_swa import DeepseekV4SWACache +from vllm.v1.attention.ops.flashmla import ( + flash_mla_sparse_fwd, + flash_mla_with_kvcache, +) +from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec +from vllm.v1.worker.workspace import current_workspace_manager + +logger = init_logger(__name__) + + +def _tensor_summary(tensor: torch.Tensor | None) -> dict[str, object] | None: + if tensor is None: + return None + return { + "shape": [int(dim) for dim in tensor.shape], + "dtype": str(tensor.dtype), + "stride": [int(stride) for stride in tensor.stride()], + "device": str(tensor.device), + "is_contiguous": tensor.is_contiguous(), + } + + +def _optional_int(value: object) -> int | None: + return int(value) if value is not None else None + + +def _sparse_mla_prefill_workspace_bounds( + seq_lens_cpu: torch.Tensor, + gather_lens_cpu: torch.Tensor, + compress_ratio: int, + swa_only: bool, +) -> tuple[int, int]: + if seq_lens_cpu.numel() == 0: + return 0, 0 + + max_gather_len = int(gather_lens_cpu.max().item()) + if swa_only: + return 0, max_gather_len + + compressed_region_size = int((seq_lens_cpu // compress_ratio).max().item()) + return compressed_region_size, compressed_region_size + max_gather_len + + +def _deepseek_v4_fp8_einsum_config( + capability_major: int, +) -> tuple[tuple[int, int, int], bool]: + if capability_major == 10: + return (1, 1, 128), True + return (1, 128, 128), False + + +def _use_deepseek_v4_sm12_triton_fp8_einsum( + equation: str, + recipe: list[int], + b_scale: torch.Tensor, +) -> bool: + capability = current_platform.get_device_capability() + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + return ( + capability is not None + and capability.major == 12 + and equation == "bhr,hdr->bhd" + and tuple(recipe) == (1, 128, 128) + and b_scale.dtype in (torch.float32, e8m0_dtype) + ) + + +def _dump_sparse_mla_attention_state( + phase: str, + prefix: str, + compress_ratio: int, + q: torch.Tensor, + output: torch.Tensor, + swa_metadata: "DeepseekSparseSWAMetadata", + attn_metadata: FlashMLASparseMetadata | None, + fields: dict[str, object], +) -> None: + dump_path = sparse_mla_attention_dump_path() + payload = { + "phase": phase, + "prefix": prefix, + "compress_ratio": compress_ratio, + "q": _tensor_summary(q), + "output": _tensor_summary(output), + 
"attn_metadata_present": attn_metadata is not None, + "swa_metadata": { + "block_table": _tensor_summary(swa_metadata.block_table), + "slot_mapping": _tensor_summary(swa_metadata.slot_mapping), + "seq_lens": _tensor_summary(swa_metadata.seq_lens), + "query_start_loc": _tensor_summary(swa_metadata.query_start_loc), + "is_valid_token": _tensor_summary(swa_metadata.is_valid_token), + "token_to_req_indices": _tensor_summary( + swa_metadata.token_to_req_indices + ), + "decode_swa_indices": _tensor_summary( + swa_metadata.decode_swa_indices + ), + "decode_swa_lens": _tensor_summary(swa_metadata.decode_swa_lens), + "prefill_seq_lens": _tensor_summary(swa_metadata.prefill_seq_lens), + "prefill_gather_lens": _tensor_summary( + swa_metadata.prefill_gather_lens + ), + "block_size": int(swa_metadata.block_size), + "num_decodes": int(swa_metadata.num_decodes), + "num_prefills": int(swa_metadata.num_prefills), + "num_decode_tokens": int(swa_metadata.num_decode_tokens), + "num_prefill_tokens": int(swa_metadata.num_prefill_tokens), + }, + "flashmla_metadata": { + "block_table": _tensor_summary( + attn_metadata.block_table if attn_metadata is not None else None + ), + "slot_mapping": _tensor_summary( + attn_metadata.slot_mapping if attn_metadata is not None else None + ), + "c128a_global_decode_topk_indices": _tensor_summary( + attn_metadata.c128a_global_decode_topk_indices + if attn_metadata is not None + else None + ), + "c128a_decode_topk_lens": _tensor_summary( + attn_metadata.c128a_decode_topk_lens + if attn_metadata is not None + else None + ), + "c128a_prefill_topk_indices": _tensor_summary( + attn_metadata.c128a_prefill_topk_indices + if attn_metadata is not None + else None + ), + "block_size": _optional_int( + attn_metadata.block_size if attn_metadata is not None else None + ), + "topk_tokens": _optional_int( + attn_metadata.topk_tokens if attn_metadata is not None else None + ), + }, + "fields": fields, + } + with open(dump_path, "a", encoding="utf-8") as dump_file: + dump_file.write(json.dumps(payload, sort_keys=True) + "\n") + raise RuntimeError( + f"DeepseekV4 sparse MLA diagnostic dump written to {dump_path}" + ) + + +def _write_sparse_mla_attention_state_if_enabled( + phase: str, + prefix: str, + compress_ratio: int, + q: torch.Tensor, + output: torch.Tensor, + swa_metadata: "DeepseekSparseSWAMetadata", + attn_metadata: FlashMLASparseMetadata | None, + fields: dict[str, object], +) -> None: + if not is_sparse_mla_attention_dump_enabled(): + return + _dump_sparse_mla_attention_state( + phase=phase, + prefix=prefix, + compress_ratio=compress_ratio, + q=q, + output=output, + swa_metadata=swa_metadata, + attn_metadata=attn_metadata, + fields=fields, + ) + +# Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather +# workspace allocated at _forward_prefill (and the matching profile-time +# reservation in attention_impl's dummy-run branch). 
+PREFILL_CHUNK_SIZE = 4 + + +@dataclass +class DeepseekV4MLAModules: + """Modules used in DeepseekV4 MLA.""" + + vllm_config: VllmConfig + fused_wqa_wkv: torch.nn.Module + q_norm: torch.nn.Module + wq_b: torch.nn.Module + kv_norm: torch.nn.Module + wo_a: torch.nn.Module + wo_b: torch.nn.Module + attn_sink: torch.nn.Module + rotary_emb: torch.nn.Module + indexer: torch.nn.Module | None + indexer_rotary_emb: torch.nn.Module + topk_indices_buffer: torch.Tensor | None + aux_stream: torch.cuda.Stream | None = None + + +# --8<-- [start:multi_head_latent_attention] +@PluggableLayer.register("deepseek_v4_multi_head_latent_attention") +class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): + """Pluggable MLA layer which allows OOT backends to add + custom implementations of the outer MLA layer (including rope & o_proj). + Note that currently oot platforms can still use CustomOp.register_oot to + replace MLA layer entirely, although we use PluggableLayer to register + this layer now. + + This class takes positions and hidden_states as input. + The input tensors can either contain prefill tokens or decode tokens. + The class does the following: + + 1. MLA Preprocess. + 2. Perform multi-head attention to prefill tokens and + multi-query attention to decode tokens separately. + 3. Return the output tensor. + """ + + # --8<-- [end:multi_head_latent_attention] + + def __init__( + self, + hidden_size: int, + num_heads: int, + head_dim: int, + scale: float, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int | None, + kv_lora_rank: int, + o_lora_rank: int | None, + mla_modules: DeepseekV4MLAModules, + window_size: int, + compress_ratio: int | None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.n_local_heads = num_heads + self.head_dim = head_dim + self.scale = scale + + # FlashMLA sparse kernel only supports 64 or 128 heads; pad up to the + # next supported size. Must match DeepseekV4MLAAttention.padded_heads. + if num_heads <= 64: + self.padded_heads = 64 + elif num_heads <= 128: + self.padded_heads = 128 + else: + raise ValueError( + f"DeepseekV4 attention does not support {num_heads} heads " + "(must be <= 128)." + ) + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.window_size = window_size + self.compress_ratio = compress_ratio if compress_ratio is not None else 1 + self.prefix = prefix + + disable_sparse_mla_reference_cudagraphs_if_enabled(mla_modules.vllm_config) + + # Extract config from vllm_config + config = mla_modules.vllm_config.model_config.hf_config + tp_size = get_tensor_model_parallel_world_size() + + # DeepseekV4-specific attributes (num_heads is already TP-adjusted) + self.eps = config.rms_norm_eps + self.rope_head_dim = config.qk_rope_head_dim + self.nope_head_dim = head_dim - self.rope_head_dim + self.n_local_groups = config.o_groups // tp_size + self.o_lora_rank = config.o_lora_rank + + # Store projection modules + self.fused_wqa_wkv = mla_modules.fused_wqa_wkv + self.q_norm = mla_modules.q_norm + self.wq_b = mla_modules.wq_b + + self.kv_norm = mla_modules.kv_norm + self.wo_a = mla_modules.wo_a + + self._wo_a_act_quant = QuantFP8( + static=False, + group_shape=GroupShape(1, 128), + use_ue8m0=True, + ) + # Bypass packed-for-deepgemm path — we need FP32 scales (not packed + # INT32) so fp8_einsum can handle layout transform internally. 
+ self._wo_a_act_quant.use_deep_gemm_supported = False + self.wo_b = mla_modules.wo_b + + # Pick fp8_einsum recipe based on GPU arch: + # SM90/SM120: FP32 block scales stay [g, r/128, d/128]. + # SM100: INT32 packed scales become [g, r, ...]. + from vllm.platforms import current_platform + + cap = current_platform.get_device_capability() + assert cap is not None, "DeepseekV4 attention requires a CUDA device" + self._einsum_recipe, self._tma_aligned_scales = ( + _deepseek_v4_fp8_einsum_config(cap.major) + ) + + self.rotary_emb = mla_modules.rotary_emb + self.indexer_rotary_emb = mla_modules.indexer_rotary_emb + self.topk_indices_buffer = mla_modules.topk_indices_buffer + + self.indexer = mla_modules.indexer + + # Per-head RMS normalization for Q (no learnable weights) + self.q_head_norm = RMSNorm(head_dim, eps=self.eps, has_weight=False) + + # TODO(yifan): currently hardcoded for FP8 sparse, make it more generic + head_bytes = ( + self.nope_head_dim # 448 fp8 NoPE + + self.rope_head_dim * 2 # 64 bf16 RoPE + + self.nope_head_dim // 64 # 7B scale factors + + 1 # 1B pad + ) + + self.aux_stream = mla_modules.aux_stream + self.ln_events = [torch.cuda.Event(), torch.cuda.Event()] + + assert cache_config is not None, "DeepseekV4 attention requires cache_config" + self.swa_cache_layer = DeepseekV4SWACache( + head_dim=self.head_dim, + window_size=self.window_size, + dtype=torch.uint8, + prefix=f"{prefix}.swa_cache", + cache_config=cache_config, + ) + + self.mla_attn = DeepseekV4MLAAttention( + num_heads=self.n_local_heads, + head_dim=self.head_dim, + scale=self.scale, + qk_nope_head_dim=self.nope_head_dim, + qk_rope_head_dim=self.rope_head_dim, + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + compress_ratio=self.compress_ratio, + window_size=self.window_size, + head_bytes=head_bytes, + swa_cache_layer=self.swa_cache_layer, + attn_sink=mla_modules.attn_sink, # already padded with -inf + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + indexer=self.indexer, + topk_indices_buffer=self.topk_indices_buffer, + ) + # Register this layer in the compilation config's static forward context + # This allows the custom op to retrieve the layer during execution + compilation_config = mla_modules.vllm_config.compilation_config + # HACK + self.layer_name = prefix + ".deepseek_v4_multi_head_latent_attention" + if self.layer_name in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {self.layer_name}") + compilation_config.static_forward_context[self.layer_name] = self + + # Create the compressor for layers with compress_ratio > 1; after + # creating the DeepseekV4MLAAttention layer to get its cache. + self.compressor = None + if self.compress_ratio > 1: + self.compressor = DeepseekCompressor( + vllm_config=mla_modules.vllm_config, + compress_ratio=self.compress_ratio, + hidden_size=self.hidden_size, + head_dim=self.head_dim, + rotate=True, + prefix=f"{prefix}.compressor", + k_cache_prefix=self.mla_attn.prefix, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + llama_4_scaling: torch.Tensor | None = None, + ) -> torch.Tensor: + qr_kv, _ = self.fused_wqa_wkv(hidden_states) + qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1) + + # Pre-allocate attention output with FlashMLA-padded head count. + # The op writes into `o_padded`; we slice to n_local_heads after. 
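# (Illustrative: with 16 local heads the op fills a
#  [num_tokens, 64, head_dim] buffer and only the first 16 head slices are
#  read back below.)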
+ num_tokens = hidden_states.shape[0] + o_padded = torch.empty( + (num_tokens, self.padded_heads, self.head_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + # Attention (inside custom op for torch.compile boundary) + torch.ops.vllm.deepseek_v4_attention( + hidden_states, + qr, + kv, + positions, + o_padded, + self.layer_name, + ) + o = o_padded[:, : self.n_local_heads, :] + + # O projection: inverse RoPE + FP8 quant + einsum + wo_b + o_fp8, o_scale = fused_inv_rope_fp8_quant( + o, + positions, + self.rotary_emb.cos_sin_cache, + n_groups=self.n_local_groups, + heads_per_group=self.n_local_heads // self.n_local_groups, + nope_dim=self.nope_head_dim, + rope_dim=self.rope_head_dim, + tma_aligned_scales=self._tma_aligned_scales, + ) + + wo_a_fp8 = self.wo_a.weight + wo_a_scale = self.wo_a.weight_scale_inv + + (z,) = current_workspace_manager().get_simultaneous( + ((num_tokens, self.n_local_groups, self.o_lora_rank), torch.bfloat16), + ) + torch.ops.vllm.deepseek_v4_fp8_einsum( + o_fp8, + o_scale, + wo_a_fp8, + wo_a_scale, + z, + "bhr,hdr->bhd", + list(self._einsum_recipe), + ) + + return self.wo_b(z.flatten(1)) + + def attention_impl( + self, + hidden_states: torch.Tensor, + qr: torch.Tensor, + kv: torch.Tensor, + positions: torch.Tensor, + out: torch.Tensor, # [num_tokens, padded_heads, head_dim], written in place + ) -> None: + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + + qr, kv = fused_q_kv_rmsnorm( + qr, + kv, + self.q_norm.weight.data, + self.kv_norm.weight.data, + self.eps, + ) + q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) + + # Overlap kv_insert with whichever of indexer/compressor is present. + # Indexer implies compressor; when both exist, compressor rides on the + # aux stream alongside kv_insert so the heavy indexer owns default. + if self.indexer is not None: + indexer = self.indexer + # Local ref so the closure keeps a non-None type for mypy. + assert self.compressor is not None + compressor = self.compressor + + def kv_insert_and_compress() -> None: + self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) + compressor(hidden_states, positions, self.rotary_emb) + + maybe_execute_in_parallel( + lambda: indexer(hidden_states, qr, positions, self.indexer_rotary_emb), + kv_insert_and_compress, + self.ln_events[0], + self.ln_events[1], + self.aux_stream, + ) + elif self.compressor is not None: + # Compressor on default, kv_insert on aux. + compressor = self.compressor + maybe_execute_in_parallel( + lambda: compressor(hidden_states, positions, self.rotary_emb), + lambda: self._fused_qnorm_rope_kv_insert( + q, kv, positions, attn_metadata + ), + self.ln_events[0], + self.ln_events[1], + self.aux_stream, + ) + else: + # SWA-only layer: no compressor, no overlap. + self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) + + # Handle dummy run (no metadata). + if not isinstance(attn_metadata, dict): + out.zero_() + return + + # Pad q to FlashMLA-required head count (64 or 128) + if self.n_local_heads < self.padded_heads: + pad_size = self.padded_heads - self.n_local_heads + q = F.pad(q, (0, 0, 0, pad_size), value=0.0) + + # MLA attention writes into the pre-allocated `out` buffer + # ([num_tokens, padded_heads, head_dim]). 
+ self.mla_attn(q, kv, positions, output=out) + + def _fused_qnorm_rope_kv_insert( + self, + q: torch.Tensor, + kv: torch.Tensor, + positions: torch.Tensor, + attn_metadata: ( + dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None + ), + ) -> None: + if not isinstance(attn_metadata, dict): + return + + swa_metadata = cast( + "DeepseekSparseSWAMetadata | None", + attn_metadata.get(self.swa_cache_layer.prefix), + ) + assert swa_metadata is not None + + swa_kv_cache = self.swa_cache_layer.kv_cache + swa_kv_cache_2d = swa_kv_cache.view(swa_kv_cache.shape[0], -1) + + # Horizontally fused: + # Q side: q_head_norm (per-head RMSNorm, no weight) + GPT-J RoPE + # KV side: GPT-J RoPE + UE8M0 FP8 quant + paged cache insert + # kv is unchanged; mla_attn reads kv solely via swa_kv_cache. + torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( + q, + kv, + swa_kv_cache_2d, + swa_metadata.slot_mapping, + positions.to(torch.int64), + self.rotary_emb.cos_sin_cache, + self.eps, + swa_metadata.block_size, + ) + + +def deepseek_v4_attention( + hidden_states: torch.Tensor, + qr: torch.Tensor, + kv: torch.Tensor, + positions: torch.Tensor, + out: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self.attention_impl(hidden_states, qr, kv, positions, out) + + +def deepseek_v4_attention_fake( + hidden_states: torch.Tensor, + qr: torch.Tensor, + kv: torch.Tensor, + positions: torch.Tensor, + out: torch.Tensor, + layer_name: str, +) -> None: + return None + + +direct_register_custom_op( + op_name="deepseek_v4_attention", + op_func=deepseek_v4_attention, + mutates_args=["out"], + fake_impl=deepseek_v4_attention_fake, +) + + +def deepseek_v4_fp8_einsum( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, + equation: str, + recipe: list[int], +) -> None: + if equation == "bhr,hdr->bhd" and b.dim() == 2: + num_groups = out.shape[1] + out_rank = out.shape[2] + hidden_size = a.shape[2] + b = b.view(num_groups, out_rank, hidden_size) + + if b_scale.dim() == 2: + scale_mn = recipe[1] + scale_k_pack = 4 if b_scale.dtype == torch.int32 else 1 + scale_k = recipe[2] * scale_k_pack + b_scale = b_scale.view( + num_groups, + (out_rank + scale_mn - 1) // scale_mn, + (hidden_size + scale_k - 1) // scale_k, + ) + + if _use_deepseek_v4_sm12_triton_fp8_einsum(equation, recipe, b_scale): + deepseek_v4_sm12_fp8_einsum(a, a_scale, b, b_scale, out) + return + + fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe)) + + +def deepseek_v4_fp8_einsum_fake( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, + equation: str, + recipe: list[int], +) -> None: + return None + + +direct_register_custom_op( + op_name="deepseek_v4_fp8_einsum", + op_func=deepseek_v4_fp8_einsum, + mutates_args=["out"], + fake_impl=deepseek_v4_fp8_einsum_fake, +) + + +class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase): + # FlashMLA FP8 sparse only supports 64 or 128 heads + SUPPORTED_HEAD_COUNTS = (64, 128) + + def __init__( + self, + num_heads: int, + head_dim: int, + scale: float, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + q_lora_rank: int | None, + kv_lora_rank: int, + compress_ratio: int, + window_size: int, + head_bytes: int, + swa_cache_layer: DeepseekV4SWACache, + attn_sink: torch.Tensor, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = 
None, + prefix: str = "", + # Sparse MLA Args + indexer: object | None = None, + topk_indices_buffer: torch.Tensor | None = None, + aux_stream: torch.cuda.Stream | None = None, + **extra_impl_args, + ) -> None: + super().__init__() + self.num_heads = num_heads + self.num_kv_heads = 1 + self.head_dim = head_dim + self.scale = scale + self.window_size = window_size + self.head_bytes = head_bytes + self.compress_ratio = compress_ratio + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.nope_head_dim = qk_nope_head_dim + self.rope_head_dim = qk_rope_head_dim + self.indexer = indexer + self.topk_indices_buffer = topk_indices_buffer + + self.prefix = prefix # Alias for compatibility with compressor + + self.aux_stream = aux_stream + self.ln_events = [torch.cuda.Event(), torch.cuda.Event()] + + # Determine padded head count for FlashMLA + if num_heads not in self.SUPPORTED_HEAD_COUNTS: + if num_heads < 64: + self.padded_heads = 64 + elif num_heads < 128: + self.padded_heads = 128 + else: + raise ValueError( + f"DeepseekV4MLAAttention does not support {num_heads} heads. " + f"Supported: <= 128 (will be padded to 64 or 128)" + ) + else: + self.padded_heads = num_heads + + # Store attention sink + assert attn_sink is not None + self.attn_sink: torch.Tensor = attn_sink + # Store SWA cache + assert swa_cache_layer is not None + self.swa_cache_layer: DeepseekV4SWACache = swa_cache_layer + + # Get vllm config for cache setup + vllm_config = get_current_vllm_config() + self.max_num_batched_tokens = ( + vllm_config.scheduler_config.max_num_batched_tokens + ) + self.max_model_len = vllm_config.model_config.max_model_len + # DeepseekV4 only supports fp8 kv-cache format for now + kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "fp8" + + assert kv_cache_dtype.startswith("fp8"), ( + f"DeepseekV4 only supports fp8 kv-cache format for now, " + f"got {kv_cache_dtype}" + ) + assert issubclass(self.get_attn_backend(), FlashMLASparseBackend), ( + "Only FlashMLA Sparse Attention backend is supported for DeepseekV4 for now" + ) + # FlashMLA Sparse Attention fp8 backend uses "fp8_ds_mla" kv-cache format + # Automatically convert fp8 kv-cache format to "fp8_ds_mla" + if ( + issubclass(self.get_attn_backend(), FlashMLASparseBackend) + and kv_cache_dtype.startswith("fp8") + and kv_cache_dtype != "fp8_ds_mla" + ): + assert cache_config is not None + cache_config.cache_dtype = "fp8_ds_mla" + kv_cache_dtype = "fp8_ds_mla" + logger.info_once( + "Using DeepSeek's fp8_ds_mla KV cache format. To use standard " + "fp8 kv-cache format, please set `--attention-backend " + "FLASHINFER_MLA_SPARSE`" + ) + + self.kv_cache_dtype = kv_cache_dtype + + # Register with compilation context for metadata lookup + compilation_config = vllm_config.compilation_config + if prefix and prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + if prefix: + compilation_config.static_forward_context[prefix] = self + + self.kv_cache = torch.tensor([]) + + def get_attn_backend(self) -> type[AttentionBackend]: + return DeepseekV4FlashMLASparseBackend + + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: + if ( + self.compress_ratio <= 1 + ): # SWA part. Allocated separately as DeepseekV4SWACache. 
+ return None + return MLAAttentionSpec( + block_size=vllm_config.cache_config.block_size, + num_kv_heads=1, + head_size=self.head_dim, + dtype=torch.uint8, + compress_ratio=self.compress_ratio, + cache_dtype_str=self.kv_cache_dtype, + alignment=576, # NOTE: FlashMLA requires 576B alignment + model_version="deepseek_v4", + ) + + def _forward_sparse_mla_swa_decode_reference( + self, + q: torch.Tensor, + swa_k_cache: torch.Tensor, + swa_metadata: "DeepseekSparseSWAMetadata", + output: torch.Tensor, + ) -> None: + num_decodes = swa_metadata.num_decodes + num_decode_tokens = swa_metadata.num_decode_tokens + mtp_decode = num_decode_tokens != num_decodes + + swa_lens = swa_metadata.decode_swa_lens[:num_decode_tokens] + swa_indices = swa_metadata.decode_swa_indices[:num_decode_tokens] + max_swa_len = swa_metadata.decode_swa_indices.shape[-1] + head_block_size = sparse_mla_decode_head_block_size(num_decode_tokens) + if not mtp_decode: + fp8ds_paged_sparse_mla_attention_with_sink_multihead( + q=q, + k_cache=swa_k_cache, + seq_lens=swa_metadata.seq_lens[:num_decodes], + gather_lens=swa_lens, + block_table=swa_metadata.block_table[:num_decodes], + block_size=swa_metadata.block_size, + candidate_offset=0, + num_candidates=max_swa_len, + scale=self.scale, + attn_sink=self.attn_sink, + output=output, + head_block_size=head_block_size, + num_heads=self.num_heads, + ) + if output.shape[1] > self.num_heads: + output[:, self.num_heads :].zero_() + return + + ( + swa_max_score, + swa_denom, + swa_acc, + ) = current_workspace_manager().get_simultaneous( + ((num_decode_tokens, self.num_heads), torch.float32), + ((num_decode_tokens, self.num_heads), torch.float32), + ((num_decode_tokens, self.num_heads, q.shape[-1]), torch.float32), + ) + swa_max_score.fill_(float("-inf")) + swa_denom.zero_() + swa_acc.zero_() + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=swa_k_cache, + slot_ids=swa_indices, + lens=swa_lens, + block_size=swa_metadata.block_size, + scale=self.scale, + max_score=swa_max_score, + denom=swa_denom, + acc=swa_acc, + head_block_size=head_block_size, + ) + finish_sparse_mla_attention_with_sink( + swa_max_score, + swa_denom, + swa_acc, + self.attn_sink, + output=output, + ) + if output.shape[1] > self.num_heads: + output[:, self.num_heads :].zero_() + + def _forward_sparse_mla_compressed_decode_reference( + self, + q: torch.Tensor, + compressed_k_cache: torch.Tensor, + swa_k_cache: torch.Tensor, + topk_indices: torch.Tensor, + topk_lens: torch.Tensor, + swa_metadata: "DeepseekSparseSWAMetadata", + attn_metadata: FlashMLASparseMetadata, + output: torch.Tensor, + ) -> None: + if self.compress_ratio not in (4, 128): + raise NotImplementedError( + "Sparse MLA reference compressed decode currently supports " + f"compress_ratio=4 or 128, got {self.compress_ratio}" + ) + + num_decodes = swa_metadata.num_decodes + num_decode_tokens = swa_metadata.num_decode_tokens + mtp_decode = num_decode_tokens != num_decodes + + max_swa_len = swa_metadata.decode_swa_indices.shape[-1] + compressed_block_size = attn_metadata.block_size // self.compress_ratio + compressed_topk = topk_indices.shape[-1] + topk_chunk_size = min( + compressed_topk, + sparse_mla_reference_topk_chunk_size(), + ) + compressed_slot_ids = topk_indices[:, 0, :] + swa_lens = swa_metadata.decode_swa_lens[:num_decode_tokens] + swa_indices = swa_metadata.decode_swa_indices[:num_decode_tokens] + head_block_size = sparse_mla_decode_head_block_size(num_decode_tokens) + if ( + not mtp_decode + and compressed_topk <= 
topk_chunk_size + and sparse_mla_matmul_decode_enabled() + ): + total_candidates = compressed_topk + max_swa_len + ( + combined_kv, + valid_tokens, + ) = current_workspace_manager().get_simultaneous( + ( + (num_decode_tokens, total_candidates, q.shape[-1]), + torch.bfloat16, + ), + ((num_decode_tokens, total_candidates), torch.bool), + ) + dequantize_combined_sparse_mla_decode_kv( + combined_kv, + compressed_k_cache, + compressed_slot_ids, + compressed_block_size, + swa_k_cache, + swa_metadata.seq_lens[:num_decodes], + swa_lens, + swa_metadata.block_table[:num_decodes], + swa_metadata.block_size, + ) + + build_combined_sparse_mla_decode_valid_mask( + valid_tokens, + compressed_slot_ids, + topk_lens, + swa_lens, + ) + matmul_sparse_mla_attention_with_sink( + q=q, + kv=combined_kv, + valid_tokens=valid_tokens, + scale=self.scale, + attn_sink=self.attn_sink, + output=output, + num_heads=self.num_heads, + ) + return + + if not mtp_decode and compressed_topk <= topk_chunk_size: + fp8ds_global_paged_sparse_mla_attention_with_sink_multihead( + q=q, + compressed_k_cache=compressed_k_cache, + slot_ids=compressed_slot_ids, + topk_lens=topk_lens, + compressed_block_size=compressed_block_size, + swa_k_cache=swa_k_cache, + seq_lens=swa_metadata.seq_lens[:num_decodes], + gather_lens=swa_lens, + block_table=swa_metadata.block_table[:num_decodes], + swa_block_size=swa_metadata.block_size, + num_compressed_candidates=compressed_topk, + num_swa_candidates=max_swa_len, + scale=self.scale, + attn_sink=self.attn_sink, + output=output, + head_block_size=head_block_size, + num_heads=self.num_heads, + ) + if output.shape[1] > self.num_heads: + output[:, self.num_heads :].zero_() + return + + ( + comp_max_score, + comp_denom, + comp_acc, + swa_max_score, + swa_denom, + swa_acc, + ) = current_workspace_manager().get_simultaneous( + ((num_decode_tokens, self.num_heads), torch.float32), + ((num_decode_tokens, self.num_heads), torch.float32), + ((num_decode_tokens, self.num_heads, q.shape[-1]), torch.float32), + ((num_decode_tokens, self.num_heads), torch.float32), + ((num_decode_tokens, self.num_heads), torch.float32), + ((num_decode_tokens, self.num_heads, q.shape[-1]), torch.float32), + ) + comp_max_score.fill_(float("-inf")) + comp_denom.zero_() + comp_acc.zero_() + swa_max_score.fill_(float("-inf")) + swa_denom.zero_() + swa_acc.zero_() + + for chunk_start in range(0, compressed_topk, topk_chunk_size): + chunk_end = min(chunk_start + topk_chunk_size, compressed_topk) + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=compressed_k_cache, + slot_ids=compressed_slot_ids[:, chunk_start:chunk_end], + lens=topk_lens, + block_size=compressed_block_size, + candidate_offset=chunk_start, + scale=self.scale, + max_score=comp_max_score, + denom=comp_denom, + acc=comp_acc, + head_block_size=head_block_size, + ) + if mtp_decode: + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=swa_k_cache, + slot_ids=swa_indices, + lens=swa_lens, + block_size=swa_metadata.block_size, + scale=self.scale, + max_score=swa_max_score, + denom=swa_denom, + acc=swa_acc, + head_block_size=head_block_size, + ) + else: + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=swa_k_cache, + seq_lens=swa_metadata.seq_lens[:num_decodes], + gather_lens=swa_lens, + block_table=swa_metadata.block_table[:num_decodes], + block_size=swa_metadata.block_size, + candidate_offset=0, + num_candidates=max_swa_len, + scale=self.scale, + max_score=swa_max_score, + 
denom=swa_denom, + acc=swa_acc, + head_block_size=head_block_size, + ) + finish_two_sparse_mla_attention_states_with_sink( + comp_max_score, + comp_denom, + comp_acc, + swa_max_score, + swa_denom, + swa_acc, + self.attn_sink, + output=output, + ) + if output.shape[1] > self.num_heads: + output[:, self.num_heads :].zero_() + + def _forward_sparse_mla_prefill_reference( + self, + q: torch.Tensor, + kv: torch.Tensor, + combined_indices: torch.Tensor, + combined_lens: torch.Tensor, + output: torch.Tensor, + state_buffers: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, + ) -> None: + kv_flat = kv.reshape(-1, q.shape[-1]) + topk_chunk_size = min( + combined_indices.shape[-1], + sparse_mla_reference_topk_chunk_size(), + ) + query_chunk_size = min( + q.shape[0], + sparse_mla_reference_query_chunk_size(), + ) + if state_buffers is None: + ( + max_score_buffer, + denom_buffer, + output_buffer, + ) = current_workspace_manager().get_simultaneous( + ((query_chunk_size, self.num_heads), torch.float32), + ((query_chunk_size, self.num_heads), torch.float32), + ((query_chunk_size, self.num_heads, q.shape[-1]), torch.float32), + ) + else: + max_score_buffer, denom_buffer, output_buffer = state_buffers + + for token_start in range(0, q.shape[0], query_chunk_size): + token_end = min(token_start + query_chunk_size, q.shape[0]) + q_chunk = q[token_start:token_end] + indices_chunk_full = combined_indices[token_start:token_end] + lens_chunk = combined_lens[token_start:token_end] + num_tokens = token_end - token_start + max_score = max_score_buffer[:num_tokens] + denom = denom_buffer[:num_tokens] + subset_acc = output_buffer[:num_tokens] + max_score.fill_(float("-inf")) + denom.zero_() + subset_acc.zero_() + + for index_start in range(0, combined_indices.shape[-1], topk_chunk_size): + index_end = min( + index_start + topk_chunk_size, + combined_indices.shape[-1], + ) + accumulate_indexed_sparse_mla_attention_chunk( + q=q_chunk, + kv_flat=kv_flat, + indices=indices_chunk_full[:, index_start:index_end], + lens=lens_chunk, + candidate_offset=index_start, + scale=self.scale, + max_score=max_score, + denom=denom, + acc=subset_acc, + ) + + finish_sparse_mla_attention_with_sink( + max_score, + denom, + subset_acc, + self.attn_sink, + output=output[token_start:token_end], + ) + if output.shape[1] > self.num_heads: + output[token_start:token_end, self.num_heads :].zero_() + + def forward( + self, + q: torch.Tensor, + kv: torch.Tensor, + positions: torch.Tensor, + output: torch.Tensor, + ) -> None: + assert output.shape == q.shape, ( + f"output buffer shape {output.shape} must match q shape {q.shape}" + ) + assert output.dtype == q.dtype, ( + f"output buffer dtype {output.dtype} must match q dtype {q.dtype}" + ) + + # Get SWA and indexer metadata from forward context + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + assert isinstance(attn_metadata, dict) + flashmla_metadata = cast( + FlashMLASparseMetadata | None, attn_metadata.get(self.prefix) + ) + swa_metadata = cast( + "DeepseekSparseSWAMetadata | None", + attn_metadata.get(self.swa_cache_layer.prefix), + ) + assert swa_metadata is not None + + swa_only = self.compress_ratio <= 1 + # SWA-only layers (compress_ratio <= 1) don't have their own KV cache + # allocation, so self.kv_cache may be empty after profiling cleanup. 
+ self_kv_cache = self.kv_cache if not swa_only else None + swa_kv_cache = self.swa_cache_layer.kv_cache + + # Split prefill and decode + num_decodes = swa_metadata.num_decodes + num_prefills = swa_metadata.num_prefills + num_decode_tokens = swa_metadata.num_decode_tokens + + if num_prefills > 0: + self._forward_prefill( + q=q[num_decode_tokens:], + positions=positions[num_decode_tokens:], + compressed_k_cache=self_kv_cache, + swa_k_cache=swa_kv_cache, + output=output[num_decode_tokens:], + attn_metadata=flashmla_metadata, + swa_metadata=swa_metadata, + ) + if num_decodes > 0: + self._forward_decode( + q=q[:num_decode_tokens], + kv_cache=self_kv_cache, + swa_metadata=swa_metadata, + attn_metadata=flashmla_metadata, + swa_only=swa_only, + output=output[:num_decode_tokens], + ) + + def _forward_decode( + self, + q: torch.Tensor, + kv_cache: torch.Tensor | None, # Only used when compress_ratio > 1 + swa_metadata: "DeepseekSparseSWAMetadata", + attn_metadata: FlashMLASparseMetadata | None, + swa_only: bool, + output: torch.Tensor, + ) -> None: + num_decodes = swa_metadata.num_decodes + num_decode_tokens = swa_metadata.num_decode_tokens + + topk_indices = None + topk_lens = None + if not swa_only: + assert attn_metadata is not None + assert swa_metadata.is_valid_token is not None + block_size = attn_metadata.block_size // self.compress_ratio + is_valid = swa_metadata.is_valid_token[:num_decode_tokens] + if self.compress_ratio == 4: + # C4A: local indices differ per layer (filled by Indexer). + assert self.topk_indices_buffer is not None + global_indices, topk_lens = compute_global_topk_indices_and_lens( + self.topk_indices_buffer[:num_decode_tokens], + swa_metadata.token_to_req_indices, + attn_metadata.block_table[:num_decodes], + block_size, + is_valid, + ) + topk_indices = global_indices.view(num_decode_tokens, 1, -1) + else: + # C128A: pre-computed during metadata build. + topk_indices = attn_metadata.c128a_global_decode_topk_indices + topk_lens = attn_metadata.c128a_decode_topk_lens + + swa_indices = swa_metadata.decode_swa_indices + swa_lens = swa_metadata.decode_swa_lens + + # We treat queries in the same seq as different queries + # and later we only attend by generated indices. + # q arrives pre-padded to self.padded_heads by the outer wrapper. 
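# Shape-wise: [num_decode_tokens, padded_heads, head_dim] ->
# [num_decode_tokens, 1, padded_heads, head_dim], i.e. each decode token is
# handed to FlashMLA as its own length-1 query sequence.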
+ q = q.unsqueeze(1) + + # Prepare SWA cache (num_blocks, swa_block_size, 1, head_bytes) + # Use unsqueeze to preserve strides (handles padded blocks correctly) + swa_cache = self.swa_cache_layer.kv_cache.unsqueeze(-2) + # Reshape KV cache to (num_blocks, block_size, 1, head_bytes) + compressed_k_cache = kv_cache + if kv_cache is not None: + kv_cache = kv_cache.unsqueeze(-2) + + decode_fields = { + "kv_cache": _tensor_summary(kv_cache), + "swa_cache": _tensor_summary(swa_cache), + "topk_indices": _tensor_summary(topk_indices), + "topk_lens": _tensor_summary(topk_lens), + "swa_indices": _tensor_summary(swa_indices), + "swa_lens": _tensor_summary(swa_lens), + "attn_sink": _tensor_summary(self.attn_sink), + "scale": float(self.scale), + "swa_only": swa_only, + "padded_heads": int(self.padded_heads), + "num_decodes": int(num_decodes), + "num_decode_tokens": int(num_decode_tokens), + } + _write_sparse_mla_attention_state_if_enabled( + phase="decode", + prefix=self.prefix, + compress_ratio=self.compress_ratio, + q=q, + output=output, + swa_metadata=swa_metadata, + attn_metadata=attn_metadata, + fields=decode_fields, + ) + + if is_sparse_mla_reference_attention_enabled(q.device): + if swa_only: + self._forward_sparse_mla_swa_decode_reference( + q=q, + swa_k_cache=self.swa_cache_layer.kv_cache, + swa_metadata=swa_metadata, + output=output, + ) + return + if self.compress_ratio in (4, 128): + assert compressed_k_cache is not None + assert attn_metadata is not None + assert topk_indices is not None + assert topk_lens is not None + self._forward_sparse_mla_compressed_decode_reference( + q=q, + compressed_k_cache=compressed_k_cache, + swa_k_cache=self.swa_cache_layer.kv_cache, + topk_indices=topk_indices, + topk_lens=topk_lens, + swa_metadata=swa_metadata, + attn_metadata=attn_metadata, + output=output, + ) + return + _dump_sparse_mla_attention_state( + phase="decode_unsupported_compressed", + prefix=self.prefix, + compress_ratio=self.compress_ratio, + q=q, + output=output, + swa_metadata=swa_metadata, + attn_metadata=attn_metadata, + fields=decode_fields, + ) + + # One FlashMLASchedMeta per layer type, shared across all same-type + # layers within this decode step. The first forward call per type + # triggers the in-kernel planner (allocating tile_scheduler_metadata + # and num_splits via PyTorch's graph-aware allocator so CUDA graph + # capture reuses the same addresses on replay); subsequent same-type + # layers see have_initialized=True and skip the planner. + if self.compress_ratio <= 1: + tile_metadata = swa_metadata.tile_sched_swaonly + elif self.compress_ratio == 4: + tile_metadata = swa_metadata.tile_sched_c4a + elif self.compress_ratio == 128: + tile_metadata = swa_metadata.tile_sched_c128a + else: + raise ValueError( + f"Unsupported compress_ratio={self.compress_ratio}; " + "expected 1, 4, or 128." + ) + assert tile_metadata is not None, ( + "swa_metadata missing tile_sched entry for " + f"compress_ratio={self.compress_ratio}; " + "DeepseekSparseSWAMetadataBuilder.build_tile_scheduler did not " + "allocate one for this layer type." 
+ ) + + out, _ = flash_mla_with_kvcache( + q=q, + k_cache=swa_cache, + block_table=None, + head_dim_v=512, + tile_scheduler_metadata=tile_metadata, + cache_seqlens=None, + is_fp8_kvcache=True, + indices=swa_indices, + topk_length=swa_lens, + softmax_scale=self.scale, + attn_sink=self.attn_sink, + extra_k_cache=kv_cache if not swa_only else None, + extra_indices_in_kvcache=topk_indices, + extra_topk_length=topk_lens, + out=output.unsqueeze(1), + ) + + def _forward_prefill( + self, + q: torch.Tensor, + positions: torch.Tensor, + compressed_k_cache: torch.Tensor | None, # Only used when compress_ratio > 1 + swa_k_cache: torch.Tensor, + output: torch.Tensor, + attn_metadata: FlashMLASparseMetadata | None, + swa_metadata: "DeepseekSparseSWAMetadata", + ) -> None: + swa_only = attn_metadata is None + + num_prefills = swa_metadata.num_prefills + num_prefill_tokens = swa_metadata.num_prefill_tokens + num_decodes = swa_metadata.num_decodes + num_decode_tokens = swa_metadata.num_decode_tokens + + # Use pre-computed prefill metadata. + seq_lens = swa_metadata.prefill_seq_lens + gather_lens = swa_metadata.prefill_gather_lens + seq_lens_cpu = swa_metadata.prefill_seq_lens_cpu + gather_lens_cpu = swa_metadata.prefill_gather_lens_cpu + assert seq_lens is not None + assert gather_lens is not None + assert seq_lens_cpu is not None + assert gather_lens_cpu is not None + + # Derive prefill-local token offsets from the full query_start_loc_cpu. + query_start_loc_cpu = swa_metadata.query_start_loc_cpu + query_start_loc = swa_metadata.query_start_loc + assert query_start_loc_cpu is not None + assert query_start_loc is not None + prefill_token_base = query_start_loc_cpu[num_decodes] + + if not swa_only: + if self.compress_ratio == 4: + assert self.topk_indices_buffer is not None + topk_indices = self.topk_indices_buffer[num_decode_tokens:] + topk_indices = topk_indices[:num_prefill_tokens] + else: + # C128A: pre-computed during metadata build. + assert attn_metadata is not None + topk_indices = attn_metadata.c128a_prefill_topk_indices + top_k = topk_indices.shape[-1] + else: + # NOTE(woosuk): topk_indices will not be used for SWA-only layers. 
+ assert self.topk_indices_buffer is not None + topk_indices = self.topk_indices_buffer[num_decode_tokens:] + top_k = 0 + + N, M = _sparse_mla_prefill_workspace_bounds( + seq_lens_cpu=seq_lens_cpu, + gather_lens_cpu=gather_lens_cpu, + compress_ratio=self.compress_ratio, + swa_only=swa_only, + ) + num_chunks = (num_prefills + PREFILL_CHUNK_SIZE - 1) // PREFILL_CHUNK_SIZE + + workspace_manager = current_workspace_manager() + reference_attention_enabled = is_sparse_mla_reference_attention_enabled( + q.device + ) + if reference_attention_enabled: + query_chunk_size = min( + q.shape[0], sparse_mla_reference_query_chunk_size() + ) + ( + kv, + max_score_buffer, + denom_buffer, + output_buffer, + ) = workspace_manager.get_simultaneous( + ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), + ((query_chunk_size, self.num_heads), torch.float32), + ((query_chunk_size, self.num_heads), torch.float32), + ((query_chunk_size, self.num_heads, q.shape[-1]), torch.float32), + ) + prefill_state_buffers = ( + max_score_buffer, + denom_buffer, + output_buffer, + ) + else: + kv = workspace_manager.get_simultaneous( + ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), + )[0] + prefill_state_buffers = None + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * PREFILL_CHUNK_SIZE + chunk_end = min(chunk_start + PREFILL_CHUNK_SIZE, num_prefills) + chunk_size = chunk_end - chunk_start + if not swa_only: + # Gather compressed KV + assert attn_metadata is not None + block_table = attn_metadata.block_table[num_decodes:] + dequantize_and_gather_k_cache( + kv[:chunk_size], + compressed_k_cache, + seq_lens=seq_lens[chunk_start:chunk_end] // self.compress_ratio, + gather_lens=None, + block_table=block_table[chunk_start:chunk_end], + block_size=attn_metadata.block_size // self.compress_ratio, + offset=0, + ) + + # Gather SWA KV + swa_block_table = swa_metadata.block_table[num_decodes:] + dequantize_and_gather_k_cache( + kv[:chunk_size], + swa_k_cache, + seq_lens=seq_lens[chunk_start:chunk_end], + gather_lens=gather_lens[chunk_start:chunk_end], + block_table=swa_block_table[chunk_start:chunk_end], + block_size=swa_metadata.block_size, + offset=N, + ) + + # Combine the topk indices and SWA indices for gathered KV cache + query_start = ( + query_start_loc_cpu[num_decodes + chunk_start] - prefill_token_base + ) + query_end = ( + query_start_loc_cpu[num_decodes + chunk_end] - prefill_token_base + ) + + combined_indices, combined_lens = combine_topk_swa_indices( + topk_indices[query_start:query_end], + query_start_loc[ + num_decodes + chunk_start : num_decodes + chunk_end + 1 + ], + seq_lens[chunk_start:chunk_end], + gather_lens[chunk_start:chunk_end], + self.window_size, + self.compress_ratio, + top_k, + M, + N, + ) + + if is_sparse_mla_attention_dump_enabled(): + _dump_sparse_mla_attention_state( + phase="prefill", + prefix=self.prefix, + compress_ratio=self.compress_ratio, + q=q[query_start:query_end], + output=output[query_start:query_end], + swa_metadata=swa_metadata, + attn_metadata=attn_metadata, + fields={ + "compressed_k_cache": _tensor_summary(compressed_k_cache), + "swa_k_cache": _tensor_summary(swa_k_cache), + "gathered_kv": _tensor_summary(kv[:chunk_size]), + "topk_indices": _tensor_summary(topk_indices), + "combined_indices": _tensor_summary(combined_indices), + "combined_lens": _tensor_summary(combined_lens), + "attn_sink": _tensor_summary(self.attn_sink), + "scale": float(self.scale), + "swa_only": swa_only, + "chunk_start": int(chunk_start), + "chunk_end": int(chunk_end), + "query_start": 
int(query_start), + "query_end": int(query_end), + }, + ) + + if reference_attention_enabled: + self._forward_sparse_mla_prefill_reference( + q=q[query_start:query_end], + kv=kv[:chunk_size], + combined_indices=combined_indices, + combined_lens=combined_lens, + output=output[query_start:query_end], + state_buffers=prefill_state_buffers, + ) + continue + + output_chunk, _, _ = flash_mla_sparse_fwd( + q=q[query_start:query_end], + kv=kv.view(-1, 1, q.shape[-1]), + indices=combined_indices.unsqueeze(1), + sm_scale=self.scale, + attn_sink=self.attn_sink, + topk_length=combined_lens, + out=output[query_start:query_end], + ) + + +class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase): + def __init__( + self, + head_dim: int, + dtype: torch.dtype, + prefix: str, + cache_config: CacheConfig, + compress_ratio: int = 1, + ): + super().__init__() + self.kv_cache = torch.tensor([]) + self.head_dim = head_dim + self.prefix = prefix + self.cache_config = cache_config + self.dtype = dtype + self.compress_ratio = compress_ratio + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: + # head_dim already carries the fp8 scale padding + # compress_ratio=1 for V3.2, >1 for DeepseekV4; both use the same cache layout. + return MLAAttentionSpec( + block_size=self.cache_config.block_size, + num_kv_heads=1, + head_size=self.head_dim, + dtype=self.dtype, + compress_ratio=self.compress_ratio, + # DeepseekV4 aligns indexer pages to FlashMLA's 576B so they can pack with + # the indexer's compressor state cache. V3.2 keeps the legacy layout. + alignment=576, + ) + + def forward(self): ... 
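The head_size this spec reports is a byte count. A minimal sketch of the arithmetic, assuming the FP8 indexer layout described below (128 fp8 values per token plus one 4-byte fp32 scale per 128-value quant block):

    head_dim = 128          # index_head_dim
    quant_block_size = 128  # fp8 quant group size
    # Same formula as k_cache_head_dim in DeepseekV4Indexer.__init__ below:
    k_cache_head_dim = head_dim + head_dim // quant_block_size * 4  # 128 + 4 = 132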
+ + def get_attn_backend(self) -> type[AttentionBackend]: + return DeepseekV4IndexerBackend + + + class DeepseekV4Indexer(nn.Module): + def __init__( + self, + vllm_config: VllmConfig, + config: DeepseekV2Config | DeepseekV3Config, + hidden_size: int, + q_lora_rank: int, + quant_config: QuantizationConfig | None, + cache_config: CacheConfig | None, + topk_indices_buffer: torch.Tensor | None, + compress_ratio: int = 1, + prefix: str = "", + ): + super().__init__() + self.vllm_config = vllm_config + self.config = config + self.quant_config = quant_config + # self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"] + self.topk_tokens = config.index_topk + self.n_head = config.index_n_heads # 64 + self.head_dim = config.index_head_dim # 128 + self.rope_dim = config.qk_rope_head_dim # 64 + self.q_lora_rank = q_lora_rank # 1536 + self.compress_ratio = compress_ratio + self.use_fp4_kv = self.vllm_config.attention_config.use_fp4_indexer_cache + logger.info_once( + "Using %s indexer cache for Lightning Indexer.", + "MXFP4" if self.use_fp4_kv else "FP8", + ) + + # no tensor parallel, just replicated + self.wq_b = ReplicatedLinear( + self.q_lora_rank, + self.head_dim * self.n_head, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wq_b", + ) + self.weights_proj = ReplicatedLinear( + hidden_size, + self.n_head, + bias=False, + quant_config=None, + prefix=f"{prefix}.weights_proj", + ) + self.k_norm = LayerNorm(self.head_dim, eps=1e-6) + self.softmax_scale = self.head_dim**-0.5 + + self.scale_fmt = "ue8m0" + self.quant_block_size = 128 # TODO: get from config + self.topk_indices_buffer = topk_indices_buffer + + self.max_model_len = ( + vllm_config.model_config.max_model_len // self.compress_ratio + ) + self.prefix = prefix + + self.max_total_seq_len = ( + get_max_prefill_buffer_size(vllm_config) // self.compress_ratio + ) + + assert cache_config is not None, "Deepseek V4 indexer requires cache_config" + # NOTE(yifan): FP8 indexer cache uses the same layout as V3.2: + # head_dim bytes = 128 fp8 + 4 fp32 scale = 132. + # For FP4 indexer cache, we still allocate the same amount of memory as FP8, + # but only use the first half of the memory. 
+ k_cache_head_dim = self.head_dim + self.head_dim // self.quant_block_size * 4 + self.k_cache = DeepseekV4IndexerCache( + head_dim=k_cache_head_dim, + dtype=torch.uint8, + prefix=f"{prefix}.k_cache", + cache_config=cache_config, + compress_ratio=self.compress_ratio, + ) + self.compressor = DeepseekCompressor( + vllm_config=vllm_config, + compress_ratio=self.compress_ratio, + hidden_size=hidden_size, + head_dim=self.head_dim, + rotate=True, + prefix=f"{prefix}.compressor", + k_cache_prefix=self.k_cache.prefix, + use_fp4_cache=self.use_fp4_kv, + ) + + self.indexer_op = SparseAttnIndexer( + self.k_cache, + self.quant_block_size, + self.scale_fmt, + self.topk_tokens, + self.head_dim, + self.max_model_len, + self.max_total_seq_len, + self.topk_indices_buffer, + skip_k_cache_insert=True, + use_fp4_cache=self.use_fp4_kv, + ) + + def forward( + self, + hidden_states: torch.Tensor, + qr: torch.Tensor, + positions: torch.Tensor, + rotary_emb: nn.Module, + ) -> torch.Tensor: + q, _ = self.wq_b(qr) + q = q.view(-1, self.n_head, self.head_dim) + k = self.compressor(hidden_states, positions, rotary_emb) + weights, _ = self.weights_proj(hidden_states) + q_quant, weights = fused_indexer_q_rope_quant( + positions, + q, + rotary_emb.cos_sin_cache, + weights, + self.softmax_scale, + self.n_head**-0.5, + use_fp4=self.use_fp4_kv, + ) + return self.indexer_op(hidden_states, q_quant, k, weights) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 997900971539..f958a6322e38 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -117,8 +117,10 @@ class RoutingMethodType(IntEnum): Custom = (6,) # Simulated Simulated = (7,) + # Deepseek V4 -> sqrtsoftplus + Bias + Normalize + DeepseekV4 = (8,) # Unspecified - Unspecified = 8.0 + Unspecified = 9.0 def get_routing_method_type( @@ -128,6 +130,14 @@ def get_routing_method_type( num_expert_group: int | None, has_e_score_bias: bool, ) -> RoutingMethodType: + if scoring_func == "sqrtsoftplus": + # DeepSeek V4 uses sqrtsoftplus routing with optional routing bias + # and top-k renormalization. + if renormalize: + return RoutingMethodType.DeepseekV4 + else: + return RoutingMethodType.Unspecified + if has_e_score_bias: if (num_expert_group or 0) > 0 and scoring_func == "sigmoid": return RoutingMethodType.DeepSeekV3 @@ -230,6 +240,13 @@ class FusedMoEQuantConfig: _w2: FusedMoEQuantDesc is_nvfp4_scale_swizzled: bool = True + # MXFP4-specific TRTLLM parameters for SwiGLU activation clamping. + # These correspond to gemm1_alpha, gemm1_beta, gemm1_clamp_limit + # in TrtLlmMxfp4ExpertsBase. + gemm1_alpha: float | None = None + gemm1_beta: float | None = None + gemm1_clamp_limit: float | None = None + def __post_init__(self): assert not self.per_act_token_quant or self.block_shape is None, ( "illegal quantization" @@ -477,6 +494,9 @@ def make( w2_zp: torch.Tensor | None = None, weight_dtype: torch.dtype | str | None = None, is_nvfp4_scale_swizzled: bool = True, + gemm1_alpha: float | None = None, + gemm1_beta: float | None = None, + gemm1_clamp_limit: float | None = None, ) -> "FusedMoEQuantConfig": """ General builder function for a FusedMoEQuantConfig. @@ -507,6 +527,9 @@ def make( - w1_zp: Optional w1 zero points for int4/int8 quantization. - w2_zp: Optional w2 zero points for int4/int8 quantization. - is_nvfp4_scale_swizzled: Whether to swizzle the nvfp4 scale swizzling. + - gemm1_alpha: Optional MXFP4 TRTLLM SwiGLU alpha parameter. 
+ - gemm1_beta: Optional MXFP4 TRTLLM SwiGLU beta parameter. + - gemm1_clamp_limit: Optional MXFP4 TRTLLM SwiGLU clamp limit. """ assert not isinstance(quant_dtype, str) or quant_dtype in { "nvfp4", @@ -540,6 +563,9 @@ def make( weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias ), is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled, + gemm1_alpha=gemm1_alpha, + gemm1_beta=gemm1_beta, + gemm1_clamp_limit=gemm1_clamp_limit, ) assert quant_config.per_act_token_quant == per_act_token_quant assert quant_config.per_out_ch_quant == per_out_ch_quant @@ -650,6 +676,9 @@ def mxfp4_w4a16_moe_quant_config( w2_scale: Union[torch.Tensor, "PrecisionConfig"], w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, + gemm1_alpha: float | None = None, + gemm1_beta: float | None = None, + gemm1_clamp_limit: float | None = None, ) -> FusedMoEQuantConfig: """ Construct a quant config for unquantized activations and mxfp4 weights. @@ -659,6 +688,9 @@ def mxfp4_w4a16_moe_quant_config( _a2=FusedMoEQuantDesc(), _w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias), _w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias), + gemm1_alpha=gemm1_alpha, + gemm1_beta=gemm1_beta, + gemm1_clamp_limit=gemm1_clamp_limit, ) @@ -670,6 +702,9 @@ def mxfp4_mxfp8_moe_quant_config( w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, block_shape: list[int] | None = None, + gemm1_alpha: float | None = None, + gemm1_beta: float | None = None, + gemm1_clamp_limit: float | None = None, ) -> FusedMoEQuantConfig: """ Construct a quant config for mxfp4 activations and mxfp4 weights. @@ -679,6 +714,9 @@ def mxfp4_mxfp8_moe_quant_config( _a2=FusedMoEQuantDesc("mxfp8"), _w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias), _w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias), + gemm1_alpha=gemm1_alpha, + gemm1_beta=gemm1_beta, + gemm1_clamp_limit=gemm1_clamp_limit, ) @@ -712,6 +750,9 @@ def ocp_mx_moe_quant_config( w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, block_shape: list[int] | None = None, + gemm1_alpha: float | None = None, + gemm1_beta: float | None = None, + gemm1_clamp_limit: float | None = None, ) -> FusedMoEQuantConfig: """ Construct a quant config for mxfp4 activations and mxfp4 weights. 
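For illustration, a sketch of how the new SwiGLU-clamp knobs thread through one of these builders. The scale tensors are placeholders and the alpha/beta/limit values are made up for the sketch; real values come from the model's MXFP4/TRTLLM configuration.

    import torch

    from vllm.model_executor.layers.fused_moe.config import (
        mxfp4_w4a16_moe_quant_config,
    )

    # Placeholder MXFP4 weight-scale tensors; real shapes come from the checkpoint.
    w1_scale = torch.empty(0, dtype=torch.uint8)
    w2_scale = torch.empty(0, dtype=torch.uint8)

    quant_config = mxfp4_w4a16_moe_quant_config(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        gemm1_alpha=1.702,      # illustrative SwiGLU gate slope
        gemm1_beta=1.0,         # illustrative linear-branch offset
        gemm1_clamp_limit=7.0,  # illustrative pre-activation clamp
    )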
@@ -729,6 +770,9 @@ def ocp_mx_moe_quant_config( per_act_token_quant=False, per_out_ch_quant=False, block_shape=block_shape, + gemm1_alpha=gemm1_alpha, + gemm1_beta=gemm1_beta, + gemm1_clamp_limit=gemm1_clamp_limit, ) diff --git a/vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py index 03341378a13c..b4394b5fd382 100644 --- a/vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/deep_gemm_moe.py @@ -25,15 +25,20 @@ per_token_group_quant_fp8_packed_for_deepgemm, silu_mul_per_token_group_quant_fp8_colmajor, ) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + silu_mul_quant_fp8_packed_triton as fused_silu_mul_fp8_quant_packed, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8Dynamic128Sym, kFp8Static128BlockSym, + kMxfp4Static, ) from vllm.utils.deep_gemm import ( DeepGemmQuantScaleFMT, get_mk_alignment_for_contiguous_layout, is_deep_gemm_supported, + m_grouped_fp8_fp4_gemm_nt_contiguous, m_grouped_fp8_gemm_nt_contiguous, ) from vllm.utils.import_utils import has_deep_gemm @@ -197,8 +202,14 @@ def _act_mul_quant( M_sum, N = input.size() activation_out_dim = self.adjust_N_for_activation(N, activation) - # 1. DeepGemm UE8M0: use packed per-token-group quant + # 1. DeepGemm UE8M0: fused SiLU+mul+clamp+quant+pack if scale_fmt == DeepGemmQuantScaleFMT.UE8M0: + if activation == MoEActivation.SILU: + return fused_silu_mul_fp8_quant_packed( + input=input, + output_q=output, + group_size=block_k, + ) act_out = torch.empty( (M_sum, activation_out_dim), dtype=input.dtype, device=input.device ) @@ -312,3 +323,225 @@ def apply( expert_map=expert_map, output=output, ) + + +class DeepGemmFP4Experts(mk.FusedMoEExpertsModular): + """DeepGemm-based fused MoE expert implementation for FP4 weights. + + Uses m_grouped_fp8_fp4_gemm_nt_contiguous with FP8 activations and + MXFP4 (FP4 E2M1 packed as uint8) weights. Requires SM100+ (Blackwell). + """ + + # FP8 activation block size (hardcoded since mxfp4_w4a8 quant config + # does not set a block_shape on the activation descriptor). 
+ _ACT_BLOCK_K = 128 + # FP4 weight block size + _WEIGHT_BLOCK_K = 32 + + def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig): + super().__init__(moe_config=moe_config, quant_config=quant_config) + assert quant_config.weight_quant_dtype == "mxfp4" + assert not quant_config.per_act_token_quant + assert not quant_config.per_out_ch_quant + + self.gemm1_clamp_limit = quant_config.gemm1_clamp_limit + + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + @staticmethod + def _supports_current_device() -> bool: + from vllm.platforms import current_platform + + return ( + is_deep_gemm_supported() + and current_platform.is_device_capability_family(100) + ) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + SUPPORTED_W_A = [ + (kMxfp4Static, kFp8Dynamic128Sym), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + + @staticmethod + def _supports_activation(activation: MoEActivation) -> bool: + return activation in [MoEActivation.SILU, MoEActivation.SWIGLUSTEP] + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + return not ( + moe_parallel_config.use_fi_nvl_two_sided_kernels + or moe_parallel_config.use_fi_nvl_one_sided_kernels + ) + + def supports_expert_map(self) -> bool: + return True + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + return TopKWeightAndReduceNoOP() + + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: MoEActivation, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + block_m = get_mk_alignment_for_contiguous_layout()[0] + M_sum = compute_aligned_M( + M, topk, local_num_experts, block_m, expert_tokens_meta + ) + assert M_sum % block_m == 0 + + activation_out_dim = self.adjust_N_for_activation(N, activation) + workspace1 = (M_sum, max(activation_out_dim, K)) + workspace2 = (M_sum, max(N, K)) + output = (M, K) + return (workspace1, workspace2, output) + + def _act_mul_quant( + self, input: torch.Tensor, output: torch.Tensor, activation: MoEActivation + ) -> tuple[torch.Tensor, torch.Tensor]: + block_k = self._ACT_BLOCK_K + scale_fmt = DeepGemmQuantScaleFMT.from_oracle() + + M_sum, N = input.size() + activation_out_dim = self.adjust_N_for_activation(N, activation) + + if scale_fmt == DeepGemmQuantScaleFMT.UE8M0: + assert activation == MoEActivation.SILU + return fused_silu_mul_fp8_quant_packed( + input=input, + output_q=output, + group_size=block_k, + clamp_limit=self.gemm1_clamp_limit, + ) + + if activation == MoEActivation.SILU: + use_ue8m0 = scale_fmt == DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0 + return silu_mul_per_token_group_quant_fp8_colmajor( + input=input, + output=output, + use_ue8m0=use_ue8m0, + ) + + act_out = torch.empty( + (M_sum, activation_out_dim), dtype=input.dtype, device=input.device + ) + self.activation(activation, act_out, input) + return per_token_group_quant_fp8( + act_out, block_k, column_major_scales=True, out_q=output + ) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: MoEActivation, + global_num_experts: int, + expert_map: 
torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + apply_router_weight_on_input: bool, + ): + assert a1q_scale is not None + assert a2_scale is None + assert self.w1_scale is not None + assert self.w2_scale is not None + + a1q = hidden_states + _, N, _ = w1.size() + # K comes from activations (full hidden dim), not from w1 which is + # packed FP4 (E, N, K//2). + K = a1q.size(1) + + local_num_experts = w1.size(0) + if global_num_experts == -1: + global_num_experts = local_num_experts + + M_sum = compute_aligned_M( + M=topk_ids.size(0), + num_topk=topk_ids.size(1), + local_num_experts=local_num_experts, + alignment=get_mk_alignment_for_contiguous_layout()[0], + expert_tokens_meta=expert_tokens_meta, + ) + + a1q_perm = _resize_cache( + workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, K) + ) + a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute( + aq=a1q, + aq_scale=a1q_scale, + topk_ids=topk_ids, + local_num_experts=local_num_experts, + expert_map=expert_map, + expert_tokens_meta=expert_tokens_meta, + aq_out=a1q_perm, + ) + assert a1q.size(0) == M_sum + + # FC1: FP8 activations x FP4 weights + # DeepGEMM 2.4.2 requires FP4-packed weights as int8 (kPackedFP4). + mm1_out = _resize_cache(workspace2, (M_sum, N)) + m_grouped_fp8_fp4_gemm_nt_contiguous( + (a1q, a1q_scale), + (w1.view(torch.int8), self.w1_scale), + mm1_out, + expert_ids, + recipe_a=(1, self._ACT_BLOCK_K), + recipe_b=(1, self._WEIGHT_BLOCK_K), + ) + + # SwiGLU activation + FP8 requant + activation_out_dim = self.adjust_N_for_activation(N, activation) + quant_out = _resize_cache( + workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, activation_out_dim) + ) + a2q, a2q_scale = self._act_mul_quant( + input=mm1_out.view(-1, N), output=quant_out, activation=activation + ) + + # FC2: FP8 activations x FP4 weights + mm2_out = _resize_cache(workspace2, (M_sum, K)) + m_grouped_fp8_fp4_gemm_nt_contiguous( + (a2q, a2q_scale), + (w2.view(torch.int8), self.w2_scale), + mm2_out, + expert_ids, + recipe_a=(1, self._ACT_BLOCK_K), + recipe_b=(1, self._WEIGHT_BLOCK_K), + ) + + if apply_router_weight_on_input: + topk_weights = torch.ones_like(topk_weights) + + deepgemm_unpermute_and_reduce( + a=mm2_out, + topk_ids=topk_ids, + topk_weights=topk_weights, + inv_perm=inv_perm, + expert_map=expert_map, + output=output, + ) diff --git a/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py index 41898060dc3c..ac317ac7762c 100644 --- a/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py @@ -28,8 +28,162 @@ from vllm.triton_utils import tl, triton from vllm.utils.import_utils import has_triton_kernels +from ..utils import swiglu_limit_func + logger = init_logger(__name__) + +def _patch_make_bitmatrix_metadata() -> None: + """Monkey-patch make_bitmatrix_metadata to support non-power-of-2 top_k. + + triton's tl.arange requires a power-of-2 range. The original kernel + computes BLOCK_SIZE = BLOCK_PER_TOK * TOKS_PER_ROW (= 32 * top_k). For + DeepSeek-V4 with top_k=6 this gives 192, which is not a power of 2 and + causes a compile error at the first forward pass. 
+ + Fix: define a drop-in replacement kernel that accepts an extra constexpr + BLOCK_SIZE_PADDED (next power of 2 >= BLOCK_SIZE) and uses it for the + tl.arange call while keeping the actual BLOCK_SIZE as the stride between + thread-blocks so that all flat indices into NonzeroIndx stay correct. + Elements beyond BLOCK_SIZE are masked out (col_indx = 0xffff) and ignored. + + This function is called once at module load time and patches the function + inside the triton_kernels tensor module so that SparseMatrix.__post_init__ + picks up the fixed version transparently. + """ + import torch + import triton + import triton.language as tl + + try: + from vllm.third_party.triton_kernels.tensor_details import ( + bitmatrix as _bm, + ) + from vllm.third_party.triton_kernels.tensor_details.bitmatrix import ( + BitmatrixMetadata, + _keyed_add, + cdiv, + ) + from vllm.third_party.triton_kernels.tensor_details.bitmatrix_details.sum_bitmatrix_rows import ( # noqa: E501 + sum_bitmatrix_rows, + ) + except ImportError: + return + + @triton.jit + def _stage2_pow2( + ColSortedIndx, + RowSortedIndx, + NonzeroIndx, + n_tokens, + ColPartialSum, + stride_pm, + stride_pn, + ColOffs, + TOKS_PER_ROW: tl.constexpr, + BLOCK_PER_TOK: tl.constexpr, + BLOCK_SIZE_PADDED: tl.constexpr, + ): + # Actual number of elements per block (may not be a power of 2). + BLOCK_SIZE: tl.constexpr = BLOCK_PER_TOK * TOKS_PER_ROW + tl.static_assert(BLOCK_SIZE_PADDED <= 32768) + if isinstance(n_tokens, tl.tensor) and n_tokens.dtype.is_ptr(): + n_tokens = tl.load(n_tokens) + nonzero_indx_size = n_tokens * TOKS_PER_ROW + pid_m = tl.program_id(0) + # Use BLOCK_SIZE_PADDED (a power of 2) for tl.arange, but stride by + # the actual BLOCK_SIZE so flat positions in NonzeroIndx are correct. + # Elements with offs_local >= BLOCK_SIZE have offs_global beyond the + # valid range, get col_indx = 0xffff, and are filtered by the mask + # below without producing any output. 
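# NOTE (editorial sketch, not part of the patch): host-side illustration of the
# padding rule described above. For DeepSeek-V4, TOKS_PER_ROW (top_k) = 6 and
# BLOCK_PER_TOK = 32, so BLOCK_SIZE = 192, which tl.arange() rejects; the kernel
# therefore ranges over the next power of two (256) and masks off lanes 192..255.
# The helper name `next_pow2` is illustrative only.
def next_pow2(n: int) -> int:
    return 1 << (max(n, 1) - 1).bit_length()

assert next_pow2(32 * 6) == 256   # padded tl.arange extent for top_k = 6
assert next_pow2(32 * 8) == 256   # already a power of two, left unchanged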
+ offs_local = tl.arange(0, BLOCK_SIZE_PADDED) + offs_global = pid_m * BLOCK_SIZE + offs_local + mask = offs_global < nonzero_indx_size + col_indx = tl.load(NonzeroIndx + offs_global, mask=mask, other=-1).to(tl.uint32) + kv_pairs = ((col_indx << 16) | offs_local).to(tl.uint32) + kv_pairs = tl.sort(kv_pairs, 0) + col_indx = kv_pairs >> 16 + offs_global = pid_m * BLOCK_SIZE + (kv_pairs & 0xFFFF) + mask = col_indx != 0xFFFF + x = kv_pairs & 0xFFFF0000 | 0x00000001 + cols_and_inclusive_run_lengths = tl.associative_scan(x, 0, _keyed_add) + exclusive_run_lengths = (cols_and_inclusive_run_lengths - 1) & 0xFFFF + row_sorted_indx = tl.load( + ColPartialSum + pid_m * stride_pm + col_indx * stride_pn, mask=mask + ) + row_sorted_indx += tl.load(ColOffs + col_indx, mask=mask) + row_sorted_indx += exclusive_run_lengths + tl.store(RowSortedIndx + offs_global, row_sorted_indx, mask=mask) + tl.store(ColSortedIndx + row_sorted_indx, offs_global, mask=mask) + + def _make_bitmatrix_metadata_pow2_safe(nonzero_indx, bitmatrix): + assert nonzero_indx.ndim == 2 + PARTIAL_BLOCK_M = 32 + col_sum, col_partial_sum = sum_bitmatrix_rows( + bitmatrix, partials_block_size=PARTIAL_BLOCK_M + ) + device = bitmatrix.device + n_indx = nonzero_indx.numel() + n_cols = bitmatrix.shape[1] + col_offs = torch.empty(n_cols, dtype=torch.int32, device=device) + combined_indx = torch.empty(n_indx * 2, dtype=torch.int32, device=device) + col_sorted_indx = combined_indx[:n_indx] + row_sorted_indx = combined_indx[n_indx:] + MEMSET_BLOCK = 1024 + memset_grid = (cdiv(n_indx * 2, MEMSET_BLOCK) + n_cols + 1,) + _bm._bitmatrix_metadata_compute_stage1[memset_grid]( + combined_indx, + n_indx * 2, + -1, + MEMSET_BLOCK, + col_sum, + col_offs, + col_sum.shape[0], + col_partial_sum, + col_partial_sum.shape[0], + col_partial_sum.stride(0), + col_partial_sum.stride(1), + BLOCK_M=512, + BLOCK_N=512, + ) + toks_per_row = nonzero_indx.shape[-1] + block_size = PARTIAL_BLOCK_M * toks_per_row + # Next power of 2 >= block_size (required by tl.arange). + block_size_padded = 1 << (max(block_size, 1) - 1).bit_length() + compute_grid = (cdiv(bitmatrix.shape_max[0], PARTIAL_BLOCK_M),) + _stage2_pow2[compute_grid]( + col_sorted_indx, + row_sorted_indx, + nonzero_indx, + bitmatrix.shape[0], + col_partial_sum, + col_partial_sum.stride(0), + col_partial_sum.stride(1), + col_offs, + TOKS_PER_ROW=toks_per_row, + BLOCK_PER_TOK=PARTIAL_BLOCK_M, + BLOCK_SIZE_PADDED=block_size_padded, + ) + return BitmatrixMetadata( + col_sum=col_sum, + col_sorted_indx=col_sorted_indx, + row_sorted_indx=row_sorted_indx, + ) + + # The most reliable patch point: SparseMatrix.__post_init__ looks up + # make_bitmatrix_metadata via its own __globals__ dict (the tensor.py + # module dict). Patching through __globals__ works regardless of how + # sys.modules maps "triton_kernels.tensor" vs + # "vllm.third_party.triton_kernels.tensor". + from triton_kernels.tensor import SparseMatrix as _SparseMatrix + + _SparseMatrix.__post_init__.__globals__["make_bitmatrix_metadata"] = ( + _make_bitmatrix_metadata_pow2_safe + ) + # Also patch the bitmatrix module itself in case it is imported directly. + _bm.make_bitmatrix_metadata = _make_bitmatrix_metadata_pow2_safe + + use_legacy_triton_kernels = False if has_triton_kernels(): @@ -59,6 +213,8 @@ use_legacy_triton_kernels = True else: raise + if not use_legacy_triton_kernels: + _patch_make_bitmatrix_metadata() except (AttributeError, ImportError) as e: logger.error( "Failed to import Triton kernels. 
Please make sure your triton " @@ -497,6 +653,8 @@ def _supports_current_device() -> bool: return False # (9,0) <= cap < (11,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell) # and ROCm gfx942/gfx950 (which map to 9.4/9.5). + if not has_triton_kernels(): + return False return (9, 0) <= (cap.major, cap.minor) < (11, 0) @staticmethod @@ -698,6 +856,37 @@ def workspace_shapes( def moe_sum(self, input: torch.Tensor, output: torch.Tensor): ops.moe_sum(input, output) + def activation( + self, + activation: MoEActivation, + output: torch.Tensor, + input: torch.Tensor, + ) -> None: + quant_config = self.quant_config or FUSED_MOE_UNQUANTIZED_CONFIG + if activation == MoEActivation.SWIGLUOAI: + alpha = ( + quant_config.gemm1_alpha + if quant_config.gemm1_alpha is not None + else 1.702 + ) + limit = ( + quant_config.gemm1_clamp_limit + if quant_config.gemm1_clamp_limit is not None + else 7.0 + ) + torch.ops._C.swigluoai_and_mul(output, input, alpha, limit) + elif ( + activation == MoEActivation.SILU + and quant_config.gemm1_clamp_limit is not None + ): + swiglu_limit_func( + output, + input, + quant_config.gemm1_clamp_limit, + ) + else: + super().activation(activation, output, input) + def apply( self, output: torch.Tensor, @@ -812,9 +1001,9 @@ def apply( act_input, ) - # matmul_ogs grouped reduction fuse sum across multiple experts: + # matmul_ogs grouped reduction fuses sum across multiple experts: # y[dst_indx // n_expts_act, :] += x - # Need to set n_expts_act to 1 to unfuse moe_sum + # Set n_expts_act to 1 to unfuse the sum so we can do it manually via moe_sum. routing_data.n_expts_act = 1 matmul_ogs( @@ -878,6 +1067,8 @@ def _supports_current_device() -> bool: return False # (9,0) <= cap < (11,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell) # and ROCm gfx942/gfx950 (which map to 9.4/9.5). + if not has_triton_kernels(): + return False return (9, 0) <= (cap.major, cap.minor) < (11, 0) @staticmethod diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py index d084283360c4..f7af9aea70ad 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py @@ -14,6 +14,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) +from vllm.model_executor.layers.fused_moe.utils import trtllm_moe_pack_topk_ids_weights from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kMxfp4Static, @@ -32,10 +33,8 @@ def __init__( self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, + **kwargs, ): - # NOTE: FusedMoEExperts.__init__ is called by the concrete subclass - # (Monolithic/Modular) via MRO, not here, to avoid mypy issues with - # multiple inheritance. This matches the NvFP4 expert pattern. 
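# NOTE (editorial sketch, not part of the patch): an eager reference for the
# clamp-limited SwiGLU that the activation() override above routes to when
# activation == SILU and gemm1_clamp_limit is set (swiglu_limit_func in
# fused_moe/utils.py, added later in this patch). The first half of the
# projection is the gate, the second half the up branch; the gate is clamped
# only from above, the up branch on both sides. Names are illustrative.
import torch
import torch.nn.functional as F

def swiglu_with_limit(x: torch.Tensor, limit: float) -> torch.Tensor:
    gate, up = x.chunk(2, dim=-1)
    if limit > 0:
        gate = gate.clamp(max=limit)
        up = up.clamp(min=-limit, max=limit)
    return F.silu(gate) * up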
self.moe_config = moe_config self.quant_config = quant_config @@ -48,23 +47,34 @@ def __init__( self.local_num_experts = moe_config.num_local_experts self.ep_rank = moe_config.moe_parallel_config.ep_rank - # MXFP4-specific TRTLLM parameters + # MXFP4-specific TRTLLM parameters from quant_config device = torch.accelerator.current_device_index() - self.gemm1_alpha = torch.tensor( - [1.702] * self.local_num_experts, - dtype=torch.float32, - device=device, - ) - self.gemm1_beta = torch.tensor( - [1.0] * self.local_num_experts, - dtype=torch.float32, - device=device, - ) - self.gemm1_clamp_limit = torch.tensor( - [7.0] * self.local_num_experts, - dtype=torch.float32, - device=device, - ) + if quant_config.gemm1_alpha is not None: + self.gemm1_alpha = torch.tensor( + [quant_config.gemm1_alpha] * self.local_num_experts, + dtype=torch.float32, + device=device, + ) + else: + self.gemm1_alpha = None + + if quant_config.gemm1_beta is not None: + self.gemm1_beta = torch.tensor( + [quant_config.gemm1_beta] * self.local_num_experts, + dtype=torch.float32, + device=device, + ) + else: + self.gemm1_beta = None + + if quant_config.gemm1_clamp_limit is not None: + self.gemm1_clamp_limit = torch.tensor( + [quant_config.gemm1_clamp_limit] * self.local_num_experts, + dtype=torch.float32, + device=device, + ) + else: + self.gemm1_clamp_limit = None from vllm.config import get_current_vllm_config @@ -97,7 +107,7 @@ def _supports_quant_scheme( @staticmethod def _supports_activation(activation: MoEActivation) -> bool: - return activation == MoEActivation.SWIGLUOAI + return activation in (MoEActivation.SWIGLUOAI, MoEActivation.SILU) @staticmethod def activation_format() -> mk.FusedMoEActivationFormat: @@ -190,36 +200,41 @@ def apply( output = torch.empty_like(hidden_states) - return trtllm_fp4_block_scale_moe( - routing_logits=router_logits.to(torch.bfloat16), - routing_bias=None, - hidden_states=x_quant, - hidden_states_scale=x_scale, - gemm1_weights=w1, - gemm1_weights_scale=self.w1_scale, - gemm1_bias=self.w1_bias, - gemm1_alpha=self.gemm1_alpha, - gemm1_beta=self.gemm1_beta, - gemm1_clamp_limit=self.gemm1_clamp_limit, - gemm2_weights=w2, - gemm2_weights_scale=self.w2_scale, - gemm2_bias=self.w2_bias, - output1_scale_scalar=None, - output1_scale_gate_scalar=None, - output2_scale_scalar=None, - num_experts=global_num_experts, - top_k=self.topk, - n_group=None, - topk_group=None, - intermediate_size=self.intermediate_size_per_partition, - local_expert_offset=self.ep_rank * self.local_num_experts, - local_num_experts=self.local_num_experts, - routed_scaling_factor=None, - routing_method_type=self.routing_method_type, - do_finalize=True, - tune_max_num_tokens=max(self.max_capture_size, 1), - output=output, - )[0] + from vllm.utils.flashinfer import _is_fi_autotuning, autotune + + with autotune(_is_fi_autotuning): + trtllm_fp4_block_scale_moe( + routing_logits=router_logits.to(torch.bfloat16), + routing_bias=None, + hidden_states=x_quant, + hidden_states_scale=x_scale, + gemm1_weights=w1, + gemm1_weights_scale=self.w1_scale, + gemm1_bias=self.w1_bias, + gemm1_alpha=self.gemm1_alpha, + gemm1_beta=self.gemm1_beta, + gemm1_clamp_limit=self.gemm1_clamp_limit, + gemm2_weights=w2, + gemm2_weights_scale=self.w2_scale, + gemm2_bias=self.w2_bias, + output1_scale_scalar=None, + output1_scale_gate_scalar=None, + output2_scale_scalar=None, + num_experts=global_num_experts, + top_k=self.topk, + n_group=None, + topk_group=None, + intermediate_size=self.intermediate_size_per_partition, + local_expert_offset=self.ep_rank * 
self.local_num_experts, + local_num_experts=self.local_num_experts, + routed_scaling_factor=None, + routing_method_type=self.routing_method_type, + do_finalize=True, + tune_max_num_tokens=max(self.max_capture_size, 1), + output=output, + ) + + return output class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModular): @@ -239,6 +254,16 @@ def _supports_parallel_config( ) -> bool: return True + @staticmethod + def _supports_routing_method( + routing_method: RoutingMethodType, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + # Modular kernel handles only the expert computation; + # routing is done externally, so accept any routing method. + return True + def supports_expert_map(self) -> bool: return True @@ -282,7 +307,7 @@ def apply( ): topk = topk_ids.size(-1) local_num_experts = w1.size(0) - intermediate_size = w2.size(1) + intermediate_size = self.intermediate_size_per_partition local_expert_offset = self.moe_config.ep_rank * local_num_experts # Handle input quantization @@ -302,9 +327,8 @@ def apply( x_quant = hidden_states x_scale = None - packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( - torch.bfloat16 - ).view(torch.int16) + # Pack topk ids and weights into format expected by the kernel. + packed_tensor = trtllm_moe_pack_topk_ids_weights(topk_ids, topk_weights) assert self.w1_scale is not None assert self.w2_scale is not None @@ -333,7 +357,10 @@ def apply( "local_expert_offset": local_expert_offset, "local_num_experts": local_num_experts, "routed_scaling_factor": None, - "routing_method_type": self.routing_method_type, + # Modular kernel receives pre-routed tokens, so routing + # is already done. Use Renormalize as a safe default that + # the TRTLLM C++ kernel supports. + "routing_method_type": RoutingMethodType.Renormalize, "do_finalize": True, "output": output, "tune_max_num_tokens": max(self.max_capture_size, 1), @@ -341,12 +368,9 @@ def apply( from flashinfer import trtllm_fp4_block_scale_routed_moe - from vllm.utils.flashinfer import autotune + from vllm.utils.flashinfer import _is_fi_autotuning, autotune - with autotune(False): - # Enable autotune when, - # https://github.com/flashinfer-ai/flashinfer/issues/2023 is - # resolved. 
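# NOTE (editorial sketch, not part of the patch): the packed routing tensor
# built by trtllm_moe_pack_topk_ids_weights above stores, per (token, slot),
# the expert id in the high 16 bits and the raw bfloat16 bits of the routing
# weight in the low 16 bits of one int32. Minimal sketch, assuming non-negative
# routing weights (so the bf16 sign bit is zero and the OR is lossless):
import torch

topk_ids = torch.tensor([[3, 17]], dtype=torch.int32)
topk_weights = torch.tensor([[0.75, 0.25]])
packed = (topk_ids << 16) | topk_weights.to(torch.bfloat16).view(torch.int16)
assert torch.equal(packed >> 16, topk_ids)   # expert id recovered from the high half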
+ with autotune(_is_fi_autotuning): trtllm_fp4_block_scale_routed_moe(**kwargs) return output diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 5a085847fcd1..ebd330197099 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -50,6 +50,8 @@ from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types +from .utils import swiglu_limit_func + def _fused_marlin_moe( hidden_states: torch.Tensor, @@ -88,6 +90,7 @@ def _fused_marlin_moe( output: torch.Tensor | None = None, input_dtype: torch.dtype | None = None, is_k_full: bool = True, + clamp_limit: float | None = None, ) -> torch.Tensor: assert hidden_states.ndim == 2 M, K = hidden_states.size() @@ -155,11 +158,18 @@ def _fused_marlin_moe( use_fp32_reduce=True, is_zp_float=False, ) - activation_func( - activation, - intermediate_cache2, - intermediate_cache1.view(-1, w13_num_shards * N), - ) + if clamp_limit is not None and activation == MoEActivation.SILU: + swiglu_limit_func( + intermediate_cache2, + intermediate_cache1.view(-1, w13_num_shards * N), + clamp_limit, + ) + else: + activation_func( + activation, + intermediate_cache2, + intermediate_cache1.view(-1, w13_num_shards * N), + ) if output is None: output = intermediate_cache3 @@ -247,6 +257,7 @@ def fused_marlin_moe( output: torch.Tensor | None = None, input_dtype: torch.dtype | None = None, inplace: bool = False, + clamp_limit: float | None = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -363,6 +374,7 @@ def fused_marlin_moe( output=None, input_dtype=input_dtype, is_k_full=is_k_full, + clamp_limit=clamp_limit, ).view(-1, topk, K) if output is None: @@ -557,6 +569,7 @@ def __init__( self.w2_g_idx_sort_indices = w2_g_idx_sort_indices self.is_k_full = is_k_full self.input_dtype = get_marlin_input_dtype() + self.gemm1_clamp_limit = quant_config.gemm1_clamp_limit super().__init__( moe_config=moe_config, @@ -850,6 +863,7 @@ def moe_sum_with_lora(moe_out: torch.Tensor, out: torch.Tensor) -> None: sort_indices2=self.w2_g_idx_sort_indices, is_k_full=self.is_k_full, input_dtype=self.input_dtype, + clamp_limit=self.gemm1_clamp_limit, ) def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None: diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index a239dfea92e4..7d54e2b717d6 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -169,5 +169,6 @@ def apply_monolithic( layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 x: torch.Tensor, router_logits: torch.Tensor, + input_ids: torch.Tensor | None = None, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py index 142e180786c6..f2461ed59985 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py @@ -84,6 +84,14 @@ def get_fused_moe_quant_config( ) -> FusedMoEQuantConfig | None: return self.moe_quant_config + def reserve_workspace(self, layer: "FusedMoE") -> None: # type: ignore[name-defined] # noqa: F821 + self.moe_kernel.reserve_workspace( + layer.w13_weight, + 
layer.w2_weight, + layer.global_num_experts, + layer.activation, + ) + def apply( self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index b28fc5dda257..7471793daba0 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -268,6 +268,7 @@ def __init__( custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, + swiglu_limit: float | None = None, e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", @@ -285,6 +286,7 @@ def __init__( routed_output_transform: torch.nn.Module | None = None, apply_routed_scale_to_output: bool = False, zero_expert_type: str | None = None, + hash_indices_table: torch.Tensor | None = None, ): super().__init__() @@ -294,6 +296,7 @@ def __init__( vllm_config = get_current_vllm_config() self.vllm_config = vllm_config + self.swiglu_limit = swiglu_limit # FIXME (varun): We should have a better way of inferring the activation # datatype. This works for now as the tensor datatype entering the MoE @@ -455,6 +458,7 @@ def __init__( self.e_score_correction_bias = e_score_correction_bias # TODO(bnell): end attributes + self.hash_indices_table = hash_indices_table self.apply_router_weight_on_input = apply_router_weight_on_input self.activation = MoEActivation.from_str(activation) @@ -479,6 +483,7 @@ def __init__( indices_type_getter=lambda: self.quant_method.topk_indices_dtype, zero_expert_type=zero_expert_type, num_logical_experts=self.logical_num_experts, + hash_indices_table=self.hash_indices_table, ) self.routing_method_type: RoutingMethodType = self.router.routing_method_type @@ -634,15 +639,15 @@ def maybe_init_modular_kernel(self) -> None: logger.debug( "%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self) ) - self._replace_quant_method( - FusedMoEModularMethod.make( - self, - self.base_quant_method, - prepare_finalize, - self.shared_experts, - inplace=not self.moe_config.disable_inplace, - ) + modular_method = FusedMoEModularMethod.make( + self, + self.base_quant_method, + prepare_finalize, + self.shared_experts, + inplace=not self.moe_config.disable_inplace, ) + modular_method.reserve_workspace(self) + self._replace_quant_method(modular_method) @property def shared_experts(self) -> SharedExperts | None: @@ -779,6 +784,19 @@ def update_expert_map(self): dp_size=get_dp_group().world_size, ) + @staticmethod + def _normalize_loaded_weight_for_copy( + expert_data: torch.Tensor, loaded_weight: torch.Tensor + ) -> torch.Tensor: + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if ( + e8m0_dtype is not None + and expert_data.dtype == torch.uint8 + and loaded_weight.dtype == e8m0_dtype + ): + return loaded_weight.view(torch.uint8) + return loaded_weight + def _load_per_tensor_weight_scale( self, shard_id: str, @@ -792,10 +810,12 @@ def _load_per_tensor_weight_scale( # We have to keep the weight scales of w1 and w3 because # we need to re-quantize w1/w3 weights after weight loading. 
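# NOTE (editorial sketch, not part of the patch): _normalize_loaded_weight_for_copy
# relies on float8_e8m0fnu and uint8 both being one byte wide, so .view(torch.uint8)
# reinterprets the checkpoint bytes in place rather than performing a numeric cast.
# Minimal sketch; assumes a torch build that exposes float8_e8m0fnu.
import torch

raw = torch.full((4,), 127, dtype=torch.uint8)     # E8M0 bit pattern for 1.0
scales_e8m0 = raw.view(torch.float8_e8m0fnu)       # checkpoint-style e8m0 tensor
param = torch.empty(4, dtype=torch.uint8)          # pre-existing uint8 parameter buffer
param.copy_(scales_e8m0.view(torch.uint8))         # bit-reinterpret, then copy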
idx = 0 if shard_id == "w1" else 1 - param_data[expert_id][idx] = loaded_weight + target = param_data[expert_id][idx] + target.copy_(self._normalize_loaded_weight_for_copy(target, loaded_weight)) # If we are in the row parallel case (down_proj) elif shard_id == "w2": - param_data[expert_id] = loaded_weight + target = param_data[expert_id] + target.copy_(self._normalize_loaded_weight_for_copy(target, loaded_weight)) def _load_combined_w13_weight_scale( self, @@ -812,7 +832,7 @@ def _load_combined_w13_weight_scale( loaded_weight = loaded_weight.narrow( shard_dim, shard_size * tp_rank, shard_size ) - param.copy_(loaded_weight) + param.copy_(self._normalize_loaded_weight_for_copy(param, loaded_weight)) def _load_model_weight_or_group_weight_scale( self, @@ -979,7 +999,9 @@ def _load_w13( hidden_dim=hidden_dim, shard_dim=shard_dim, ) - expert_data.copy_(loaded_weight) + expert_data.copy_( + self._normalize_loaded_weight_for_copy(expert_data, loaded_weight) + ) def _load_w2( self, @@ -1015,7 +1037,9 @@ def _load_w2( hidden_dim=hidden_dim, shard_dim=shard_dim, ) - expert_data.copy_(loaded_weight) + expert_data.copy_( + self._normalize_loaded_weight_for_copy(expert_data, loaded_weight) + ) def _load_single_value( self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int @@ -1541,10 +1565,12 @@ def forward( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, + input_ids: torch.Tensor | None = None, ) -> torch.Tensor: return self.runner.forward( hidden_states, router_logits, + input_ids, ) @property diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index b0f967085ae4..c7fc0830be2f 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -1101,6 +1101,46 @@ def _allocate_buffers( return workspace13, workspace2, fused_out + def reserve_workspace( + self, + w1: torch.Tensor, + w2: torch.Tensor, + global_num_experts: int, + activation: MoEActivation, + max_num_tokens: int | None = None, + ) -> None: + moe_config = self.fused_experts.moe_config + max_num_tokens = max_num_tokens or moe_config.max_num_tokens + dummy_hidden_states = torch.empty( + (max_num_tokens, moe_config.hidden_dim), + device="meta", + ) + dummy_topk_ids = torch.empty( + (max_num_tokens, moe_config.experts_per_token), + device="meta", + dtype=torch.int32, + ) + _, max_m, n, k, top_k = self.fused_experts.moe_problem_size( + dummy_hidden_states, + w1, + w2, + dummy_topk_ids, + ) + local_num_experts = w1.shape[0] + self._allocate_buffers( + moe_config.in_dtype, + w1.device, + max_m, + max_m, + n, + k, + top_k, + global_num_experts, + local_num_experts, + None, + activation, + ) + def _maybe_apply_shared_experts( self, shared_experts_input: torch.Tensor | None, @@ -1565,6 +1605,23 @@ def output_is_reduced(self) -> bool: """ return self.prepare_finalize.output_is_reduced() + def reserve_workspace( + self, + w1: torch.Tensor, + w2: torch.Tensor, + global_num_experts: int, + activation: MoEActivation, + max_num_tokens: int | None = None, + ) -> None: + assert isinstance(self.impl, FusedMoEKernelModularImpl) + self.impl.reserve_workspace( + w1, + w2, + global_num_experts, + activation, + max_num_tokens=max_num_tokens, + ) + def apply_monolithic( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py index 9519487f82f1..f476d980d555 100644 --- 
a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -17,6 +17,7 @@ ) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, + FusedMoEQuantDesc, mxfp4_mxfp8_moe_quant_config, mxfp4_w4a16_moe_quant_config, ocp_mx_moe_quant_config, @@ -24,6 +25,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4 from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, + kFp8Dynamic128Sym, kMxfp4Static, kMxfp8Dynamic, ) @@ -46,6 +48,8 @@ class Mxfp4MoeBackend(Enum): NONE = "None" + # DeepGEMM FP8xFP4 backend (SM100+) + DEEPGEMM_MXFP4 = "DEEPGEMM_MXFP4" # FlashInfer TRTLLM backends FLASHINFER_TRTLLM_MXFP4_MXFP8 = "FLASHINFER_TRTLLM_MXFP4_MXFP8" FLASHINFER_TRTLLM_MXFP4_BF16 = "FLASHINFER_TRTLLM_MXFP4_BF16" @@ -81,7 +85,14 @@ class Mxfp4MoeBackend(Enum): def backend_to_kernel_cls( backend: Mxfp4MoeBackend, ) -> list[type[mk.FusedMoEExperts]]: - if backend in ( + if backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4: + from vllm.model_executor.layers.fused_moe.experts.deep_gemm_moe import ( + DeepGemmFP4Experts, + ) + + return [DeepGemmFP4Experts] + + elif backend in ( Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, ): @@ -159,11 +170,13 @@ def backend_to_kernel_cls( def map_mxfp4_backend(runner_backend: MoEBackend) -> Mxfp4MoeBackend: """Map user's moe_backend string to Mxfp4MoeBackend.""" mapping: dict[str, Mxfp4MoeBackend] = { + "deep_gemm": Mxfp4MoeBackend.DEEPGEMM_MXFP4, "flashinfer_trtllm": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, "flashinfer_trtllm_afp8": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, "flashinfer_cutlass": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, "flashinfer_cutlass_afp8": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, "triton": Mxfp4MoeBackend.TRITON, + "triton_unfused": Mxfp4MoeBackend.TRITON_UNFUSED, "marlin": Mxfp4MoeBackend.MARLIN, "aiter": Mxfp4MoeBackend.AITER, "xpu": Mxfp4MoeBackend.XPU, @@ -177,7 +190,7 @@ def map_mxfp4_backend(runner_backend: MoEBackend) -> Mxfp4MoeBackend: ) -def _get_priority_backends() -> list[Mxfp4MoeBackend]: +def _get_priority_backends_for_gpt_oss() -> list[Mxfp4MoeBackend]: """ Get available backends in priority order based on platform and config. Only includes BF16 backends. MXFP8 backends are selected via env vars. @@ -187,7 +200,9 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]: Mxfp4MoeBackend.AITER, Mxfp4MoeBackend.TRITON, Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, - Mxfp4MoeBackend.TRITON_UNFUSED, + # TRITON_UNFUSED has bug with MTP support + # TODO re-enable after kernel is fixed + # TRITON_UNFUSED Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN, Mxfp4MoeBackend.XPU, @@ -196,8 +211,28 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]: return _AVAILABLE_BACKENDS +def _get_priority_backends() -> list[Mxfp4MoeBackend]: + """ + Get available backends in priority order. SM100+ prefers DeepGEMM FP4 / + TRTLLM MXFP8; SM90 falls through to Triton_unfused or Marlin (the + backend-level ``is_supported_config`` check filters by device capability). 
+ """ + _AVAILABLE_BACKENDS = [ + Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, + Mxfp4MoeBackend.DEEPGEMM_MXFP4, + # TRITON_UNFUSED has bug with MTP support + # TODO re-enable after kernel is fixed + # TRITON_UNFUSED + Mxfp4MoeBackend.MARLIN, + Mxfp4MoeBackend.BATCHED_MARLIN, + ] + return _AVAILABLE_BACKENDS + + def _backend_activation_key(backend: Mxfp4MoeBackend) -> QuantKey | None: - """Map backend to its activation key (MXFP8 or None for BF16).""" + """Map backend to its activation key (FP8, MXFP8, or None for BF16).""" + if backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4: + return kFp8Dynamic128Sym if backend in ( Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, @@ -290,7 +325,7 @@ def _return_or_raise( ) # Select kernels in order of backend. - AVAILABLE_BACKENDS = _get_priority_backends() + AVAILABLE_BACKENDS = _get_priority_backends_for_gpt_oss() # Handle explicit FlashInfer MXFP4 BF16 configuration. if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"): @@ -387,11 +422,95 @@ def _return_or_raise( return Mxfp4MoeBackend.NONE, None +def select_mxfp4_moe_backend( + config: FusedMoEConfig, +) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts] | None]: + """ + Select the MXFP4 MoE backend with MXFP8 activation as top priority. + Falls back through BF16 and other backends. + """ + activation_format = ( + mk.FusedMoEActivationFormat.BatchedExperts + if config.moe_parallel_config.use_batched_activation_format + else mk.FusedMoEActivationFormat.Standard + ) + + def _make_log_backend(backend: Mxfp4MoeBackend): + return f"Using '{backend.value}' Mxfp4 MoE backend." + + def _make_log_unsupported(backend: Mxfp4MoeBackend, reason: str | None) -> str: + if reason: + return ( + f"Mxfp4 MoE backend '{backend.value}' does not support the " + f"deployment configuration since {reason}." + ) + return ( + f"Mxfp4 MoE backend '{backend.value}' does not support the " + "deployment configuration." + ) + + def _return_or_raise( + backend: Mxfp4MoeBackend, + config: FusedMoEConfig, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + activation_format: mk.FusedMoEActivationFormat, + ) -> tuple[Mxfp4MoeBackend, type[mk.FusedMoEExperts]]: + reason: str | None = None + for k_cls in backend_to_kernel_cls(backend): + supported, reason = k_cls.is_supported_config( + k_cls, config, weight_key, activation_key, activation_format + ) + if supported: + logger.info_once(_make_log_backend(backend), scope="local") + return backend, k_cls + raise ValueError(_make_log_unsupported(backend, reason)) + + # Honor explicit moe_backend (e.g. "marlin", "triton_unfused") before + # falling back to the auto priority list. + runner_backend = config.moe_backend + if runner_backend != "auto": + requested_backend = map_mxfp4_backend(runner_backend) + if ( + activation_format == mk.FusedMoEActivationFormat.BatchedExperts + and requested_backend == Mxfp4MoeBackend.MARLIN + ): + requested_backend = Mxfp4MoeBackend.BATCHED_MARLIN + return _return_or_raise( + requested_backend, + config, + kMxfp4Static, + _backend_activation_key(requested_backend), + activation_format, + ) + + # Iterate priority backends: TRTLLM MXFP8, then Triton. 
+ for backend in _get_priority_backends(): + activation_key = _backend_activation_key(backend) + for k_cls in backend_to_kernel_cls(backend): + supported, reason = k_cls.is_supported_config( + k_cls, config, kMxfp4Static, activation_key, activation_format + ) + if supported: + logger.info_once(_make_log_backend(backend), scope="local") + return backend, k_cls + else: + logger.debug_once(_make_log_unsupported(backend, reason), scope="local") + + raise NotImplementedError( + "No MXFP4 MoE backend supports the deployment configuration." + ) + + def mxfp4_round_up_hidden_size_and_intermediate_size( backend: Mxfp4MoeBackend, hidden_size: int, intermediate_size: int ) -> tuple[int, int]: """Round up hidden_size and intermediate_size based on backend requirements.""" - if backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN): + if backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4: + # DeepGEMM requires M/N/K alignment + intermediate_size = round_up(intermediate_size, 128) + hidden_size = round_up(hidden_size, 128) + elif backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN): intermediate_size = round_up(intermediate_size, 128) if current_platform.is_xpu(): hidden_size = round_up(hidden_size, 128) @@ -434,6 +553,20 @@ def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format( ]: """Convert loaded weights into backend-specific kernel format.""" + if mxfp4_backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4: + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _upcast_e8m0_to_fp32, + ) + + return ( + w13_weight.data, + w2_weight.data, + _upcast_e8m0_to_fp32(w13_weight_scale.data), + _upcast_e8m0_to_fp32(w2_weight_scale.data), + w13_bias, + w2_bias, + ) + num_experts = w13_weight.shape[0] intermediate_size = w13_weight.shape[1] // 2 hidden_size = w13_weight.shape[2] * 2 @@ -738,9 +871,10 @@ def _interleave_mxfp4_cutlass_sm90(w): elif mxfp4_backend in TRITON_BACKENDS: from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig - assert w13_bias is not None and w2_bias is not None - w13_bias = w13_bias.to(torch.float32) - w2_bias = w2_bias.to(torch.float32) + if w13_bias is not None: + w13_bias = w13_bias.to(torch.float32) + if w2_bias is not None: + w2_bias = w2_bias.to(torch.float32) w13_weight, w13_flex, w13_scale = _swizzle_mxfp4( w13_weight, @@ -797,15 +931,271 @@ def _interleave_mxfp4_cutlass_sm90(w): ) +def convert_weight_to_mxfp4_moe_kernel_format( + mxfp4_backend: Mxfp4MoeBackend, + layer: torch.nn.Module, + w13_weight: torch.Tensor, + w2_weight: torch.Tensor, + w13_weight_scale: torch.Tensor, + w2_weight_scale: torch.Tensor, + w13_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, + _cache_permute_indices: dict[torch.Size, torch.Tensor] | None = None, +) -> tuple[ + torch.Tensor, + torch.Tensor, + Union[torch.Tensor, "PrecisionConfig"], + Union[torch.Tensor, "PrecisionConfig"], + torch.Tensor | None, + torch.Tensor | None, +]: + """Convert loaded weights into backend-specific kernel format. + + Supports DeepGEMM, TRTLLM MXFP8, Triton and Marlin backends. + """ + + if mxfp4_backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4: + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _upcast_e8m0_to_fp32, + ) + + # Weights stay as uint8 packed FP4 — no layout change needed. + # Convert E8M0 uint8 scales to float32. 
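# NOTE (editorial sketch, not part of the patch): E8M0 stores only a biased
# exponent, so each uint8 scale byte e decodes to 2**(e - 127). A hedged sketch
# of that decode; the real _upcast_e8m0_to_fp32 may treat edge cases such as
# the 0xFF (NaN) encoding differently.
import torch

scales_u8 = torch.tensor([121, 127, 130], dtype=torch.uint8)    # example E8M0 bytes
scales_f32 = torch.pow(2.0, scales_u8.to(torch.int32) - 127)    # -> [2**-6, 1.0, 8.0]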
+ return ( + w13_weight.data, + w2_weight.data, + _upcast_e8m0_to_fp32(w13_weight_scale.data), + _upcast_e8m0_to_fp32(w2_weight_scale.data), + w13_bias, + w2_bias, + ) + + if mxfp4_backend in (Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN): + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + prepare_moe_mxfp4_layer_for_marlin, + ) + + return prepare_moe_mxfp4_layer_for_marlin( + layer, + w13_weight, + w2_weight, + w13_weight_scale, + w2_weight_scale, + w13_bias, + w2_bias, + ) + + num_experts = w13_weight.shape[0] + intermediate_size = w13_weight.shape[1] // 2 + hidden_size = w13_weight.shape[2] * 2 + + sf_block_size = 32 # mxfp4 block size + + if mxfp4_backend in TRTLLM_BACKENDS: + assert _cache_permute_indices is not None + from flashinfer.fp4_quantization import nvfp4_block_scale_interleave + from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache + + w13_weight = w13_weight.data + w2_weight = w2_weight.data + w13_weight_scale = w13_weight_scale.data + w2_weight_scale = w2_weight_scale.data + if w13_bias is not None: + w13_bias = w13_bias.data.to(torch.float32) + if w2_bias is not None: + w2_bias = w2_bias.data.to(torch.float32) + + # Swap w1/w3 and interleave to match TRTLLM SwiGLU convention. + # Standard loading gives contiguous [w1/gate, w3/up]. + # TRTLLM kernel expects interleaved [w3_0, w1_0, w3_1, w1_1, ...]. + w1_weight = w13_weight[:, :intermediate_size, :] + w3_weight = w13_weight[:, intermediate_size:, :] + w13_weight = torch.stack([w3_weight, w1_weight], dim=2).reshape( + w13_weight.shape + ) + + w1_scale = w13_weight_scale[:, :intermediate_size, :] + w3_scale = w13_weight_scale[:, intermediate_size:, :] + w13_weight_scale = torch.stack([w3_scale, w1_scale], dim=2).reshape( + w13_weight_scale.shape + ) + + if w13_bias is not None: + b1 = w13_bias[:, :intermediate_size] + b3 = w13_bias[:, intermediate_size:] + w13_bias = torch.stack([b3, b1], dim=2).reshape(w13_bias.shape) + + # Shuffle weights and scaling factors for transposed mma output. + # Permute indices depend only on shape (cached by torch.Size), + # so compute once and apply to all experts via batched indexing. 
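# NOTE (editorial sketch, not part of the patch): toy illustration of the w1/w3
# interleave performed above. Rows of w13 move from the contiguous [w1; w3]
# layout to alternating [w3_0, w1_0, w3_1, w1_1, ...] pairs, which is the order
# the TRTLLM SwiGLU kernel expects.
import torch

w13 = torch.arange(8).reshape(1, 8, 1)        # rows 0..3 stand in for w1, rows 4..7 for w3
w1_part, w3_part = w13[:, :4], w13[:, 4:]
interleaved = torch.stack([w3_part, w1_part], dim=2).reshape(w13.shape)
# interleaved rows are now 4, 0, 5, 1, 6, 2, 7, 3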
+ epilogue_tile_m = 128 + + # w13 weight permute + w13_perm = get_w2_permute_indices_with_cache( + _cache_permute_indices, + w13_weight[0].view(torch.uint8), + epilogue_tile_m, + ).to(w13_weight.device) + w13_weight = w13_weight.view(torch.uint8)[:, w13_perm].contiguous() + + # w13 scale permute + interleave + w13_sf_perm = get_w2_permute_indices_with_cache( + _cache_permute_indices, + w13_weight_scale[0].view(torch.uint8), + epilogue_tile_m, + num_elts_per_sf=16, + ).to(w13_weight_scale.device) + w13_s = w13_weight_scale.view(torch.uint8)[:, w13_sf_perm].contiguous() + E, N_s, K_s = w13_s.shape + w13_weight_scale = ( + nvfp4_block_scale_interleave(w13_s.reshape(E * N_s, K_s)) + .reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size) + .view(torch.float8_e4m3fn) + ) + + # w2 weight permute + w2_perm = get_w2_permute_indices_with_cache( + _cache_permute_indices, + w2_weight[0].view(torch.uint8), + epilogue_tile_m, + ).to(w2_weight.device) + w2_weight = w2_weight.view(torch.uint8)[:, w2_perm].contiguous() + + # w2 scale permute + interleave + w2_sf_perm = get_w2_permute_indices_with_cache( + _cache_permute_indices, + w2_weight_scale[0].view(torch.uint8), + epilogue_tile_m, + num_elts_per_sf=16, + ).to(w2_weight_scale.device) + w2_s = w2_weight_scale.view(torch.uint8)[:, w2_sf_perm].contiguous() + E2, N2_s, K2_s = w2_s.shape + w2_weight_scale = ( + nvfp4_block_scale_interleave(w2_s.reshape(E2 * N2_s, K2_s)) + .reshape(num_experts, hidden_size, intermediate_size // sf_block_size) + .view(torch.float8_e4m3fn) + ) + + # w13 bias permute + if w13_bias is not None: + w13_b_perm = get_w2_permute_indices_with_cache( + _cache_permute_indices, + w13_bias[0].reshape(-1, 1), + epilogue_tile_m, + ).to(w13_bias.device) + w13_bias = w13_bias.reshape(num_experts, -1, 1)[:, w13_b_perm].reshape( + num_experts, -1 + ) + + # w2 bias permute + if w2_bias is not None: + w2_b_perm = get_w2_permute_indices_with_cache( + _cache_permute_indices, + w2_bias[0].reshape(-1, 1), + epilogue_tile_m, + ).to(w2_bias.device) + w2_bias = w2_bias.reshape(num_experts, -1, 1)[:, w2_b_perm].reshape( + num_experts, -1 + ) + + return ( + w13_weight, + w2_weight, + w13_weight_scale, + w2_weight_scale, + w13_bias, + w2_bias, + ) + + elif mxfp4_backend in TRITON_BACKENDS: + from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig + + if mxfp4_backend == Mxfp4MoeBackend.TRITON: + + def shuffle_weight(w: torch.Tensor) -> torch.Tensor: + shape = w.shape + n = shape[-1] + first = w[..., : n // 2] + second = w[..., n // 2 :] + stacked = torch.stack((first, second), dim=-1) + return stacked.reshape(shape) + + w13_weight = shuffle_weight(w13_weight) + w13_weight_scale = shuffle_weight(w13_weight_scale) + + if w13_bias is not None: + w13_bias = shuffle_weight(w13_bias.to(torch.float32)) + else: + if w13_bias is not None: + w13_bias = w13_bias.to(torch.float32) + + if w2_bias is not None: + w2_bias = w2_bias.to(torch.float32) + + w13_weight, w13_flex, w13_scale = _swizzle_mxfp4( + w13_weight, + w13_weight_scale, + ) + w2_weight, w2_flex, w2_scale = _swizzle_mxfp4( + w2_weight, + w2_weight_scale, + ) + + w13_precision_config = PrecisionConfig( + weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex) + ) + w2_precision_config = PrecisionConfig( + weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex) + ) + + del layer.w13_weight + del layer.w2_weight + + return ( + w13_weight, + w2_weight, + w13_precision_config, + w2_precision_config, + w13_bias, + w2_bias, + ) + else: + raise ValueError( + f"Unsupported 
mxfp4_backend for Mxfp4MoEMethod: {mxfp4_backend}. " + f"Expected TRTLLM or Triton backend." + ) + + def make_mxfp4_moe_quant_config( mxfp4_backend: Mxfp4MoeBackend, w1_scale: Union[torch.Tensor, "PrecisionConfig"], w2_scale: Union[torch.Tensor, "PrecisionConfig"], + gemm1_alpha: float | None = None, + gemm1_beta: float | None = None, + swiglu_limit: float | None = None, w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, ) -> FusedMoEQuantConfig | None: """Create a FusedMoEQuantConfig for the given MXFP4 backend.""" - if mxfp4_backend in ( + if mxfp4_backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4: + from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + ) + + # DeepGEMM FP4 uses FP8 per-token-group activation quantization + # with block 128, matching the FP8 DeepGEMM path. + _fp8_dtype = current_platform.fp8_dtype() + _block_shape = GroupShape(128, 128) + return FusedMoEQuantConfig( + _a1=FusedMoEQuantDesc(_fp8_dtype, _block_shape, None, None, None, None), + _a2=FusedMoEQuantDesc(_fp8_dtype, _block_shape, None, None, None, None), + _w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias), + _w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias), + gemm1_alpha=gemm1_alpha, + gemm1_beta=gemm1_beta, + gemm1_clamp_limit=swiglu_limit, + ) + elif mxfp4_backend in ( Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, ): @@ -814,6 +1204,9 @@ def make_mxfp4_moe_quant_config( w2_bias=w2_bias, w1_scale=w1_scale, w2_scale=w2_scale, + gemm1_alpha=gemm1_alpha, + gemm1_beta=gemm1_beta, + gemm1_clamp_limit=swiglu_limit, ) elif mxfp4_backend in ( Mxfp4MoeBackend.MARLIN, @@ -829,6 +1222,9 @@ def make_mxfp4_moe_quant_config( w2_bias=w2_bias, w1_scale=w1_scale, w2_scale=w2_scale, + gemm1_alpha=gemm1_alpha, + gemm1_beta=gemm1_beta, + gemm1_clamp_limit=swiglu_limit, ) else: return ocp_mx_moe_quant_config( @@ -837,6 +1233,9 @@ def make_mxfp4_moe_quant_config( w2_bias=w2_bias, w1_scale=w1_scale, w2_scale=w2_scale, + gemm1_alpha=gemm1_alpha, + gemm1_beta=gemm1_beta, + gemm1_clamp_limit=swiglu_limit, ) diff --git a/vllm/model_executor/layers/fused_moe/router/base_router.py b/vllm/model_executor/layers/fused_moe/router/base_router.py index 011aa9a03097..0138eb59c91c 100644 --- a/vllm/model_executor/layers/fused_moe/router/base_router.py +++ b/vllm/model_executor/layers/fused_moe/router/base_router.py @@ -228,6 +228,8 @@ def _compute_routing( hidden_states: torch.Tensor, router_logits: torch.Tensor, indices_type: torch.dtype | None, + *, + input_ids: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Compute the actual routing logic. @@ -249,6 +251,8 @@ def select_experts( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, + *, + input_ids: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Route the input hidden states to the top-k experts based on the @@ -278,7 +282,7 @@ def select_experts( # Step 3: Compute routing (delegated to subclass) topk_weights, topk_ids = self._compute_routing( - hidden_states, router_logits, indices_type + hidden_states, router_logits, indices_type, input_ids=input_ids ) # Capture logical ids before EPLB mapping. 
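The input_ids threading added above exists so that routers which map tokens directly to experts (the DeepSeek-V4 hash-table path added to fused_topk_bias below) can bypass score-based top-k selection. A minimal eager sketch of that path, combining the new "sqrtsoftplus" scoring with a hash-indices lookup; the function name and the (vocab_size, top_k) table shape are illustrative assumptions, and renormalization is omitted:

import torch
import torch.nn.functional as F

def route_with_hash_table(
    gating_output: torch.Tensor,        # (num_tokens, num_experts) router logits
    input_ids: torch.Tensor,            # (num_tokens,) token ids
    hash_indices_table: torch.Tensor,   # (vocab_size, top_k) precomputed expert ids
    routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    scores = F.softplus(gating_output).sqrt()       # "sqrtsoftplus" scoring
    topk_ids = hash_indices_table[input_ids]        # experts chosen by token id, not by score
    topk_weights = scores.gather(1, topk_ids.long())
    return topk_weights * routed_scaling_factor, topk_ids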
diff --git a/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py b/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py index 0367189ca1ab..c1bd7a6993ab 100644 --- a/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py +++ b/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py @@ -46,6 +46,8 @@ def _compute_routing( hidden_states: torch.Tensor, router_logits: torch.Tensor, indices_type: torch.dtype | None, + *, + input_ids: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Compute routing using the custom routing function.""" topk_weights, topk_ids = self.custom_routing_function( diff --git a/vllm/model_executor/layers/fused_moe/router/fused_moe_router.py b/vllm/model_executor/layers/fused_moe/router/fused_moe_router.py index d7aed4fdeb2b..d82085254f9b 100644 --- a/vllm/model_executor/layers/fused_moe/router/fused_moe_router.py +++ b/vllm/model_executor/layers/fused_moe/router/fused_moe_router.py @@ -31,6 +31,8 @@ def select_experts( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, + *, + input_ids: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Route the input hidden states to the top-k experts based on the diff --git a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py index a5361b399e2a..07d8af6cce90 100644 --- a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py +++ b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py @@ -4,6 +4,7 @@ from collections.abc import Callable import torch +import torch.nn.functional as F import vllm._custom_ops as ops import vllm.envs as envs @@ -56,6 +57,38 @@ def vllm_topk_sigmoid( return topk_weights, topk_indices +def vllm_topk_softplus_sqrt( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool = False, + e_score_correction_bias: torch.Tensor | None = None, + input_tokens: torch.Tensor | None = None, + hash_indices_table: torch.Tensor | None = None, + routed_scaling_factor: float = 1.0, +) -> tuple[torch.Tensor, ...]: + idx_dtype = topk_indices.dtype + if input_tokens is not None and input_tokens.dtype != idx_dtype: + input_tokens = input_tokens.to(idx_dtype) + if hash_indices_table is not None and hash_indices_table.dtype != idx_dtype: + hash_indices_table = hash_indices_table.to(idx_dtype) + + ops.topk_hash_softplus_sqrt( + topk_weights, + topk_indices, + token_expert_indices, + gating_output, + renormalize, + routed_scaling_factor, + e_score_correction_bias, + input_tokens, + hash_indices_table, + ) + + return topk_weights, topk_indices + + @functools.lru_cache(maxsize=8) def _aiter_get_num_expert_group(num_experts: int) -> int: _AITER_MAX_EXPERTS_PER_GROUP = 32 @@ -72,11 +105,14 @@ def _aiter_get_num_expert_group(num_experts: int) -> int: def fused_topk_bias( hidden_states: torch.Tensor, gating_output: torch.Tensor, + scoring_func: str, e_score_correction_bias: torch.Tensor, topk: int, renormalize: bool, - scoring_func: str = "softmax", indices_type: torch.dtype | None = None, + input_tokens: torch.Tensor | None = None, + hash_indices_table: torch.Tensor | None = None, + routed_scaling_factor: float = 1.0, ): if not rocm_aiter_ops.is_fused_moe_enabled(): assert hidden_states.size(0) == gating_output.size(0), ( @@ -107,6 +143,8 @@ def fused_topk_bias( renormalize, e_score_correction_bias, ) + if 
routed_scaling_factor != 1.0: + topk_weights *= routed_scaling_factor return topk_weights, topk_ids elif scoring_func == "sigmoid": topk_weights, topk_ids = vllm_topk_sigmoid( @@ -117,9 +155,24 @@ def fused_topk_bias( renormalize, e_score_correction_bias, ) + if routed_scaling_factor != 1.0: + topk_weights *= routed_scaling_factor return topk_weights, topk_ids + elif scoring_func == "sqrtsoftplus": + return vllm_topk_softplus_sqrt( + topk_weights, + topk_ids, + token_expert_indices, + gating_output, + renormalize, + e_score_correction_bias, + input_tokens, + hash_indices_table, + routed_scaling_factor, + ) else: raise ValueError(f"Unsupported scoring function: {scoring_func}") + elif rocm_aiter_ops.is_fused_moe_enabled() and scoring_func == "sigmoid": M = hidden_states.size(0) num_experts = gating_output.shape[-1] @@ -143,6 +196,8 @@ def fused_topk_bias( topk_group=num_expert_group, need_renorm=renormalize, ) + if routed_scaling_factor != 1.0: + topk_weights *= routed_scaling_factor return topk_weights, topk_ids n_routed_experts = gating_output.shape[-1] @@ -150,20 +205,31 @@ def fused_topk_bias( scores = gating_output.softmax(dim=-1) elif scoring_func == "sigmoid": scores = gating_output.sigmoid() + elif scoring_func == "sqrtsoftplus": + scores = F.softplus(gating_output).sqrt() else: raise ValueError(f"Unsupported scoring function: {scoring_func}") - - scores_for_choice = scores.view( - -1, n_routed_experts - ) + e_score_correction_bias.unsqueeze(0) - + if e_score_correction_bias is not None: + scores_for_choice = scores.view( + -1, n_routed_experts + ) + e_score_correction_bias.unsqueeze(0) + else: + scores_for_choice = scores.view(-1, n_routed_experts) # For batch invariance, use sorted=True to ensure deterministic expert selection - use_sorted = envs.VLLM_BATCH_INVARIANT - topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1] + if hash_indices_table is not None: + topk_indices = hash_indices_table[input_tokens].to(topk_ids.dtype) + else: + use_sorted = envs.VLLM_BATCH_INVARIANT + topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[ + 1 + ] topk_weights = scores.gather(1, topk_indices) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights.to(torch.float32), topk_indices.to( + topk_weights = topk_weights.to(torch.float32) + if routed_scaling_factor != 1.0: + topk_weights *= routed_scaling_factor + return topk_weights, topk_indices.to( torch.int32 if indices_type is None else indices_type ) @@ -176,12 +242,14 @@ def __init__( top_k: int, global_num_experts: int, eplb_state: EplbLayerState, - e_score_correction_bias: torch.Tensor, - scoring_func: str, + e_score_correction_bias: torch.Tensor | None = None, renormalize: bool = True, routed_scaling_factor: float = 1.0, enable_eplb: bool = False, indices_type_getter: Callable[[], torch.dtype | None] | None = None, + *, + scoring_func: str = "sigmoid", + hash_indices_table: torch.Tensor | None = None, ): super().__init__( top_k=top_k, @@ -194,6 +262,8 @@ def __init__( self.renormalize = renormalize self.scoring_func = scoring_func self.routed_scaling_factor = routed_scaling_factor + self.scoring_func = scoring_func + self._hash_indices_table = hash_indices_table @property def routing_method_type(self) -> RoutingMethodType: @@ -210,19 +280,23 @@ def _compute_routing( hidden_states: torch.Tensor, router_logits: torch.Tensor, indices_type: torch.dtype | None, + *, + input_ids: torch.Tensor | None = None, ) -> tuple[torch.Tensor, 
torch.Tensor]: """Compute routing using fused top-k with bias.""" topk_weights, topk_ids = fused_topk_bias( hidden_states=hidden_states, gating_output=router_logits, - e_score_correction_bias=self.e_score_correction_bias.data, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias.data + if self.e_score_correction_bias is not None + else None, topk=self.top_k, renormalize=self.renormalize, - scoring_func=self.scoring_func, indices_type=indices_type, + input_tokens=input_ids, + hash_indices_table=self._hash_indices_table, + routed_scaling_factor=self.routed_scaling_factor, ) - if self.routed_scaling_factor != 1.0: - topk_weights *= self.routed_scaling_factor - return topk_weights, topk_ids diff --git a/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py b/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py index 01376e6b16b5..45311dba08e3 100644 --- a/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py +++ b/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py @@ -151,6 +151,8 @@ def _compute_routing( hidden_states: torch.Tensor, router_logits: torch.Tensor, indices_type: torch.dtype | None, + *, + input_ids: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Compute routing using standard fused top-k.""" topk_weights, topk_ids, token_expert_indices = fused_topk( diff --git a/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py b/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py index 1bf141d81e4b..74c3a62a1f11 100644 --- a/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py +++ b/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py @@ -292,6 +292,8 @@ def _compute_routing( hidden_states: torch.Tensor, router_logits: torch.Tensor, indices_type: torch.dtype | None, + *, + input_ids: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Compute routing using grouped top-k.""" @@ -308,6 +310,7 @@ def valid_grouping() -> bool: topk_weights, topk_ids = fused_topk_bias( hidden_states=hidden_states, gating_output=router_logits, + scoring_func=self.scoring_func, e_score_correction_bias=self.e_score_correction_bias.data, topk=self.top_k, renormalize=self.renormalize, diff --git a/vllm/model_executor/layers/fused_moe/router/router_factory.py b/vllm/model_executor/layers/fused_moe/router/router_factory.py index 42d418d7e537..da7896de6159 100644 --- a/vllm/model_executor/layers/fused_moe/router/router_factory.py +++ b/vllm/model_executor/layers/fused_moe/router/router_factory.py @@ -55,6 +55,7 @@ def create_fused_moe_router( # zero expert parameters zero_expert_type: str | None = None, num_logical_experts: int | None = None, + hash_indices_table: torch.Tensor | None = None, ) -> FusedMoERouter: """ Factory function to create the appropriate FusedMoERouter subclass based on @@ -99,6 +100,9 @@ def create_fused_moe_router( num_logical_experts: Number of real (non-zero) experts. Required when zero_expert_type is not None. 
+ Hash Indices Table: + Used to map input_ids to experts, need for Deepseek V4 + Returns: An instance of the appropriate FusedMoERouter subclass """ @@ -179,17 +183,20 @@ def create_fused_moe_router( indices_type_getter=indices_type_getter, ) - if e_score_correction_bias is not None: + assert scoring_func in ["sigmoid", "softmax", "sqrtsoftplus"] + + if e_score_correction_bias is not None or hash_indices_table is not None: return FusedTopKBiasRouter( top_k=top_k, global_num_experts=global_num_experts, eplb_state=eplb_state, e_score_correction_bias=e_score_correction_bias, - scoring_func=scoring_func, renormalize=renormalize, routed_scaling_factor=routed_scaling_factor, enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, + scoring_func=scoring_func, + hash_indices_table=hash_indices_table, ) return FusedTopKRouter( diff --git a/vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py b/vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py index f8e46371841a..8fb36b72cb70 100644 --- a/vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py +++ b/vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py @@ -334,6 +334,8 @@ def _compute_routing( hidden_states: torch.Tensor, router_logits: torch.Tensor, indices_type: torch.dtype | None, + *, + input_ids: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Use routing simulator to compute routing.""" routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY diff --git a/vllm/model_executor/layers/fused_moe/router/zero_expert_router.py b/vllm/model_executor/layers/fused_moe/router/zero_expert_router.py index c87070bc5acf..65760727770a 100644 --- a/vllm/model_executor/layers/fused_moe/router/zero_expert_router.py +++ b/vllm/model_executor/layers/fused_moe/router/zero_expert_router.py @@ -72,6 +72,8 @@ def _compute_routing( hidden_states: torch.Tensor, router_logits: torch.Tensor, indices_type: torch.dtype | None, + *, + input_ids: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Compute routing with full bias, compute zero expert output, mask zero expert IDs.""" diff --git a/vllm/model_executor/layers/fused_moe/runner/moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/moe_runner.py index 6b13c0c36323..d2d34db6b2b9 100644 --- a/vllm/model_executor/layers/fused_moe/runner/moe_runner.py +++ b/vllm/model_executor/layers/fused_moe/runner/moe_runner.py @@ -91,6 +91,7 @@ def _moe_forward( hidden_states: torch.Tensor, router_logits: torch.Tensor, shared_experts_input: torch.Tensor | None, + input_ids: torch.Tensor | None, layer_name: _layer_name_type, ) -> torch.Tensor: layer = get_layer_from_name(_resolve_layer_name(layer_name)) @@ -99,6 +100,7 @@ def _moe_forward( hidden_states, router_logits, shared_experts_input, + input_ids, ) @@ -106,6 +108,7 @@ def _moe_forward_fake( hidden_states: torch.Tensor, router_logits: torch.Tensor, shared_experts_input: torch.Tensor | None, + input_ids: torch.Tensor | None, layer_name: _layer_name_type, ) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -115,6 +118,7 @@ def _moe_forward_shared( hidden_states: torch.Tensor, router_logits: torch.Tensor, shared_experts_input: torch.Tensor | None, + input_ids: torch.Tensor | None, layer_name: _layer_name_type, ) -> tuple[torch.Tensor, torch.Tensor]: layer = get_layer_from_name(_resolve_layer_name(layer_name)) @@ -123,6 +127,7 @@ def _moe_forward_shared( hidden_states, router_logits, shared_experts_input, + input_ids, ) @@ -130,6 +135,7 @@ 
def _moe_forward_shared_fake( hidden_states: torch.Tensor, router_logits: torch.Tensor, shared_experts_input: torch.Tensor | None, + input_ids: torch.Tensor | None, layer_name: _layer_name_type, ) -> tuple[torch.Tensor, torch.Tensor]: # Output shapes: @@ -433,6 +439,7 @@ def _apply_quant_method( hidden_states: torch.Tensor, router_logits: torch.Tensor, shared_experts_input: torch.Tensor | None, + input_ids: torch.Tensor | None = None, ) -> tuple[torch.Tensor | None, torch.Tensor]: """Run expert routing and the fused MoE kernel via the quant method. @@ -454,6 +461,7 @@ def _apply_quant_method( topk_weights, topk_ids = self.router.select_experts( hidden_states=hidden_states, router_logits=router_logits, + input_ids=input_ids, ) # Passing shared_experts_input in case SharedExpertsOrder is @@ -523,6 +531,7 @@ def forward( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, + input_ids: torch.Tensor | None = None, ) -> torch.Tensor: """Invoke the fused moe layer. @@ -565,6 +574,7 @@ def forward( hidden_states, router_logits, shared_experts_input, + input_ids, self._encode_layer_name(), ) @@ -672,6 +682,7 @@ def _forward_impl( hidden_states: torch.Tensor, router_logits: torch.Tensor, shared_experts_input: torch.Tensor | None, + input_ids: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """Entry point called by the custom op to run the MoE computation. @@ -712,6 +723,7 @@ def _forward_impl( hidden_states=hidden_states, router_logits=router_logits, shared_experts_input=shared_experts_input, + input_ids=input_ids, ) return self._maybe_combine( diff --git a/vllm/model_executor/layers/fused_moe/runner/moe_runner_interface.py b/vllm/model_executor/layers/fused_moe/runner/moe_runner_interface.py index 80bd83e3732e..9a6c37aa3983 100644 --- a/vllm/model_executor/layers/fused_moe/runner/moe_runner_interface.py +++ b/vllm/model_executor/layers/fused_moe/runner/moe_runner_interface.py @@ -26,6 +26,7 @@ def forward( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, + input_ids: torch.Tensor | None = None, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index e33111aa0ab2..89697033403d 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -309,6 +309,7 @@ def apply_monolithic( layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 x: torch.Tensor, router_logits: torch.Tensor, + input_ids: torch.Tensor | None = None, ) -> torch.Tensor: assert self.is_monolithic if self.unquantized_backend == UnquantizedMoeBackend.CPU: diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index d8e174051b51..ffab3ca0bfa9 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -4,6 +4,7 @@ from math import prod import torch +import torch.nn.functional as F from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( @@ -384,3 +385,20 @@ def trtllm_moe_pack_topk_ids_weights( return (topk_ids.to(torch.int32) << 16) | topk_weights.to(torch.bfloat16).view( torch.int16 ) + + +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) +def swiglu_limit_func( + output: torch.Tensor, + input: torch.Tensor, # first half is gate, second half is up + 
swiglu_limit: float = 0.0, +) -> None: + d = input.shape[1] // 2 + gate = input[:, :d] + up = input[:, d:] + + if swiglu_limit > 0: + gate = torch.clamp(gate, max=swiglu_limit) + up = torch.clamp(up, min=-swiglu_limit, max=swiglu_limit) + + output.copy_(F.silu(gate) * up) diff --git a/vllm/model_executor/layers/mhc.py b/vllm/model_executor/layers/mhc.py new file mode 100644 index 000000000000..b6e4374032bc --- /dev/null +++ b/vllm/model_executor/layers/mhc.py @@ -0,0 +1,436 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +from functools import cache + +import tilelang +import tilelang.language as T +import torch + +from vllm.utils.math_utils import cdiv +from vllm.utils.torch_utils import direct_register_custom_op + + +@cache +def compute_num_split(block_k: int, k: int | None, grid_size: int) -> int: + device_props = torch.cuda.get_device_properties(0) + n_sms = device_props.multi_processor_count + split_k = n_sms // grid_size + if k is not None: + # avoid split_k for small k + num_block_k = cdiv(k, block_k) + split_k = min(split_k, num_block_k // 4) + split_k = max(split_k, 1) + return split_k + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, + }, +) +def mhc_pre_big_fuse_tilelang( + gemm_out_mul, + gemm_out_sqrsum, + hc_scale, + hc_base, + residual, + post_mix, + comb_mix, + layer_input, + hidden_size: int, + rms_eps: float, + hc_pre_eps: float, + hc_sinkhorn_eps: float, + hc_post_mult_value: float, + sinkhorn_repeat: int, + n_splits: int = 16, + hc_mult: int = 4, +): + """Deeply fused kernels, everything other than gemm & sqrsum in mHC pre block.""" + num_tokens = T.dynamic("num_tokens") + hc_mult3 = hc_mult * (2 + hc_mult) + hidden_block = math.gcd(512, hidden_size) + + gemm_out_mul: T.Tensor[[n_splits, num_tokens, hc_mult3], T.float32] # type: ignore[no-redef, valid-type] + gemm_out_sqrsum: T.Tensor[[n_splits, num_tokens], T.float32] # type: ignore[no-redef, valid-type] + hc_scale: T.Tensor[[3], T.float32] # type: ignore[no-redef, valid-type] + hc_base: T.Tensor[[hc_mult3], T.float32] # type: ignore[no-redef, valid-type] + residual: T.Tensor[[num_tokens, hc_mult, hidden_size], T.bfloat16] # type: ignore[no-redef, valid-type] + # outputs + post_mix: T.Tensor[[num_tokens, hc_mult], T.float32] # type: ignore[no-redef, valid-type] + comb_mix: T.Tensor[[num_tokens, hc_mult * hc_mult], T.float32] # type: ignore[no-redef, valid-type] + layer_input: T.Tensor[[num_tokens, hidden_size], T.bfloat16] # type: ignore[no-redef, valid-type] + + with T.Kernel(num_tokens, threads=96) as i: + T.pdl_sync() + ################################################################## + # _pre_norm_fn_fwd_norm + rms = T.alloc_fragment(1, T.float32) + mixes = T.alloc_fragment(hc_mult3, T.float32) + T.clear(mixes) + rms[0] = 0 + for i_split in T.serial(n_splits): + rms[0] += gemm_out_sqrsum[i_split, i] + rms[0] = T.rsqrt(rms[0] / (hc_mult * hidden_size) + rms_eps) + for j in T.Parallel(hc_mult3): + mixes[j] = 0 + for i_split in T.serial(n_splits): + mixes[j] += gemm_out_mul[i_split, i, j] + mixes[j] *= rms[0] + mixes_shared = T.alloc_shared(hc_mult3, T.float32) + T.copy(mixes, mixes_shared) + + if T.get_thread_binding() < 32: + ################################################################## + # _pre_split_mixes_fwd (post & comb) + cm = T.alloc_fragment((hc_mult, hc_mult), 
T.float32) + for j in T.Parallel(hc_mult): + post_mix[i, j] = ( + T.sigmoid( + mixes_shared[j + hc_mult] * hc_scale[1] + hc_base[j + hc_mult] + ) + * hc_post_mult_value + ) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = ( + mixes_shared[j * hc_mult + k + hc_mult * 2] * hc_scale[2] + + hc_base[j * hc_mult + k + hc_mult * 2] + ) + + ################################################################## + # _sinkhorn_fwd + row_sum = T.alloc_fragment(hc_mult, T.float32) + col_sum = T.alloc_fragment(hc_mult, T.float32) + + # comb = comb.softmax(-1) + eps + row_max = T.alloc_fragment(hc_mult, T.float32) + T.reduce_max(cm, row_max, dim=1) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = T.exp(cm[j, k] - row_max[j]) + T.reduce_sum(cm, row_sum, dim=1) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / row_sum[j] + hc_sinkhorn_eps + + # comb = comb / (comb.sum(-2) + eps) + T.reduce_sum(cm, col_sum, dim=0) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps) + + for _ in T.serial(sinkhorn_repeat - 1): + # comb = comb / (comb.sum(-1) + eps) + T.reduce_sum(cm, row_sum, dim=1) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / (row_sum[j] + hc_sinkhorn_eps) + + # comb = comb / (comb.sum(-2) + eps) + T.reduce_sum(cm, col_sum, dim=0) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps) + + # save comb_mix to global memory + for j, k in T.Parallel(hc_mult, hc_mult): + comb_mix[i, j * hc_mult + k] = cm[j, k] + else: + ################################################################## + # _pre_split_mixes_fwd (pre) + pre_mix_shared = T.alloc_shared(hc_mult, T.float32) + for j in T.Parallel(hc_mult): + pre_mix_shared[j] = ( + T.sigmoid( + mixes_shared[j] * hc_scale[0] + hc_base[j], + ) + + hc_pre_eps + ) + ################################################################### + # _pre_apply_mix_fwd + for i0_h in T.Pipelined(hidden_size // hidden_block, num_stages=2): + xs = T.alloc_shared((hc_mult, hidden_block), T.float32) + xl = T.alloc_fragment((hc_mult, hidden_block), T.float32) + T.copy(residual[i, 0, i0_h * hidden_block], xs) + T.copy(xs, xl) + + ol = T.alloc_fragment(hidden_block, T.float32) + T.clear(ol) + + for i_hc in T.serial(hc_mult): + pre = pre_mix_shared[i_hc] + for i1_h in T.Parallel(hidden_block): + ol[i1_h] += pre * xl[i_hc, i1_h] + + T.copy(ol, layer_input[i, i0_h * hidden_block]) + T.pdl_trigger() + + +def mhc_pre( + residual: torch.Tensor, + fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + rms_eps: float, + hc_pre_eps: float, + hc_sinkhorn_eps: float, + hc_post_mult_value: float, + sinkhorn_repeat: int, + n_splits: int = 1, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Forward pass for mHC pre block. 
+ + Args: + residual: shape (..., hc_mult, hidden_size), dtype torch.bfloat16 + fn: shape (hc_mult3, hc_mult * hidden_size), dtype torch.float32 + hc_scale: shape (3,), dtype torch.float32 + hc_base: shape (hc_mult3,), dtype torch.float32 + rms_eps: RMS normalization epsilon + hc_pre_eps: pre-mix epsilon + hc_sinkhorn_eps: sinkhorn epsilon + hc_post_mult_value: post-mix multiplier value + sinkhorn_repeat: number of sinkhorn iterations + n_splits: split-k factor; + + Returns: + post_mix: shape (..., hc_mult), dtype torch.float32 + comb_mix: shape (..., hc_mult, hc_mult), dtype torch.float32 + layer_input: shape (..., hidden_size), dtype torch.bfloat16 + """ + + # Validate shapes + assert residual.dtype == torch.bfloat16 + assert fn.dtype == torch.float32 + assert hc_scale.dtype == torch.float32 + assert hc_base.dtype == torch.float32 + + hc_mult = residual.shape[-2] + hidden_size = residual.shape[-1] + hc_mult2 = hc_mult * hc_mult + hc_mult3 = hc_mult * 2 + hc_mult2 + + hc_hidden_size = hc_mult * hidden_size + assert fn.shape[0] == hc_mult3 + assert fn.shape[1] == hc_hidden_size + assert hc_scale.shape == (3,) + assert hc_base.shape == (hc_mult3,) + + outer_shape = residual.shape[:-2] + + residual_flat = residual.view(-1, hc_mult, hidden_size) + num_tokens = residual_flat.shape[0] + fn_flat = fn + + # these number are from deepgemm kernel impl + block_k = 64 + block_m = 64 + n_splits = compute_num_split(block_k, hc_hidden_size, cdiv(num_tokens, block_m)) + + post_mix = torch.empty( + num_tokens, + hc_mult, + dtype=torch.float32, + device=residual.device, + ) + comb_mix = torch.empty( + num_tokens, + hc_mult2, + dtype=torch.float32, + device=residual.device, + ) + layer_input = torch.empty( + num_tokens, + hidden_size, + dtype=torch.bfloat16, + device=residual.device, + ) + + gemm_out_mul = torch.empty( + n_splits, + num_tokens, + hc_mult3, + dtype=torch.float32, + device=residual.device, + ) + gemm_out_sqrsum = torch.empty( + n_splits, + num_tokens, + dtype=torch.float32, + device=residual.device, + ) + + from vllm.utils.deep_gemm import tf32_hc_prenorm_gemm + + tf32_hc_prenorm_gemm( + residual_flat.view(num_tokens, hc_mult * hidden_size), + fn_flat, + gemm_out_mul, + gemm_out_sqrsum, + n_splits, + ) + + mhc_pre_big_fuse_tilelang( + gemm_out_mul, + gemm_out_sqrsum, + hc_scale, + hc_base, + residual_flat, + post_mix, + comb_mix, + layer_input, + hidden_size, + rms_eps, + hc_pre_eps, + hc_sinkhorn_eps, + hc_post_mult_value, + sinkhorn_repeat, + n_splits, + hc_mult, + ) + + post_mix = post_mix.view(*outer_shape, hc_mult, 1) + comb_mix = comb_mix.view(*outer_shape, hc_mult, hc_mult) + layer_input = layer_input.view(*outer_shape, hidden_size) + + return post_mix, comb_mix, layer_input + + +def _mhc_pre_fake( + residual: torch.Tensor, + fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + rms_eps: float, + hc_pre_eps: float, + hc_sinkhorn_eps: float, + hc_post_mult_value: float, + sinkhorn_repeat: int, + n_splits: int = 1, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hc_mult = residual.shape[-2] + hidden_size = residual.shape[-1] + outer_shape = residual.shape[:-2] + + # Create empty tensors with correct shapes for meta device / shape inference + post_mix = torch.empty( + *outer_shape, + hc_mult, + 1, + dtype=torch.float32, + device=residual.device, + ) + comb_mix = torch.empty( + *outer_shape, + hc_mult, + hc_mult, + dtype=torch.float32, + device=residual.device, + ) + layer_input = torch.empty( + *outer_shape, + hidden_size, + dtype=torch.bfloat16, + 
device=residual.device, + ) + + return post_mix, comb_mix, layer_input + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, + }, +) +def mhc_post_tilelang( + a, + b, + c, + d, + x, + hc: int, + hidden: int, + n_thr: int = 128, + h_blk: int = 1024, +) -> tilelang.JITKernel: + # rename for shorter code + n = T.dynamic("num_tokens") + h = hidden + + h_blk = math.gcd(hidden, h_blk) + a: T.Tensor((n, hc, hc), T.float32) # type: ignore[no-redef, valid-type] + b: T.Tensor((n, hc, h), T.bfloat16) # type: ignore[no-redef, valid-type] + c: T.Tensor((n, hc), T.float32) # type: ignore[no-redef, valid-type] + d: T.Tensor((n, h), T.bfloat16) # type: ignore[no-redef, valid-type] + x: T.Tensor((n, hc, h), T.bfloat16) # type: ignore[no-redef, valid-type] + with T.Kernel(n, threads=n_thr) as i_n: + x_shared = T.alloc_shared((hc, h_blk), T.bfloat16) + b_shared = T.alloc_shared((hc, h_blk), T.bfloat16) + d_shared = T.alloc_shared(h_blk, T.bfloat16) + + x_local = T.alloc_fragment((hc, h_blk), T.float32) + b_local = T.alloc_fragment((hc, h_blk), T.float32) + d_local = T.alloc_fragment(h_blk, T.float32) + + a_local = T.alloc_fragment((hc, hc), T.float32) + c_local = T.alloc_fragment(hc, T.float32) + T.pdl_sync() + T.copy(a[i_n, 0, 0], a_local) + T.copy(c[i_n, 0], c_local) + + for i0_h in T.Pipelined(T.ceildiv(h, h_blk), num_stages=2): + T.copy(b[i_n, 0, i0_h * h_blk], b_shared) + T.copy(d[i_n, i0_h * h_blk], d_shared) + + T.copy(b_shared, b_local) + T.copy(d_shared, d_local) + for i_hco, i1_h in T.Parallel(hc, h_blk): + x_local[i_hco, i1_h] = c_local[i_hco] * d_local[i1_h] + for i_hci in T.serial(hc): + x_local[i_hco, i1_h] += a_local[i_hci, i_hco] * b_local[i_hci, i1_h] + T.copy(x_local, x_shared) + + T.copy(x_shared, x[i_n, 0, i0_h * h_blk]) + T.pdl_trigger() + + +def mhc_post( + x: torch.Tensor, + residual: torch.Tensor, + post_layer_mix: torch.Tensor, + comb_res_mix: torch.Tensor, +) -> torch.Tensor: + out = torch.empty_like(residual) + mhc_post_tilelang( + comb_res_mix, + residual, + post_layer_mix.squeeze(-1), + x, + out, + residual.shape[-2], + residual.shape[-1], + ) + return out + + +def _mhc_post_fake( + x: torch.Tensor, + residual: torch.Tensor, + post_layer_mix: torch.Tensor, + comb_res_mix: torch.Tensor, +) -> torch.Tensor: + return torch.empty_like(residual) + + +direct_register_custom_op( + op_name="mhc_pre", + op_func=mhc_pre, + mutates_args=[], + fake_impl=_mhc_pre_fake, +) +direct_register_custom_op( + op_name="mhc_post", + op_func=mhc_post, + mutates_args=[], + fake_impl=_mhc_post_fake, +) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index df052fdc33bb..9f2b4e70206c 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -32,6 +32,7 @@ "inc", "mxfp4", "gpt_oss_mxfp4", + "deepseek_v4_fp8", "cpu_awq", "online", # Below are values of the OnlineQuantScheme enum, specified as strings to @@ -112,6 +113,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: # lazy import to avoid triggering `torch.compile` too early from vllm.config.quantization import OnlineQuantScheme from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig + from vllm.model_executor.models.deepseek_v4 import DeepseekV4FP8Config from .awq import AWQConfig from .awq_marlin import 
AWQMarlinConfig @@ -163,6 +165,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "inc": INCConfig, "mxfp4": Mxfp4Config, "gpt_oss_mxfp4": GptOssMxfp4Config, + "deepseek_v4_fp8": DeepseekV4FP8Config, "cpu_awq": CPUAWQConfig, "humming": HummingConfig, "online": OnlineQuantizationConfig, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_nvfp4.py index 09a216fd2cb1..29c673d0f6e3 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_nvfp4.py @@ -265,6 +265,7 @@ def apply_monolithic( layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, + input_ids: torch.Tensor | None = None, ) -> torch.Tensor: assert self.is_monolithic assert self.moe_kernel is not None diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_int8.py index 2e8a935cad6f..88cdbadd3f83 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a8_int8.py @@ -305,6 +305,7 @@ def apply_monolithic( layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, + input_ids: torch.Tensor | None = None, ) -> torch.Tensor: assert not layer.enable_eplb, "EPLB not supported for W4A8-int MoE yet." 
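# Illustrative sketch (not part of this change) of the signature pattern repeated
# across the quant methods in this diff: input_ids is an optional, defaulted kwarg,
# so call sites that predate the plumbing stay valid; the value is only forwarded to
# router.select_experts, which now accepts it. Function name below is hypothetical.
import torch

def apply_monolithic_sketch(
    layer,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    input_ids: torch.Tensor | None = None,  # new kwarg; unused by most methods
) -> torch.Tensor:
    del input_ids  # methods that do not consume token ids simply drop it
    return x       # real implementations dispatch to their fused MoE kernel here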
assert layer.activation in ( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py index ed8ed79c50c6..bba7e0e7abce 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_fp8.py @@ -367,6 +367,7 @@ def apply_monolithic( layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, + input_ids: torch.Tensor | None = None, ) -> torch.Tensor: assert self.moe_kernel is not None return self.moe_kernel.apply_monolithic( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py index 02e946b1b61e..ecd0b54890d1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w8a8_mxfp8.py @@ -168,6 +168,7 @@ def apply_monolithic( layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, + input_ids: torch.Tensor | None = None, ) -> torch.Tensor: assert self.moe_kernel is not None return self.moe_kernel.apply_monolithic( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16_marlin.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16_marlin.py index 81b7efaa6d7e..8f86e687b7f6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16_marlin.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_wna16_marlin.py @@ -517,6 +517,7 @@ def apply_monolithic( layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, + input_ids: torch.Tensor | None = None, ) -> torch.Tensor: assert self.kernel_backend == "Flashinfer" return flashinfer_trtllm_mxint4_moe( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 0dc8907248ef..1c9237d3f60a 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -269,6 +269,7 @@ class Fp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config + self.is_scale_e8m0 = getattr(quant_config, "is_scale_e8m0", False) self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() self.out_dtype = torch.get_default_dtype() self.input_dtype = get_current_vllm_config().model_config.dtype @@ -362,6 +363,7 @@ def create_weights( input_size_per_partition, self.weight_block_size, weight_loader, + scale_dtype=(torch.float8_e8m0fnu if self.is_scale_e8m0 else None), ) # The weight_scale_inv name is intentional for deepseekv3 layer.register_parameter("weight_scale_inv", scale) @@ -866,6 +868,7 @@ def apply_monolithic( layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, + input_ids: torch.Tensor | None = None, ) -> torch.Tensor: assert self.is_monolithic assert self.moe_kernel is not None diff --git 
a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 852ed1a10a34..242cc105e470 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -950,6 +950,7 @@ def apply_monolithic( layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, + input_ids: torch.Tensor | None = None, ) -> torch.Tensor: assert self.is_monolithic assert self.moe_kernel is not None @@ -1442,6 +1443,7 @@ def apply_monolithic( layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, + input_ids: torch.Tensor | None = None, ) -> torch.Tensor: assert self.is_monolithic assert self.moe_kernel is not None @@ -1920,6 +1922,7 @@ def apply_monolithic( layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, + input_ids: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from flashinfer.fused_moe.core import ( ActivationType, diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index b53c7cc9ac1a..a9228483bcfd 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -20,10 +20,12 @@ TRITON_BACKENDS, Mxfp4MoeBackend, convert_gpt_oss_weight_to_mxfp4_moe_kernel_format, + convert_weight_to_mxfp4_moe_kernel_format, make_mxfp4_moe_kernel, make_mxfp4_moe_quant_config, mxfp4_round_up_hidden_size_and_intermediate_size, select_gpt_oss_mxfp4_moe_backend, + select_mxfp4_moe_backend, ) from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.quantization import QuantizationMethods @@ -151,6 +153,16 @@ def __init__(self, moe: FusedMoEConfig): self.w13_precision_config = None self.w2_precision_config = None + + def _workspace_reserve_num_tokens(self) -> int: + scheduler_config = get_current_vllm_config().scheduler_config + candidates = [self.moe.max_num_tokens, self.max_capture_size or 0] + for attr in ("max_num_batched_tokens", "max_num_seqs"): + value = getattr(scheduler_config, attr, None) + if isinstance(value, int): + candidates.append(value) + return max(candidates) + @property def skip_forward_padding(self) -> bool: # SM100_FI_MXFP4_MXFP8_TRTLLM supports padding with mxfp8 quant @@ -217,6 +229,7 @@ def create_weights( ) layer.register_parameter("w13_weight_scale", w13_weight_scale) set_weight_attrs(w13_weight_scale, extra_weight_attrs) + w13_weight_scale.quant_method = "block" # down_proj (row parallel) w2_weight = torch.nn.Parameter( @@ -242,6 +255,7 @@ def create_weights( ) layer.register_parameter("w2_weight_scale", w2_weight_scale) set_weight_attrs(w2_weight_scale, extra_weight_attrs) + w2_weight_scale.quant_method = "block" if self.moe.has_bias: w13_bias = torch.nn.Parameter( @@ -363,6 +377,358 @@ def _setup_kernel( routing_tables=layer._maybe_init_expert_routing_tables(), shared_experts=layer.shared_experts, ) + if not self.is_monolithic: + self.moe_kernel.reserve_workspace( + layer.w13_weight, + layer.w2_weight, + layer.global_num_experts, + layer.activation, + max_num_tokens=self._workspace_reserve_num_tokens(), + ) + + def process_weights_after_loading(self, layer): + w13 = layer.w13_weight + w2 = layer.w2_weight + w13_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + w13_bias = getattr(layer, "w13_bias", None) + w2_bias = getattr(layer, "w2_bias", None) + + if self.mxfp4_backend == Mxfp4MoeBackend.NONE: + return + + self._setup_kernel(layer, w13, w2, 
w13_scale, w2_scale, w13_bias, w2_bias) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + w1_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + w1_bias = getattr(layer, "w13_bias", None) + w2_bias = getattr(layer, "w2_bias", None) + + if self.mxfp4_backend in TRITON_BACKENDS: + assert self.w13_precision_config is not None + assert self.w2_precision_config is not None + w1_scale = self.w13_precision_config + w2_scale = self.w2_precision_config + + return make_mxfp4_moe_quant_config( + mxfp4_backend=self.mxfp4_backend, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_bias=w1_bias, + w2_bias=w2_bias, + gemm1_alpha=1.702, + gemm1_beta=1.0, + swiglu_limit=7.0, + ) + + def select_gemm_impl( + self, + prepare_finalize: mk.FusedMoEPrepareAndFinalize, + layer: torch.nn.Module, + ) -> mk.FusedMoEExpertsModular: + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel " + "initialization logic. This function should not be called." + ) + + def apply( + self, + layer: FusedMoE, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + shared_experts_input: torch.Tensor | None, + ) -> torch.Tensor: + assert not self.is_monolithic + assert self.moe_kernel is not None + return self.moe_kernel.apply( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + expert_map=layer.expert_map, + shared_experts_input=shared_experts_input, + ) + + def apply_monolithic( + self, + layer: FusedMoE, + x: torch.Tensor, + router_logits: torch.Tensor, + input_ids: torch.Tensor | None = None, + ) -> torch.Tensor: + assert self.is_monolithic + assert self.moe_kernel is not None + return self.moe_kernel.apply_monolithic( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + router_logits=router_logits, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + ) + + +class Mxfp4MoEMethod(FusedMoEMethodBase): + """MXFP4 MoE quantization method.""" + + def __init__(self, moe: FusedMoEConfig): + super().__init__(moe) + self.weight_dtype = "mxfp4" + self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe) + + self.max_capture_size = ( + get_current_vllm_config().compilation_config.max_cudagraph_capture_size + ) + + self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {} + self.moe_kernel: mk.FusedMoEKernel | None = None + + # Used for triton kernel precision configs + self.w13_precision_config = None + self.w2_precision_config = None + + + def _workspace_reserve_num_tokens(self) -> int: + scheduler_config = get_current_vllm_config().scheduler_config + candidates = [self.moe.max_num_tokens, self.max_capture_size or 0] + for attr in ("max_num_batched_tokens", "max_num_seqs"): + value = getattr(scheduler_config, attr, None) + if isinstance(value, int): + candidates.append(value) + return max(candidates) + + @property + def skip_forward_padding(self) -> bool: + # SM100_FI_MXFP4_MXFP8_TRTLLM supports padding with mxfp8 quant + # so can skip the padding in the forward before applying the moe method + return self.mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8 + + def maybe_roundup_sizes( + self, + hidden_size: int, + intermediate_size_per_partition: int, + 
act_dtype: torch.dtype, + moe_parallel_config: FusedMoEParallelConfig, + ) -> tuple[int, int]: + hidden_size, intermediate_size_per_partition = super().maybe_roundup_sizes( + hidden_size=hidden_size, + intermediate_size_per_partition=intermediate_size_per_partition, + act_dtype=act_dtype, + moe_parallel_config=moe_parallel_config, + ) + return mxfp4_round_up_hidden_size_and_intermediate_size( + self.mxfp4_backend, hidden_size, intermediate_size_per_partition + ) + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + self.num_experts = num_experts + weight_dtype = torch.uint8 + scale_dtype = torch.uint8 + mxfp4_block = 32 + + layer.params_dtype = params_dtype + layer.num_experts = num_experts + self.intermediate_size = intermediate_size_per_partition + self.hidden_size = hidden_size + + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // 2, + dtype=weight_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w13_weight_scale = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // mxfp4_block, + dtype=scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + w13_weight_scale.quant_method = "block" + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + intermediate_size_per_partition // 2, + dtype=weight_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + intermediate_size_per_partition // mxfp4_block, + dtype=scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + w2_weight_scale.quant_method = "block" + + if self.moe.has_bias: + w13_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + dtype=torch.bfloat16, + ), + requires_grad=False, + ) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) + + w2_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + dtype=torch.bfloat16, + ), + requires_grad=False, + ) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + + def _setup_kernel( + self, + layer: FusedMoE, + w13: torch.Tensor, + w2: torch.Tensor, + w13_scale: torch.Tensor, + w2_scale: torch.Tensor, + w13_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, + ) -> None: + num_experts = self.num_experts + intermediate_size = self.intermediate_size + hidden_size = self.hidden_size + sf_block_size = 32 + + # Shape assertions + assert ( + w13.dim() == 3 + and w13.shape[0] == num_experts + and w13.shape[1] == intermediate_size * 2 + and w13.shape[2] == hidden_size // 2 + ) + assert ( + w13_scale.dim() == 3 + and w13_scale.shape[0] == num_experts + and w13_scale.shape[1] == intermediate_size * 2 + and w13_scale.shape[2] == hidden_size // sf_block_size + ) + assert ( + w2.dim() == 3 + and 
w2.shape[0] == num_experts + and w2.shape[1] == hidden_size + and w2.shape[2] == intermediate_size // 2 + ) + assert ( + w2_scale.dim() == 3 + and w2_scale.shape[1] == hidden_size + and w2_scale.shape[2] == intermediate_size // sf_block_size + ) + if w13_bias is not None: + assert ( + w13_bias.dim() == 2 + and w13_bias.shape[0] == num_experts + and w13_bias.shape[1] == intermediate_size * 2 + ) + if w2_bias is not None: + assert ( + w2_bias.dim() == 2 + and w2_bias.shape[0] == num_experts + and w2_bias.shape[1] == hidden_size + ) + + # Convert weights to kernel format + w13, w2, w13_scale, w2_scale, w13_bias, w2_bias = ( + convert_weight_to_mxfp4_moe_kernel_format( + mxfp4_backend=self.mxfp4_backend, + layer=layer, + w13_weight=w13, + w2_weight=w2, + w13_weight_scale=w13_scale, + w2_weight_scale=w2_scale, + w13_bias=w13_bias, + w2_bias=w2_bias, + _cache_permute_indices=self._cache_permute_indices, + ) + ) + + # For TRITON backends, weights are wrapped tensors from triton_kernels + # that don't support .detach(). Manually assign parameters. + if self.mxfp4_backend not in TRITON_BACKENDS: + replace_parameter(layer, "w13_weight", w13) + replace_parameter(layer, "w2_weight", w2) + replace_parameter(layer, "w13_weight_scale", w13_scale) + replace_parameter(layer, "w2_weight_scale", w2_scale) + else: + layer.w13_weight = w13 + layer.w2_weight = w2 + self.w13_precision_config = w13_scale + self.w2_precision_config = w2_scale + + if w13_bias is not None and w2_bias is not None: + replace_parameter(layer, "w13_bias", w13_bias) + replace_parameter(layer, "w2_bias", w2_bias) + + # Build quant config + self.moe_quant_config = self.get_fused_moe_quant_config(layer) + + # Build kernel (modular or monolithic) + if self.moe_quant_config is not None and self.experts_cls is not None: + self.moe_kernel = make_mxfp4_moe_kernel( + moe_quant_config=self.moe_quant_config, + moe_config=self.moe, + mxfp4_backend=self.mxfp4_backend, + experts_cls=self.experts_cls, + routing_tables=layer._maybe_init_expert_routing_tables(), + shared_experts=layer.shared_experts, + ) + if not self.is_monolithic: + self.moe_kernel.reserve_workspace( + layer.w13_weight, + layer.w2_weight, + layer.global_num_experts, + layer.activation, + max_num_tokens=self._workspace_reserve_num_tokens(), + ) def process_weights_after_loading(self, layer): w13 = layer.w13_weight @@ -384,6 +750,7 @@ def get_fused_moe_quant_config( w2_scale = layer.w2_weight_scale w1_bias = getattr(layer, "w13_bias", None) w2_bias = getattr(layer, "w2_bias", None) + swiglu_limit = getattr(layer, "swiglu_limit", None) if self.mxfp4_backend in TRITON_BACKENDS: assert self.w13_precision_config is not None @@ -397,6 +764,7 @@ def get_fused_moe_quant_config( w2_scale=w2_scale, w1_bias=w1_bias, w2_bias=w2_bias, + swiglu_limit=swiglu_limit, ) def select_gemm_impl( @@ -437,6 +805,7 @@ def apply_monolithic( layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, + input_ids: torch.Tensor | None = None, ) -> torch.Tensor: assert self.is_monolithic assert self.moe_kernel is not None diff --git a/vllm/model_executor/layers/quantization/online/moe_base.py b/vllm/model_executor/layers/quantization/online/moe_base.py index 25c3359ee8be..417ce1770f9e 100644 --- a/vllm/model_executor/layers/quantization/online/moe_base.py +++ b/vllm/model_executor/layers/quantization/online/moe_base.py @@ -130,6 +130,7 @@ def apply_monolithic( layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 x: torch.Tensor, router_logits: torch.Tensor, + input_ids: torch.Tensor | None = None, ) -> 
torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.is_monolithic assert self.moe_kernel is not None diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 64753a173dfe..c50d4396ee39 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -1457,6 +1457,7 @@ def apply_monolithic( layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, + input_ids: torch.Tensor | None = None, ) -> torch.Tensor: assert self.is_monolithic assert self.moe_kernel is not None diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 19fdb1ec884d..c6473c406c92 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -149,6 +149,148 @@ def _per_token_group_quant_fp8( tl.store(y_s_ptr, y_s) +@triton.jit +def _silu_mul_quant_fp8_packed_kernel( + input_ptr, + output_q_ptr, + output_scale_ptr, + M, + input_stride_m, + output_q_stride_m, + output_scale_stride_k, + clamp_limit, + N: tl.constexpr, + NUM_GROUPS: tl.constexpr, + fp8_min: tl.constexpr, + fp8_max: tl.constexpr, + GROUP_SIZE: tl.constexpr, + BLOCK_M: tl.constexpr, + HAS_CLAMP: tl.constexpr, +): + N_2: tl.constexpr = N // 2 + + pid_pack = tl.program_id(0) + pid_m = tl.program_id(1) + m_offset = pid_m * BLOCK_M + + if m_offset >= M: + return + + offs_m = tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, GROUP_SIZE) + row_mask = (m_offset + offs_m) < M + + base_row_offset = (m_offset + offs_m[:, None]) * input_stride_m + base_out_offset = (m_offset + offs_m[:, None]) * output_q_stride_m + + packed_scale = tl.zeros((BLOCK_M,), dtype=tl.int32) + + for pack_idx in tl.static_range(4): + group_id = pid_pack * 4 + pack_idx + + if group_id < NUM_GROUPS: + n_offset = group_id * GROUP_SIZE + + act_ptrs = input_ptr + base_row_offset + n_offset + offs_n[None, :] + act_in = tl.load(act_ptrs, mask=row_mask[:, None], other=0.0) + + mul_ptrs = act_ptrs + N_2 + mul_in = tl.load(mul_ptrs, mask=row_mask[:, None], other=0.0) + + act_f32 = act_in.to(tl.float32) + mul_f32 = mul_in.to(tl.float32) + + if HAS_CLAMP: + act_f32 = tl.minimum(act_f32, clamp_limit) + mul_f32 = tl.clamp(mul_f32, -clamp_limit, clamp_limit) + + y = (act_f32 / (1.0 + tl.exp(-act_f32))) * mul_f32 + # Round through bf16 to match unfused precision path + y = y.to(tl.bfloat16).to(tl.float32) + + absmax = tl.max(tl.abs(y), axis=1) + + scale_raw = tl.maximum(absmax / fp8_max, 1e-10) + exponent = tl.ceil(tl.log2(scale_raw)) + scale = tl.math.exp2(exponent) + + y_q = tl.clamp(y / scale[:, None], fp8_min, fp8_max) + + out_q_ptrs = output_q_ptr + base_out_offset + n_offset + offs_n[None, :] + tl.store( + out_q_ptrs, + y_q.to(output_q_ptr.dtype.element_ty), + mask=row_mask[:, None], + ) + + exponent_biased = tl.clamp(exponent + 127.0, 0.0, 255.0).to(tl.int32) + packed_scale = packed_scale | (exponent_biased << (pack_idx * 8)) + + scale_ptrs = output_scale_ptr + pid_pack * output_scale_stride_k + m_offset + offs_m + tl.store(scale_ptrs, packed_scale, mask=row_mask) + + +def silu_mul_quant_fp8_packed_triton( + input: torch.Tensor, + group_size: int = 128, + output_q: torch.Tensor | None = None, + clamp_limit: float | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + assert input.dim() == 2 + assert input.is_contiguous() + + M, N = input.shape + N_2 = N // 2 + + assert N_2 % 
group_size == 0 + + fp8_dtype = torch.float8_e4m3fn + finfo = torch.finfo(fp8_dtype) + fp8_min, fp8_max = finfo.min, finfo.max + + num_groups_per_row = N_2 // group_size + num_packed_groups = (num_groups_per_row + 3) // 4 + tma_aligned_M = ((M + 3) // 4) * 4 + + if output_q is None: + output_q = torch.empty((M, N_2), dtype=fp8_dtype, device=input.device) + + output_scale_packed = torch.zeros( + (num_packed_groups, tma_aligned_M), + dtype=torch.int32, + device=input.device, + ).T[:M, :] + + BLOCK_M = 8 + grid = (num_packed_groups, (M + BLOCK_M - 1) // BLOCK_M) + + num_warps = max(4, group_size // 32) + num_stages = 2 + + has_clamp = clamp_limit is not None + _silu_mul_quant_fp8_packed_kernel[grid]( + input, + output_q, + output_scale_packed, + M, + input.stride(0), + output_q.stride(0), + output_scale_packed.stride(1), + clamp_limit if has_clamp else 0.0, + N=N, + NUM_GROUPS=num_groups_per_row, + fp8_min=fp8_min, + fp8_max=fp8_max, + GROUP_SIZE=group_size, + BLOCK_M=BLOCK_M, + HAS_CLAMP=has_clamp, + num_warps=num_warps, + num_stages=num_stages, + ) + + return output_q, output_scale_packed + + @triton.jit def _silu_mul_per_token_group_quant_fp8_colmajor( y_ptr, # [M, N] @@ -675,6 +817,35 @@ def get_w8a8_block_fp8_configs( return None +def _get_default_w8a8_block_fp8_config( + M: int, + block_n: int, + block_k: int, +) -> dict[str, Any]: + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_n and + # BLOCK_SIZE_K must be divisible by block_k. + # M-aware tuning for low-M decode: BLOCK_SIZE_M=64 wastes most of the + # M-dim for single-request decode and short MTP-style draft batches. SM12x + # keeps benefiting from the low-M tile through M=32 on DeepSeek V4 shapes. + capability = current_platform.get_device_capability() + capability_major = getattr(capability, "major", None) + if capability_major is None and capability is not None: + capability_major = capability[0] + low_m_limit = 32 if capability_major == 12 else 8 + if low_m_limit >= M: + block_m, num_stages = 16, (2 if current_platform.is_rocm() else 3) + else: + block_m, num_stages = 64, 2 + return { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": num_stages, + } + + def w8a8_triton_block_scaled_mm( A: torch.Tensor, B: torch.Tensor, @@ -710,6 +881,12 @@ def w8a8_triton_block_scaled_mm( N, K = B.shape assert triton.cdiv(N, block_n) == Bs.shape[0] assert triton.cdiv(K, block_k) == Bs.shape[1] + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if e8m0_dtype is not None: + if As.dtype == e8m0_dtype: + As = _upcast_e8m0_to_fp32(As) + if Bs.dtype == e8m0_dtype: + Bs = _upcast_e8m0_to_fp32(Bs) C_shape = A.shape[:-1] + (N,) C = A.new_empty(C_shape, dtype=output_dtype) @@ -719,17 +896,7 @@ def w8a8_triton_block_scaled_mm( # Get the optimal config if there is one config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: - # Default config - # Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0] - # BLOCK_SIZE_K must be divisible by block_size[1] - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": block_size[0], - "BLOCK_SIZE_K": block_size[1], - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 2, - } + config = _get_default_w8a8_block_fp8_config(M, block_size[0], block_size[1]) def grid(META): return ( @@ -823,19 +990,65 @@ def requant_weight_ue8m0_inplace( s_old.copy_(s_requant) +def _upcast_e8m0_to_fp32(scale: torch.Tensor) -> torch.Tensor: + """Upcast E8M0 (exponent-only) scale to float32. 
+ + E8M0 stores only the 8-bit biased exponent (bias=127). To convert + to float32 we place those 8 bits into the exponent field of an + IEEE-754 float32 (bits 23-30) with sign=0 and mantissa=0. + """ + exp_bits = scale.view(torch.uint8).to(torch.int32) + fp32_bits = exp_bits << 23 + return fp32_bits.view(torch.float32) + + def deepgemm_post_process_fp8_weight_block( - wq: torch.Tensor, ws: torch.Tensor, quant_block_shape: tuple[int], use_e8m0: bool + wq: torch.Tensor, + ws: torch.Tensor, + quant_block_shape: tuple[int, ...], + use_e8m0: bool, + is_bmm: bool = False, + bmm_batch_size: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: assert wq.dtype == torch.float8_e4m3fn, ( "Expected quantized tensor dtype " f"to be torch.float8_e4m3fn, got {wq.dtype} instead." ) - assert ws.dtype == torch.float32, ( - f"Expected tensor scales dtype to be torch.float32, got {ws.dtype} instead" - ) - if use_e8m0: - requant_weight_ue8m0_inplace(wq, ws, block_size=quant_block_shape) + if ws.dtype == torch.float8_e8m0fnu: + # Scales already in E8M0 from checkpoint — upcast to fp32 + # and skip requantization (weights already have power-of-two scales). + ws = _upcast_e8m0_to_fp32(ws) + else: + assert ws.dtype == torch.float32, ( + f"Expected tensor scales dtype to be torch.float32 or " + f"torch.float8_e8m0fnu, got {ws.dtype} instead" + ) + if use_e8m0: + requant_weight_ue8m0_inplace(wq, ws, block_size=quant_block_shape) + + if is_bmm: + # Reshape 2D weight/scale to 3D for grouped BMM (einsum): + # wq: (g*r, d) -> (g, r, d) + # ws: (g*r/128, d/128) -> (g, r/128, d/128) + g = bmm_batch_size + assert wq.ndim == 2 and ws.ndim == 2 + d = wq.size(1) + r = wq.size(0) // g + wq = wq.view(g, r, d) + ws = ws.view(g, r // quant_block_shape[0], d // quant_block_shape[1]) + # Pre-transform scale with recipe=(1, 128, 128) to broadcast + pack + # into TMA-aligned UE8M0 (INT32) layout. At runtime fp8_einsum uses + # recipe=(1, 1, 128) which sees INT dtype and skips re-transform. 
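# Illustrative sketch (not part of this change): quick numeric check of the
# _upcast_e8m0_to_fp32 helper above. An E8M0 byte is only a biased exponent
# (bias 127), so shifting it into the float32 exponent field must reproduce
# 2**(e - 127); a plain uint8 stand-in is used here instead of float8_e8m0fnu.
import torch

def upcast_e8m0_sketch(scale_u8: torch.Tensor) -> torch.Tensor:
    exp_bits = scale_u8.to(torch.int32)
    return (exp_bits << 23).view(torch.float32)

e = torch.tensor([126, 127, 130], dtype=torch.uint8)
assert torch.equal(upcast_e8m0_sketch(e), torch.exp2(e.float() - 127.0))  # 0.5, 1.0, 8.0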
+ dg_ws = transform_sf_into_required_layout( + sf=ws, + mn=r, + k=d, + recipe=(1, quant_block_shape[0], quant_block_shape[1]), + num_groups=g, + is_sfa=False, + ) + return wq, dg_ws original_ndim = wq.ndim if wq.ndim == 2: @@ -984,11 +1197,13 @@ def create_fp8_scale_parameter( input_size_per_partition: int, block_size: list[int] | None, weight_loader: Callable | None, + scale_dtype: torch.dtype | None = None, ) -> torch.nn.Parameter: """Create scale parameter based on quantization strategy.""" + dtype = scale_dtype if scale_dtype is not None else torch.float32 if parameter_type == ChannelQuantScaleParameter: scale = parameter_type( - data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), + data=torch.empty((sum(output_partition_sizes), 1), dtype=dtype), output_dim=0, weight_loader=weight_loader, ) @@ -1000,7 +1215,7 @@ def create_fp8_scale_parameter( data=torch.empty( (output_size_per_partition + block_n - 1) // block_n, (input_size_per_partition + block_k - 1) // block_k, - dtype=torch.float32, + dtype=dtype, ), input_dim=1, output_dim=0, @@ -1008,13 +1223,14 @@ def create_fp8_scale_parameter( ) elif parameter_type == PerTensorScaleParameter: scale = parameter_type( - data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + data=torch.empty(len(output_partition_sizes), dtype=dtype), weight_loader=weight_loader, ) else: raise ValueError(f"Unknown parameter type: {parameter_type}") - scale[:] = torch.finfo(torch.float32).min + if dtype == torch.float32: + scale[:] = torch.finfo(torch.float32).min set_weight_attrs(scale, {"scale_type": "weight_scale"}) return scale diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py index 9a541877575d..b955b37be603 100644 --- a/vllm/model_executor/layers/rotary_embedding/__init__.py +++ b/vllm/model_executor/layers/rotary_embedding/__init__.py @@ -7,7 +7,10 @@ import torch from .base import RotaryEmbedding -from .deepseek_scaling_rope import DeepseekScalingRotaryEmbedding +from .deepseek_scaling_rope import ( + DeepseekScalingRotaryEmbedding, + DeepseekV4ScalingRotaryEmbedding, +) from .dual_chunk_rope import DualChunkRotaryEmbedding from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding from .dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding @@ -60,11 +63,13 @@ def get_rope( rope_parameters = rope_parameters or {} base = rope_parameters.get("rope_theta", 10000) scaling_type = rope_parameters.get("rope_type", "default") - partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0) - - if partial_rotary_factor <= 0.0 or partial_rotary_factor > 1.0: - raise ValueError(f"{partial_rotary_factor=} must be between 0.0 and 1.0") - rotary_dim = int(head_size * partial_rotary_factor) + if rotary_dim := rope_parameters.get("rope_dim", None): + pass + else: + partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0) + if partial_rotary_factor <= 0.0 or partial_rotary_factor > 1.0: + raise ValueError(f"{partial_rotary_factor=} must be between 0.0 and 1.0") + rotary_dim = int(head_size * partial_rotary_factor) key = ( head_size, @@ -289,7 +294,11 @@ def get_rope( "mscale_all_dim", ) } - rotary_emb = DeepseekScalingRotaryEmbedding( + if rope_parameters.get("is_deepseek_v4", False): + cls = DeepseekV4ScalingRotaryEmbedding + else: + cls = DeepseekScalingRotaryEmbedding + rotary_emb = cls( head_size, rotary_dim, original_max_position, diff --git 
a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py index 69c1101664d0..6cb9101a78b1 100644 --- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py @@ -132,10 +132,8 @@ def forward_native( ] cos, sin = cos_sin.chunk(2, dim=-1) if self.is_neox_style: - # NOTE(woosuk): Here we assume that the positions tensor has the - # shape [batch_size, seq_len]. - cos = cos.repeat(1, 1, 2).unsqueeze(-2) - sin = sin.repeat(1, 1, 2).unsqueeze(-2) + cos = torch.cat((cos, cos), dim=-1).unsqueeze(-2) + sin = torch.cat((sin, sin), dim=-1).unsqueeze(-2) else: cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) @@ -197,3 +195,118 @@ def forward_cuda( return query, key else: return self.forward_native(positions, query, key, offsets) + + +class DeepseekV4ScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. github.com/jquesnelle/yarn + + Compared to DeepseekScalingRotaryEmbedding: + - Applies RoPE to the last rotary_dim + - The forward method requires an inverse parameter to indicate + whether to negate the sin + - Supports applying RoPE to query only (without key) + - cos_sin_cache stored as fp32 for higher precision RoPE + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + cache_fp32 = self._compute_cos_sin_cache() + self.register_buffer("cos_sin_cache", cache_fp32, persistent=False) + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange( + self.max_position_embeddings * self.scaling_factor, + device=current_platform.device_type, + dtype=torch.float32, + ) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * self.mscale + sin = freqs.sin() * self.mscale + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, + inverse: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """PyTorch-native implementation equivalent to forward().""" + + head_size = query.size(-1) + query_rot = query[..., -self.rotary_dim :] + key_rot = key[..., -self.rotary_dim :] if key is not None else None + + if self.rotary_dim < head_size: + query_pass = query[..., : -self.rotary_dim] + key_pass = key[..., : -self.rotary_dim] if key is not None else None + + cos_sin = self.cos_sin_cache[ + torch.add(positions, offsets) if offsets is not None else positions + ] + cos, sin = cos_sin.chunk(2, dim=-1) + if self.is_neox_style: + cos = torch.cat((cos, cos), dim=-1).unsqueeze(-2) + sin = torch.cat((sin, sin), dim=-1).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + if inverse: + sin = -sin + rotate_fn = rotate_neox if self.is_neox_style else rotate_gptj + orig_dtype = query.dtype + query_rot = (query_rot * cos + rotate_fn(query_rot) * sin).to(orig_dtype) + if key_rot is not None: + key_rot = (key_rot * cos + rotate_fn(key_rot) * sin).to(orig_dtype) + + if self.rotary_dim < head_size: + query = torch.cat((query_pass, query_rot), dim=-1) + key = torch.cat((key_pass, key_rot), dim=-1) if key is not None else None + else: + query = query_rot + key = key_rot + + return query, 
key + + def forward_hip( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, + inverse: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + return self.forward_native(positions, query, key, offsets) + + def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, + inverse: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + from vllm import _custom_ops as ops + + # The indexer and attention have different head_dim, + # we obtain the corresponding head_dim via the query. + head_size = query.size(-1) + rope_dim_offset = head_size - self.rotary_dim + # ops.rotary_embedding() is an in-place operation + # that updates the query and key tensors. + ops.rotary_embedding( + torch.add(positions, offsets) if offsets is not None else positions, + query, + key, + head_size, + self.cos_sin_cache, + self.is_neox_style, + rope_dim_offset=rope_dim_offset, + inverse=inverse, + ) + return query, key diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index 1f1f7e7df89f..62cfb817dcd3 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -4,13 +4,16 @@ import torch -import vllm.envs as envs from vllm._aiter_ops import rocm_aiter_ops from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp from vllm.platforms import current_platform -from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits, has_deep_gemm +from vllm.utils.deep_gemm import ( + fp8_fp4_mqa_logits, + fp8_fp4_paged_mqa_logits, + has_deep_gemm, +) from vllm.utils.torch_utils import ( LayerNameType, _encode_layer_name, @@ -19,6 +22,7 @@ ) from vllm.v1.attention.backends.mla.indexer import ( DeepseekV32IndexerMetadata, + sparse_indexer_max_logits_bytes, ) from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.v1.worker.workspace import current_workspace_manager @@ -31,13 +35,90 @@ logger = init_logger(__name__) RADIX_TOPK_WORKSPACE_SIZE = 1024 * 1024 +SM120_SHORT_ROW_TOPK_ALWAYS_WIDTH = 4096 +SM120_SHORT_ROW_TOPK_MAX_WIDTH = 12288 +SM120_SHORT_ROW_TOPK_MAX_ROWS = 16 + +# MXFP4 layout: 2 values packed per byte, ue8m0 (1-byte) scale per block of 32. +MXFP4_BLOCK_SIZE = 32 + + +def _should_use_sm120_short_row_topk_decode( + topk_tokens: int, + logits_width: int, + num_rows: int, + is_cuda_sm120: bool, +) -> bool: + if not is_cuda_sm120 or topk_tokens != 512: + return False + if logits_width <= SM120_SHORT_ROW_TOPK_ALWAYS_WIDTH: + return True + return ( + logits_width < SM120_SHORT_ROW_TOPK_MAX_WIDTH + and num_rows <= SM120_SHORT_ROW_TOPK_MAX_ROWS + ) + + +def _use_sm120_short_row_topk_decode( + logits: torch.Tensor, + topk_tokens: int, +) -> bool: + return _should_use_sm120_short_row_topk_decode( + topk_tokens, + logits.shape[1], + logits.shape[0], + current_platform.is_cuda() + and current_platform.is_device_capability_family(120), + ) + + +def _gather_workspace_shapes( + total_seq_lens: int, + head_dim: int, + fp8_dtype: torch.dtype, + use_fp4_cache: bool, +) -> tuple[tuple[tuple[int, int], torch.dtype], tuple[tuple[int, int], torch.dtype]]: + """Return ((values_shape, values_dtype), (scales_shape, scales_dtype)) for + the K-gather workspace. FP8 path: (T, head_dim) fp8 + (T, 4) uint8 fp32 + scales. 
MXFP4 path: (T, head_dim // 2) uint8 packed mxfp4 + + (T, head_dim // MXFP4_BLOCK_SIZE) uint8 ue8m0 scales.""" + if use_fp4_cache: + return ( + ((total_seq_lens, head_dim // 2), torch.uint8), + ((total_seq_lens, head_dim // MXFP4_BLOCK_SIZE), torch.uint8), + ) + return ( + ((total_seq_lens, head_dim), fp8_dtype), + ((total_seq_lens, 4), torch.uint8), + ) + + +def kv_cache_as_quant_view( + kv_cache: torch.Tensor, + head_dim: int, + use_fp4_cache: bool, +) -> torch.Tensor: + """4D ``[num_blocks, block_size, 1, head_width]`` view expected by + DeepGEMM, from the 3D indexer kv-cache allocation.""" + if use_fp4_cache: + assert kv_cache.ndim == 3 and kv_cache.dtype == torch.uint8 + num_blocks, block_size, _ = kv_cache.shape + page_bytes = int(kv_cache.stride(0)) + fp4_bytes = head_dim // 2 + head_dim // MXFP4_BLOCK_SIZE + return torch.as_strided( + kv_cache, + size=(num_blocks, block_size, 1, fp4_bytes), + stride=(page_bytes, fp4_bytes, fp4_bytes, 1), + ) + return kv_cache.unsqueeze(-2) def sparse_attn_indexer( hidden_states: torch.Tensor, k_cache_prefix: LayerNameType, kv_cache: torch.Tensor, - q_fp8: torch.Tensor, + q_quant: torch.Tensor, + q_scale: torch.Tensor | None, k: torch.Tensor, weights: torch.Tensor, quant_block_size: int, @@ -47,6 +128,8 @@ def sparse_attn_indexer( max_model_len: int, total_seq_lens: int, topk_indices_buffer: torch.Tensor, + skip_k_cache_insert: bool, + use_fp4_cache: bool = False, ) -> torch.Tensor: # careful! this will be None in dummy run attn_metadata = get_forward_context().attn_metadata @@ -56,15 +139,18 @@ def sparse_attn_indexer( # assert isinstance(attn_metadata, dict) if not isinstance(attn_metadata, dict): # Reserve workspace for indexer during profiling run + values_spec, scales_spec = _gather_workspace_shapes( + total_seq_lens, head_dim, fp8_dtype, use_fp4_cache + ) current_workspace_manager().get_simultaneous( - ((total_seq_lens, head_dim), torch.float8_e4m3fn), - ((total_seq_lens, 4), torch.uint8), + values_spec, + scales_spec, ((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8), ) # Dummy allocation to simulate for peak logits tensor memory during inference. # FP8 elements so elements == bytes - max_logits_elems = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024 + max_logits_elems = sparse_indexer_max_logits_bytes() _ = torch.empty( max_logits_elems, dtype=torch.uint8, device=hidden_states.device ) @@ -73,7 +159,8 @@ def sparse_attn_indexer( hidden_states, k_cache_prefix, kv_cache, - q_fp8, + q_quant, + q_scale, k, weights, quant_block_size, @@ -83,6 +170,8 @@ def sparse_attn_indexer( max_model_len, total_seq_lens, topk_indices_buffer, + skip_k_cache_insert, + use_fp4_cache, ) attn_metadata_narrowed = attn_metadata[k_cache_prefix] assert isinstance(attn_metadata_narrowed, DeepseekV32IndexerMetadata) @@ -91,49 +180,81 @@ def sparse_attn_indexer( has_prefill = attn_metadata_narrowed.num_prefills > 0 num_decode_tokens = attn_metadata_narrowed.num_decode_tokens + # q_scale is required iff the FP4 cache path is enabled; the FP8 path + # folds the Q scale into `weights` inside fused_indexer_q_rope_quant. + if use_fp4_cache: + assert q_scale is not None, "use_fp4_cache=True requires q_scale" + else: + assert q_scale is None, "q_scale must be None when use_fp4_cache=False" + # During speculative decoding, k may be padded to the CUDA graph batch # size while slot_mapping only covers actual tokens. Truncate k to avoid # out-of-bounds reads in the kernel. 
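# Illustrative sketch (head_dim=128 is an assumed value, not taken from this diff)
# of the per-token byte budget behind _gather_workspace_shapes and
# kv_cache_as_quant_view above: the FP8 path stores head_dim fp8 bytes plus a
# 4-byte fp32 scale, while MXFP4 packs two values per byte plus one ue8m0 scale
# byte per 32-value block; the 3D indexer cache is re-viewed (not copied) into the
# 4D [num_blocks, block_size, 1, width] layout via the page stride.
MXFP4_BLOCK_SIZE = 32  # mirrors the constant defined earlier in this file
head_dim = 128
fp8_width = head_dim + 4                                  # 132 bytes per token
fp4_width = head_dim // 2 + head_dim // MXFP4_BLOCK_SIZE  # 64 + 4 = 68 bytes per token
assert (fp8_width, fp4_width) == (132, 68)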
num_tokens = slot_mapping.shape[0] - k = k[:num_tokens] - - # scale_fmt can be None, but the function expects str - assert scale_fmt is not None - ops.indexer_k_quant_and_cache( - k, - kv_cache, - slot_mapping, - quant_block_size, - scale_fmt, - ) + if k is not None: + k = k[:num_tokens] + + if not skip_k_cache_insert: + # scale_fmt can be None, but the function expects str + assert scale_fmt is not None + assert not use_fp4_cache, "Unfused FP4 Insert is not supported yet" + ops.indexer_k_quant_and_cache( + k, + kv_cache, + slot_mapping, + quant_block_size, + scale_fmt, + ) topk_indices_buffer[: hidden_states.shape[0]] = -1 if has_prefill: prefill_metadata = attn_metadata_narrowed.prefill assert prefill_metadata is not None - # Get the full shared workspace buffers once (will allocate on first use) + # Get the full shared workspace buffers once (will allocate on first use). + # Layout switches between FP8 (head_dim bytes + 4-byte fp32 scale) and + # MXFP4 (head_dim/2 bytes packed + head_dim/MXFP4_BLOCK_SIZE ue8m0 + # scales) based on use_fp4_cache. workspace_manager = current_workspace_manager() - k_fp8_full, k_scale_full = workspace_manager.get_simultaneous( - ((total_seq_lens, head_dim), fp8_dtype), - ((total_seq_lens, 4), torch.uint8), + values_spec, scales_spec = _gather_workspace_shapes( + total_seq_lens, head_dim, fp8_dtype, use_fp4_cache + ) + k_quant_full, k_scale_full = workspace_manager.get_simultaneous( + values_spec, + scales_spec, ) for chunk in prefill_metadata.chunks: - k_fp8 = k_fp8_full[: chunk.total_seq_lens] + k_quant = k_quant_full[: chunk.total_seq_lens] k_scale = k_scale_full[: chunk.total_seq_lens] if not chunk.skip_kv_gather: ops.cp_gather_indexer_k_quant_cache( kv_cache, - k_fp8, + k_quant, k_scale, chunk.block_table, chunk.cu_seq_lens, ) - logits = fp8_mqa_logits( - q_fp8[chunk.token_start : chunk.token_end], - (k_fp8, k_scale.view(torch.float32).flatten()), + q_slice = q_quant[chunk.token_start : chunk.token_end] + q_scale_slice = ( + q_scale[chunk.token_start : chunk.token_end] + if q_scale is not None + else None + ) + # DeepGEMM scalar-type tags (zero-copy): MXFP4 values → int8 + # (kPackedFP4), scales → int32 squeezed to 1-D kv_sf / 2-D q_sf. + if use_fp4_cache: + q_slice_cast = q_slice.view(torch.int8) + k_quant_cast = k_quant.view(torch.int8) + k_scale_cast = k_scale.view(torch.int32).squeeze(-1) + else: + q_slice_cast = q_slice + k_quant_cast = k_quant + k_scale_cast = k_scale.view(torch.float32).squeeze(-1) + logits = fp8_fp4_mqa_logits( + (q_slice_cast, q_scale_slice), + (k_quant_cast, k_scale_cast), weights[chunk.token_start : chunk.token_end], chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, @@ -171,32 +292,55 @@ def sparse_attn_indexer( if has_decode: decode_metadata = attn_metadata_narrowed.decode assert decode_metadata is not None - # kv_cache shape [ - # kv_cache size requirement [num_block, block_size, n_head, head_dim], - # we only have [num_block, block_size, head_dim], - kv_cache = kv_cache.unsqueeze(-2) + kv_cache = kv_cache_as_quant_view(kv_cache, head_dim, use_fp4_cache) decode_lens = decode_metadata.decode_lens if decode_metadata.requires_padding: # pad in edge case where we have short chunked prefill length < # decode_threshold since we unstrictly split # prefill and decode by decode_threshold - # (currently set to 1 + speculative tokens) - padded_q_fp8_decode_tokens = pack_seq_triton( - q_fp8[:num_decode_tokens], decode_lens - ) + # (currently set to 1 + speculative tokens). 
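# Illustrative sketch (assumed semantics of pack_seq_triton, not code from this
# diff): ragged rows are packed into a padded [B, max_len, ...] tensor given
# per-sequence lengths, with pad_value filling the tail slots. A plain-PyTorch
# reference like this is useful for eyeballing the FP8 vs MXFP4 padding note below.
import torch

def pack_seq_reference(x: torch.Tensor, lens: torch.Tensor, pad_value=0) -> torch.Tensor:
    b, max_len = lens.numel(), int(lens.max())
    out = x.new_full((b, max_len, *x.shape[1:]), pad_value)
    start = 0
    for i, n in enumerate(lens.tolist()):
        out[i, :n] = x[start:start + n]
        start += n
    return out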
+ # FP8 Q is float8_e4m3fn (pack_seq_triton's fp32 pad path is OK — + # downstream context_lens masks stale slots). MXFP4 Q is two + # uint8 tensors (values + ue8m0 scales) — use the dedicated uint8 + # packer with pad_byte=0 so padded slots dequantize to 0 and + # can't produce NaN/Inf in the logits kernel. + if q_scale is not None: + padded_q_quant_decode_tokens = pack_seq_triton( + q_quant[:num_decode_tokens], decode_lens, pad_value=0 + ) + padded_q_scale = pack_seq_triton( + q_scale[:num_decode_tokens], decode_lens, pad_value=0 + ) + else: + padded_q_quant_decode_tokens = pack_seq_triton( + q_quant[:num_decode_tokens], decode_lens + ) + padded_q_scale = None else: - padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape( - decode_lens.shape[0], -1, *q_fp8.shape[1:] + padded_q_quant_decode_tokens = q_quant[:num_decode_tokens].reshape( + decode_lens.shape[0], -1, *q_quant.shape[1:] ) + if q_scale is not None: + padded_q_scale = q_scale[:num_decode_tokens].reshape( + decode_lens.shape[0], -1, *q_scale.shape[1:] + ) + else: + padded_q_scale = None # TODO: move and optimize below logic with triton kernels - batch_size = padded_q_fp8_decode_tokens.shape[0] - next_n = padded_q_fp8_decode_tokens.shape[1] + batch_size = padded_q_quant_decode_tokens.shape[0] + next_n = padded_q_quant_decode_tokens.shape[1] num_padded_tokens = batch_size * next_n seq_lens = decode_metadata.seq_lens[:batch_size] - # seq_lens is (B, next_n) for native spec decode, (B,) otherwise. - # fp8_paged_mqa_logits and all topk kernels accept both shapes. - logits = fp8_paged_mqa_logits( - padded_q_fp8_decode_tokens, + # seq_lens is always 2D: (B, next_n) for native spec decode, (B, 1) + # otherwise. deep_gemm fp8_fp4_paged_mqa_logits requires 2D context_lens; + # the downstream topk kernels accept both 1D and 2D. 
+ padded_q_quant_cast = ( + padded_q_quant_decode_tokens.view(torch.int8) + if use_fp4_cache + else padded_q_quant_decode_tokens + ) + logits = fp8_fp4_paged_mqa_logits( + (padded_q_quant_cast, padded_q_scale), kv_cache, weights[:num_padded_tokens], seq_lens, @@ -208,7 +352,18 @@ def sparse_attn_indexer( num_rows = logits.shape[0] topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens] - if current_platform.is_cuda(): + if _use_sm120_short_row_topk_decode(logits, topk_tokens): + torch.ops._C.top_k_per_row_decode( + logits, + next_n, + seq_lens, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + elif current_platform.is_cuda() and topk_tokens in (512, 2048): workspace_manager = current_workspace_manager() (topk_workspace,) = workspace_manager.get_simultaneous( ((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8), @@ -263,7 +418,8 @@ def sparse_attn_indexer_fake( hidden_states: torch.Tensor, k_cache_prefix: LayerNameType, kv_cache: torch.Tensor, - q_fp8: torch.Tensor, + q_quant: torch.Tensor, + q_scale: torch.Tensor | None, k: torch.Tensor, weights: torch.Tensor, quant_block_size: int, @@ -273,6 +429,8 @@ def sparse_attn_indexer_fake( max_model_len: int, total_seq_lens: int, topk_indices_buffer: torch.Tensor | None, + skip_k_cache_insert: bool, + use_fp4_cache: bool = False, ) -> torch.Tensor: return topk_indices_buffer @@ -309,6 +467,8 @@ def __init__( max_model_len: int, max_total_seq_len: int, topk_indices_buffer: torch.Tensor, + skip_k_cache_insert: bool = False, + use_fp4_cache: bool = False, ): super().__init__() self.k_cache = k_cache @@ -319,6 +479,8 @@ def __init__( self.max_model_len = max_model_len self.max_total_seq_len = max_total_seq_len self.topk_indices_buffer = topk_indices_buffer + self.skip_k_cache_insert = skip_k_cache_insert + self.use_fp4_cache = use_fp4_cache if current_platform.is_cuda() and not has_deep_gemm(): raise RuntimeError( "Sparse Attention Indexer CUDA op requires DeepGEMM to be installed." @@ -327,14 +489,14 @@ def __init__( def forward_native( self, hidden_states: torch.Tensor, - q_fp8: torch.Tensor, + q_quant: torch.Tensor | tuple[torch.Tensor, torch.Tensor], k: torch.Tensor, weights: torch.Tensor, ): if current_platform.is_cuda() or current_platform.is_xpu(): - return self.forward_cuda(hidden_states, q_fp8, k, weights) + return self.forward_cuda(hidden_states, q_quant, k, weights) elif current_platform.is_rocm(): - return self.forward_hip(hidden_states, q_fp8, k, weights) + return self.forward_hip(hidden_states, q_quant, k, weights) else: raise NotImplementedError( "SparseAttnIndexer native forward is only implemented for " @@ -344,15 +506,22 @@ def forward_native( def forward_cuda( self, hidden_states: torch.Tensor, - q_fp8: torch.Tensor, + q_quant: torch.Tensor | tuple[torch.Tensor, torch.Tensor], k: torch.Tensor, weights: torch.Tensor, ): + # FP8 path: single tensor (per-token scale is folded into `weights`). + # FP4 path: (values, scales) tuple with scales required by the kernel. 
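+        # The tuple form only appears on the FP4 cache path (use_fp4_cache);
+        # sparse_attn_indexer asserts q_scale is present iff that path is on.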
+ if isinstance(q_quant, tuple): + q_values, q_scale = q_quant + else: + q_values, q_scale = q_quant, None return torch.ops.vllm.sparse_attn_indexer( hidden_states, _encode_layer_name(self.k_cache.prefix), self.k_cache.kv_cache, - q_fp8, + q_values, + q_scale, k, weights, self.quant_block_size, @@ -362,21 +531,30 @@ def forward_cuda( self.max_model_len, self.max_total_seq_len, self.topk_indices_buffer, + self.skip_k_cache_insert, + self.use_fp4_cache, ) def forward_hip( self, hidden_states: torch.Tensor, - q_fp8: torch.Tensor, + q_quant: torch.Tensor | tuple[torch.Tensor, torch.Tensor], k: torch.Tensor, weights: torch.Tensor, ): + assert not self.skip_k_cache_insert, ( + "AMD platform doesn't support skip cache insert yet" + ) + assert not self.use_fp4_cache, "AMD platform doesn't support fp4 cache yet" + assert isinstance(q_quant, torch.Tensor), ( + "AMD sparse_attn_indexer expects a single FP8 q_quant tensor" + ) if rocm_aiter_ops.is_enabled(): return torch.ops.vllm.rocm_aiter_sparse_attn_indexer( hidden_states, _encode_layer_name(self.k_cache.prefix), self.k_cache.kv_cache, - q_fp8, + q_quant, k, weights, self.quant_block_size, diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 4918c83bdc39..e26b511de4ce 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -299,6 +299,13 @@ def cpu_unquantized_gemm( return layer.cpu_linear(x, weight, bias) +def cublas_gemm_bf16_bf16_fp32( + x: torch.Tensor, + weight: torch.Tensor, +): + return ops.router_gemm_bf16_fp32(x, weight) + + def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]: if current_platform.is_rocm(): return rocm_unquantized_gemm diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 856f4b33ed3b..e8f5101b577d 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -107,6 +107,31 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: ) +class DeepseekV4ForCausalLMConfig(VerifyAndUpdateConfig): + @staticmethod + def verify_and_update_model_config(model_config: "ModelConfig") -> None: + quant_config = getattr(model_config.hf_config, "quantization_config", None) + if quant_config is not None and quant_config.get("quant_method") == "fp8": + model_type = getattr(model_config.hf_config, "model_type", None) + if model_type == "deepseek_v4": + model_config.hf_config.quantization_config["quant_method"] = ( + "deepseek_v4_fp8" + ) + + hf_text_quant_config = getattr( + model_config.hf_text_config, "quantization_config", None + ) + if ( + hf_text_quant_config is not None + and hf_text_quant_config.get("quant_method") == "fp8" + ): + model_type = getattr(model_config.hf_text_config, "model_type", None) + if model_type == "deepseek_v4": + model_config.hf_text_config.quantization_config["quant_method"] = ( + "deepseek_v4_fp8" + ) + + class GptOssForCausalLMConfig(VerifyAndUpdateConfig): @staticmethod def verify_and_update_model_config(model_config: "ModelConfig") -> None: @@ -635,6 +660,7 @@ def verify_and_update_model_config(model_config: "ModelConfig") -> None: MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { "ColBERTJinaRobertaModel": JinaRobertaModelConfig, "ColQwen3_5": Qwen3_5ForConditionalGenerationConfig, + "DeepseekV4ForCausalLM": DeepseekV4ForCausalLMConfig, "DeepseekV32ForCausalLM": DeepseekV32ForCausalLM, "Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLMoeForConditionalGenerationConfig, # noqa: E501 "FalconMambaForCausalLM": 
MambaModelConfig, diff --git a/vllm/model_executor/models/deepseek_v4.py b/vllm/model_executor/models/deepseek_v4.py new file mode 100644 index 000000000000..0ff1d7b5fd97 --- /dev/null +++ b/vllm/model_executor/models/deepseek_v4.py @@ -0,0 +1,1482 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import typing +from collections.abc import Callable, Iterable +from itertools import islice + +import regex as re +import torch +import torch.nn as nn +import torch.nn.functional as F + +from vllm import envs +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.deepseek_v4_attention import ( + DeepseekV4Indexer, + DeepseekV4MLAModules, + DeepseekV4MultiHeadLatentAttentionWrapper, +) +from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear +from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod +from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import ( + fused_topk_bias, +) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.fp8 import Fp8Config +from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4MoEMethod +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped, +) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.deepseek_v2 import DeepseekV2MLP +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors +from vllm.triton_utils import tl, triton +from vllm.utils.multi_stream_utils import AuxStreamType +from vllm.utils.torch_utils import direct_register_custom_op + +from .interfaces import SupportsPP +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + WeightsMapper, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) + + +class DeepseekV4FP8Config(Fp8Config): + """FP8 config that routes MoE layers to MXFP4 quantization. + + DeepSeek V4 checkpoints use FP8 for linear/attention layers but + MXFP4 for MoE expert weights. This config inherits standard FP8 + behavior and overrides only the MoE dispatch. 
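+    Linear and attention layers keep the inherited FP8 handling; only
+    FusedMoE layers (unless skipped via ignored_layers) are dispatched to
+    Mxfp4MoEMethod in ``get_quant_method``.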
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_scale_e8m0: bool = True + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "deepseek_v4_fp8" + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant, hf_config=None + ) -> QuantizationMethods | None: + if not ( + isinstance(hf_quant_cfg, dict) + and hf_quant_cfg.get("quant_method") in ("fp8", "deepseek_v4_fp8") + ): + return None + model_type = getattr(hf_config, "model_type", None) + if model_type == "deepseek_v4" or user_quant == "deepseek_v4_fp8": + return "deepseek_v4_fp8" + return None + + def get_quant_method(self, layer, prefix): + if isinstance(layer, FusedMoE): + if is_layer_skipped( + prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping, + ): + return UnquantizedFusedMoEMethod(layer.moe_config) + return Mxfp4MoEMethod(layer.moe_config) + return super().get_quant_method(layer, prefix) + + def is_mxfp4_quant(self, prefix, layer): + return isinstance(layer, FusedMoE) + + +@triton.jit +def _deepseek_v4_stage_mega_moe_inputs_kernel( + hidden_states, + x_fp8, + x_sf, + topk_ids, + topk_weights, + topk_idx_out, + topk_weights_out, + hidden_stride_m: tl.constexpr, + hidden_stride_k: tl.constexpr, + x_stride_m: tl.constexpr, + x_stride_k: tl.constexpr, + x_sf_stride_m: tl.constexpr, + x_sf_stride_k: tl.constexpr, + topk_ids_stride_m: tl.constexpr, + topk_ids_stride_k: tl.constexpr, + topk_weights_stride_m: tl.constexpr, + topk_weights_stride_k: tl.constexpr, + topk_idx_stride_m: tl.constexpr, + topk_idx_stride_k: tl.constexpr, + topk_weights_out_stride_m: tl.constexpr, + topk_weights_out_stride_k: tl.constexpr, + hidden_size: tl.constexpr, + top_k: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_K: tl.constexpr, + BLOCK_TOPK: tl.constexpr, +) -> None: + token_id = tl.program_id(0) + k_block_id = tl.program_id(1) + + k_offsets = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K) + k_mask = k_offsets < hidden_size + hidden = tl.load( + hidden_states + token_id * hidden_stride_m + k_offsets * hidden_stride_k, + mask=k_mask, + other=0.0, + ).to(tl.float32) + + num_groups: tl.constexpr = BLOCK_K // GROUP_K + hidden_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K]) + amax = tl.max(hidden_groups, axis=1) + amax = tl.maximum(amax, 1.0e-4) + + scale = amax / 448.0 + scale_bits = scale.to(tl.uint32, bitcast=True) + scale_exp = ((scale_bits >> 23) & 0xFF) + ((scale_bits & 0x7FFFFF) != 0).to( + tl.uint32 + ) + scale_exp = tl.minimum(tl.maximum(scale_exp, 1), 254) + rounded_scale = (scale_exp << 23).to(tl.float32, bitcast=True) + + hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K]) + scaled = hidden_groups * (1.0 / rounded_scale)[:, None] + scaled = tl.reshape(scaled, [BLOCK_K]) + fp8 = scaled.to(tl.float8e4nv) + tl.store( + x_fp8 + token_id * x_stride_m + k_offsets * x_stride_k, + fp8, + mask=k_mask, + ) + + scale_offsets = tl.arange(0, num_groups) + packed_scale = tl.sum(scale_exp << (scale_offsets * 8), axis=0).to(tl.int32) + tl.store( + x_sf + token_id * x_sf_stride_m + k_block_id * x_sf_stride_k, + packed_scale, + ) + + if k_block_id == 0: + topk_offsets = tl.arange(0, BLOCK_TOPK) + topk_mask = topk_offsets < top_k + + ids = tl.load( + topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k, + mask=topk_mask, + other=0, + ).to(tl.int64) + tl.store( + topk_idx_out + + token_id * topk_idx_stride_m + + topk_offsets * topk_idx_stride_k, + ids, + mask=topk_mask, + ) + + weights = tl.load( + 
topk_weights + + token_id * topk_weights_stride_m + + topk_offsets * topk_weights_stride_k, + mask=topk_mask, + other=0.0, + ) + tl.store( + topk_weights_out + + token_id * topk_weights_out_stride_m + + topk_offsets * topk_weights_out_stride_k, + weights, + mask=topk_mask, + ) + + +def _stage_deepseek_v4_mega_moe_inputs( + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + x_fp8: torch.Tensor, + x_sf: torch.Tensor, + topk_idx_out: torch.Tensor, + topk_weights_out: torch.Tensor, +) -> None: + num_tokens, hidden_size = hidden_states.shape + if num_tokens == 0: + return + if hidden_size % 128 != 0: + raise ValueError( + "DeepSeek V4 MegaMoE input staging requires hidden_size to be " + "a multiple of 128." + ) + top_k = topk_ids.shape[1] + if topk_weights.shape != topk_ids.shape: + raise ValueError( + "DeepSeek V4 MegaMoE input staging requires topk_weights and " + "topk_ids to have the same shape." + ) + + block_k = 128 + grid = (num_tokens, triton.cdiv(hidden_size, block_k)) + block_topk = triton.next_power_of_2(top_k) + _deepseek_v4_stage_mega_moe_inputs_kernel[grid]( + hidden_states, + x_fp8, + x_sf, + topk_ids, + topk_weights, + topk_idx_out, + topk_weights_out, + hidden_states.stride(0), + hidden_states.stride(1), + x_fp8.stride(0), + x_fp8.stride(1), + x_sf.stride(0), + x_sf.stride(1), + topk_ids.stride(0), + topk_ids.stride(1), + topk_weights.stride(0), + topk_weights.stride(1), + topk_idx_out.stride(0), + topk_idx_out.stride(1), + topk_weights_out.stride(0), + topk_weights_out.stride(1), + hidden_size, + top_k, + BLOCK_K=block_k, + GROUP_K=32, + BLOCK_TOPK=block_topk, + num_warps=4, + ) + + +def make_deepseek_v4_expert_params_mapping( + num_experts: int, +) -> list[tuple[str, str, int, str]]: + return [ + ( + "experts.w13_" if shard_id in ("w1", "w3") else "experts.w2_", + f"experts.{expert_id}.{weight_name}.", + expert_id, + shard_id, + ) + for expert_id in range(num_experts) + for shard_id, weight_name in [ + ("w1", "w1"), + ("w2", "w2"), + ("w3", "w3"), + ] + ] + + +class DeepseekV4MegaMoEExperts(nn.Module): + _symm_buffer_cache: dict[tuple[int, int, int, int, int, int, int], object] = {} + + def __init__( + self, + vllm_config: VllmConfig, + *, + num_experts: int, + num_local_experts: int, + experts_start_idx: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + prefix: str = "", + ): + super().__init__() + self.prefix = prefix + self.num_experts = num_experts + self.num_local_experts = num_local_experts + self.experts_start_idx = experts_start_idx + self.experts_end_idx = experts_start_idx + num_local_experts + self.top_k = top_k + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + + weight_attrs = {"weight_loader": self.weight_loader} + self.w13_weight = nn.Parameter( + torch.zeros( + num_local_experts, + 2 * intermediate_size, + hidden_size // 2, + dtype=torch.uint8, + ), + requires_grad=False, + ) + set_weight_attrs(self.w13_weight, weight_attrs) + + self.w13_weight_scale = nn.Parameter( + torch.zeros( + num_local_experts, + 2 * intermediate_size, + hidden_size // 32, + dtype=torch.uint8, + ), + requires_grad=False, + ) + set_weight_attrs(self.w13_weight_scale, weight_attrs) + self.w13_weight_scale.quant_method = "block" + + self.w2_weight = nn.Parameter( + torch.zeros( + num_local_experts, + hidden_size, + intermediate_size // 2, + dtype=torch.uint8, + ), + requires_grad=False, + ) + set_weight_attrs(self.w2_weight, 
weight_attrs) + + self.w2_weight_scale = nn.Parameter( + torch.zeros( + num_local_experts, + hidden_size, + intermediate_size // 32, + dtype=torch.uint8, + ), + requires_grad=False, + ) + set_weight_attrs(self.w2_weight_scale, weight_attrs) + self.w2_weight_scale.quant_method = "block" + + self._transformed_l1_weights: tuple[torch.Tensor, torch.Tensor] | None = None + self._transformed_l2_weights: tuple[torch.Tensor, torch.Tensor] | None = None + + # Register in the static forward context so the custom-op wrapper + # can look up this module by name from within a torch.compile graph. + compilation_config = vllm_config.compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + def _map_global_expert_id(self, expert_id: int) -> int: + if expert_id < self.experts_start_idx or expert_id >= self.experts_end_idx: + return -1 + return expert_id - self.experts_start_idx + + def weight_loader( + self, + param: nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + return_success: bool = False, + ) -> bool | None: + local_expert_id = self._map_global_expert_id(expert_id) + if local_expert_id == -1: + return False if return_success else None + + expert_data = param.data[local_expert_id] + if shard_id in ("w1", "w3"): + if "w13_" not in weight_name: + return False if return_success else None + shard_offset = 0 if shard_id == "w1" else self.intermediate_size + expert_data = expert_data.narrow(0, shard_offset, self.intermediate_size) + elif shard_id == "w2": + if "w2_" not in weight_name: + return False if return_success else None + else: + raise ValueError(f"Unsupported expert shard id: {shard_id}") + + if expert_data.shape != loaded_weight.shape: + raise ValueError( + f"DeepSeek V4 MegaMoE expert weight shape mismatch for " + f"{weight_name}: parameter shard {tuple(expert_data.shape)} " + f"vs checkpoint {tuple(loaded_weight.shape)}" + ) + expert_data.copy_(loaded_weight) + return True if return_success else None + + @staticmethod + def _ue8m0_uint8_to_float(sf: torch.Tensor) -> torch.Tensor: + return (sf.to(torch.int32) << 23).view(torch.float32) + + def _check_runtime_supported(self) -> None: + if not torch.cuda.is_available(): + raise NotImplementedError("DeepSeek V4 MegaMoE requires CUDA.") + device = self.w13_weight.device + if device.type != "cuda": + raise NotImplementedError( + "DeepSeek V4 MegaMoE expert weights must be loaded on CUDA." + ) + if torch.cuda.get_device_capability(device)[0] != 10: + raise NotImplementedError("DeepGEMM MegaMoE requires SM100 GPUs.") + if self.hidden_size % 128 != 0 or self.intermediate_size % 128 != 0: + raise ValueError( + "DeepGEMM MegaMoE requires hidden and intermediate sizes " + "to be multiples of 128." 
+ ) + + def finalize_weights(self) -> None: + if self._transformed_l1_weights is not None: + return + + self._check_runtime_supported() + import vllm.third_party.deep_gemm as deep_gemm + + w13_scale = deep_gemm.transform_sf_into_required_layout( + self._ue8m0_uint8_to_float(self.w13_weight_scale.data).contiguous(), + 2 * self.intermediate_size, + self.hidden_size, + (1, 32), + self.num_local_experts, + ) + w2_scale = deep_gemm.transform_sf_into_required_layout( + self._ue8m0_uint8_to_float(self.w2_weight_scale.data).contiguous(), + self.hidden_size, + self.intermediate_size, + (1, 32), + self.num_local_experts, + ) + self._transformed_l1_weights, self._transformed_l2_weights = ( + deep_gemm.transform_weights_for_mega_moe( + (self.w13_weight.data.view(torch.int8).contiguous(), w13_scale), + (self.w2_weight.data.view(torch.int8).contiguous(), w2_scale), + ) + ) + # Drop the original loader-side parameters: the MegaMoE kernels only + # consume the transformed views above. transform_weights_for_mega_moe + # allocates a fresh tensor for the L1 weight (see _interleave_l1_weights) + # and fresh SF tensors for L1/L2; the L2 weight is the only tensor that + # aliases the original storage, and _transformed_l2_weights still holds + # it, so the storage stays live after we drop the Parameter. + self.w13_weight = None + self.w13_weight_scale = None + self.w2_weight = None + self.w2_weight_scale = None + + def get_symm_buffer(self): + import vllm.third_party.deep_gemm as deep_gemm + + group = get_ep_group().device_group + device = torch.accelerator.current_device_index() + key = ( + id(group), + device, + self.num_experts, + self.max_num_tokens, + self.top_k, + self.hidden_size, + self.intermediate_size, + ) + symm_buffer = self._symm_buffer_cache.get(key) + if symm_buffer is None: + symm_buffer = deep_gemm.get_symm_buffer_for_mega_moe( + group, + self.num_experts, + self.max_num_tokens, + self.top_k, + self.hidden_size, + self.intermediate_size, + ) + self._symm_buffer_cache[key] = symm_buffer + return symm_buffer + + def forward( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + *, + activation_clamp: float | None, + fast_math: bool = True, + ) -> torch.Tensor: + if hidden_states.shape[0] > self.max_num_tokens: + raise ValueError( + f"DeepSeek V4 MegaMoE got {hidden_states.shape[0]} tokens, " + f"but the symmetric buffer was sized for {self.max_num_tokens}." + ) + y = torch.empty_like(hidden_states, dtype=torch.bfloat16) + torch.ops.vllm.deepseek_v4_mega_moe_experts( + hidden_states, + topk_weights, + topk_ids, + y, + self.prefix, + activation_clamp, + fast_math, + ) + return y + + def _run_mega_moe( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + y: torch.Tensor, + activation_clamp: float | None, + fast_math: bool, + ) -> None: + import vllm.third_party.deep_gemm as deep_gemm + + symm_buffer = self.get_symm_buffer() + num_tokens = hidden_states.shape[0] + _stage_deepseek_v4_mega_moe_inputs( + hidden_states, + topk_weights, + topk_ids, + symm_buffer.x[:num_tokens], + symm_buffer.x_sf[:num_tokens], + symm_buffer.topk_idx[:num_tokens], + symm_buffer.topk_weights[:num_tokens], + ) + + # This method must have been already called duing the weight loading phase. + # We call it again here to cover the dummy weight loading case. 
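+        # finalize_weights() is idempotent: it returns immediately once
+        # _transformed_l1_weights is populated, so this repeat call is cheap.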
+ self.finalize_weights() + + assert self._transformed_l1_weights is not None + assert self._transformed_l2_weights is not None + deep_gemm.fp8_fp4_mega_moe( + y, + self._transformed_l1_weights, + self._transformed_l2_weights, + symm_buffer, + activation_clamp=activation_clamp, + fast_math=fast_math, + ) + + +DeepseekV4MegaMoEExperts.weight_loader.supports_moe_loading = True # type: ignore[attr-defined] + + +def _deepseek_v4_mega_moe_experts_op( + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + out: torch.Tensor, + layer_name: str, + activation_clamp: float | None, + fast_math: bool, +) -> None: + self = get_forward_context().no_compile_layers[layer_name] + self._run_mega_moe( + hidden_states, + topk_weights, + topk_ids, + out, + activation_clamp, + fast_math, + ) + + +def _deepseek_v4_mega_moe_experts_op_fake( + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + out: torch.Tensor, + layer_name: str, + activation_clamp: float | None, + fast_math: bool, +) -> None: + return None + + +direct_register_custom_op( + op_name="deepseek_v4_mega_moe_experts", + op_func=_deepseek_v4_mega_moe_experts_op, + mutates_args=["out"], + fake_impl=_deepseek_v4_mega_moe_experts_op_fake, +) + + +class DeepseekV4MoE(nn.Module): + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + + self.tp_size = get_tensor_model_parallel_world_size() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.prefix = prefix + if vllm_config.parallel_config.enable_expert_parallel: + self.use_mega_moe = envs.VLLM_DEEPSEEK_V4_USE_MEGA_MOE + else: + self.use_mega_moe = False + + self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0) + self.hidden_size = config.hidden_size + + self.n_routed_experts = config.n_routed_experts + self.n_activated_experts = config.num_experts_per_tok + self.moe_intermediate_size = config.moe_intermediate_size + self.swiglu_limit = config.swiglu_limit + self.renormalize = config.norm_topk_prob + self.scoring_func = getattr(config, "scoring_func", "sqrtsoftplus") + if self.use_mega_moe and self.scoring_func != "sqrtsoftplus": + raise NotImplementedError( + "DeepSeek V4 MegaMoE currently supports sqrtsoftplus routing only." 
+ ) + + self.gate = GateLinear( + config.hidden_size, + config.n_routed_experts, + out_dtype=torch.float32, + bias=False, + prefix=f"{prefix}.gate", + ) + self.gate.e_score_correction_bias = None + self.gate.tid2eid = None + is_hash_moe = extract_layer_index(prefix) < config.num_hash_layers + self.hash_indices_dtype = torch.int64 if self.use_mega_moe else torch.int32 + + if is_hash_moe: + # hash MoE doesn't use e_score_correction_bias + # Use randint instead of empty to avoid garbage values causing + # invalid memory access in dummy mode (--load-format="dummy") + self.gate.tid2eid = nn.Parameter( + torch.randint( + 0, + config.n_routed_experts, + (config.vocab_size, config.num_experts_per_tok), + dtype=self.hash_indices_dtype, + ), + requires_grad=False, + ) + elif getattr(config, "topk_method", None) == "noaux_tc": + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts, dtype=torch.float32), + requires_grad=False, + ) + + if config.n_shared_experts is None: + self.shared_experts = None + else: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + + self.shared_experts = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=self.use_mega_moe, + prefix=f"{prefix}.shared_experts", + ) + + if self.use_mega_moe: + self._init_mega_moe_experts(vllm_config, config, prefix) + else: + self._init_fused_moe_experts(config, quant_config, prefix) + + def _init_mega_moe_experts( + self, + vllm_config: VllmConfig, + config, + prefix: str, + ) -> None: + self.ep_group = get_ep_group() + self.ep_size = self.ep_group.world_size + self.ep_rank = self.ep_group.rank_in_group + assert config.n_routed_experts % self.ep_size == 0 + + self.n_local_experts = config.n_routed_experts // self.ep_size + self.experts_start_idx = self.ep_rank * self.n_local_experts + self.experts_end_idx = self.experts_start_idx + self.n_local_experts + + self.experts = DeepseekV4MegaMoEExperts( + vllm_config, + num_experts=config.n_routed_experts, + num_local_experts=self.n_local_experts, + experts_start_idx=self.experts_start_idx, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + prefix=f"{prefix}.experts", + ) + + def _init_fused_moe_experts( + self, + config, + quant_config, + prefix: str, + ) -> None: + self.tp_rank = get_tensor_model_parallel_rank() + assert config.n_routed_experts % self.tp_size == 0 + + self.n_local_experts = config.n_routed_experts // self.tp_size + self.experts_start_idx = self.tp_rank * self.n_local_experts + self.experts_end_idx = self.experts_start_idx + self.n_local_experts + + self.experts = FusedMoE( + shared_experts=self.shared_experts, + gate=self.gate, + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + scoring_func=self.scoring_func, + routed_scaling_factor=self.routed_scaling_factor, + e_score_correction_bias=self.gate.e_score_correction_bias, + hash_indices_table=self.gate.tid2eid, + swiglu_limit=self.swiglu_limit, + router_logits_dtype=torch.float32, + ) + + def forward( + self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None + ) -> torch.Tensor: + if self.gate.tid2eid is not None: + if input_ids is None: + raise ValueError("DeepSeek V4 
hash MoE routing requires input_ids.") + input_ids = input_ids.to(dtype=self.hash_indices_dtype) + if not self.use_mega_moe: + return self._forward_fused_moe(hidden_states, input_ids) + + org_shape = hidden_states.shape + router_logits, _ = self.gate(hidden_states) + topk_weights, topk_ids = fused_topk_bias( + hidden_states=hidden_states, + gating_output=router_logits, + scoring_func=self.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias.data + if self.gate.e_score_correction_bias is not None + else None, + topk=self.n_activated_experts, + renormalize=self.renormalize, + indices_type=self.hash_indices_dtype, + input_tokens=input_ids, + hash_indices_table=self.gate.tid2eid, + routed_scaling_factor=self.routed_scaling_factor, + ) + activation_clamp = ( + float(self.swiglu_limit) if self.swiglu_limit is not None else None + ) + final_hidden_states = self.experts( + hidden_states, + topk_weights, + topk_ids, + activation_clamp=activation_clamp, + ) + + if self.shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + final_hidden_states += shared_output + + return final_hidden_states.view(org_shape) + + def _forward_fused_moe( + self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None + ) -> torch.Tensor: + org_shape = hidden_states.shape + if self.experts.is_internal_router: + # In this case, the gate/router runs inside the FusedMoE class + final_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=hidden_states, + input_ids=input_ids, + ) + else: + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + input_ids=input_ids, + ) + + return final_hidden_states.view(org_shape) + + def finalize_mega_moe_weights(self) -> None: + if self.use_mega_moe: + self.experts.finalize_weights() + + +class DeepseekV4Attention(nn.Module): + def __init__( + self, + vllm_config: VllmConfig, + prefix: str, + topk_indices_buffer: torch.Tensor | None = None, + aux_stream: torch.cuda.Stream | None = None, + ): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + layer_id = extract_layer_index(prefix) + + self.layer_id = layer_id + self.hidden_size = config.hidden_size + self.n_heads = config.num_attention_heads + tp_size = get_tensor_model_parallel_world_size() + assert self.n_heads % tp_size == 0 + + self.n_local_heads = self.n_heads // tp_size + self.q_lora_rank = config.q_lora_rank + self.o_lora_rank = config.o_lora_rank + self.head_dim = config.head_dim + self.rope_head_dim = config.qk_rope_head_dim + self.nope_head_dim = self.head_dim - self.rope_head_dim + self.n_groups = config.o_groups + self.n_local_groups = self.n_groups // tp_size + self.window_size = config.sliding_window + # NOTE(zyongye) Compress ratio can't be 0 + # we do this for because MTP layer is not included + # in the compress ratio list + if layer_id < config.num_hidden_layers: + self.compress_ratio = max(1, config.compress_ratios[layer_id]) + else: + self.compress_ratio = 1 + self.eps = config.rms_norm_eps + self.max_position_embeddings = config.max_position_embeddings + + # Padded to min 64 heads for FlashMLA, initialized to -inf + # (no sink effect). Weight loading fills the first n_local_heads slots. 
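+        # Illustration: 64 attention heads with TP=8 gives n_local_heads=8;
+        # the remaining 56 padded slots stay at -inf, and exp(-inf) == 0, so
+        # they contribute nothing to the attention softmax.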
+ padded_heads = max(self.n_local_heads, 64) + self.attn_sink = nn.Parameter( + torch.full((padded_heads,), -float("inf"), dtype=torch.float32), + requires_grad=False, + ) + + self.fused_wqa_wkv = MergedColumnParallelLinear( + self.hidden_size, + [self.q_lora_rank, self.head_dim], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.fused_wqa_wkv", + disable_tp=True, # fused ReplicatedLinear + ) + self.q_norm = RMSNorm(self.q_lora_rank, self.eps) + self.wq_b = ColumnParallelLinear( + self.q_lora_rank, + self.n_heads * self.head_dim, + bias=False, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.wq_b", + ) + + self.kv_norm = RMSNorm(self.head_dim, self.eps) + self.wo_a = ColumnParallelLinear( + self.n_heads * self.head_dim // self.n_groups, + self.n_groups * self.o_lora_rank, + bias=False, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.wo_a", + ) + self.wo_a.is_bmm = True + self.wo_a.bmm_batch_size = self.n_local_groups + self.wo_b = RowParallelLinear( + self.n_groups * self.o_lora_rank, + self.hidden_size, + bias=False, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.wo_b", + ) + self.softmax_scale = self.head_dim**-0.5 + self.scale_fmt = config.quantization_config["scale_fmt"] + + self.rope_parameters = config.rope_scaling + + # Initialize rotary embedding BEFORE DeepseekV4MLAModules (which needs it) + rope_parameters = config.rope_parameters + rope_parameters["rope_theta"] = ( + config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta + ) + if config.rope_parameters["rope_type"] != "default": + config.rope_parameters["rope_type"] = ( + "deepseek_yarn" + if config.rope_parameters.get("apply_yarn_scaling", True) + else "deepseek_llama_scaling" + ) + rope_parameters["mscale"] = 0 # Disable mscale + rope_parameters["mscale_all_dim"] = 0 # Disable mscale + rope_parameters["is_deepseek_v4"] = True + rope_parameters["rope_dim"] = self.rope_head_dim + self.rotary_emb = get_rope( + self.head_dim, + max_position=self.max_position_embeddings, + rope_parameters=rope_parameters, + is_neox_style=False, + dtype=config.torch_dtype, + ) + + self.indexer = None + if self.compress_ratio == 4: + # Only C4A uses sparse attention and hence has indexer. 
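+            # (Layers with any other compress ratio keep self.indexer = None,
+            # set just above.)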
+ self.indexer = DeepseekV4Indexer( + vllm_config, + config=config, + hidden_size=self.hidden_size, + q_lora_rank=self.q_lora_rank, + quant_config=quant_config, + cache_config=vllm_config.cache_config, + topk_indices_buffer=topk_indices_buffer, + compress_ratio=self.compress_ratio, + prefix=f"{prefix}.indexer", + ) + + mla_modules = DeepseekV4MLAModules( + vllm_config=vllm_config, + fused_wqa_wkv=self.fused_wqa_wkv, + q_norm=self.q_norm, + wq_b=self.wq_b, + kv_norm=self.kv_norm, + wo_a=self.wo_a, + wo_b=self.wo_b, + attn_sink=self.attn_sink, + rotary_emb=self.rotary_emb, + indexer=self.indexer, + indexer_rotary_emb=self.rotary_emb, + topk_indices_buffer=topk_indices_buffer, + aux_stream=aux_stream, + ) + self.mla_attn = DeepseekV4MultiHeadLatentAttentionWrapper( + hidden_size=self.hidden_size, + num_heads=self.n_local_heads, + head_dim=self.head_dim, + scale=self.softmax_scale, + qk_nope_head_dim=self.nope_head_dim, + qk_rope_head_dim=self.rope_head_dim, + v_head_dim=self.head_dim, + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.head_dim, + o_lora_rank=self.o_lora_rank, + mla_modules=mla_modules, + window_size=self.window_size, + compress_ratio=self.compress_ratio, + cache_config=vllm_config.cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + llama_4_scaling: torch.Tensor | None, + ): + return self.mla_attn(positions, hidden_states, llama_4_scaling) + + +class DeepseekV4DecoderLayer(nn.Module): + def __init__( + self, + vllm_config, + prefix, + topk_indices_buffer: torch.Tensor | None = None, + aux_stream_dict: dict[AuxStreamType, torch.cuda.Stream] | None = None, + ): + super().__init__() + config = vllm_config.model_config.hf_config + self.hidden_size = config.hidden_size + + self.rms_norm_eps = config.rms_norm_eps + self.attn = DeepseekV4Attention( + vllm_config, + prefix=f"{prefix}.attn", + topk_indices_buffer=topk_indices_buffer, + aux_stream=aux_stream_dict.get(AuxStreamType.Attention) + if aux_stream_dict is not None + else None, + ) + self.ffn = DeepseekV4MoE(vllm_config, prefix=f"{prefix}.ffn") + + self.attn_norm = RMSNorm(self.hidden_size, self.rms_norm_eps) + self.ffn_norm = RMSNorm(self.hidden_size, self.rms_norm_eps) + self.hc_mult = config.hc_mult + self.hc_sinkhorn_iters = config.hc_sinkhorn_iters + self.hc_eps = config.hc_eps + self.hc_post_alpha = 2.0 + mix_hc = (2 + self.hc_mult) * self.hc_mult + hc_dim = self.hc_mult * self.hidden_size + self.hc_attn_fn = nn.Parameter( + torch.empty( + (mix_hc, hc_dim), + dtype=torch.float32, + ), + requires_grad=False, + ) + self.hc_ffn_fn = nn.Parameter( + torch.empty( + (mix_hc, hc_dim), + dtype=torch.float32, + ), + requires_grad=False, + ) + self.hc_attn_base = nn.Parameter( + torch.empty( + mix_hc, + dtype=torch.float32, + ), + requires_grad=False, + ) + self.hc_ffn_base = nn.Parameter( + torch.empty( + mix_hc, + dtype=torch.float32, + ), + requires_grad=False, + ) + self.hc_attn_scale = nn.Parameter( + torch.empty( + 3, + dtype=torch.float32, + ), + requires_grad=False, + ) + self.hc_ffn_scale = nn.Parameter( + torch.empty( + 3, + dtype=torch.float32, + ), + requires_grad=False, + ) + + def hc_pre( + self, + x: torch.Tensor, + hc_fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + ): + # Lazy import to avoid top-level tilelang dependency. + # Registers both torch.ops.vllm.mhc_pre and mhc_post, + # so hc_post() doesn't need its own import. 
+ import vllm.model_executor.layers.mhc # noqa: F401 + + post_mix, res_mix, layer_input = torch.ops.vllm.mhc_pre( + residual=x, + fn=hc_fn, + hc_scale=hc_scale, + hc_base=hc_base, + rms_eps=self.rms_norm_eps, + hc_pre_eps=self.hc_eps, + hc_sinkhorn_eps=self.hc_eps, + hc_post_mult_value=self.hc_post_alpha, + sinkhorn_repeat=self.hc_sinkhorn_iters, + ) + return layer_input, post_mix, res_mix + + def hc_post( + self, + x: torch.Tensor, + residual: torch.Tensor, + post: torch.Tensor, + comb: torch.Tensor, + ): + return torch.ops.vllm.mhc_post(x, residual, post, comb) + + def forward( + self, + x: torch.Tensor, + positions: torch.Tensor, + input_ids: torch.Tensor | None, + ) -> torch.Tensor: + residual = x + x, post, comb = self.hc_pre( + x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base + ) + x = self.attn_norm(x) + x = self.attn(positions, x, None) + x = self.hc_post(x, residual, post, comb) + + residual = x + x, post, comb = self.hc_pre( + x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base + ) + x = self.ffn_norm(x) + x = self.ffn(x, input_ids) + x = self.hc_post(x, residual, post, comb) + return x + + +@support_torch_compile +class DeepseekV4Model(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + + self.vocab_size = config.vocab_size + self.hc_eps = config.hc_eps + self.hc_mult = config.hc_mult + self.hc_dim = self.hc_mult * config.hidden_size + self.rms_norm_eps = config.rms_norm_eps + + aux_stream_list = [torch.cuda.Stream() for _ in range(1)] + self.aux_stream_dict = { + AuxStreamType.Attention: aux_stream_list[0], + } + + self.device = current_platform.device_type + # Reserved topk indices buffer for all Indexer layers to reuse. + self.topk_indices_buffer = torch.empty( + vllm_config.scheduler_config.max_num_batched_tokens, + config.index_topk, + dtype=torch.int32, + device=self.device, + ) + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: DeepseekV4DecoderLayer( + vllm_config, + prefix=prefix, + topk_indices_buffer=self.topk_indices_buffer, + aux_stream_dict=self.aux_stream_dict, + ), + prefix=f"{prefix}.layers", + ) + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, self.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], self.hc_dim + ) + + self.hc_head_fn = nn.Parameter( + torch.empty( + self.hc_mult, + self.hc_dim, + dtype=torch.float32, + ), + requires_grad=False, + ) + self.hc_head_base = nn.Parameter( + torch.empty( + self.hc_mult, + dtype=torch.float32, + ), + requires_grad=False, + ) + self.hc_head_scale = nn.Parameter( + torch.empty(1, dtype=torch.float32), + requires_grad=False, + ) + + # Pre-hc_head residual stream buffer for the MTP draft. Stable + # address (outside the cudagraph pool) so the copy_ in forward() + # refreshes it correctly across captured shapes. 
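+        # The buffer is (max_num_batched_tokens, hc_mult * hidden_size) and is
+        # read back via DeepseekV4ForCausalLM.get_mtp_target_hidden_states().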
+ if get_pp_group().is_last_rank: + self._mtp_hidden_buffer = torch.empty( + vllm_config.scheduler_config.max_num_batched_tokens, + self.hc_dim, + dtype=vllm_config.model_config.dtype, + device=self.device, + ) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_input_ids(input_ids) + hidden_states = hidden_states.unsqueeze(-2).repeat(1, self.hc_mult, 1) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"].view( + -1, self.hc_mult, self.config.hidden_size + ) + + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states = layer( + hidden_states, + positions, + input_ids, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states.flatten(1)}) + + # Stash pre-hc_head residual for the MTP draft (captured copy_). + num_tokens = hidden_states.shape[0] + self._mtp_hidden_buffer[:num_tokens].copy_(hidden_states.flatten(1)) + + hidden_states = hc_head( + hidden_states, + self.hc_head_fn, + self.hc_head_scale, + self.hc_head_base, + self.rms_norm_eps, + self.hc_eps, + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "w1", 0), + ("gate_up_proj", "w3", 1), + ("attn.fused_wqa_wkv", "attn.wq_a", 0), + ("attn.fused_wqa_wkv", "attn.wkv", 1), + ("compressor.fused_wkv_wgate", "compressor.wkv", 0), + ("compressor.fused_wkv_wgate", "compressor.wgate", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + # TP for attention + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + n_head = self.config.num_attention_heads + n_local_head = n_head // tp_size + head_rank_start = n_local_head * tp_rank + head_rank_end = n_local_head * (tp_rank + 1) + + # Pre-compute expert mapping ONCE. + expert_mapping = self.get_expert_mapping() + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if ".experts." in name: + continue + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + break + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(name) + break + else: + if ".experts." in name: + # E8M0 scales are stored as float8_e8m0fnu in + # checkpoints but the MoE param is uint8. copy_() + # would do a numeric conversion (e.g. 2^-7 → 0), + # destroying the raw exponent bytes. 
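+                    # Example: an e8m0 scale of 2^-7 is the raw byte 0x78
+                    # (exponent 120); .view(torch.uint8) keeps that byte,
+                    # whereas a numeric cast would turn 0.0078125 into 0.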
+ if ( + "weight_scale" in name + and loaded_weight.dtype == torch.float8_e8m0fnu + ): + loaded_weight = loaded_weight.view(torch.uint8) + skip_expert_weight = False + for mapping in expert_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name_mapped = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name_mapped, self): + skip_expert_weight = True + break + param = params_dict[name_mapped] + # We should ask the weight loader to return success or not + # here since otherwise we may skip experts with other + # available replicas. + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + name = name_mapped + break + if skip_expert_weight: + continue + loaded_params.add(name_mapped) + continue + elif "attn_sink" in name: + if is_pp_missing_parameter(name, self): + continue + narrow_weight = loaded_weight[head_rank_start:head_rank_end] + n = narrow_weight.shape[0] + params_dict[name][:n].copy_(narrow_weight) + loaded_params.add(name) + continue + else: + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + continue + + return loaded_params + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + first_layer = next(iter(islice(self.layers, self.start_layer, self.end_layer))) + if first_layer.ffn.use_mega_moe: + return make_deepseek_v4_expert_params_mapping(self.config.n_routed_experts) + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + self, + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.n_routed_experts, + ) + + def finalize_mega_moe_weights(self) -> None: + for layer in islice(self.layers, self.start_layer, self.end_layer): + layer.ffn.finalize_mega_moe_weights() + + +@torch.compile(backend=current_platform.simple_compile_backend) +def hc_head( + hidden_states: torch.Tensor, + hc_fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + rms_norm_eps: float, + hc_eps: float, +) -> torch.Tensor: + x = hidden_states + shape, dtype = x.size(), x.dtype + x = x.flatten(1).float() + rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + rms_norm_eps) + mixes = F.linear(x, hc_fn) * rsqrt + pre = torch.sigmoid(mixes * hc_scale + hc_base) + hc_eps + y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=1) + return y.to(dtype) + + +class DeepseekV4ForCausalLM(nn.Module, SupportsPP): + model_cls = DeepseekV4Model + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "layers.": "model.layers.", + "embed.": "model.embed.", + "norm.": "model.norm.", + "hc_head": "model.hc_head", + "mtp.": "model.mtp.", + }, + orig_to_new_regex={ + # Routed MoE expert scales: experts.N.wX.scale -> .weight_scale + re.compile(r"(\.experts\.\d+\.w[123])\.scale$"): r"\1.weight_scale", + # Everything else (FP8 linear + shared experts): .scale -> .weight_scale_inv + re.compile(r"\.scale$"): ".weight_scale_inv", + }, + orig_to_new_suffix={ + "head.weight": "lm_head.weight", + "embed.weight": "embed_tokens.weight", + ".ffn.gate.bias": ".ffn.gate.e_score_correction_bias", + }, + orig_to_new_substr={ + 
".attn.compressor.": ".attn.mla_attn.compressor.", + ".shared_experts.w2": ".shared_experts.down_proj", + }, + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + self.config = config + + self.model = self.model_cls( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) + else: + self.lm_head = PPMissingLayer() + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + return hidden_states + + def get_mtp_target_hidden_states(self) -> torch.Tensor | None: + """Pre-hc_head residual stream buffer (max_num_batched_tokens, + hc_mult * hidden_size) for the MTP draft model. Populated by + forward(); valid after each target step.""" + return getattr(self.model, "_mtp_hidden_buffer", None) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, skip_substrs=["mtp."]) + loaded_params = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + self.model.finalize_mega_moe_weights() + return loaded_params + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() diff --git a/vllm/model_executor/models/deepseek_v4_mtp.py b/vllm/model_executor/models/deepseek_v4_mtp.py new file mode 100644 index 000000000000..c1f0e3fb5d3a --- /dev/null +++ b/vllm/model_executor/models/deepseek_v4_mtp.py @@ -0,0 +1,483 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""MTP draft model for DeepSeek V4 (internal codename: DeepseekV4). + +Split from ``deepseek_mtp.py`` because the V4 architecture introduces several +pieces that have no analogue in V3/V32: + * separate ``e_proj`` / ``h_proj`` with fp8 linear quantization (instead of + the fused ``eh_proj``); + * ``hc_head`` hypercompressed vocab projection applied in ``compute_logits``; + * ``DeepseekV4DecoderLayer`` with its own aux-stream management; + * V4-specific checkpoint weight-name remapping in ``load_weights``. 
+""" + +import typing +from collections.abc import Callable, Iterable + +import regex as re +import torch +import torch.nn as nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors +from vllm.utils.multi_stream_utils import AuxStreamType + +from .deepseek_mtp import SharedHead +from .deepseek_v2 import get_spec_layer_idx_from_weight_name +from .deepseek_v4 import ( + DeepseekV4DecoderLayer, + hc_head, + make_deepseek_v4_expert_params_mapping, +) +from .utils import maybe_prefix + +logger = init_logger(__name__) + +# MoE expert scales are fused into per-layer w13/w2 tensors; other FP8 linear +# scales use `.weight_scale_inv`. Mirrors the regex in +# DeepseekV4ForCausalLM.hf_to_vllm_mapper. +_EXPERT_SCALE_RE = re.compile(r"\.experts\.\d+\.w[123]\.scale$") + + +class DeepSeekV4MultiTokenPredictorLayer(nn.Module): + def __init__( + self, + vllm_config: VllmConfig, + topk_indices_buffer: torch.Tensor, + prefix: str, + ) -> None: + super().__init__() + + config = vllm_config.speculative_config.draft_model_config.hf_config + self.config = config + quant_config = vllm_config.quant_config + self.rms_norm_eps = config.rms_norm_eps + + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # V4 keeps e_ and h_ proj separate (with fp8 linear quant) rather than + # fusing them the way V3 does with eh_proj. 
+ self.e_proj = ReplicatedLinear( + config.hidden_size, + config.hidden_size, + bias=False, + return_bias=False, + quant_config=quant_config, + ) + self.h_proj = ReplicatedLinear( + config.hidden_size, + config.hidden_size, + bias=False, + return_bias=False, + quant_config=quant_config, + ) + + self.hc_eps = config.hc_eps + self.hc_mult = config.hc_mult + self.hc_dim = self.hc_mult * config.hidden_size + self.hc_head_fn = nn.Parameter( + torch.empty(self.hc_mult, self.hc_dim, dtype=torch.float32), + requires_grad=False, + ) + self.hc_head_base = nn.Parameter( + torch.empty(self.hc_mult, dtype=torch.float32), + requires_grad=False, + ) + self.hc_head_scale = nn.Parameter( + torch.empty(1, dtype=torch.float32), + requires_grad=False, + ) + + self.shared_head = SharedHead( + config=config, prefix=prefix, quant_config=quant_config + ) + self.aux_stream_dict = { + AuxStreamType.Attention: torch.cuda.Stream(), + } + self.mtp_block = DeepseekV4DecoderLayer( + vllm_config, + prefix, + topk_indices_buffer=topk_indices_buffer, + aux_stream_dict=self.aux_stream_dict, + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + spec_step_index: int = 0, + ) -> torch.Tensor: + assert inputs_embeds is not None + # masking inputs at position 0, as not needed by MTP + inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds) + inputs_embeds = self.enorm(inputs_embeds) + + # Target stashes pre-hc_head residual as flat (T, hc_mult * D); + # reshape to (T, hc_mult, D) — the training-time layout. + previous_hidden_states = previous_hidden_states.view( + -1, self.hc_mult, self.config.hidden_size + ) + previous_hidden_states = self.hnorm(previous_hidden_states) + hidden_states = self.h_proj(previous_hidden_states) + self.e_proj( + inputs_embeds + ).unsqueeze(-2) + hidden_states = self.mtp_block( + positions=positions, x=hidden_states, input_ids=None + ) + # Return the flat pre-hc_head residual so it can be re-fed as the + # next spec step's `previous_hidden_states` when + # num_speculative_tokens > 1. hc_head is deferred to compute_logits. 
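+        # (T, hc_mult, hidden_size) -> (T, hc_mult * hidden_size), matching the
+        # flat layout the target stashes in _mtp_hidden_buffer.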
+ return hidden_states.flatten(1) + + +class DeepSeekV4MultiTokenPredictor(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = config.num_nextn_predict_layers + self.device = current_platform.device_type + + topk_tokens = config.index_topk + self.topk_indices_buffer = torch.empty( + vllm_config.scheduler_config.max_num_batched_tokens, + topk_tokens, + dtype=torch.int32, + device=self.device, + ) + + # to map the exact layer index from weights + self.layers = torch.nn.ModuleDict( + { + str(idx): DeepSeekV4MultiTokenPredictorLayer( + vllm_config, + self.topk_indices_buffer, + f"{prefix}.layers.{idx}", + ) + for idx in range( + self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers, + ) + } + ) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=maybe_prefix(prefix, "embed_tokens"), + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + current_step_idx = spec_step_idx % self.num_mtp_layers + return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( + input_ids, + positions, + previous_hidden_states, + inputs_embeds, + current_step_idx, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + spec_step_idx: int = 0, + ) -> torch.Tensor: + current_step_idx = spec_step_idx % self.num_mtp_layers + mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)] + # MTP forward returns the pre-hc_head residual (T, hc_mult * D); apply + # hc_head here so logits are computed from the dense hidden state. 
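The predictor above indexes its ModuleDict by absolute layer id, so speculative steps wrap around the available MTP layers. A minimal sketch of that key selection (num_hidden_layers=61 and num_nextn_predict_layers=1 are assumed values for illustration only):

```python
# Illustrative only; the real values come from the HF config.
num_hidden_layers = 61            # assumed
num_nextn_predict_layers = 1      # assumed
mtp_start_layer_idx = num_hidden_layers

for spec_step_idx in range(3):
    current_step_idx = spec_step_idx % num_nextn_predict_layers
    print(spec_step_idx, "->", str(mtp_start_layer_idx + current_step_idx))
# 0 -> 61
# 1 -> 61
# 2 -> 61  (a single MTP layer is reused for every speculative step)
```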
+ hidden_states = hidden_states.view( + -1, mtp_layer.hc_mult, mtp_layer.config.hidden_size + ) + hidden_states = hc_head( + hidden_states, + mtp_layer.hc_head_fn, + mtp_layer.hc_head_scale, + mtp_layer.hc_head_base, + mtp_layer.rms_norm_eps, + mtp_layer.hc_eps, + ) + logits = self.logits_processor( + mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states) + ) + return logits + + +@support_torch_compile +class DeepSeekV4MTP(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.config = vllm_config.model_config.hf_config + self.quant_config = vllm_config.quant_config + self.model = DeepSeekV4MultiTokenPredictor( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids, positions, hidden_states, inputs_embeds, spec_step_idx + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + spec_step_idx: int = 0, + ) -> torch.Tensor | None: + return self.model.compute_logits(hidden_states, spec_step_idx) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # Weight name remapping for checkpoint compatibility. + # Maps checkpoint weight paths to model parameter paths. + WEIGHT_NAME_REMAPPING: dict[str, str] = { + ".emb.tok_emb.weight": ".embed_tokens.weight", + ".head.weight": ".shared_head.head.weight", + ".norm.weight": ".shared_head.norm.weight", + } + + def _remap_weight_name(name: str) -> str: + """Remap checkpoint weight names to model parameter names.""" + for old_pattern, new_pattern in WEIGHT_NAME_REMAPPING.items(): + if old_pattern in name: + name = name.replace(old_pattern, new_pattern) + return name + + def _find_mtp_layer_idx(name: str) -> int: + subnames = name.split(".") + for subname in subnames: + try: + # we return the first encountered integer + return int(subname) + except ValueError: + continue + return 0 + + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "w1", 0), + ("gate_up_proj", "w3", 1), + ("attn.fused_wqa_wkv", "attn.wq_a", 0), + ("attn.fused_wqa_wkv", "attn.wkv", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + # TP for attention + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + n_head = self.config.num_attention_heads + n_local_head = n_head // tp_size + head_rank_start = n_local_head * tp_rank + head_rank_end = n_local_head * (tp_rank + 1) + + # Pre-compute expert mapping ONCE. 
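The _remap_weight_name and _find_mtp_layer_idx helpers above are pure string utilities; a standalone replica with an illustrative checkpoint key:

```python
# Standalone replica of the two name helpers above, for illustration.
WEIGHT_NAME_REMAPPING = {
    ".emb.tok_emb.weight": ".embed_tokens.weight",
    ".head.weight": ".shared_head.head.weight",
    ".norm.weight": ".shared_head.norm.weight",
}

def remap_weight_name(name: str) -> str:
    for old, new in WEIGHT_NAME_REMAPPING.items():
        if old in name:
            name = name.replace(old, new)
    return name

def find_mtp_layer_idx(name: str) -> int:
    for sub in name.split("."):
        try:
            return int(sub)   # first integer component wins
        except ValueError:
            continue
    return 0

name = "mtp.0.emb.tok_emb.weight"     # made-up checkpoint key
print(find_mtp_layer_idx(name))       # 0
print(remap_weight_name(name))        # mtp.0.embed_tokens.weight
```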
+ first_layer = next(iter(self.model.layers.values())) + if first_layer.mtp_block.ffn.use_mega_moe: + expert_mapping = make_deepseek_v4_expert_params_mapping( + self.config.n_routed_experts + ) + else: + expert_mapping = FusedMoE.make_expert_params_mapping( + self, + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.n_routed_experts, + ) + + for name, loaded_weight in weights: + mtp_layer_idx = _find_mtp_layer_idx(name) + # V4 checkpoints store MTP weights as `mtp.{i}.*`; remap to + # `model.layers.{num_hidden_layers + i}.*` so that + # get_spec_layer_idx_from_weight_name can identify them. + name = name.replace( + f"mtp.{mtp_layer_idx}.", + f"model.layers.{self.config.num_hidden_layers + mtp_layer_idx}.", + ) + + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is None: + continue + + name = _remap_weight_name(name) + name = self._rewrite_spec_layer_name(spec_layer, name) + + if spec_layer != self.model.mtp_start_layer_idx and ".layers" not in name: + continue + if name.endswith(".scale"): + suffix = ( + ".weight_scale" + if _EXPERT_SCALE_RE.search(name) + else ".weight_scale_inv" + ) + name = name.removesuffix(".scale") + suffix + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if ".experts." in name: + continue + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(name) + break + else: + if ".experts." in name: + # Reinterpret E8M0 scales as uint8 to preserve raw + # exponent bytes; numeric copy_() would zero them. + # Mirrors the main DeepseekV4 loader. + if ( + "weight_scale" in name + and loaded_weight.dtype == torch.float8_e8m0fnu + ): + loaded_weight = loaded_weight.view(torch.uint8) + for mapping in expert_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name_mapped = name.replace(weight_name, param_name) + param = params_dict[name_mapped] + # We should ask the weight loader to return success or not + # here since otherwise we may skip experts with other + # available replicas. 
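Two of the rewrites inside the load loop are easy to see in isolation: the `mtp.{i}` prefix is shifted past num_hidden_layers, and a trailing `.scale` becomes `.weight_scale` for fused expert tensors or `.weight_scale_inv` for everything else. A sketch with an assumed num_hidden_layers of 61, using stdlib `re` in place of the `regex` module:

```python
import re  # the file imports `regex as re`; stdlib re suffices for this pattern

_EXPERT_SCALE_RE = re.compile(r"\.experts\.\d+\.w[123]\.scale$")
num_hidden_layers = 61   # assumed for illustration
mtp_layer_idx = 0

name = "mtp.0.ffn.experts.7.w1.scale"   # made-up checkpoint key
# 1) mtp.{i}.* -> model.layers.{num_hidden_layers + i}.*
name = name.replace(
    f"mtp.{mtp_layer_idx}.",
    f"model.layers.{num_hidden_layers + mtp_layer_idx}.",
)
# 2) `.scale` -> `.weight_scale` for fused expert tensors, else `.weight_scale_inv`
if name.endswith(".scale"):
    suffix = ".weight_scale" if _EXPERT_SCALE_RE.search(name) else ".weight_scale_inv"
    name = name.removesuffix(".scale") + suffix
print(name)  # model.layers.61.ffn.experts.7.w1.weight_scale
```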
+ weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + name = name_mapped + loaded_params.add(name_mapped) + break + continue + elif "attn_sink" in name: + narrow_weight = loaded_weight[head_rank_start:head_rank_end] + n = narrow_weight.shape[0] + params_dict[name][:n].copy_(narrow_weight) + loaded_params.add(name) + continue + else: + if ".shared_experts.w2" in name: + name = name.replace( + ".shared_experts.w2", ".shared_experts.down_proj" + ) + if name.endswith(".ffn.gate.bias"): + name = name.replace(".bias", ".e_score_correction_bias") + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + continue + + loaded_layers: set[int] = set() + for param_name in loaded_params: + spec_layer = get_spec_layer_idx_from_weight_name(self.config, param_name) + if spec_layer is not None: + loaded_layers.add(spec_layer) + for layer_idx in range( + self.model.mtp_start_layer_idx, + self.model.mtp_start_layer_idx + self.model.num_mtp_layers, + ): + if layer_idx not in loaded_layers: + raise ValueError( + f"MTP speculative decoding layer {layer_idx} weights " + f"missing from checkpoint. The checkpoint may have " + f"been quantized without including the MTP layers. " + f"Use a checkpoint that includes MTP layer weights, " + f"or disable speculative decoding." + ) + self.finalize_mega_moe_weights() + logger.info_once("MTP draft model loaded: %d params", len(loaded_params)) + return loaded_params + + def finalize_mega_moe_weights(self) -> None: + for layer in self.model.layers.values(): + layer.mtp_block.ffn.finalize_mega_moe_weights() + + def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: + """ + Rewrite the weight name to match the format of the original model. + Add .mtp_block for modules in transformer layer block for spec layer + and rename shared layer weights to be top level. + """ + spec_layer_weight_names = [ + "embed_tokens", + "enorm", + "hnorm", + "h_proj", + "e_proj", + "shared_head", + "hc_head_fn", + "hc_head_base", + "hc_head_scale", + ] + shared_weight_names = ["embed_tokens"] + spec_layer_weight = False + shared_weight = False + for weight_name in spec_layer_weight_names: + if weight_name in name: + spec_layer_weight = True + if weight_name in shared_weight_names: + shared_weight = True + break + if not spec_layer_weight: + # treat rest weights as weights for transformer layer block + name = name.replace( + f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block." 
+ ) + elif shared_weight: + # treat shared weights as top level weights + name = name.replace(f"model.layers.{spec_layer}.", "model.") + return name diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 3f80f66161fe..01f357a4993a 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -96,6 +96,7 @@ "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"), "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"), + "DeepseekV4ForCausalLM": ("deepseek_v4", "DeepseekV4ForCausalLM"), "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"), "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"), "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"), @@ -586,6 +587,7 @@ "Eagle3DeepseekV3ForCausalLM": ("deepseek_eagle3", "Eagle3DeepseekV2ForCausalLM"), "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), + "DeepSeekV4MTPModel": ("deepseek_v4_mtp", "DeepSeekV4MTP"), "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"), "ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"), "Exaone4_5_MTP": ("exaone4_5_mtp", "Exaone4_5_MTP"), diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py index 42b522f691e6..755fa56d294c 100644 --- a/vllm/reasoning/__init__.py +++ b/vllm/reasoning/__init__.py @@ -28,6 +28,10 @@ "deepseek_v3_reasoning_parser", "DeepSeekV3ReasoningParser", ), + "deepseek_v4": ( + "deepseek_v3_reasoning_parser", + "DeepSeekV3ReasoningParser", + ), "ernie45": ( "ernie45_reasoning_parser", "Ernie45ReasoningParser", diff --git a/vllm/renderers/deepseek_v4.py b/vllm/renderers/deepseek_v4.py new file mode 100644 index 000000000000..3dc82b9622e5 --- /dev/null +++ b/vllm/renderers/deepseek_v4.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.config import VllmConfig +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ConversationMessage, + parse_chat_messages, + parse_chat_messages_async, +) +from vllm.logger import init_logger +from vllm.tokenizers.deepseek_v4 import DeepseekV4Tokenizer +from vllm.utils.async_utils import make_async + +from .base import BaseRenderer +from .inputs import DictPrompt +from .inputs.preprocess import parse_dec_only_prompt +from .params import ChatParams + +logger = init_logger(__name__) + + +class DeepseekV4Renderer(BaseRenderer[DeepseekV4Tokenizer]): + def __init__( + self, + config: VllmConfig, + tokenizer: DeepseekV4Tokenizer | None, + ) -> None: + super().__init__(config, tokenizer) + + self._apply_chat_template_async = make_async( + self._apply_chat_template, executor=self._executor + ) + + def _apply_chat_template(self, *args, **kwargs): + return self.get_tokenizer().apply_chat_template(*args, **kwargs) + + def render_messages( + self, + messages: list[ChatCompletionMessageParam], + params: ChatParams, + ) -> tuple[list[ConversationMessage], DictPrompt]: + conversation, mm_data, mm_uuids = parse_chat_messages( + messages, + self.model_config, + content_format="string", + media_io_kwargs=params.media_io_kwargs, + mm_processor_kwargs=params.mm_processor_kwargs, + ) + + prompt_raw = self._apply_chat_template( + conversation=conversation, + messages=messages, + **params.get_apply_chat_template_kwargs(), + ) + + prompt = parse_dec_only_prompt(prompt_raw) + if mm_data is not None: + 
prompt["multi_modal_data"] = mm_data + if mm_uuids is not None: + prompt["multi_modal_uuids"] = mm_uuids + + return conversation, prompt + + async def render_messages_async( + self, + messages: list[ChatCompletionMessageParam], + params: ChatParams, + ) -> tuple[list[ConversationMessage], DictPrompt]: + conversation, mm_data, mm_uuids = await parse_chat_messages_async( + messages, + self.model_config, + content_format="string", + media_io_kwargs=params.media_io_kwargs, + mm_processor_kwargs=params.mm_processor_kwargs, + ) + + prompt_raw = await self._apply_chat_template_async( + conversation=conversation, + messages=messages, + **params.get_apply_chat_template_kwargs(), + ) + + prompt = parse_dec_only_prompt(prompt_raw) + if mm_data is not None: + prompt["multi_modal_data"] = mm_data + if mm_uuids is not None: + prompt["multi_modal_uuids"] = mm_uuids + + return conversation, prompt diff --git a/vllm/renderers/registry.py b/vllm/renderers/registry.py index 85a34a986720..df35987b24cc 100644 --- a/vllm/renderers/registry.py +++ b/vllm/renderers/registry.py @@ -21,6 +21,7 @@ _VLLM_RENDERERS = { "deepseek_v32": ("deepseek_v32", "DeepseekV32Renderer"), + "deepseek_v4": ("deepseek_v4", "DeepseekV4Renderer"), "hf": ("hf", "HfRenderer"), "grok2": ("grok2", "Grok2Renderer"), "kimi_audio": ("hf", "HfRenderer"), diff --git a/vllm/tokenizers/deepseek_v4.py b/vllm/tokenizers/deepseek_v4.py new file mode 100644 index 000000000000..76725dab16a1 --- /dev/null +++ b/vllm/tokenizers/deepseek_v4.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy +from typing import Any + +from transformers import PreTrainedTokenizerFast + +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam + +from .deepseek_v4_encoding import encode_messages +from .hf import HfTokenizer, get_cached_tokenizer +from .protocol import TokenizerLike + + +def get_deepseek_v4_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer: + """ + Wraps a tokenizer to use the custom DeepSeek V4 chat template encoding. + """ + dsv4_tokenizer = copy.copy(tokenizer) + + added_vocab = tokenizer.get_added_vocab() + added_vocab_size = len(added_vocab) + tokenizer_vocab_size = tokenizer.vocab_size + + class _DeepseekV4Tokenizer(tokenizer.__class__): # type: ignore + def apply_chat_template( + self, + messages: list["ChatCompletionMessageParam"], + tools: list[dict[str, Any]] | None = None, + **kwargs, + ) -> str | list[int]: + thinking = kwargs.get("thinking", False) + enable_thinking = kwargs.get("enable_thinking", False) + thinking = thinking or enable_thinking + thinking_mode = "thinking" if thinking else "chat" + + conversation = kwargs.get("conversation", messages) + messages = conversation.copy() + if tools is not None and len(tools) > 0: + messages.insert(0, {"role": "system"}) + messages[0]["tools"] = tools # type: ignore[typeddict-unknown-key] + + # The V4 reference currently accepts only "max", "high", or None. 
+ reasoning_effort = kwargs.get("reasoning_effort") + if reasoning_effort not in ("max", "high"): + reasoning_effort = None + + encode_config = dict( + thinking_mode=thinking_mode, + drop_thinking=kwargs.get("drop_thinking", True), + reasoning_effort=reasoning_effort, + ) + + prompt_str = encode_messages(messages, **encode_config) # type: ignore + + if kwargs.get("tokenize", True): + tokenizer_kwargs = { + k: kwargs[k] for k in ("truncation", "max_length") if k in kwargs + } + return self.encode( + prompt_str, + add_special_tokens=False, + **tokenizer_kwargs, + ) + + return prompt_str + + def num_special_tokens_to_add(self) -> int: + return len(self.encode("")) + + def __len__(self) -> int: + return tokenizer_vocab_size + added_vocab_size + + def get_added_vocab(self) -> dict[str, int]: + return added_vocab.copy() + + def __reduce__(self): + return get_deepseek_v4_tokenizer, (tokenizer,) + + _DeepseekV4Tokenizer.__name__ = f"DSV4{tokenizer.__class__.__name__}" + + dsv4_tokenizer.__class__ = _DeepseekV4Tokenizer + return dsv4_tokenizer + + +class DeepseekV4Tokenizer(TokenizerLike): + @classmethod + def from_pretrained(cls, *args, **kwargs) -> HfTokenizer: + tokenizer = PreTrainedTokenizerFast.from_pretrained(*args, **kwargs) + return get_cached_tokenizer(get_deepseek_v4_tokenizer(tokenizer)) diff --git a/vllm/tokenizers/deepseek_v4_encoding.py b/vllm/tokenizers/deepseek_v4_encoding.py new file mode 100644 index 000000000000..6895771e2f59 --- /dev/null +++ b/vllm/tokenizers/deepseek_v4_encoding.py @@ -0,0 +1,757 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa +# fmt: off + +""" +DeepSeek-V4 Encoding + +A self-contained implementation for encoding/decoding DeepSeek-V4 chat messages +with tool calling, thinking mode, and quick instruction task support. 
+""" + +from typing import Any, Dict, List, Union, Optional, Tuple +import copy +import json + +import regex as re + +# ============================================================ +# Special Tokens +# ============================================================ + +bos_token: str = "<|begin▁of▁sentence|>" +eos_token: str = "<|end▁of▁sentence|>" +thinking_start_token: str = "" +thinking_end_token: str = "" +dsml_token: str = "|DSML|" + +USER_SP_TOKEN = "<|User|>" +ASSISTANT_SP_TOKEN = "<|Assistant|>" +LATEST_REMINDER_SP_TOKEN = "<|latest_reminder|>" + +# Task special tokens for internal classification tasks +DS_TASK_SP_TOKENS = { + "action": "<|action|>", + "query": "<|query|>", + "authority": "<|authority|>", + "domain": "<|domain|>", + "title": "<|title|>", + "read_url": "<|read_url|>", +} +VALID_TASKS = set(DS_TASK_SP_TOKENS.keys()) + +# ============================================================ +# Templates +# ============================================================ + +system_msg_template: str = "{content}" +user_msg_template: str = "{content}" +latest_reminder_msg_template: str = "{content}" +assistant_msg_template: str = "{reasoning}{content}{tool_calls}" + eos_token +assistant_msg_wo_eos_template: str = "{reasoning}{content}{tool_calls}" +thinking_template: str = "{reasoning}" + +response_format_template: str = ( + "## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}" +) +tool_call_template: str = ( + "<{dsml_token}invoke name=\"{name}\">\n{arguments}\n" +) +tool_calls_template = ( + "<{dsml_token}{tc_block_name}>\n{tool_calls}\n" +) +tool_calls_block_name: str = "tool_calls" + +tool_output_template: str = ( + "{content}" +) + +REASONING_EFFORT_MAX = ( + "Reasoning Effort: Absolute maximum with no shortcuts permitted.\n" + "You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.\n" + "Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.\n\n" +) + +TOOLS_TEMPLATE = """## Tools + +You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<{dsml_token}tool_calls>" block like the following: + +<{dsml_token}tool_calls> +<{dsml_token}invoke name="$TOOL_NAME"> +<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE +... + +<{dsml_token}invoke name="$TOOL_NAME2"> +... + + + +String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`. + +If thinking_mode is enabled (triggered by {thinking_start_token}), you MUST output your complete reasoning inside {thinking_start_token}...{thinking_end_token} BEFORE any tool calls or final response. + +Otherwise, output directly after {thinking_end_token} with tool calls or final response. + +### Available Tool Schemas + +{tool_schemas} + +You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls. 
+""" + +# ============================================================ +# Utility Functions +# ============================================================ + +def to_json(value: Any) -> str: + """Serialize a value to JSON string.""" + try: + return json.dumps(value, ensure_ascii=False) + except Exception: + return json.dumps(value, ensure_ascii=True) + + +def tools_from_openai_format(tools): + """Extract function definitions from OpenAI-format tool list.""" + return [tool["function"] for tool in tools] + + +def tool_calls_from_openai_format(tool_calls): + """Convert OpenAI-format tool calls to internal format.""" + return [ + { + "name": tool_call["function"]["name"], + "arguments": tool_call["function"]["arguments"], + } + for tool_call in tool_calls + ] + + +def tool_calls_to_openai_format(tool_calls): + """Convert internal tool calls to OpenAI format.""" + return [ + { + "type": "function", + "function": { + "name": tool_call["name"], + "arguments": tool_call["arguments"], + } + } + for tool_call in tool_calls + ] + + +def encode_arguments_to_dsml(tool_call: Dict[str, Any]) -> str: + """ + Encode tool call arguments into DSML parameter format. + + Args: + tool_call: Dict with "name" and "arguments" keys. + + Returns: + DSML-formatted parameter string. + """ + p_dsml_template = '<{dsml_token}parameter name="{key}" string="{is_str}">{value}' + P_dsml_strs = [] + + if isinstance(tool_call["arguments"], str): + arguments = json.loads(tool_call["arguments"]) + else: + arguments = tool_call["arguments"] + + for k, v in arguments.items(): + p_dsml_str = p_dsml_template.format( + dsml_token=dsml_token, + key=k, + is_str="true" if isinstance(v, str) else "false", + value=v if isinstance(v, str) else to_json(v), + ) + P_dsml_strs.append(p_dsml_str) + + return "\n".join(P_dsml_strs) + + +def decode_dsml_to_arguments(tool_name: str, tool_args: Dict[str, Tuple[str, str]]) -> Dict[str, str]: + """ + Decode DSML parameters back to a tool call dict. + + Args: + tool_name: Name of the tool. + tool_args: Dict mapping param_name -> (value, is_string_flag). + + Returns: + Dict with "name" and "arguments" (JSON string) keys. + """ + def _decode_value(key: str, value: str, string: str): + if string == "true": + value = to_json(value) + return f"{to_json(key)}: {value}" + + tool_args_json = "{" + ", ".join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]) + "}" + return dict(name=tool_name, arguments=tool_args_json) + + +def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str: + """ + Render tool schemas into the system prompt format. + + Args: + tools: List of tool schema dicts (each with name, description, parameters). + + Returns: + Formatted tools section string. 
+ """ + tools_json = [to_json(t) for t in tools] + + return TOOLS_TEMPLATE.format( + tool_schemas="\n".join(tools_json), + dsml_token=dsml_token, + thinking_start_token=thinking_start_token, + thinking_end_token=thinking_end_token, + ) + + +def find_last_user_index(messages: List[Dict[str, Any]]) -> int: + """Find the index of the last user/developer message.""" + last_user_index = -1 + for idx in range(len(messages) - 1, -1, -1): + if messages[idx].get("role") in ["user", "developer"]: + last_user_index = idx + break + return last_user_index + + +# ============================================================ +# Message Rendering +# ============================================================ + +def render_message(index: int, messages: List[Dict[str, Any]], thinking_mode: str, drop_thinking: bool = True, reasoning_effort: Optional[str] = None) -> str: + """ + Render a single message at the given index into its encoded string form. + + This is the core function that converts each message in the conversation + into the DeepSeek-V4 format. + + Args: + index: Index of the message to render. + messages: Full list of messages in the conversation. + thinking_mode: Either "chat" or "thinking". + drop_thinking: Whether to drop reasoning content from earlier turns. + reasoning_effort: Optional reasoning effort level ("max", "high", or None). + + Returns: + Encoded string for this message. + """ + assert 0 <= index < len(messages) + assert thinking_mode in ["chat", "thinking"], f"Invalid thinking_mode `{thinking_mode}`" + + prompt = "" + msg = messages[index] + last_user_idx = find_last_user_index(messages) + + role = msg.get("role") + content = msg.get("content") + tools = msg.get("tools") + response_format = msg.get("response_format") + tool_calls = msg.get("tool_calls") + reasoning = msg.get("reasoning") + wo_eos = msg.get("wo_eos", False) + + if tools: + tools = tools_from_openai_format(tools) + if tool_calls: + tool_calls = tool_calls_from_openai_format(tool_calls) + + # Reasoning effort prefix (only at index 0 in thinking mode with max effort) + assert reasoning_effort in ['max', None, 'high'], f"Invalid reasoning effort: {reasoning_effort}" + if index == 0 and thinking_mode == "thinking" and reasoning_effort == 'max': + prompt += REASONING_EFFORT_MAX + + if role == "system": + prompt += system_msg_template.format(content=content or "") + if tools: + prompt += "\n\n" + render_tools(tools) + if response_format: + prompt += "\n\n" + response_format_template.format(schema=to_json(response_format)) + + elif role == "developer": + assert content, f"Invalid message for role `{role}`: {msg}" + + content_developer = USER_SP_TOKEN + content_developer += content + + if tools: + content_developer += "\n\n" + render_tools(tools) + if response_format: + content_developer += "\n\n" + response_format_template.format(schema=to_json(response_format)) + + prompt += user_msg_template.format(content=content_developer) + + elif role == "user": + prompt += USER_SP_TOKEN + + # Handle content blocks (tool results mixed with text) + content_blocks = msg.get("content_blocks") + if content_blocks: + parts = [] + for block in content_blocks: + block_type = block.get("type") + if block_type == "text": + parts.append(block.get("text", "")) + elif block_type == "tool_result": + tool_content = block.get("content", "") + if isinstance(tool_content, list): + text_parts = [] + for b in tool_content: + if b.get("type") == "text": + text_parts.append(b.get("text", "")) + else: + text_parts.append(f"[Unsupported {b.get('type')}]") 
+ tool_content = "\n\n".join(text_parts) + parts.append(tool_output_template.format(content=tool_content)) + else: + parts.append(f"[Unsupported {block_type}]") + prompt += "\n\n".join(parts) + else: + prompt += content or "" + + elif role == "latest_reminder": + prompt += LATEST_REMINDER_SP_TOKEN + latest_reminder_msg_template.format(content=content) + + elif role == "tool": + raise NotImplementedError("deepseek_v4 merges tool messages into user; please preprocess with merge_tool_messages()") + + elif role == "assistant": + thinking_part = "" + tc_content = "" + + if tool_calls: + tc_list = [ + tool_call_template.format( + dsml_token=dsml_token, + name=tc.get("name"), + arguments=encode_arguments_to_dsml(tc) + ) + for tc in tool_calls + ] + tc_content += '\n\n' + tool_calls_template.format( + dsml_token=dsml_token, + tool_calls="\n".join(tc_list), + tc_block_name=tool_calls_block_name, + ) + + summary_content = content or "" + reasoning = reasoning or "" + + # Check if previous message has a task - if so, this is a task output (no thinking) + prev_has_task = index - 1 >= 0 and messages[index - 1].get("task") is not None + + if thinking_mode == "thinking" and not prev_has_task: + if not drop_thinking or index > last_user_idx: + thinking_part = thinking_template.format(reasoning=reasoning) + thinking_end_token + else: + thinking_part = "" + + if wo_eos: + prompt += assistant_msg_wo_eos_template.format( + reasoning=thinking_part, + content=summary_content, + tool_calls=tc_content, + ) + else: + prompt += assistant_msg_template.format( + reasoning=thinking_part, + content=summary_content, + tool_calls=tc_content, + ) + else: + raise NotImplementedError(f"Unknown role: {role}") + + # Append transition tokens based on what follows + if index + 1 < len(messages) and messages[index + 1].get("role") not in ["assistant", "latest_reminder"]: + return prompt + + task = messages[index].get("task") + if task is not None: + # Task special token for internal classification tasks + assert task in VALID_TASKS, f"Invalid task: '{task}'. Valid tasks are: {list(VALID_TASKS)}" + task_sp_token = DS_TASK_SP_TOKENS[task] + + if task != "action": + # Non-action tasks: append task sp token directly after the message + prompt += task_sp_token + else: + # Action task: append Assistant + thinking token + action sp token + prompt += ASSISTANT_SP_TOKEN + prompt += thinking_end_token if thinking_mode != "thinking" else thinking_start_token + prompt += task_sp_token + + elif messages[index].get("role") in ["user", "developer"]: + # Normal generation: append Assistant + thinking token + prompt += ASSISTANT_SP_TOKEN + if not drop_thinking and thinking_mode == "thinking": + prompt += thinking_start_token + elif drop_thinking and thinking_mode == "thinking" and index >= last_user_idx: + prompt += thinking_start_token + else: + prompt += thinking_end_token + + return prompt + + +# ============================================================ +# Preprocessing +# ============================================================ + +def merge_tool_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Merge tool messages into the preceding user message using content_blocks format. + + DeepSeek-V4 does not have a standalone "tool" role; instead, tool results + are encoded as blocks within user messages. + + This function converts a standard OpenAI-format conversation (with separate + "tool" role messages) into V4 format where tool results are merged into + user messages. 
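For example, an OpenAI-style tool turn is folded into a following user message carrying a tool_result content block; roughly (data shapes only, the tool name and content are made up):

```python
# Illustrative before/after of the merge described above.
before = [
    {"role": "assistant", "tool_calls": [
        {"id": "call_1", "type": "function",
         "function": {"name": "get_weather", "arguments": '{"city": "Beijing"}'}},
    ]},
    {"role": "tool", "tool_call_id": "call_1", "content": "Sunny, 25C"},
]

# The assistant message is kept as-is; the tool message becomes:
after_tail = {
    "role": "user",
    "content_blocks": [
        {"type": "tool_result", "tool_use_id": "call_1", "content": "Sunny, 25C"},
    ],
}
```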
+ + Args: + messages: List of message dicts in OpenAI format. + + Returns: + Processed message list with tool messages merged into user messages. + """ + merged: List[Dict[str, Any]] = [] + + for msg in messages: + msg = copy.deepcopy(msg) + role = msg.get("role") + + if role == "tool": + # Convert tool message to a user message with tool_result block + tool_block = { + "type": "tool_result", + "tool_use_id": msg.get("tool_call_id", ""), + "content": msg.get("content", ""), + } + # Merge into previous message if it's already a user (merged tool) + if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1]: + merged[-1]["content_blocks"].append(tool_block) + else: + merged.append({ + "role": "user", + "content_blocks": [tool_block], + }) + elif role == "user": + text_block = {"type": "text", "text": msg.get("content", "")} + if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1] and merged[-1].get("task") is None: + merged[-1]["content_blocks"].append(text_block) + else: + new_msg = { + "role": "user", + "content": msg.get("content", ""), + "content_blocks": [text_block], + } + # Preserve extra fields (task, wo_eos, mask, etc.) + for key in ("task", "wo_eos", "mask"): + if key in msg: + new_msg[key] = msg[key] + merged.append(new_msg) + else: + merged.append(msg) + + return merged + + +def sort_tool_results_by_call_order(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Sort tool_result blocks within user messages by the order of tool_calls + in the preceding assistant message. + + Args: + messages: Preprocessed message list (after merge_tool_messages). + + Returns: + Message list with sorted tool result blocks. + """ + last_tool_call_order: Dict[str, int] = {} + + for msg in messages: + role = msg.get("role") + if role == "assistant" and msg.get("tool_calls"): + last_tool_call_order = {} + for idx, tc in enumerate(msg["tool_calls"]): + tc_id = tc.get("id") or tc.get("function", {}).get("id", "") + if tc_id: + last_tool_call_order[tc_id] = idx + + elif role == "user" and msg.get("content_blocks"): + tool_blocks = [b for b in msg["content_blocks"] if b.get("type") == "tool_result"] + if len(tool_blocks) > 1 and last_tool_call_order: + sorted_blocks = sorted( + tool_blocks, + key=lambda b: last_tool_call_order.get(b.get("tool_use_id", ""), 0) + ) + sorted_idx = 0 + new_blocks = [] + for block in msg["content_blocks"]: + if block.get("type") == "tool_result": + new_blocks.append(sorted_blocks[sorted_idx]) + sorted_idx += 1 + else: + new_blocks.append(block) + msg["content_blocks"] = new_blocks + + return messages + + +# ============================================================ +# Main Encoding Function +# ============================================================ + +def encode_messages( + messages: List[Dict[str, Any]], + thinking_mode: str, + context: Optional[List[Dict[str, Any]]] = None, + drop_thinking: bool = True, + add_default_bos_token: bool = True, + reasoning_effort: Optional[str] = None, +) -> str: + """ + Encode a list of messages into the DeepSeek-V4 prompt format. + + This is the main entry point for encoding conversations. It handles: + - BOS token insertion + - Thinking mode with optional reasoning content dropping + - Tool message merging into user messages + - Multi-turn conversation context + + Args: + messages: List of message dicts to encode. + thinking_mode: Either "chat" or "thinking". + context: Optional preceding context messages (already encoded prefix). 
+ drop_thinking: If True, drop reasoning from earlier assistant turns + (only keep reasoning for messages after the last user message). + add_default_bos_token: Whether to prepend BOS token at conversation start. + reasoning_effort: Optional reasoning effort level ("max", "high", or None). + + Returns: + The encoded prompt string. + """ + context = context if context else [] + + # Preprocess: merge tool messages and sort tool results + messages = merge_tool_messages(messages) + messages = sort_tool_results_by_call_order(context + messages)[len(context):] + if context: + context = merge_tool_messages(context) + context = sort_tool_results_by_call_order(context) + + full_messages = context + messages + + prompt = bos_token if add_default_bos_token and len(context) == 0 else "" + + # Resolve drop_thinking: if any message has tools defined, don't drop thinking + effective_drop_thinking = drop_thinking + if any(m.get("tools") for m in full_messages): + effective_drop_thinking = False + + if thinking_mode == "thinking" and effective_drop_thinking: + full_messages = _drop_thinking_messages(full_messages) + # After dropping, recalculate how many messages to render + # (context may have shrunk too) + num_to_render = len(full_messages) - len(_drop_thinking_messages(context)) + context_len = len(full_messages) - num_to_render + else: + num_to_render = len(messages) + context_len = len(context) + + for idx in range(num_to_render): + prompt += render_message( + idx + context_len, + full_messages, + thinking_mode=thinking_mode, + drop_thinking=effective_drop_thinking, + reasoning_effort=reasoning_effort, + ) + + return prompt + + +def _drop_thinking_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Drop reasoning and non-essential messages before the last user message. + + Behavior: + - Messages with role in ["user", "system", "tool", "latest_reminder"] are always kept. + - Messages at or after the last user index are always kept. + - Assistant messages before the last user get reasoning removed. + - Developer messages before the last user are dropped entirely. + """ + last_user_idx = find_last_user_index(messages) + result = [] + keep_roles = {"user", "system", "tool", "latest_reminder", "direct_search_results"} + + for idx, msg in enumerate(messages): + role = msg.get("role") + if role in keep_roles or idx >= last_user_idx: + result.append(msg) + elif role == "assistant": + msg = copy.copy(msg) + msg.pop("reasoning", None) + result.append(msg) + # developer and other roles before last_user_idx are dropped + + return result + + +# ============================================================ +# Parsing (Decoding model output) +# ============================================================ + +def _read_until_stop(index: int, text: str, stop: List[str]) -> Tuple[int, str, Optional[str]]: + """ + Read text from index until one of the stop strings is found. + + Returns: + Tuple of (new_index, content_before_stop, matched_stop_string_or_None). + """ + min_pos = len(text) + matched_stop = None + + for s in stop: + pos = text.find(s, index) + if pos != -1 and pos < min_pos: + min_pos = pos + matched_stop = s + + if matched_stop: + content = text[index:min_pos] + return min_pos + len(matched_stop), content, matched_stop + else: + content = text[index:] + return len(text), content, None + + +def parse_tool_calls(index: int, text: str) -> Tuple[int, Optional[str], List[Dict[str, str]]]: + """ + Parse DSML tool calls from text starting at the given index. 
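_read_until_stop above is the only scanning primitive the parser uses. A standalone replica showing its (new_index, content, matched_stop) contract, with placeholder stop strings rather than the model's real special tokens:

```python
# Standalone replica of _read_until_stop, for illustration.
from typing import List, Optional, Tuple

def read_until_stop(index: int, text: str, stop: List[str]) -> Tuple[int, str, Optional[str]]:
    min_pos, matched = len(text), None
    for s in stop:
        pos = text.find(s, index)
        if pos != -1 and pos < min_pos:
            min_pos, matched = pos, s
    if matched:
        return min_pos + len(matched), text[index:min_pos], matched
    return len(text), text[index:], None

text = "some reasoning</think>final answer<eos>"   # placeholder markers only
i, chunk, stop = read_until_stop(0, text, ["</think>", "<eos>"])
print((chunk, stop))      # ('some reasoning', '</think>')
i, chunk, stop = read_until_stop(i, text, ["<eos>"])
print((chunk, stop))      # ('final answer', '<eos>')
```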
+ + Args: + index: Starting position in text. + text: The full text to parse. + + Returns: + Tuple of (new_index, last_stop_token, list_of_tool_call_dicts). + Each tool call dict has "name" and "arguments" keys. + """ + tool_calls: List[Dict[str, Any]] = [] + stop_token = None + tool_calls_end_token = f"" + + while index < len(text): + index, content_before, stop_token = _read_until_stop(index, text, [f"<{dsml_token}invoke", tool_calls_end_token]) + if content_before != ">\n": + raise ValueError(f"Tool call format error: expected '>\\n' but got '{content_before}'") + + if stop_token == tool_calls_end_token: + break + + if stop_token is None: + raise ValueError("Missing special token in tool calls") + + index, tool_name_content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"\n$', tool_name_content, flags=re.DOTALL) + if len(p_tool_name) != 1: + raise ValueError(f"Tool name format error: '{tool_name_content}'") + tool_name = p_tool_name[0] + + tool_args: Dict[str, Tuple[str, str]] = {} + while stop_token == f"<{dsml_token}parameter": + index, param_content, stop_token = _read_until_stop(index, text, [f"/{dsml_token}parameter"]) + + param_kv = re.findall(r'^ name="(.*?)" string="(true|false)">(.*?)<$', param_content, flags=re.DOTALL) + if len(param_kv) != 1: + raise ValueError(f"Parameter format error: '{param_content}'") + param_name, string, param_value = param_kv[0] + + if param_name in tool_args: + raise ValueError(f"Duplicate parameter name: '{param_name}'") + tool_args[param_name] = (param_value, string) + + index, content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"\n": + raise ValueError(f"Parameter format error: expected '>\\n' but got '{content}'") + + tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args) + tool_calls.append(tool_call) + + return index, stop_token, tool_calls + + +def parse_message_from_completion_text(text: str, thinking_mode: str) -> Dict[str, Any]: + """ + Parse a model completion text into a structured assistant message. + + This function takes the raw text output from the model (a single assistant turn) + and extracts: + - reasoning (thinking block) + - content (summary/response) + - tool_calls (if any) + + NOTE: This function is designed to parse only correctly formatted strings and + will raise ValueError for malformed output. + + Args: + text: The raw completion text (including EOS token). + thinking_mode: Either "chat" or "thinking". + + Returns: + Dict with keys: "role", "content", "reasoning", "tool_calls". + tool_calls are in OpenAI format. 
+ """ + summary_content, reasoning = "", "" + tool_calls: List[Dict[str, str]] = [] + index, stop_token = 0, None + tool_calls_start_token = f"\n\n<{dsml_token}{tool_calls_block_name}" + + is_thinking = thinking_mode == "thinking" + is_tool_calling = False + + if is_thinking: + index, content_delta, stop_token = _read_until_stop(index, text, [thinking_end_token, tool_calls_start_token]) + reasoning = content_delta + if stop_token != thinking_end_token: + raise ValueError("Invalid thinking format: missing ") + + index, content_delta, stop_token = _read_until_stop(index, text, [eos_token, tool_calls_start_token]) + summary_content = content_delta + if stop_token == tool_calls_start_token: + is_tool_calling = True + else: + if stop_token != eos_token: + raise ValueError("Invalid format: missing EOS token") + + if is_tool_calling: + index, stop_token, tool_calls = parse_tool_calls(index, text) + + index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token]) + if tool_ends_text: + raise ValueError("Unexpected content after tool calls") + + if len(text) != index or stop_token not in [eos_token, None]: + raise ValueError("Unexpected content at end") + + for sp_token in [bos_token, eos_token, thinking_start_token, thinking_end_token, dsml_token]: + if sp_token in summary_content or sp_token in reasoning: + raise ValueError(f"Unexpected special token '{sp_token}' in content") + + return { + "role": "assistant", + "content": summary_content, + "reasoning": reasoning, + "tool_calls": tool_calls_to_openai_format(tool_calls) + } + +# fmt: on diff --git a/vllm/tokenizers/registry.py b/vllm/tokenizers/registry.py index 8f16e6d28f43..8778aa9d691f 100644 --- a/vllm/tokenizers/registry.py +++ b/vllm/tokenizers/registry.py @@ -42,6 +42,7 @@ _VLLM_TOKENIZERS = { "deepseek_v32": ("deepseek_v32", "DeepseekV32Tokenizer"), + "deepseek_v4": ("deepseek_v4", "DeepseekV4Tokenizer"), "grok2": ("grok2", "Grok2Tokenizer"), "hf": ("hf", "CachedHfTokenizer"), "kimi_audio": ("kimi_audio", "KimiAudioTokenizer"), diff --git a/vllm/tool_parsers/__init__.py b/vllm/tool_parsers/__init__.py index 7d5ea8d5ea7b..8a39ca825d5f 100644 --- a/vllm/tool_parsers/__init__.py +++ b/vllm/tool_parsers/__init__.py @@ -34,6 +34,10 @@ "deepseekv32_tool_parser", "DeepSeekV32ToolParser", ), + "deepseek_v4": ( + "deepseekv4_tool_parser", + "DeepSeekV4ToolParser", + ), "ernie45": ( "ernie45_tool_parser", "Ernie45ToolParser", diff --git a/vllm/tool_parsers/deepseekv32_tool_parser.py b/vllm/tool_parsers/deepseekv32_tool_parser.py index 0bbb6d7e73e0..b8623592365c 100644 --- a/vllm/tool_parsers/deepseekv32_tool_parser.py +++ b/vllm/tool_parsers/deepseekv32_tool_parser.py @@ -46,21 +46,24 @@ class DeepSeekV32ToolParser(ToolParser): """ + tool_call_start_token: str = "<|DSML|function_calls>" + tool_call_end_token: str = "" + def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None): super().__init__(tokenizer, tools) self.prev_tool_call_arr: list[dict] = [] - # Sentinel token - self.tool_call_start_token: str = "<|DSML|function_calls>" - # Streaming state self.current_tool_index: int = 0 self._sent_content_idx: int = 0 # Regex patterns for complete parsing self.tool_call_complete_regex = re.compile( - r"<|DSML|function_calls>(.*?)", re.DOTALL + re.escape(self.tool_call_start_token) + + r"(.*?)" + + re.escape(self.tool_call_end_token), + re.DOTALL, ) self.invoke_complete_regex = re.compile( r'<|DSML|invoke\s+name="([^"]+)"\s*>(.*?)', re.DOTALL @@ -86,7 +89,7 @@ def adjust_request( request = 
super().adjust_request(request) if request.tools and request.tool_choice != "none": # Ensure tool call tokens - # (<|DSML|function_calls>, ) + # (e.g. <|DSML|function_calls>, ) # are not skippedduring decoding. # Even though they are not marked as special tokens, # setting skip_special_tokens=False ensures proper handling in diff --git a/vllm/tool_parsers/deepseekv4_tool_parser.py b/vllm/tool_parsers/deepseekv4_tool_parser.py new file mode 100644 index 000000000000..006e5cf14bd5 --- /dev/null +++ b/vllm/tool_parsers/deepseekv4_tool_parser.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import regex as re + +from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import Tool +from vllm.tool_parsers.deepseekv32_tool_parser import DeepSeekV32ToolParser + + +class DeepSeekV4ToolParser(DeepSeekV32ToolParser): + """ + DeepSeek V4 DSML tool parser. + + V4 keeps the V3.2 DSML invoke/parameter grammar, but wraps tool calls in + ``<|DSML|tool_calls>`` instead of ``<|DSML|function_calls>``. + """ + + tool_call_start_token: str = "<|DSML|tool_calls>" + tool_call_end_token: str = "" + + def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None): + super().__init__(tokenizer, tools) + + self.tool_call_start_token = "<|DSML|tool_calls>" + self.tool_call_end_token = "" + self.tool_call_complete_regex = re.compile( + re.escape(self.tool_call_start_token) + + r"(.*?)" + + re.escape(self.tool_call_end_token), + re.DOTALL, + ) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 93dba4fd2f34..32241aec53e1 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -89,6 +89,7 @@ def __getitem__(self, key): qwen3_vl_nemotron_embed="Qwen3VLNemotronEmbedConfig", deepseek_vl_v2="DeepseekVLV2Config", deepseek_v32="DeepseekV3Config", + deepseek_v4="DeepseekV4Config", flex_olmo="FlexOlmoConfig", fireredlid="FireRedLIDConfig", funaudiochat="FunAudioChatConfig", diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 45eff21513be..667ed5a2596c 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -26,6 +26,7 @@ "OpsColQwen3Config": "vllm.transformers_utils.configs.colqwen3", "Qwen3VLNemotronEmbedConfig": "vllm.transformers_utils.configs.colqwen3", "DeepseekVLV2Config": "vllm.transformers_utils.configs.deepseek_vl2", + "DeepseekV4Config": "vllm.transformers_utils.configs.deepseek_v4", "DotsOCRConfig": "vllm.transformers_utils.configs.dotsocr", "EAGLEConfig": "vllm.transformers_utils.configs.eagle", "FireRedLIDConfig": "vllm.transformers_utils.configs.fireredlid", @@ -88,6 +89,7 @@ "Qwen3VLNemotronEmbedConfig", "DeepseekVLV2Config", "DeepseekV3Config", + "DeepseekV4Config", "DotsOCRConfig", "EAGLEConfig", "FlexOlmoConfig", diff --git a/vllm/transformers_utils/configs/deepseek_v4.py b/vllm/transformers_utils/configs/deepseek_v4.py new file mode 100755 index 000000000000..7708272c3bd4 --- /dev/null +++ b/vllm/transformers_utils/configs/deepseek_v4.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +from transformers import PretrainedConfig + + +class DeepseekV4Config(PretrainedConfig): + model_type = "deepseek_v4" + + def __init__( + self, + max_position_embeddings: int = 1048576, + rope_scaling: dict[str, Any] | 
None = None, + rope_parameters: dict[str, Any] | None = None, + rope_theta: float = 10000.0, + **kwargs, + ): + self.max_position_embeddings = max_position_embeddings + self.rope_scaling = rope_scaling + self.rope_theta = rope_theta + self.rope_parameters = rope_scaling or rope_parameters + super().__init__(**kwargs) diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py index cef655803474..443223689a95 100644 --- a/vllm/transformers_utils/model_arch_config_convertor.py +++ b/vllm/transformers_utils/model_arch_config_convertor.py @@ -47,6 +47,9 @@ def get_hidden_size(self) -> int: def get_head_size(self) -> int: if self.is_deepseek_mla(): + # special case for deepseek_v4 + if hasattr(self.hf_text_config, "compress_ratios"): + return self.hf_text_config.head_dim qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", 0) if not envs.VLLM_MLA_DISABLE: return self.hf_text_config.kv_lora_rank + qk_rope_head_dim @@ -222,6 +225,7 @@ def is_deepseek_mla(self) -> bool: "deepseek_v2", "deepseek_v3", "deepseek_v32", + "deepseek_v4", "deepseek_mtp", "glm_moe_dsa", "glm4_moe_lite", @@ -233,7 +237,11 @@ def is_deepseek_mla(self) -> bool: "pangu_ultra_moe_mtp", "bailing_hybrid", ): - return getattr(self.hf_text_config, "kv_lora_rank", None) is not None + # check is deepseek_v4 model + if hasattr(self.hf_text_config, "compress_ratios"): + return getattr(self.hf_text_config, "head_dim", None) is not None + else: + return getattr(self.hf_text_config, "kv_lora_rank", None) is not None elif self.hf_text_config.model_type == "eagle": # if the model is an EAGLE module, check for the # underlying architecture diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 637e9ec37e08..8be663a726ac 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -125,12 +125,16 @@ def _missing(*_: Any, **__: Any) -> NoReturn: ) +_cublaslt_gemm_nt_impl: Callable[..., Any] | None = None _fp8_gemm_nt_impl: Callable[..., Any] | None = None +_fp8_einsum_impl: Callable[..., Any] | None = None _grouped_impl: Callable[..., Any] | None = None _grouped_masked_impl: Callable[..., Any] | None = None -_fp8_mqa_logits_impl: Callable[..., Any] | None = None -_fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None +_grouped_fp4_impl: Callable[..., Any] | None = None +_fp8_fp4_mqa_logits_impl: Callable[..., Any] | None = None +_fp8_fp4_paged_mqa_logits_impl: Callable[..., Any] | None = None _get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None +_tf32_hc_prenorm_gemm_impl: Callable[..., Any] | None = None _get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None _get_mk_alignment_for_contiguous_layout_impl: Callable[..., Any] | None = None _transform_sf_into_required_layout_impl: Callable[..., Any] | None = None @@ -173,20 +177,27 @@ def _import_deep_gemm(): def _lazy_init() -> None: """Import deep_gemm and resolve symbols on first use.""" - global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl - global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl + global _cublaslt_gemm_nt_impl + global _fp8_gemm_nt_impl, _fp8_einsum_impl + global _grouped_impl, _grouped_masked_impl, _grouped_fp4_impl + global _fp8_fp4_mqa_logits_impl, _fp8_fp4_paged_mqa_logits_impl global _get_paged_mqa_logits_metadata_impl + global _tf32_hc_prenorm_gemm_impl global _get_mn_major_tma_aligned_tensor_impl global _get_mk_alignment_for_contiguous_layout_impl global _transform_sf_into_required_layout_impl # fast path if ( - 
_fp8_gemm_nt_impl is not None + _cublaslt_gemm_nt_impl is not None + or _fp8_gemm_nt_impl is not None + or _fp8_einsum_impl is not None or _grouped_impl is not None or _grouped_masked_impl is not None - or _fp8_mqa_logits_impl is not None - or _fp8_paged_mqa_logits_impl is not None + or _grouped_fp4_impl is not None + or _fp8_fp4_mqa_logits_impl is not None + or _fp8_fp4_paged_mqa_logits_impl is not None or _get_paged_mqa_logits_metadata_impl is not None + or _tf32_hc_prenorm_gemm_impl is not None or _get_mk_alignment_for_contiguous_layout_impl is not None or _transform_sf_into_required_layout_impl is not None ): @@ -202,18 +213,29 @@ def _lazy_init() -> None: envs.VLLM_CACHE_ROOT, "deep_gemm" ) + if ( + current_platform.is_device_capability_family(120) + and envs.VLLM_DEEP_GEMM_SM120_PAGED_MQA_TILED + ): + os.environ.setdefault("DG_SM120_PAGED_MQA_TILED", "1") _dg = _import_deep_gemm() if _dg is None: return + _cublaslt_gemm_nt_impl = getattr(_dg, "cublaslt_gemm_nt", None) _fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None) + _fp8_einsum_impl = getattr(_dg, "fp8_einsum", None) _grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None) _grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None) - _fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None) - _fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None) + _grouped_fp4_impl = getattr(_dg, "m_grouped_fp8_fp4_gemm_nt_contiguous", None) + # DeepGEMM exposes fp8_fp4_*_mqa_logits as the canonical symbols that + # handle both the FP8 and FP4 Q/K paths via a tuple-typed `q`. + _fp8_fp4_mqa_logits_impl = getattr(_dg, "fp8_fp4_mqa_logits", None) + _fp8_fp4_paged_mqa_logits_impl = getattr(_dg, "fp8_fp4_paged_mqa_logits", None) _get_paged_mqa_logits_metadata_impl = getattr( _dg, "get_paged_mqa_logits_metadata", None ) + _tf32_hc_prenorm_gemm_impl = getattr(_dg, "tf32_hc_prenorm_gemm", None) _get_mn_major_tma_aligned_tensor_impl = getattr( _dg, "get_mn_major_tma_aligned_tensor", None ) @@ -259,6 +281,13 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: return _get_mn_major_tma_aligned_tensor_impl(x) +def cublaslt_gemm_nt(*args, **kwargs): + _lazy_init() + if _cublaslt_gemm_nt_impl is None: + return _missing(*args, **kwargs) + return _cublaslt_gemm_nt_impl(*args, **kwargs) + + def fp8_gemm_nt(*args, **kwargs): _lazy_init() if _fp8_gemm_nt_impl is None: @@ -271,6 +300,13 @@ def fp8_gemm_nt(*args, **kwargs): return _fp8_gemm_nt_impl(*args, disable_ue8m0_cast=not use_ue8m0, **kwargs) +def fp8_einsum(*args, **kwargs): + _lazy_init() + if _fp8_einsum_impl is None: + return _missing(*args, **kwargs) + return _fp8_einsum_impl(*args, **kwargs) + + def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs): _lazy_init() if _grouped_impl is None: @@ -280,6 +316,15 @@ def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs): ) +def m_grouped_fp8_fp4_gemm_nt_contiguous(*args, **kwargs): + _lazy_init() + if _grouped_fp4_impl is None: + return _missing(*args, **kwargs) + return _grouped_fp4_impl( + *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs + ) + + def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): _lazy_init() if _grouped_masked_impl is None: @@ -298,37 +343,93 @@ def transform_sf_into_required_layout(*args, **kwargs): ) -def fp8_mqa_logits( - q: torch.Tensor, +def _fp8_mqa_logits_torch_reference( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + 
clean_logits: bool, +) -> torch.Tensor: + q_values, q_scale = q + if q_scale is not None: + raise NotImplementedError("SM120 MQA logits fallback only supports FP8 Q") + + k_values, k_scales = kv + q_f32 = q_values.to(torch.float32) + k_f32 = k_values.to(torch.float32) * k_scales.reshape(-1, 1).to(torch.float32) + k_t = k_f32.transpose(0, 1).contiguous() + + seq_len, num_heads, _ = q_f32.shape + seq_len_kv = k_f32.shape[0] + logits = torch.zeros( + (seq_len, seq_len_kv), device=q_values.device, dtype=torch.float32 + ) + + # Avoid materializing the full [H, M, N] score tensor for all heads. + for head_start in range(0, num_heads, 8): + head_end = min(head_start + 8, num_heads) + q_chunk = q_f32[:, head_start:head_end, :].transpose(0, 1).contiguous() + scores = torch.matmul(q_chunk, k_t) + head_weights = weights[:, head_start:head_end].transpose(0, 1).unsqueeze(-1) + logits += (scores.relu() * head_weights).sum(dim=0) + + if clean_logits: + offsets = torch.arange(seq_len_kv, device=q_values.device) + valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & ( + offsets[None, :] < cu_seqlen_ke[:, None] + ) + logits = logits.masked_fill(~valid, float("-inf")) + + return logits + + +def fp8_fp4_mqa_logits( + q: tuple[torch.Tensor, torch.Tensor | None], kv: tuple[torch.Tensor, torch.Tensor], weights: torch.Tensor, cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor, clean_logits: bool, ) -> torch.Tensor: - """Compute FP8 MQA logits for a single sequence without KV paging. + """Compute MQA logits for a single sequence without KV paging. + + Unified FP8/FP4 dispatch — the underlying DeepGEMM kernel takes + ``q = (values, scales_or_None)`` where ``scales`` is None for FP8 Q + (per-token scale is folded into ``weights``) and a packed block-scale + tensor for MXFP4 Q. Args: - q: Query tensor of shape [M, H, D]. Casted to - `torch.float8_e4m3fn` by caller. - kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with - dtype `torch.float8_e4m3fn` and `k_scales` has shape [N]) - with dtype `torch.float32`. + q: Tuple ``(q_values, q_scale)``. FP8 path: q_values is [M, H, D] + float8_e4m3fn and q_scale is None (per-token scale is folded + into ``weights``). FP4 path: q_values is packed uint8 and + q_scale is the companion block-scale tensor. + kv: Tuple `(k_packed, k_scales)` — FP8 layout is [N, D] + float8_e4m3fn plus fp32 scales [N]; FP4 layout is packed uint8. weights: weights of shape [M, H], dtype `torch.float32`. - cu_seqlen_ks: Start indices (inclusive) for valid K per query position, - shape [M], dtype int32. - cu_seqlen_ke: End indices (exclusive) for valid K per query position, - shape [M], dtype int32. + cu_seqlen_ks: Start indices (inclusive) for valid K per query + position, shape [M], dtype int32. + cu_seqlen_ke: End indices (exclusive) for valid K per query + position, shape [M], dtype int32. clean_logits: Whether to clean the unfilled logits into `-inf`. Returns: Logits tensor of shape [M, N], dtype `torch.float32`. 
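With clean_logits set, the torch reference above keeps, for query row i, only key positions in [cu_seqlen_ks[i], cu_seqlen_ke[i]). A toy demonstration of just that mask:

```python
# Toy demonstration of the per-row [ks, ke) validity mask used above.
import torch

logits = torch.zeros(2, 5)                  # [M=2 queries, N=5 keys]
cu_seqlen_ks = torch.tensor([0, 2])         # inclusive start per query
cu_seqlen_ke = torch.tensor([3, 5])         # exclusive end per query

offsets = torch.arange(logits.shape[1])
valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & (
    offsets[None, :] < cu_seqlen_ke[:, None]
)
masked = logits.masked_fill(~valid, float("-inf"))
print(masked)
# row 0 keeps keys 0..2, row 1 keeps keys 2..4; the rest become -inf
```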
""" _lazy_init() - if _fp8_mqa_logits_impl is None: + if current_platform.is_device_capability_family(120) and q[1] is None: + return _fp8_mqa_logits_torch_reference( + q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits + ) + if _fp8_fp4_mqa_logits_impl is None: return _missing() - return _fp8_mqa_logits_impl( - q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=clean_logits + return _fp8_fp4_mqa_logits_impl( + q, + kv, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + clean_logits=clean_logits, ) @@ -344,7 +445,7 @@ def get_paged_mqa_logits_metadata( num_sms: Number of SMs available. 132 for Hopper Returns: - Backend-specific tensor consumed by `fp8_paged_mqa_logits` to + Backend-specific tensor consumed by `fp8_fp4_paged_mqa_logits` to schedule work across SMs. """ _lazy_init() @@ -353,9 +454,9 @@ def get_paged_mqa_logits_metadata( return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, num_sms) -def fp8_paged_mqa_logits( - q_fp8: torch.Tensor, - kv_cache_fp8: torch.Tensor, +def fp8_fp4_paged_mqa_logits( + q: tuple[torch.Tensor, torch.Tensor | None], + kv_cache: torch.Tensor, weights: torch.Tensor, context_lens: torch.Tensor, block_tables: torch.Tensor, @@ -363,14 +464,20 @@ def fp8_paged_mqa_logits( max_model_len: int, clean_logits: bool, ) -> torch.Tensor: - """Compute FP8 MQA logits using paged KV-cache. + """Compute MQA logits using a paged KV-cache. + + Unified FP8/FP4 dispatch — the underlying DeepGEMM kernel takes + ``q = (values, scales_or_None)``; pass ``(q_tensor, None)`` for the FP8 + path and ``(q_values, q_scale)`` for MXFP4. Args: - q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to - `torch.float8_e4m3fn` by caller. - kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape - [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last - 4 bytes per (block,pos) store the `float` dequant scale. + q: Tuple ``(q_values, q_scale)``. FP8 path: q_values is + [B, next_n, H, D] float8_e4m3fn and q_scale is None. FP4 path: + q_values is packed uint8 and q_scale is the companion + block-scale tensor. + kv_cache: Paged KV-cache. FP8 layout is [num_blocks, block_size, 1, + D+4], dtype `torch.uint8`, with the last 4 bytes per (block, pos) + storing the float dequant scale. weights: Tensor of shape [B * next_n, H], dtype `torch.float32`. context_lens: Tensor of shape [B], dtype int32; effective context length for each batch element. @@ -386,11 +493,11 @@ def fp8_paged_mqa_logits( `torch.float32`. 
""" _lazy_init() - if _fp8_paged_mqa_logits_impl is None: + if _fp8_fp4_paged_mqa_logits_impl is None: return _missing() - return _fp8_paged_mqa_logits_impl( - q_fp8, - kv_cache_fp8, + return _fp8_fp4_paged_mqa_logits_impl( + q, + kv_cache, weights, context_lens, block_tables, @@ -400,6 +507,32 @@ def fp8_paged_mqa_logits( ) +def tf32_hc_prenorm_gemm( + x: torch.Tensor, + fn: torch.Tensor, + out: torch.Tensor, + sqrsum: torch.Tensor, + num_split: int, +) -> torch.Tensor: + """ + Perform the following computation: + out = x.float() @ fn.T + sqrsum = x.float().square().sum(-1) + + See the caller function for shape requirement + """ + _lazy_init() + if _tf32_hc_prenorm_gemm_impl is None: + return _missing() + return _tf32_hc_prenorm_gemm_impl( + x, + fn, + out, + sqrsum, + num_split, + ) + + def _ceil_to_ue8m0(x: torch.Tensor): return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) @@ -482,10 +615,12 @@ def should_use_deepgemm_for_fp8_linear( "calc_diff", "DeepGemmQuantScaleFMT", "fp8_gemm_nt", + "fp8_einsum", "m_grouped_fp8_gemm_nt_contiguous", + "m_grouped_fp8_fp4_gemm_nt_contiguous", "fp8_m_grouped_gemm_nt_masked", - "fp8_mqa_logits", - "fp8_paged_mqa_logits", + "fp8_fp4_mqa_logits", + "fp8_fp4_paged_mqa_logits", "get_paged_mqa_logits_metadata", "per_block_cast_to_fp8", "is_deep_gemm_e8m0_used", diff --git a/vllm/utils/multi_stream_utils.py b/vllm/utils/multi_stream_utils.py index 3ade910bf99c..cc6bc6462449 100644 --- a/vllm/utils/multi_stream_utils.py +++ b/vllm/utils/multi_stream_utils.py @@ -2,11 +2,21 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Callable +from enum import Enum from typing import Any import torch +class AuxStreamType(Enum): + Attention = 1 + + +class EventType(Enum): + Main = 0 + Attention = 1 + + def maybe_execute_in_parallel( fn0: Callable[[], Any], fn1: Callable[[], Any], diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index 16535ee3c6c1..d83489238d33 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -392,6 +392,11 @@ class CommonAttentionMetadata: dcp_local_seq_lens_cpu: torch.Tensor | None = None """Sequence lengths of the local rank in decode context parallelism world""" + positions: torch.Tensor | None = None + """(num_actual_tokens,) token positions. Optional; set when the caller + has positions available so that builders can pre-compute position-dependent + metadata (e.g. C128A topk indices for DeepSeek V4).""" + is_prefilling: torch.Tensor | None = None """(batch_size,) bool tensor: True if request is still in prefill phase (num_computed_tokens < num_prompt_tokens). 
Used by some backends to diff --git a/vllm/v1/attention/backends/mla/compressor_utils.py b/vllm/v1/attention/backends/mla/compressor_utils.py new file mode 100644 index 000000000000..36b115f64444 --- /dev/null +++ b/vllm/v1/attention/backends/mla/compressor_utils.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm.triton_utils import tl, triton + + +@triton.jit +def _compressed_slot_mapping_kernel( + # [num_tokens] + slot_mapping_ptr, + # [num_reqs + 1] + query_start_loc_ptr, + # [num_reqs] + seq_lens_ptr, + # [num_reqs, max_num_blocks] + block_table_ptr, + block_table_stride, + block_size, + COMPRESS_RATIO: tl.constexpr, + PAD_ID: tl.constexpr, + TRITON_BLOCK_SIZE: tl.constexpr, +): + batch_idx = tl.program_id(0) + + query_start = tl.load(query_start_loc_ptr + batch_idx) + query_end = tl.load(query_start_loc_ptr + batch_idx + 1) + query_len = query_end - query_start + + seq_len = tl.load(seq_lens_ptr + batch_idx) + start_pos = seq_len - query_len + + for i in range(0, query_len, TRITON_BLOCK_SIZE): + offset = i + tl.arange(0, TRITON_BLOCK_SIZE) + mask = offset < query_len + + pos = start_pos + i + tl.arange(0, TRITON_BLOCK_SIZE) + is_valid = (pos + 1) % COMPRESS_RATIO == 0 + pos_after_compress = pos // COMPRESS_RATIO + + block_ids = pos_after_compress // block_size + block_numbers = tl.load( + block_table_ptr + batch_idx * block_table_stride + block_ids, + mask=mask & is_valid, + ) + slot_ids = block_numbers * block_size + pos_after_compress % block_size + + # NOTE + slot_ids = tl.where(is_valid, slot_ids, PAD_ID) + tl.store(slot_mapping_ptr + query_start + offset, slot_ids, mask=mask) + + +def get_compressed_slot_mapping( + num_tokens: int, + query_start_loc: torch.Tensor, + seq_lens: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + compress_ratio: int, + out: torch.Tensor | None = None, +) -> torch.Tensor: + if out is not None: + # Guard: for padded / invalid sequences. + # Negative positions produce bogus block indices that lead to illegal memory + # accesses inside the block_table load. + # NOTE: Fill -1 to the whole tensor, not just the first `num_tokens`. 
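For readers tracing the Triton kernel above, a minimal pure-PyTorch sketch of the same mapping (a hypothetical helper for illustration and testing, not part of this change):

```python
import torch


def compressed_slot_mapping_reference(
    query_start_loc: torch.Tensor,  # [num_reqs + 1]
    seq_lens: torch.Tensor,         # [num_reqs]
    block_table: torch.Tensor,      # [num_reqs, max_num_blocks]
    block_size: int,
    compress_ratio: int,
) -> torch.Tensor:
    """Slot id per query token in the compressed cache; -1 where the
    position does not complete a compression window."""
    num_tokens = int(query_start_loc[-1])
    slot_mapping = torch.full((num_tokens,), -1, dtype=torch.int64)
    for req in range(seq_lens.shape[0]):
        q_start = int(query_start_loc[req])
        query_len = int(query_start_loc[req + 1]) - q_start
        start_pos = int(seq_lens[req]) - query_len
        for i in range(query_len):
            pos = start_pos + i
            if (pos + 1) % compress_ratio != 0:
                continue  # no compressed token is emitted at this position
            pos_c = pos // compress_ratio
            block = int(block_table[req, pos_c // block_size])
            slot_mapping[q_start + i] = block * block_size + pos_c % block_size
    return slot_mapping
```

This loops token by token purely for clarity; the Triton kernel above does the same work per request in parallel.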
+ out.fill_(-1) + slot_mapping = out[:num_tokens] + else: + slot_mapping = torch.full( + (num_tokens,), -1, dtype=torch.int64, device=query_start_loc.device + ) + + num_reqs = block_table.shape[0] + _compressed_slot_mapping_kernel[(num_reqs,)]( + slot_mapping, + query_start_loc, + seq_lens, + block_table, + block_table.stride(0), + block_size, + compress_ratio, + PAD_ID=-1, + TRITON_BLOCK_SIZE=1024, + ) + return slot_mapping diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index e67282aab8cc..fe4cde313654 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -15,6 +15,8 @@ ) from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability +from vllm.triton_utils import tl, triton +from vllm.utils.math_utils import cdiv from vllm.utils.platform_utils import num_compute_units from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.v1.attention.backend import ( @@ -27,6 +29,11 @@ MultipleOf, SparseMLAAttentionImpl, ) +from vllm.v1.attention.backends.mla.compressor_utils import get_compressed_slot_mapping +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + is_sparse_mla_reference_attention_enabled_for_platform, + sparse_mla_reference_cudagraphs_allowed, +) from vllm.v1.attention.backends.mla.sparse_utils import ( triton_convert_req_index_to_global_index, ) @@ -65,8 +72,8 @@ """ NOTE: FlashMLA Sparse uses an fp8 cache with the following format -In the "FP8 with scale" format, each token's KV cache is 656 Bytes, -structured as: +For DeepSeek V3.2, in the "FP8 with scale" format, each token's KV cache is 656 +Bytes, structured as: - **First 512 bytes:** The "quantized NoPE" part, containing 512 `float8_e4m3` values. - **Next 16 bytes:** Scale factors, containing 4 `float32` values. @@ -74,6 +81,16 @@ the second for the next 128, and so on. - **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This part is not quantized for accuracy. + +For DeepSeek V4, in the "FP8 with scale" format, each token's KV cache is 584 +Bytes, structured as: +- **First 448 bytes:** The "quantized NoPE" part, containing 448 + `float8_e4m3` values. +- **Next 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This + part is not quantized for accuracy. +- **Last 8 bytes:** Scale factors, containing 7 `ue8m0` values + 1B pad. + The first `ue8m0` is the scale for the first 64 `float8_e4m3` values, + the second for the next 64, and so on. """ @@ -104,7 +121,8 @@ def get_impl_cls() -> type["FlashMLASparseImpl"]: @classmethod def get_supported_head_sizes(cls) -> list[int]: - return [576] + # V3.2: 576 (512 NoPE + 64 RoPE); DeepseekV4: 512 (448 NoPE + 64 RoPE) + return [512, 576] @classmethod def is_mla(cls) -> bool: @@ -127,13 +145,37 @@ def get_kv_cache_shape( cache_dtype_str: str = "auto", ) -> tuple[int, ...]: if cache_dtype_str == "fp8_ds_mla": - # custom storage format is 656 bytes - # see FlashMLA readme.md for details + # V3.2 main MLA: 656-byte custom storage format. See module docstring. 
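To make the byte layouts described above concrete, a hedged sketch that splits one DeepSeek V4 cache row into its parts (the helper name is hypothetical; it assumes the standard E8M0 bias of 127 for the UE8M0 scales):

```python
import torch


def unpack_deepseek_v4_cache_row(row: torch.Tensor):
    """Split one 584-byte DeepSeek V4 cache row (uint8) into its parts:
    448 fp8 NoPE values, 64 bf16 RoPE values, and 7 UE8M0 block scales."""
    assert row.dtype == torch.uint8 and row.numel() == 584
    nope_fp8 = row[:448].view(torch.float8_e4m3fn)
    rope_bf16 = row[448:576].view(torch.bfloat16)  # 128 bytes -> 64 values
    scale_bytes = row[576:583]                     # 7 scales, last byte is padding
    # Assumption: UE8M0 scales are power-of-two exponents with the standard
    # E8M0 bias of 127, one scale per 64-value NoPE block.
    scales = torch.exp2(scale_bytes.to(torch.float32) - 127.0)
    return nope_fp8, rope_bf16, scales
```

The V3.2 row described just above differs only in the split: 512 fp8 values, four float32 scales (one per 128 values), then the 64 bf16 RoPE values, for 656 bytes total.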
return (num_blocks, block_size, 656) else: return (num_blocks, block_size, head_size) +class DeepseekV4FlashMLASparseBackend(FlashMLASparseBackend): + @staticmethod + def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: + return [256] + + @staticmethod + def get_name() -> str: + return "V4_FLASHMLA_SPARSE" + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + if cache_dtype_str == "fp8_ds_mla": + # DeepseekV4 main MLA: 584B per token (448 NoPE + 128 RoPE + 8 fp8 scale). + # head_size passed in is the semantic head_dim (512). + return (num_blocks, block_size, 584) + else: + return (num_blocks, block_size, head_size) + + @dataclass class FlashMLASparseMetadata(AttentionMetadata): num_reqs: int @@ -159,6 +201,7 @@ class FP8KernelMetadata: class FP8SeparatePrefillDecode: @dataclass class Decode: + seq_lens: torch.Tensor kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata" decode_query_len: int # needed for reshape in spec decode @@ -206,6 +249,13 @@ class Chunk: fp8_extra_metadata: FP8SeparatePrefillDecode | FP8KernelMetadata | None = None fp8_use_mixed_batch: bool = False + # Pre-computed C128A metadata (DeepseekV4 only, compress_ratio == 128). + # Decode: global slot ids + valid-entry counts (fused from positions). + c128a_global_decode_topk_indices: torch.Tensor | None = None + c128a_decode_topk_lens: torch.Tensor | None = None + # Prefill: local topk indices (used by combine_topk_swa_indices). + c128a_prefill_topk_indices: torch.Tensor | None = None + def get_prefill_workspace_size(max_model_len: int): # NOTE(Lucas): 5 is a magic number for controlling the prefill buffer size. @@ -220,6 +270,20 @@ def get_prefill_workspace_size(max_model_len: int): class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]): _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + @classmethod + def get_cudagraph_support( + cls, + vllm_config: VllmConfig, + kv_cache_spec: AttentionSpec, + ) -> AttentionCGSupport: + if ( + getattr(kv_cache_spec, "model_version", None) == "deepseek_v4" + and is_sparse_mla_reference_attention_enabled_for_platform() + and not sparse_mla_reference_cudagraphs_allowed(vllm_config) + ): + return AttentionCGSupport.NEVER + return cls._cudagraph_support + def __init__( self, kv_cache_spec: AttentionSpec, @@ -235,8 +299,9 @@ def __init__( parallel_config = vllm_config.parallel_config self.device = device - # Treat requests with query length <= 1 as decodes to match the - # DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2) + # Classify single-token queries (plus num_speculative_tokens via + # supports_spec_as_decode=True) as decodes; longer queries go to + # prefill. self._init_reorder_batch_threshold(1, supports_spec_as_decode=True) sm_count = num_compute_units(device.index) @@ -300,6 +365,68 @@ def __init__( device=device, ) + # DeepseekV4: has compress_ratios in hf_config. + hf_config = vllm_config.model_config.hf_config + self.is_deepseek_v4 = ( + hasattr(hf_config, "compress_ratios") and len(hf_config.compress_ratios) > 0 + ) + self.compress_ratio = 1 + if self.is_deepseek_v4: + assert hasattr(self.kv_cache_spec, "compress_ratio") + self.compress_ratio = self.kv_cache_spec.compress_ratio + # Pre-allocate compressed slot mapping buffer for CUDA graph + # address stability when compress_ratio > 1. 
+ if self.compress_ratio > 1: + max_num_batched_tokens = ( + vllm_config.scheduler_config.max_num_batched_tokens + ) + self.compressed_slot_mapping_buffer = torch.empty( + max_num_batched_tokens, + dtype=torch.int64, + device=self.device, + ) + + # Pre-allocate C128A topk buffers for CUDA graph address stability. + if self.compress_ratio == 128: + max_num_batched_tokens = ( + vllm_config.scheduler_config.max_num_batched_tokens + ) + # Pad to B_TOPK alignment (128 covers both h_q=64 B_TOPK=64 and + # h_q=128 B_TOPK=128). FlashMLA decode asserts extra_topk % B_TOPK + # == 0; unaligned widths (e.g. 17 = ceil(2136/128)) crash the + # sm100 head64 kernel. Padded slots stay -1 and decode_lens caps + # them via topk_length, so the pad is a no-op at kernel level. + # Mirrors _SPARSE_PREFILL_TOPK_ALIGNMENT in cache_utils.py. + _C128A_TOPK_ALIGNMENT = 128 + c128a_max_compressed = cdiv( + self.model_config.max_model_len, self.compress_ratio + ) + c128a_max_compressed = ( + cdiv(c128a_max_compressed, _C128A_TOPK_ALIGNMENT) + * _C128A_TOPK_ALIGNMENT + ) + # Stored so _build_c128a_metadata passes it as the kernel's + # max_compressed_tokens, matching the buffer stride. Otherwise + # the kernel's default 8192 iterates past row width and spills + # writes into adjacent rows (present in both decode and prefill + # branches of _build_c128a_topk_metadata_kernel). + self.c128a_max_compressed = c128a_max_compressed + self.c128a_global_decode_buffer = torch.empty( + (max_num_batched_tokens, c128a_max_compressed), + dtype=torch.int32, + device=self.device, + ) + self.c128a_decode_lens_buffer = torch.empty( + max_num_batched_tokens, + dtype=torch.int32, + device=self.device, + ) + self.c128a_prefill_buffer = torch.empty( + (max_num_batched_tokens, c128a_max_compressed), + dtype=torch.int32, + device=self.device, + ) + def _build_fp8_mixed_decode_prefill( self, common_attn_metadata: CommonAttentionMetadata, @@ -460,15 +587,7 @@ def _build_fp8_separate_prefill_decode( decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item() # Use padded head count since that's what the kernel will see - padded_heads = self.fp8_decode_padded_heads - scheduler_metadata, _ = get_mla_metadata( - cache_seqlens=self.topk_tokens_tensor[:num_decodes], - num_q_tokens_per_head_k=decode_query_len * padded_heads, - topk=self.topk_tokens, - num_heads_q=padded_heads, - num_heads_k=1, - is_fp8_kvcache=True, - ) + scheduler_metadata, _ = get_mla_metadata() kernel_meta = FlashMLASparseMetadata.FP8KernelMetadata( scheduler_metadata=scheduler_metadata, @@ -476,6 +595,7 @@ def _build_fp8_separate_prefill_decode( cache_lens=self.max_model_len_tensor[:num_decodes], ) fp8_metadata.decode = FP8Meta.Decode( + seq_lens=common_attn_metadata.seq_lens[:num_decodes], kernel_metadata=kernel_meta, decode_query_len=decode_query_len, ) @@ -502,35 +622,109 @@ def build( ) req_id_per_token = self.req_id_per_token_buffer[:num_tokens] + slot_mapping = cm.slot_mapping + if self.compress_ratio > 1: + slot_mapping = get_compressed_slot_mapping( + common_attn_metadata.num_actual_tokens, + common_attn_metadata.query_start_loc, + common_attn_metadata.seq_lens, + common_attn_metadata.block_table_tensor.clamp(min=0), + int(self.kv_cache_spec.storage_block_size), + self.compress_ratio, + out=self.compressed_slot_mapping_buffer, + ) + fp8_extra_metadata: ( FlashMLASparseMetadata.FP8SeparatePrefillDecode | FlashMLASparseMetadata.FP8KernelMetadata | None ) = None - fp8_use_mixed_batch = self.num_heads < MIN_HEADS_FOR_BF16_PREFILL - if self.use_fp8_kv_cache: + 
fp8_use_mixed_batch = ( + self.num_heads < MIN_HEADS_FOR_BF16_PREFILL and not self.is_deepseek_v4 + ) + # DeepseekV4 has its own attention impl (DeepseekV4MLAAttention) that does not + # consume fp8_extra_metadata. Skipping the build here avoids a + # forced D2H sync on seq_lens that would otherwise fire on every + # prefill-bearing step, lifting GPU utilization on long-prefill + # workloads (e.g. LongBench) from ~83% to ~100%. + if self.use_fp8_kv_cache and not self.is_deepseek_v4: if fp8_use_mixed_batch: fp8_extra_metadata = self._build_fp8_mixed_decode_prefill(cm) else: fp8_extra_metadata = self._build_fp8_separate_prefill_decode(cm) + # Pre-compute C128A topk indices for DeepseekV4. + c128a_fields = {} + if self.is_deepseek_v4 and self.compress_ratio == 128: + c128a_fields = self._build_c128a_metadata(cm, req_id_per_token) + metadata = FlashMLASparseMetadata( num_reqs=cm.num_reqs, max_query_len=cm.max_query_len, max_seq_len=cm.max_seq_len, num_actual_tokens=cm.num_actual_tokens, query_start_loc=cm.query_start_loc, - slot_mapping=cm.slot_mapping, + slot_mapping=slot_mapping, block_table=cm.block_table_tensor, req_id_per_token=req_id_per_token, block_size=self.kv_cache_spec.block_size, topk_tokens=self.topk_tokens, fp8_extra_metadata=fp8_extra_metadata, fp8_use_mixed_batch=fp8_use_mixed_batch, + **c128a_fields, ) return metadata + def _build_c128a_metadata( + self, + cm: CommonAttentionMetadata, + req_id_per_token: torch.Tensor, + ) -> dict[str, torch.Tensor | None]: + """Pre-compute C128A topk indices for DeepseekV4 (compress_ratio >= 128).""" + # Must match SWA's decode split (no `require_uniform=True`) so + # `c128a_global_decode_topk_indices.shape[0]` lines up with q in + # `_forward_decode`. The per-token C128A kernel handles non-uniform + # query lengths. 
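A quick sketch of the topk-width padding used when sizing the C128A buffers above (the function name is illustrative; the constants mirror the comment in the allocation code):

```python
from math import ceil


def c128a_topk_width(max_model_len: int, compress_ratio: int = 128,
                     alignment: int = 128) -> int:
    """Row width of the pre-allocated C128A topk buffers: the compressed
    length rounded up to the decode kernel's B_TOPK alignment."""
    width = ceil(max_model_len / compress_ratio)
    return ceil(width / alignment) * alignment


# e.g. a 2136-token context compresses to ceil(2136 / 128) = 17 entries,
# which is padded to 128 so that extra_topk % B_TOPK == 0 holds.
assert c128a_topk_width(2136) == 128
```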
+ (num_decodes, _, num_decode_tokens, num_prefill_tokens) = ( + split_decodes_and_prefills( + cm, + decode_threshold=self.reorder_batch_threshold or 1, + ) + ) + + num_total = num_decode_tokens + num_prefill_tokens + if num_total == 0: + return {} + + assert cm.positions is not None, ( + "positions is required for C128A metadata build" + ) + block_size = self.kv_cache_spec.block_size // self.compress_ratio + global_decode, decode_lens, prefill_local = build_c128a_topk_metadata( + cm.positions[:num_total], + self.compress_ratio, + num_decode_tokens, + req_id_per_token, + cm.block_table_tensor[:num_decodes], + block_size, + cm.slot_mapping, + self.c128a_global_decode_buffer, + self.c128a_decode_lens_buffer, + self.c128a_prefill_buffer, + max_compressed_tokens=self.c128a_max_compressed, + ) + + result: dict[str, torch.Tensor | None] = {} + if num_decode_tokens > 0: + result["c128a_global_decode_topk_indices"] = global_decode.view( + num_decode_tokens, 1, -1 + ) + result["c128a_decode_topk_lens"] = decode_lens + if num_prefill_tokens > 0: + result["c128a_prefill_topk_indices"] = prefill_local + return result + class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]): @staticmethod @@ -552,7 +746,7 @@ def __init__( attn_type: str, kv_sharing_target_layer_name: str | None, # MLA Specific Arguments - topk_indice_buffer: torch.Tensor | None = None, + topk_indices_buffer: torch.Tensor | None = None, indexer: "Indexer | None" = None, **mla_args, ) -> None: @@ -615,7 +809,11 @@ def _forward_bf16_kv( NUM_TOPK_TOKENS=topk_indices.shape[1], ) - return self._bf16_flash_mla_kernel(q, kv_c_and_k_pe_cache, topk_indices) + return self._bf16_flash_mla_kernel( + q, + kv_c_and_k_pe_cache, + topk_indices, + ) def _forward_fp8_kv_separate_prefill_decode( self, @@ -656,7 +854,10 @@ def _forward_fp8_kv_separate_prefill_decode( fp8_metadata = attn_metadata.fp8_extra_metadata assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeparatePrefillDecode) - def _fp8_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor: + def _fp8_decode( + q: torch.Tensor, + topk_indices: torch.Tensor, + ) -> torch.Tensor: # Reshape q: (num_decode_tokens, num_heads, head_dim) # -> (num_decodes, seq_len, num_heads, head_dim) q = reshape_query_for_spec_decode(q, num_decodes) @@ -692,7 +893,8 @@ def _fp8_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor: if num_decode_tokens > 0: attn_out[:num_decode_tokens] = _fp8_decode( - q[:num_decode_tokens], topk_indices[:num_decode_tokens] + q[:num_decode_tokens], + topk_indices[:num_decode_tokens], ) assert fp8_metadata.prefill is not None @@ -823,6 +1025,7 @@ def _bf16_flash_mla_kernel( output = flash_mla_sparse_fwd( q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale )[0] + output = output[:, : self.num_heads, :] return output @@ -864,3 +1067,123 @@ def forward_mqa( ) return attn_out, None + + +def build_c128a_topk_metadata( + positions: torch.Tensor, + compress_ratio: int, + num_decode_tokens: int, + token_to_req_indices: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + slot_mapping: torch.Tensor, + global_decode_buffer: torch.Tensor, + decode_lens_buffer: torch.Tensor, + prefill_buffer: torch.Tensor, + max_compressed_tokens: int = 8192, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Single kernel for all C128A tokens (decode + prefill). + + Decode tokens: position → block_table lookup → global slot ids + topk_lens. + Prefill tokens: position → local indices [0, ..., n-1, -1, ...]. 
+ + Writes into pre-allocated buffers for CUDA graph address stability. + Returns slices of the buffers. + """ + num_tokens = positions.shape[0] + num_prefill_tokens = num_tokens - num_decode_tokens + + global_decode = global_decode_buffer[:num_decode_tokens] + decode_lens = decode_lens_buffer[:num_decode_tokens] + prefill_local = prefill_buffer[:num_prefill_tokens] + + if num_tokens == 0: + return global_decode, decode_lens, prefill_local + + _build_c128a_topk_metadata_kernel[(num_tokens,)]( + global_decode_buffer, + global_decode_buffer.stride(0), + decode_lens_buffer, + prefill_buffer, + prefill_buffer.stride(0), + positions, + compress_ratio, + max_compressed_tokens, + num_decode_tokens, + token_to_req_indices, + block_table, + block_table.stride(0), + block_size, + slot_mapping, + BLOCK_SIZE=1024, + ) + return global_decode, decode_lens, prefill_local + + +@triton.jit +def _build_c128a_topk_metadata_kernel( + # Decode outputs + global_decode_ptr, + global_decode_stride, + decode_lens_ptr, + # Prefill output + prefill_local_ptr, + prefill_local_stride, + # Inputs + positions_ptr, + compress_ratio, + max_compressed_tokens, + num_decode_tokens, + token_to_req_indices_ptr, + block_table_ptr, + block_table_stride, + block_size, + slot_mapping_ptr, + BLOCK_SIZE: tl.constexpr, +): + token_idx = tl.program_id(0) + position = tl.load(positions_ptr + token_idx) + num_compressed = (position + 1) // compress_ratio + num_compressed = tl.minimum(num_compressed, max_compressed_tokens) + is_decode = token_idx < num_decode_tokens + + if is_decode: + # --- Decode: block-table lookup → global slot ids + count --- + is_valid_token = tl.load(slot_mapping_ptr + token_idx) >= 0 + req_idx = tl.load(token_to_req_indices_ptr + token_idx) + count = tl.zeros((), dtype=tl.int32) + for i in range(0, max_compressed_tokens, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + mask = offset < max_compressed_tokens + is_valid = offset < num_compressed + + block_indices = offset // block_size + block_numbers = tl.load( + block_table_ptr + req_idx * block_table_stride + block_indices, + mask=mask & is_valid, + ) + block_offsets = offset % block_size + slot_ids = block_numbers * block_size + block_offsets + slot_ids = tl.where(is_valid, slot_ids, -1) + tl.store( + global_decode_ptr + token_idx * global_decode_stride + offset, + slot_ids, + mask=mask, + ) + count += tl.sum(is_valid.to(tl.int32), axis=0) + + tl.store( + decode_lens_ptr + token_idx, + tl.where(is_valid_token, count, 0), + ) + else: + # --- Prefill: write local indices --- + pfx_idx = token_idx - num_decode_tokens + for i in range(0, max_compressed_tokens, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + mask = offset < max_compressed_tokens + tl.store( + prefill_local_ptr + pfx_idx * prefill_local_stride + offset, + tl.where(offset < num_compressed, offset, -1), + mask=mask, + ) diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 237ccfeb4729..59e2e4a059d9 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os from dataclasses import dataclass import torch -import vllm.envs as envs from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform @@ -22,15 +22,30 @@ CommonAttentionMetadata, MultipleOf, ) +from vllm.v1.attention.backends.mla.compressor_utils import 
get_compressed_slot_mapping from vllm.v1.attention.backends.utils import ( split_decodes_and_prefills, ) -from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec from vllm.v1.worker.cp_utils import get_total_cp_world_size logger = init_logger(__name__) +def sparse_indexer_max_logits_bytes(is_sm12x: bool | None = None) -> int: + configured_mb = os.getenv("VLLM_SPARSE_INDEXER_MAX_LOGITS_MB") + if configured_mb is not None: + return int(configured_mb) * 1024 * 1024 + + if is_sm12x is None: + is_sm12x = ( + current_platform.is_cuda() + and current_platform.is_device_capability_family(120) + ) + default_mb = 256 if is_sm12x else 512 + return default_mb * 1024 * 1024 + + @triton.jit def _prepare_uniform_decode_kernel( seq_lens_ptr, @@ -154,6 +169,16 @@ def get_kv_cache_stride_order( return (0, 1, 2) +class DeepseekV4IndexerBackend(DeepseekV32IndexerBackend): + @staticmethod + def get_name() -> str: + return "DEEPSEEK_V4_INDEXER" + + @staticmethod + def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: + return [256] + + @dataclass class DeepseekV32IndexerPrefillChunkMetadata: block_table: torch.Tensor @@ -179,7 +204,7 @@ class DeepSeekV32IndexerDecodeMetadata: # seq_lens: per-token effective context lengths. # - flatten path / plain decode: 1D (batch_size,) # - native MTP path: 2D (B, next_n) where [b,j] = L_b - next_n + j + 1 - # Both fp8_paged_mqa_logits and the topk kernels accept both shapes. + # Both fp8_fp4_paged_mqa_logits and the topk kernels accept both shapes. seq_lens: torch.Tensor decode_lens: torch.Tensor requires_padding: bool @@ -191,16 +216,8 @@ class DeepseekV32IndexerMetadata: # FIXME (zyongye) # hacky way to access the data now, need to be in chunked meta seq_lens: torch.Tensor - - num_reqs: int - max_query_len: int max_seq_len: int - - num_actual_tokens: int # Number of tokens excluding padding. - query_start_loc: torch.Tensor slot_mapping: torch.Tensor - # The dimension of the attention heads - head_dim: int # New for MLA (compared to FlashAttention) # For handling prefill decode split @@ -213,71 +230,6 @@ class DeepseekV32IndexerMetadata: prefill: DeepseekV32IndexerPrefillMetadata | None = None -# TODO (zyongye) optimize this, this is now vibe coded -def kv_spans_from_batches( - start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor, device: torch.device -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Args: - start_seq_loc: 1D long tensor [B+1], cumulative counts of - selected tokens per batch. - Example: [0, 2, 4, 7] -> - batch sizes (selected) [2, 2, 3], N=7 tokens total. - seq_len_per_batch: 1D long tensor [B], - full sequence length (KV length) of each batch. - Example: [5, 9, 4]. - - Returns: - start_tensor: 1D long tensor [N], start offset in the - concatenated KV cache for each token's batch. - end_location: 1D long tensor [N], - **exclusive** end = start + token's local position. - (So the attended KV slice is kv[start:end].) - - Assumes each batch contributes its full `seq_len_per_batch[i]` - keys to the KV cache, andthe selected tokens within a batch - are the **last** `counts[i]` positions of that sequence. 
- """ - q = start_seq_loc.to(dtype=torch.long) - L = seq_len_per_batch.to(dtype=torch.long) - assert q.dim() == 1 and L.dim() == 1 - assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1" - - # Selected tokens per batch and totals - counts = q[1:] - q[:-1] # [B] - N = int(q[-1].item()) # total selected tokens - B = L.numel() - - if N == 0: - return ( - torch.empty(0, dtype=torch.long, device=device), - torch.empty(0, dtype=torch.long, device=device), - ) - - # KV start offsets per batch in the concatenated KV cache - kv_starts_per_batch = torch.cumsum(L, dim=0) - L # [B] - - # For each selected token, which batch does it belong to? - batch_id = torch.repeat_interleave(torch.arange(B), counts) # [N] - - # Map batch KV start to each token - start_tensor = kv_starts_per_batch[batch_id] # [N] - - # End-align local positions inside each batch: - # local_pos = L[b] - counts[b] + (1..counts[b]) for each batch b - L_expand = torch.repeat_interleave(L, counts) # [N] - m_expand = torch.repeat_interleave(counts, counts) # [N] - # position within the selected block: 1..counts[b] - pos_within = ( - torch.arange(N, dtype=torch.long) - torch.repeat_interleave(q[:-1], counts) + 1 - ) - - local_pos = L_expand - m_expand + pos_within # [N], 1-based - end_location = start_tensor + local_pos # exclusive end - - return start_tensor.int().to(device), end_location.int().to(device) - - def get_max_prefill_buffer_size(vllm_config: VllmConfig): max_model_len = vllm_config.model_config.max_model_len # NOTE(Chen): 40 is a magic number for controlling the prefill buffer size. @@ -293,7 +245,7 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig): class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): reorder_batch_threshold: int = 1 - natively_supported_next_n: list[int] = [1, 2] + natively_supported_next_n_fp4: list[int] = [1, 2] # TODO (matt): integrate kernel with next_n = 4 support @classmethod @@ -314,9 +266,29 @@ def __init__(self, *args, **kwargs): if self.vllm_config.speculative_config else 0 ) + self.use_fp4_indexer_cache = ( + self.vllm_config.attention_config.use_fp4_indexer_cache + ) + + assert ( + current_platform.is_device_capability_family(100) + or not self.use_fp4_indexer_cache + ), ( + "use_fp4_indexer_cache requires Blackwell datacenter GPUs " + "(sm_10x, e.g. B200/GB200); sm_120 (consumer Blackwell) and " + "earlier architectures are not supported." + ) + next_n = self.num_speculative_tokens + 1 self.reorder_batch_threshold += self.num_speculative_tokens - self.use_flattening = next_n not in self.natively_supported_next_n + # NOTE(zyongye) fp4 indexer cache only natively supports next_n in + # natively_supported_next_n_fp4; for other next_n values we fall back + # to the flattening path. When fp4 indexer cache is disabled, the + # native (non-flattening) path handles all next_n values. + self.use_flattening = ( + self.use_fp4_indexer_cache + and next_n not in self.natively_supported_next_n_fp4 + ) sm_count = num_compute_units(self.device.index) self.num_sms = sm_count @@ -331,7 +303,6 @@ def __init__(self, *args, **kwargs): ) if not self.use_flattening and next_n > 1: # Native MTP: 2D buffer for per-token seq_lens. - # Flattening path is never used, so no expanded_seq_lens_buffer. 
self.decode_seq_lens_buffer = torch.zeros( (scheduler_config.max_num_seqs, next_n), dtype=torch.int32, @@ -367,53 +338,27 @@ def __init__(self, *args, **kwargs): (self.num_sms + 1, 2), dtype=torch.int32, device=self.device ) - def build_one_prefill_chunk( - self, - req_slice: slice, - query_slice: slice, - query_start_loc_cpu, - seq_lens_cpu, - block_table, - skip_kv_gather: bool = False, - ) -> DeepseekV32IndexerPrefillChunkMetadata: - prefill_query_start_loc = ( - query_start_loc_cpu[req_slice.start : req_slice.stop + 1] - - query_start_loc_cpu[req_slice.start] - ) - cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches( - prefill_query_start_loc, seq_lens_cpu[req_slice], self.device - ) - token_start = query_start_loc_cpu[req_slice.start].item() - total_seq_lens = seq_lens_cpu[req_slice].sum() - num_reqs = req_slice.stop - req_slice.start - seq_idx = torch.arange(0, num_reqs, dtype=torch.int32) - token_to_seq = torch.repeat_interleave(seq_idx, seq_lens_cpu[req_slice]).to( - self.device - ) - assert total_seq_lens <= self.max_prefill_buffer_size - cu_seq_lens = ( - torch.cat( - [ - torch.zeros(1, dtype=torch.int32), - seq_lens_cpu[req_slice].cumsum(dim=0), - ] + # KV compression. Default to 1 for no compression. + self.compress_ratio = 1 + # Get compress_ratio for DeepseekV4 support + if isinstance(self.kv_cache_spec, MLAAttentionSpec): + self.compress_ratio = self.kv_cache_spec.compress_ratio + + # Pre-allocate buffers for CUDA graph compatibility when + if self.compress_ratio > 1: + # compress_ratio > 1 (DeepseekV4) + # Compressed slot mapping output buffer + self.compressed_slot_mapping_buffer = torch.zeros( + (scheduler_config.max_num_batched_tokens,), + dtype=torch.int64, + device=self.device, + ) + # Buffer for compressed seq_lens in decode path + self.expanded_seq_lens_buffer = torch.zeros( + (scheduler_config.max_num_batched_tokens,), + dtype=torch.int32, + device=self.device, ) - .to(torch.int32) - .to(self.device) - ) - - return DeepseekV32IndexerPrefillChunkMetadata( - cu_seqlen_ks=cu_seqlen_ks[query_slice], - cu_seqlen_ke=cu_seqlen_ke[query_slice], - cu_seq_lens=cu_seq_lens, - token_to_seq=token_to_seq, - total_seq_lens=total_seq_lens, - block_table=block_table[req_slice], - token_start=token_start + query_slice.start, - token_end=token_start + query_slice.stop, - num_reqs=num_reqs, - skip_kv_gather=skip_kv_gather, - ) def _prepare_decode_tensors( self, @@ -520,11 +465,15 @@ def _prepare_decode_tensors( requires_padding = min_decode_len != max_decode_len if use_native and next_n > 1: assert self.decode_seq_lens_buffer.dim() == 2 - # (B, next_n): token j attends to L - next_n + j + 1 KV tokens - self.decode_seq_lens_buffer[:num_decodes] = ( - seq_lens.unsqueeze(1) - next_n + 1 + self.offsets_buffer + # (B, max_decode_len): token j attends to + # L - max_decode_len + j + 1 KV tokens. 
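A tiny worked example of the per-token KV lengths described in the comment above (the assignment just below fills the real buffer; this snippet is illustrative only):

```python
import torch

# With context length L = 100 and max_decode_len = 2, decode token 0 may
# attend to 99 KV tokens and token 1 to all 100, i.e. L - max_decode_len + j + 1.
seq_lens = torch.tensor([100])
max_decode_len = 2
offsets = torch.arange(max_decode_len)
per_token_kv = seq_lens.unsqueeze(1) - max_decode_len + 1 + offsets
assert per_token_kv.tolist() == [[99, 100]]
```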
+ self.decode_seq_lens_buffer[:num_decodes, :max_decode_len] = ( + seq_lens.unsqueeze(1) + - max_decode_len + + 1 + + self.offsets_buffer[:max_decode_len] ) - seq_lens = self.decode_seq_lens_buffer[:num_decodes] + seq_lens = self.decode_seq_lens_buffer[:num_decodes, :max_decode_len] return seq_lens, block_table, decode_lens, num_decodes, requires_padding def build( @@ -535,8 +484,12 @@ def build( ) -> DeepseekV32IndexerMetadata: num_reqs = common_attn_metadata.num_reqs num_tokens = common_attn_metadata.num_actual_tokens - + query_start_loc = common_attn_metadata.query_start_loc query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + seq_lens = common_attn_metadata.seq_lens + slot_mapping = common_attn_metadata.slot_mapping + block_table = common_attn_metadata.block_table_tensor + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( common_attn_metadata, @@ -548,37 +501,67 @@ def build( assert num_decodes + num_prefills == num_reqs assert num_decode_tokens + num_prefill_tokens == num_tokens + compressed_slot_mapping = slot_mapping + compressed_seq_lens = seq_lens + if self.compress_ratio > 1: + compressed_slot_mapping = get_compressed_slot_mapping( + num_tokens, + query_start_loc, + seq_lens, + block_table, + self.kv_cache_spec.storage_block_size, + self.compress_ratio, + out=self.compressed_slot_mapping_buffer, + ) + compressed_seq_lens = seq_lens // self.compress_ratio + prefill_metadata = None if num_prefills > 0: + # This CPU value is an upper bound for async-spec extend rows. It + # is safe for chunking/allocation because CUDA metadata below is + # built from exact device seq_lens and gather ignores the tail. + assert common_attn_metadata.seq_lens_cpu_upper_bound is not None + seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound + compressed_seq_lens_cpu = ( + seq_lens_cpu // self.compress_ratio + if self.compress_ratio > 1 + else seq_lens_cpu + ) prefill_query_lens_cpu = torch.diff( query_start_loc_cpu[num_decodes : num_decodes + num_prefills + 1] ) - max_logits_bytes = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024 + max_logits_bytes = sparse_indexer_max_logits_bytes() # Upper bound is exact for prefill rows (the `[num_decodes:]` # slice below). assert common_attn_metadata.seq_lens_cpu_upper_bound is not None seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound chunk_specs = split_indexer_prefill_chunks( - seq_lens_cpu[num_decodes:], + compressed_seq_lens_cpu[num_decodes:], prefill_query_lens_cpu, self.max_prefill_buffer_size, max_logits_bytes, request_offset=num_decodes, ) - chunks = [ - self.build_one_prefill_chunk( - req_slice, - query_slice, + + chunks = [] + for req_slice, query_slice in chunk_specs: + metadata = build_prefill_chunk_metadata( + req_slice.start, + req_slice.stop, + query_start_loc, query_start_loc_cpu, - seq_lens_cpu, + seq_lens, + compressed_seq_lens, + compressed_seq_lens_cpu, common_attn_metadata.block_table_tensor, + self.compress_ratio, + query_slice=query_slice, skip_kv_gather=query_slice.start > 0, ) - for req_slice, query_slice in chunk_specs - ] - prefill_metadata = DeepseekV32IndexerPrefillMetadata( - chunks=chunks, - ) + # Skip when total_seq_lens is 0 (i.e., no compressed token). 
+ if metadata is not None: + chunks.append(metadata) + prefill_metadata = DeepseekV32IndexerPrefillMetadata(chunks) decode_metadata = None if num_decodes > 0: @@ -596,7 +579,7 @@ def build( max_decode_len = int(decode_lens_cpu.max().item()) next_n = 1 + self.num_speculative_tokens - use_native = not self.use_flattening and max_decode_len == next_n + use_native = not self.use_flattening and max_decode_len <= next_n seq_lens, block_table, decode_lens, batch_size, requires_padding = ( self._prepare_decode_tensors( @@ -613,11 +596,35 @@ def build( ) ) + # For DeepseekV4 (compress_ratio > 1), the indexer KV cache stores + # compressed tokens. Convert uncompressed seq_lens to compressed. + if self.compress_ratio > 1: + # True iff seq_lens aliases decode_seq_lens_buffer (flatten or + # native wrote it); False iff it aliases common_attn_metadata. + seq_lens_is_local_view = (use_native and next_n > 1) or ( + not use_native and max_decode_len > 1 + ) + if seq_lens_is_local_view: + seq_lens //= self.compress_ratio + else: + # Copy to avoid mutating shared state; keeps CG address stable. + self.expanded_seq_lens_buffer[:num_decodes] = ( + seq_lens // self.compress_ratio + ) + self.expanded_seq_lens_buffer[num_decodes:num_decode_tokens] = 0 + seq_lens = self.expanded_seq_lens_buffer[:num_decode_tokens] + + # Non-MTP: deep_gemm paged MQA logits requires 2D context_lens + # (csrc/apis/attention.hpp). Unsqueeze to (B, 1) so downstream + # kernels see the same (B, next_n) layout as the MTP path. + if seq_lens.dim() == 1: + seq_lens = seq_lens.unsqueeze(-1) + # DeepGEMM is required for the paged MQA logits on CUDA devices if current_platform.is_cuda() and has_deep_gemm(): self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata( seq_lens, - self.kv_cache_spec.block_size, + self.kv_cache_spec.storage_block_size, self.num_sms, ) @@ -631,13 +638,8 @@ def build( attn_metadata = DeepseekV32IndexerMetadata( seq_lens=common_attn_metadata.seq_lens, - num_reqs=common_attn_metadata.num_reqs, - max_query_len=common_attn_metadata.max_query_len, max_seq_len=common_attn_metadata.max_seq_len, - num_actual_tokens=common_attn_metadata.num_actual_tokens, - query_start_loc=common_attn_metadata.query_start_loc, - slot_mapping=common_attn_metadata.slot_mapping, - head_dim=128, + slot_mapping=compressed_slot_mapping, num_decodes=num_decodes, num_decode_tokens=num_decode_tokens, num_prefills=num_prefills, @@ -647,3 +649,138 @@ def build( ) return attn_metadata + + +def build_prefill_chunk_metadata( + start_idx: int, + end_idx: int, + query_start_loc: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + uncompressed_seq_lens: torch.Tensor, + compressed_seq_lens: torch.Tensor, + compressed_seq_lens_cpu: torch.Tensor, + block_table: torch.Tensor, + compress_ratio: int, + query_slice: slice | None = None, + skip_kv_gather: bool = False, +) -> DeepseekV32IndexerPrefillChunkMetadata | None: + total_seq_lens = compressed_seq_lens_cpu[start_idx:end_idx].sum().item() + if total_seq_lens == 0: + return None + + num_reqs = end_idx - start_idx + device = block_table.device + token_to_seq = torch.empty(total_seq_lens, dtype=torch.int32, device=device) + + cu_seq_lens = torch.empty(num_reqs + 1, dtype=torch.int32, device=device) + # Assigning to slice avoids cpu sync. 
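For orientation, the per-token compressed KV spans that the Triton kernel further below produces can be written out directly (illustrative numbers; `compress_ratio` as used throughout this file):

```python
# A query token at absolute (0-based) position p can attend to the
# compressed tokens completed so far, i.e. (p + 1) // compress_ratio of
# them. Its exclusive end index in the concatenated compressed KV is the
# request's seq_start plus that count; the start index is seq_start itself.
compress_ratio = 128
seq_start = 0  # offset of this request in the concatenated compressed KV
spans = {p: (seq_start, seq_start + (p + 1) // compress_ratio)
         for p in (127, 128, 300)}
assert spans == {127: (0, 1), 128: (0, 1), 300: (0, 2)}
```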
+ cu_seq_lens[:1] = 0 + torch.cumsum(compressed_seq_lens[start_idx:end_idx], dim=0, out=cu_seq_lens[1:]) + + query_start_loc = ( + query_start_loc[start_idx : end_idx + 1] - query_start_loc[start_idx] + ) + + total_query_len = int( + (query_start_loc_cpu[end_idx] - query_start_loc_cpu[start_idx]).item() + ) + if query_slice is not None: + qs_start = query_slice.start + qs_stop = query_slice.stop + else: + qs_start = 0 + qs_stop = total_query_len + output_query_len = qs_stop - qs_start + + cu_seq_len_ks = torch.empty(output_query_len, dtype=torch.int32, device=device) + cu_seq_len_ke = torch.empty(output_query_len, dtype=torch.int32, device=device) + + _build_prefill_chunk_metadata_kernel[(num_reqs,)]( + query_start_loc, + uncompressed_seq_lens[start_idx:end_idx], + cu_seq_lens, + token_to_seq, + cu_seq_len_ks, + cu_seq_len_ke, + qs_start, + qs_stop, + BLOCK_SIZE=1024, + COMPRESS_RATIO=compress_ratio, + ) + + token_start = query_start_loc_cpu[start_idx].item() + if query_slice is not None: + token_end = token_start + qs_stop + token_start = token_start + qs_start + skip_kv_gather = skip_kv_gather or qs_start > 0 + else: + token_end = query_start_loc_cpu[end_idx].item() + + return DeepseekV32IndexerPrefillChunkMetadata( + cu_seqlen_ks=cu_seq_len_ks, + cu_seqlen_ke=cu_seq_len_ke, + cu_seq_lens=cu_seq_lens, + token_to_seq=token_to_seq, + total_seq_lens=total_seq_lens, + block_table=block_table[start_idx:end_idx], + token_start=token_start, + token_end=token_end, + num_reqs=num_reqs, + skip_kv_gather=skip_kv_gather, + ) + + +@triton.jit +def _build_prefill_chunk_metadata_kernel( + # Inputs + query_start_loc_ptr, + uncompressed_seq_lens_ptr, + cu_compressed_seq_lens_ptr, + # Outputs + token_to_seq_ptr, + cu_compressed_seq_len_ks_ptr, + cu_compressed_seq_len_ke_ptr, + query_slice_start, + query_slice_stop, + BLOCK_SIZE: tl.constexpr, + COMPRESS_RATIO: tl.constexpr, +): + batch_idx = tl.program_id(0) + + query_start = tl.load(query_start_loc_ptr + batch_idx) + query_end = tl.load(query_start_loc_ptr + batch_idx + 1) + query_len = query_end - query_start + + seq_start = tl.load(cu_compressed_seq_lens_ptr + batch_idx) + seq_end = tl.load(cu_compressed_seq_lens_ptr + batch_idx + 1) + compressed_seq_len = seq_end - seq_start + + uncompressed_seq_len = tl.load(uncompressed_seq_lens_ptr + batch_idx) + start_pos = uncompressed_seq_len - query_len + + for i in range(0, query_len, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + abs_pos = query_start + offset + mask = ( + (offset < query_len) + & (abs_pos >= query_slice_start) + & (abs_pos < query_slice_stop) + ) + out_pos = abs_pos - query_slice_start + + # Compute cu_seq_len_ks + tl.store(cu_compressed_seq_len_ks_ptr + out_pos, seq_start, mask=mask) + + # Compute cu_seq_len_ke + seq_len_per_token = (start_pos + 1 + offset) // COMPRESS_RATIO + tl.store( + cu_compressed_seq_len_ke_ptr + out_pos, + seq_start + seq_len_per_token, + mask=mask, + ) + + # Compute token_to_seq + for i in range(0, compressed_seq_len, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + mask = offset < compressed_seq_len + tl.store(token_to_seq_ptr + seq_start + offset, batch_idx, mask=mask) diff --git a/vllm/v1/attention/backends/mla/sparse_mla_env.py b/vllm/v1/attention/backends/mla/sparse_mla_env.py new file mode 100644 index 000000000000..fa1ce61314d3 --- /dev/null +++ b/vllm/v1/attention/backends/mla/sparse_mla_env.py @@ -0,0 +1,166 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Environment controls 
for the portable sparse MLA fallback.""" + +import os + +import torch + +from vllm.logger import init_logger +from vllm.platforms import current_platform + +_TRITON_MLA_SPARSE_ENV = "VLLM_TRITON_MLA_SPARSE" +_TRITON_MLA_SPARSE_DUMP_ENV = "VLLM_TRITON_MLA_SPARSE_DUMP" +_TRITON_MLA_SPARSE_DUMP_PATH_ENV = "VLLM_TRITON_MLA_SPARSE_DUMP_PATH" +_TRITON_MLA_SPARSE_TOPK_CHUNK_ENV = "VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE" +_TRITON_MLA_SPARSE_QUERY_CHUNK_ENV = "VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE" +_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH_ENV = ( + "VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH" +) +_TRITON_MLA_SPARSE_HEAD_BLOCK_ENV = "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE" +_TRITON_MLA_SPARSE_MATMUL_DECODE_ENV = "VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE" + +_ENV_TRUE_VALUES = {"1", "true", "yes", "on"} +_ENV_FALSE_VALUES = {"0", "false", "no", "off"} + +logger = init_logger(__name__) + + +def _optional_env_flag(name: str) -> bool | None: + raw_value = os.getenv(name) + if raw_value is None: + return None + value = raw_value.lower() + if value in _ENV_TRUE_VALUES: + return True + if value in _ENV_FALSE_VALUES: + return False + return None + + +def _is_sm12x_device(device: torch.device) -> bool: + if not torch.cuda.is_available(): + return False + index = device.index if device.index is not None else torch.cuda.current_device() + return torch.cuda.get_device_capability(index)[0] == 12 + + +def is_sparse_mla_attention_dump_enabled() -> bool: + configured = _optional_env_flag(_TRITON_MLA_SPARSE_DUMP_ENV) + if configured is not None: + return configured + return False + + +def sparse_mla_reference_attention_configured() -> bool | None: + return _optional_env_flag(_TRITON_MLA_SPARSE_ENV) + + +def is_sparse_mla_reference_attention_enabled_for_platform() -> bool: + configured = sparse_mla_reference_attention_configured() + if configured is not None: + return configured + return current_platform.is_device_capability_family(120) + + +def is_sparse_mla_reference_attention_enabled(device: torch.device) -> bool: + configured = sparse_mla_reference_attention_configured() + if configured is not None: + return configured + return _is_sm12x_device(device) + + +def _uses_speculative_decoding(vllm_config) -> bool: + return bool(getattr(vllm_config, "speculative_config", None)) + + +def sparse_mla_reference_cudagraphs_allowed(vllm_config=None) -> bool: + configured = _optional_env_flag(_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH_ENV) + if configured is not None: + return configured + return not ( + vllm_config is not None and _uses_speculative_decoding(vllm_config) + ) + + +def disable_sparse_mla_reference_cudagraphs_if_enabled(vllm_config) -> None: + if not is_sparse_mla_reference_attention_enabled_for_platform(): + return + if sparse_mla_reference_cudagraphs_allowed(vllm_config): + logger.warning_once( + "Keeping vLLM compile and CUDA graphs enabled for the DeepSeek V4 " + "Triton sparse MLA fallback because " + f"{_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH_ENV}=1 or speculative " + "decoding is not configured. This is an " + "experimental performance mode." 
+ ) + return + + from vllm.config.compilation import CompilationMode, CUDAGraphMode + + compilation_config = vllm_config.compilation_config + if ( + compilation_config.mode == CompilationMode.NONE + and compilation_config.cudagraph_mode == CUDAGraphMode.NONE + ): + return + + logger.warning_once( + "Disabling vLLM compile and CUDA graphs for the DeepSeek V4 Triton " + "sparse MLA fallback because the current fallback path is not " + "compile/graph-safe yet, or because speculative decoding uses " + "multi-token sparse MLA decode." + ) + compilation_config.mode = CompilationMode.NONE + compilation_config.compile_sizes = [] + compilation_config.compile_ranges_endpoints = [] + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + compilation_config.cudagraph_capture_sizes = [] + compilation_config.max_cudagraph_capture_size = 0 + + +def sparse_mla_attention_dump_path() -> str: + return ( + os.getenv(_TRITON_MLA_SPARSE_DUMP_PATH_ENV) + or "/tmp/deepseek_v4_triton_mla_sparse_dump.jsonl" + ) + + +def sparse_mla_reference_topk_chunk_size() -> int: + raw_value = os.getenv(_TRITON_MLA_SPARSE_TOPK_CHUNK_ENV) + if raw_value is None: + return 512 + try: + return max(1, int(raw_value)) + except ValueError: + return 512 + + +def sparse_mla_reference_query_chunk_size() -> int: + raw_value = os.getenv(_TRITON_MLA_SPARSE_QUERY_CHUNK_ENV) + if raw_value is None: + return 256 + try: + return max(1, int(raw_value)) + except ValueError: + return 256 + + +def sparse_mla_reference_head_block_size() -> int | None: + raw_value = os.getenv(_TRITON_MLA_SPARSE_HEAD_BLOCK_ENV) + if raw_value is None: + return None + try: + value = int(raw_value) + except ValueError: + return None + if value in (1, 2, 4): + return value + return None + + +def sparse_mla_matmul_decode_enabled() -> bool: + configured = _optional_env_flag(_TRITON_MLA_SPARSE_MATMUL_DECODE_ENV) + if configured is not None: + return configured + return current_platform.is_device_capability_family(120) diff --git a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py new file mode 100644 index 000000000000..11cbbe95774b --- /dev/null +++ b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py @@ -0,0 +1,2362 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Portable sparse MLA Triton kernels.""" + +import torch + +from vllm.triton_utils import tl, triton +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + sparse_mla_reference_head_block_size, +) + + +def sparse_mla_decode_head_block_size(num_decode_tokens: int) -> int: + """Choose the SM12x sparse MLA head grouping for decode kernels. + + Single-token decode is latency sensitive and does best with one head per + program. Once there are enough query tokens, grouping heads lets the kernel + reuse each dequantized KV row across multiple heads. 
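Example (illustrative; values follow the thresholds in the body below and
    ignore the VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE override):

        sparse_mla_decode_head_block_size(1)   # -> 1 head per program
        sparse_mla_decode_head_block_size(8)   # -> 2
        sparse_mla_decode_head_block_size(64)  # -> 4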
+ """ + + configured_head_block_size = sparse_mla_reference_head_block_size() + if configured_head_block_size is not None: + return configured_head_block_size + if num_decode_tokens <= 4: + return 1 + if num_decode_tokens < 16: + return 2 + return 4 + + +@triton.jit +def _merge_two_subsets_with_sink_kernel( + out0_ptr, + lse0_ptr, + out1_ptr, + lse1_ptr, + sink_ptr, + output_ptr, + stride_out0_t: tl.constexpr, + stride_out0_h: tl.constexpr, + stride_out0_d: tl.constexpr, + stride_lse0_t: tl.constexpr, + stride_lse0_h: tl.constexpr, + stride_out1_t: tl.constexpr, + stride_out1_h: tl.constexpr, + stride_out1_d: tl.constexpr, + stride_lse1_t: tl.constexpr, + stride_lse1_h: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + mask = offsets < head_dim + + lse0 = tl.load(lse0_ptr + token_idx * stride_lse0_t + head_idx * stride_lse0_h) + lse1 = tl.load(lse1_ptr + token_idx * stride_lse1_t + head_idx * stride_lse1_h) + sink = tl.load(sink_ptr + head_idx) + merge_max = tl.maximum(tl.maximum(lse0, lse1), sink) + + weight0 = tl.exp(lse0 - merge_max) + weight1 = tl.exp(lse1 - merge_max) + weight_sink = tl.exp(sink - merge_max) + denom = weight0 + weight1 + weight_sink + + out0 = tl.load( + out0_ptr + + token_idx * stride_out0_t + + head_idx * stride_out0_h + + offsets * stride_out0_d, + mask=mask, + other=0.0, + ).to(tl.float32) + out1 = tl.load( + out1_ptr + + token_idx * stride_out1_t + + head_idx * stride_out1_h + + offsets * stride_out1_d, + mask=mask, + other=0.0, + ).to(tl.float32) + merged = (out0 * weight0 + out1 * weight1) / denom + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + merged, + mask=mask, + ) + + +def merge_two_sparse_mla_subsets_with_sink( + subset0_output: torch.Tensor, + subset0_lse: torch.Tensor, + subset1_output: torch.Tensor, + subset1_lse: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert subset0_output.shape == subset1_output.shape + assert subset0_output.shape == output.shape + assert subset0_lse.shape == subset1_lse.shape + assert subset0_lse.shape == subset0_output.shape[:2] + assert attn_sink.shape[0] == subset0_output.shape[1] + assert subset0_output.is_cuda + assert subset1_output.is_cuda + assert output.is_cuda + + num_tokens, num_heads, head_dim = subset0_output.shape + block_d = min(128, triton.next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _merge_two_subsets_with_sink_kernel[grid]( + subset0_output, + subset0_lse, + subset1_output, + subset1_lse, + attn_sink, + output, + subset0_output.stride(0), + subset0_output.stride(1), + subset0_output.stride(2), + subset0_lse.stride(0), + subset0_lse.stride(1), + subset1_output.stride(0), + subset1_output.stride(1), + subset1_output.stride(2), + subset1_lse.stride(0), + subset1_lse.stride(1), + output.stride(0), + output.stride(1), + output.stride(2), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) + + +@triton.jit +def _merge_single_subset_with_sink_kernel( + subset_output_ptr, + subset_lse_ptr, + sink_ptr, + output_ptr, + stride_subset_t: tl.constexpr, + stride_subset_h: tl.constexpr, + stride_subset_d: 
tl.constexpr, + stride_lse_t: tl.constexpr, + stride_lse_h: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + mask = offsets < head_dim + + subset_lse = tl.load( + subset_lse_ptr + token_idx * stride_lse_t + head_idx * stride_lse_h + ) + sink = tl.load(sink_ptr + head_idx) + merge_max = tl.maximum(subset_lse, sink) + + subset_weight = tl.exp(subset_lse - merge_max) + sink_weight = tl.exp(sink - merge_max) + denom = subset_weight + sink_weight + subset_output = tl.load( + subset_output_ptr + + token_idx * stride_subset_t + + head_idx * stride_subset_h + + offsets * stride_subset_d, + mask=mask, + other=0.0, + ).to(tl.float32) + merged = subset_output * subset_weight / denom + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + merged, + mask=mask, + ) + + +def merge_sparse_mla_subset_with_sink( + subset_output: torch.Tensor, + subset_lse: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert subset_output.shape == output.shape + assert subset_lse.shape == subset_output.shape[:2] + assert attn_sink.shape[0] == subset_output.shape[1] + assert subset_output.is_cuda + assert subset_lse.is_cuda + assert attn_sink.is_cuda + assert output.is_cuda + + num_tokens, num_heads, head_dim = subset_output.shape + block_d = min(128, triton.next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _merge_single_subset_with_sink_kernel[grid]( + subset_output, + subset_lse, + attn_sink, + output, + subset_output.stride(0), + subset_output.stride(1), + subset_output.stride(2), + subset_lse.stride(0), + subset_lse.stride(1), + output.stride(0), + output.stride(1), + output.stride(2), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) + + +@triton.jit +def _build_combined_decode_valid_mask_kernel( + output_ptr, + slot_ids_ptr, + topk_lens_ptr, + swa_lens_ptr, + stride_output_t: tl.constexpr, + stride_output_c: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + num_compressed_candidates: tl.constexpr, + num_candidates: tl.constexpr, + BLOCK_C: tl.constexpr, +): + token_idx = tl.program_id(0) + offsets = tl.arange(0, BLOCK_C) + candidate_mask = offsets < num_candidates + + topk_lens = tl.load(topk_lens_ptr + token_idx) + swa_lens = tl.load(swa_lens_ptr + token_idx) + is_compressed = offsets < num_compressed_candidates + swa_offsets = offsets - num_compressed_candidates + slot_ids = tl.load( + slot_ids_ptr + token_idx * stride_slot_t + offsets * stride_slot_c, + mask=is_compressed, + other=-1, + ) + valid_compressed = is_compressed & (offsets < topk_lens) & (slot_ids >= 0) + valid_swa = (~is_compressed) & (swa_offsets < swa_lens) + valid = valid_compressed | valid_swa + tl.store( + output_ptr + token_idx * stride_output_t + offsets * stride_output_c, + valid, + mask=candidate_mask, + ) + + +def build_combined_sparse_mla_decode_valid_mask( + output: torch.Tensor, + compressed_slot_ids: torch.Tensor, + topk_lens: torch.Tensor, + swa_lens: torch.Tensor, +) -> None: + """Build `[compressed, SWA]` validity mask for SM12x decode fallback.""" + if compressed_slot_ids.dim() == 3: + assert 
compressed_slot_ids.shape[1] == 1 + compressed_slot_ids = compressed_slot_ids[:, 0, :] + + assert output.dim() == 2 + assert output.dtype == torch.bool + assert compressed_slot_ids.dim() == 2 + assert output.shape[0] == compressed_slot_ids.shape[0] + assert output.shape[0] == topk_lens.shape[0] + assert output.shape[0] == swa_lens.shape[0] + assert output.shape[1] >= compressed_slot_ids.shape[1] + assert output.is_cuda + assert compressed_slot_ids.is_cuda + assert topk_lens.is_cuda + assert swa_lens.is_cuda + + num_candidates = output.shape[1] + block_c = triton.next_power_of_2(num_candidates) + _build_combined_decode_valid_mask_kernel[(output.shape[0],)]( + output, + compressed_slot_ids, + topk_lens, + swa_lens, + output.stride(0), + output.stride(1), + compressed_slot_ids.stride(0), + compressed_slot_ids.stride(1), + compressed_slot_ids.shape[1], + num_candidates, + BLOCK_C=block_c, + num_warps=4, + ) + + +def matmul_sparse_mla_attention_with_sink( + q: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + scale: float, + attn_sink: torch.Tensor, + output: torch.Tensor, + num_heads: int | None = None, +) -> None: + """Compute sink-aware sparse MLA over materialized BF16 KV. + + This path intentionally dequantizes/gathers KV once and then reuses it + across all heads with batched matrix multiplications. It is useful for the + SM12x decode fallback where the direct Triton reference kernel otherwise + repeats fp8_ds_mla dequantization once per head group. + """ + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert kv.dim() == 3, f"Expected kv shape [T, K, D], got {kv.shape}" + assert valid_tokens.shape == kv.shape[:2] + assert q.shape[0] == kv.shape[0] + assert q.shape[-1] == kv.shape[-1] + assert output.shape[0] == q.shape[0] + assert output.shape[2] == q.shape[-1] + assert q.is_cuda and kv.is_cuda and valid_tokens.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + active_heads = num_heads if num_heads is not None else output.shape[1] + assert active_heads <= q.shape[1] + assert active_heads <= output.shape[1] + assert active_heads <= attn_sink.shape[0] + + q_active = q[:, :active_heads] + if q_active.dtype != kv.dtype: + q_active = q_active.to(kv.dtype) + + scores = torch.bmm(q_active, kv.transpose(1, 2)).float() + scores.mul_(scale) + scores.masked_fill_(~valid_tokens[:, None, :], float("-inf")) + scores = torch.cat( + ( + scores, + attn_sink[:active_heads][None, :, None].expand( + q.shape[0], active_heads, 1 + ), + ), + dim=2, + ) + + weights = torch.softmax(scores, dim=-1)[..., : kv.shape[1]] + result = torch.bmm(weights.to(kv.dtype), kv) + output[:, :active_heads].copy_(result.to(output.dtype)) + if output.shape[1] > active_heads: + output[:, active_heads:].zero_() + + +@triton.jit +def _accumulate_gathered_attention_chunk_kernel( + q_ptr, + kv_ptr, + slot_ids_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_kv_t: tl.constexpr, + stride_kv_c: tl.constexpr, + stride_kv_d: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + HAS_SLOT_IDS: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = 
tl.program_id(0) + head_idx = tl.program_id(1) + offsets = tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_idx * stride_q_h + + offsets * stride_q_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + acc_offset = ( + token_idx * stride_acc_t + + head_idx * stride_acc_h + + offsets * stride_acc_d + ) + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to( + tl.float32 + ) + valid_len = tl.load(lens_ptr + token_idx) + + for candidate_idx in range(0, num_candidates): + is_valid = (candidate_offset + candidate_idx) < valid_len + if HAS_SLOT_IDS: + slot_id = tl.load( + slot_ids_ptr + + token_idx * stride_slot_t + + candidate_idx * stride_slot_c + ) + is_valid = is_valid & (slot_id >= 0) + + if is_valid: + kv = tl.load( + kv_ptr + + token_idx * stride_kv_t + + candidate_idx * stride_kv_c + + offsets * stride_kv_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + score = tl.sum(q * kv, axis=0) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offset, running_max) + tl.store(denom_ptr + state_offset, running_denom) + tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask) + + +def accumulate_gathered_sparse_mla_attention_chunk( + q: torch.Tensor, + kv: torch.Tensor, + lens: torch.Tensor, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, + slot_ids: torch.Tensor | None = None, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert kv.dim() == 3, f"Expected kv shape [T, K, D], got {kv.shape}" + assert q.shape[0] == kv.shape[0] + assert q.shape[-1] == kv.shape[-1] + assert lens.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert q.is_cuda and kv.is_cuda and lens.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + if slot_ids is not None: + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0] + assert slot_ids.dim() == 2 + assert slot_ids.shape == kv.shape[:2] + assert slot_ids.is_cuda + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + num_candidates = kv.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, num_heads) + _accumulate_gathered_attention_chunk_kernel[grid]( + q, + kv, + slot_ids, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + kv.stride(0), + kv.stride(1), + kv.stride(2), + slot_ids.stride(0) if slot_ids is not None else 0, + slot_ids.stride(1) if slot_ids is not None else 0, + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + 
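# HAS_SLOT_IDS is a constexpr flag: when slot_ids is None the kernel is + # specialized without the per-candidate slot_id >= 0 validity check. +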
HAS_SLOT_IDS=slot_ids is not None, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _accumulate_indexed_attention_chunk_kernel( + q_ptr, + kv_flat_ptr, + indices_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_kv_t, + stride_kv_d: tl.constexpr, + stride_indices_t: tl.constexpr, + stride_indices_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + offsets = tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_idx * stride_q_h + + offsets * stride_q_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + acc_offset = ( + token_idx * stride_acc_t + + head_idx * stride_acc_h + + offsets * stride_acc_d + ) + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to( + tl.float32 + ) + valid_len = tl.load(lens_ptr + token_idx) + + for candidate_idx in range(0, num_candidates): + kv_index = tl.load( + indices_ptr + + token_idx * stride_indices_t + + candidate_idx * stride_indices_c + ) + is_valid = ((candidate_offset + candidate_idx) < valid_len) & ( + kv_index >= 0 + ) + + if is_valid: + kv = tl.load( + kv_flat_ptr + + kv_index.to(tl.int64) * stride_kv_t + + offsets * stride_kv_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + score = tl.sum(q * kv, axis=0) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offset, running_max) + tl.store(denom_ptr + state_offset, running_denom) + tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask) + + +def accumulate_indexed_sparse_mla_attention_chunk( + q: torch.Tensor, + kv_flat: torch.Tensor, + indices: torch.Tensor, + lens: torch.Tensor, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert kv_flat.dim() == 2 + assert indices.dim() == 2 + assert indices.shape[0] == q.shape[0] + assert kv_flat.shape[-1] == q.shape[-1] + assert lens.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert q.is_cuda and kv_flat.is_cuda and indices.is_cuda and lens.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + num_candidates = indices.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, num_heads) 
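+ # One program per (token, head); BLOCK_D spans the whole head dim so each + # program holds its full accumulator in registers. max_score/denom/acc are + # read, updated, and stored back, so successive calls over candidate chunks + # compose a single online-softmax pass.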
+ _accumulate_indexed_attention_chunk_kernel[grid]( + q, + kv_flat, + indices, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + kv_flat.stride(0), + kv_flat.stride(1), + indices.stride(0), + indices.stride(1), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _accumulate_fp8ds_global_slots_attention_chunk_kernel( + q_ptr, + k_cache_ptr, + slot_ids_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + offsets = tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_idx * stride_q_h + + offsets * stride_q_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + acc_offset = ( + token_idx * stride_acc_t + + head_idx * stride_acc_h + + offsets * stride_acc_d + ) + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to( + tl.float32 + ) + valid_len = tl.load(lens_ptr + token_idx) + + fp8_mask = offsets < fp8_dim + rope_mask = (offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(offsets - fp8_dim, 0) + + for candidate_idx in range(0, num_candidates): + slot_id = tl.load( + slot_ids_ptr + + token_idx * stride_slot_t + + candidate_idx * stride_slot_c + ) + is_valid = ((candidate_offset + candidate_idx) < valid_len) & (slot_id >= 0) + + if is_valid: + block_idx = slot_id // cache_block_size + pos_in_block = slot_id % cache_block_size + cache_block_ptr = k_cache_ptr + block_idx.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv, axis=0) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + 
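# The denominator gets the same exp(old_max - new_max) rescale as the + # numerator above, keeping the online-softmax state consistent. +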
running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offset, running_max) + tl.store(denom_ptr + state_offset, running_denom) + tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask) + + +def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk( + q: torch.Tensor, + k_cache: torch.Tensor, + slot_ids: torch.Tensor, + lens: torch.Tensor, + block_size: int, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert slot_ids.dim() == 2 + assert slot_ids.shape[0] == q.shape[0] + assert lens.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda and slot_ids.is_cuda and lens.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + num_candidates = slot_ids.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, num_heads) + _accumulate_fp8ds_global_slots_attention_chunk_kernel[grid]( + q, + k_cache, + slot_ids, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + slot_ids.stride(0), + slot_ids.stride(1), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel( + q_ptr, + k_cache_ptr, + slot_ids_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=matrix_mask, + 
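# masked head/dim lanes load 0.0, so padding never perturbs the scores that are kept +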
other=0.0, + ).to(tl.float32) + + state_offsets = ( + token_idx * stride_state_t + head_offsets * stride_state_h + ) + acc_offsets = ( + token_idx * stride_acc_t + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d + ) + running_max = tl.load( + max_score_ptr + state_offsets, + mask=head_mask, + other=-float("inf"), + ) + running_denom = tl.load(denom_ptr + state_offsets, mask=head_mask, other=0.0) + running_acc = tl.load(acc_ptr + acc_offsets, mask=matrix_mask, other=0.0).to( + tl.float32 + ) + valid_len = tl.load(lens_ptr + token_idx) + + fp8_mask = dim_offsets < fp8_dim + rope_mask = (dim_offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + + for candidate_idx in range(0, num_candidates): + slot_id = tl.load( + slot_ids_ptr + + token_idx * stride_slot_t + + candidate_idx * stride_slot_c + ) + is_valid = ((candidate_offset + candidate_idx) < valid_len) & (slot_id >= 0) + + if is_valid: + block_idx = slot_id // cache_block_size + pos_in_block = slot_id % cache_block_size + cache_block_ptr = k_cache_ptr + block_idx.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offsets, running_max, mask=head_mask) + tl.store(denom_ptr + state_offsets, running_denom, mask=head_mask) + tl.store(acc_ptr + acc_offsets, running_acc, mask=matrix_mask) + + +def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q: torch.Tensor, + k_cache: torch.Tensor, + slot_ids: torch.Tensor, + lens: torch.Tensor, + block_size: int, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, + head_block_size: int = 2, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert slot_ids.dim() == 2 + assert slot_ids.shape[0] == q.shape[0] + assert lens.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert head_block_size in (1, 2, 4) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype 
== torch.float32 + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda and slot_ids.is_cuda and lens.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + num_candidates = slot_ids.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, triton.cdiv(num_heads, head_block_size)) + _accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel[grid]( + q, + k_cache, + slot_ids, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + slot_ids.stride(0), + slot_ids.stride(1), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _accumulate_fp8ds_paged_attention_chunk_kernel( + q_ptr, + k_cache_ptr, + seq_lens_ptr, + gather_lens_ptr, + block_table_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_block_table_t, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + offsets = tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_idx * stride_q_h + + offsets * stride_q_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + acc_offset = ( + token_idx * stride_acc_t + + head_idx * stride_acc_h + + offsets * stride_acc_d + ) + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to( + tl.float32 + ) + + seq_len = tl.load(seq_lens_ptr + token_idx) + gather_len = tl.load(gather_lens_ptr + token_idx) + start_pos = seq_len - gather_len + fp8_mask = offsets < fp8_dim + rope_mask = (offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(offsets - fp8_dim, 0) + + for candidate_idx in range(0, num_candidates): + gather_idx = candidate_offset + candidate_idx + is_valid = gather_idx < gather_len + + if is_valid: + pos = start_pos + gather_idx + block_in_seq = pos // cache_block_size + pos_in_block = pos % cache_block_size + physical_block = tl.load( + block_table_ptr + token_idx * stride_block_table_t + block_in_seq + ) + cache_block_ptr = ( + k_cache_ptr + physical_block.to(tl.int64) * block_stride + ) + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + offsets, mask=fp8_mask, other=0) + x_fp8 = 
x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv, axis=0) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offset, running_max) + tl.store(denom_ptr + state_offset, running_denom) + tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask) + + +def accumulate_fp8ds_paged_sparse_mla_attention_chunk( + q: torch.Tensor, + k_cache: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int, + num_candidates: int, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert seq_lens.shape[0] == q.shape[0] + assert gather_lens.shape[0] == q.shape[0] + assert block_table.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda + assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, num_heads) + _accumulate_fp8ds_paged_attention_chunk_kernel[grid]( + q, + k_cache, + seq_lens, + gather_lens, + block_table, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + block_table.stride(0), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _accumulate_fp8ds_paged_attention_chunk_multihead_kernel( + q_ptr, + k_cache_ptr, + seq_lens_ptr, + gather_lens_ptr, + block_table_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_block_table_t, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + cache_block_size: tl.constexpr, + 
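# fp8_ds_mla block layout: cache_block_size tokens of token_data_size bytes + # (448 fp8 NoPE bytes + 64 bf16 RoPE values = 576 bytes), then scale_dim UE8M0 + # scale bytes per token; block_stride is the byte distance between physical blocks. +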
token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + + state_offsets = token_idx * stride_state_t + head_offsets * stride_state_h + acc_offsets = ( + token_idx * stride_acc_t + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d + ) + running_max = tl.load( + max_score_ptr + state_offsets, + mask=head_mask, + other=-float("inf"), + ) + running_denom = tl.load(denom_ptr + state_offsets, mask=head_mask, other=0.0) + running_acc = tl.load(acc_ptr + acc_offsets, mask=matrix_mask, other=0.0).to( + tl.float32 + ) + + seq_len = tl.load(seq_lens_ptr + token_idx) + gather_len = tl.load(gather_lens_ptr + token_idx) + start_pos = seq_len - gather_len + fp8_mask = dim_offsets < fp8_dim + rope_mask = (dim_offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + + for candidate_idx in range(0, num_candidates): + gather_idx = candidate_offset + candidate_idx + is_valid = gather_idx < gather_len + + if is_valid: + pos = start_pos + gather_idx + block_in_seq = pos // cache_block_size + pos_in_block = pos % cache_block_size + physical_block = tl.load( + block_table_ptr + token_idx * stride_block_table_t + block_in_seq + ) + cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offsets, running_max, mask=head_mask) + tl.store(denom_ptr + state_offsets, running_denom, mask=head_mask) + tl.store(acc_ptr + acc_offsets, running_acc, mask=matrix_mask) + + +def accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q: torch.Tensor, + k_cache: torch.Tensor, + 
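# seq_lens/gather_lens select a trailing window: only the last gather_len + # positions of each sequence (ending at seq_len) are attended. +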
seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int, + num_candidates: int, + head_block_size: int = 2, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert seq_lens.shape[0] == q.shape[0] + assert gather_lens.shape[0] == q.shape[0] + assert block_table.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert head_block_size in (1, 2, 4) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda + assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, triton.cdiv(num_heads, head_block_size)) + _accumulate_fp8ds_paged_attention_chunk_multihead_kernel[grid]( + q, + k_cache, + seq_lens, + gather_lens, + block_table, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + block_table.stride(0), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _fp8ds_paged_attention_with_sink_multihead_kernel( + q_ptr, + k_cache_ptr, + seq_lens_ptr, + gather_lens_ptr, + block_table_ptr, + sink_ptr, + output_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_block_table_t, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + candidate_offset: tl.constexpr, + num_candidates: tl.constexpr, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + running_max = tl.full((HEAD_BLOCK,), -float("inf"), tl.float32) + running_denom = tl.zeros((HEAD_BLOCK,), tl.float32) + running_acc = tl.zeros((HEAD_BLOCK, BLOCK_D), tl.float32) + + seq_len = tl.load(seq_lens_ptr + token_idx) + gather_len = tl.load(gather_lens_ptr + token_idx) + start_pos 
= seq_len - gather_len + fp8_mask = dim_offsets < fp8_dim + rope_mask = (dim_offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + + for candidate_idx in range(0, num_candidates): + gather_idx = candidate_offset + candidate_idx + is_valid = gather_idx < gather_len + if is_valid: + pos = start_pos + gather_idx + block_in_seq = pos // cache_block_size + pos_in_block = pos % cache_block_size + physical_block = tl.load( + block_table_ptr + token_idx * stride_block_table_t + block_in_seq + ) + cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + sink = tl.load(sink_ptr + head_offsets, mask=head_mask, other=-float("inf")) + has_tokens = running_denom > 0.0 + has_sink = sink > -float("inf") + valid_max = tl.where(has_tokens, running_max, -float("inf")) + valid_sink = tl.where(has_sink, sink, -float("inf")) + merge_max = tl.maximum(valid_max, valid_sink) + has_any = has_tokens | has_sink + safe_merge_max = tl.where(has_any, merge_max, 0.0) + safe_running_max = tl.where(has_tokens, running_max, safe_merge_max) + safe_sink = tl.where(has_sink, sink, safe_merge_max) + subset_scale = tl.where( + has_tokens, tl.exp(safe_running_max - safe_merge_max), 0.0 + ) + sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0) + total_weight = running_denom * subset_scale + sink_weight + inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0) + final = running_acc * subset_scale[:, None] * inv_total[:, None] + + tl.store( + output_ptr + + token_idx * stride_output_t + + head_offsets[:, None] * stride_output_h + + dim_offsets[None, :] * stride_output_d, + final, + mask=matrix_mask, + ) + + +def fp8ds_paged_sparse_mla_attention_with_sink_multihead( + q: torch.Tensor, + k_cache: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + candidate_offset: int, + num_candidates: int, + scale: float, + attn_sink: torch.Tensor, + output: torch.Tensor, + head_block_size: int = 1, + num_heads: int | None = None, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert seq_lens.shape[0] == q.shape[0] + assert gather_lens.shape[0] == q.shape[0] + assert 
block_table.shape[0] == q.shape[0] + assert output.shape[0] == q.shape[0] + assert output.shape[2] == q.shape[-1] + assert head_block_size in (1, 2, 4) + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda + assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + active_heads = num_heads if num_heads is not None else output.shape[1] + assert active_heads <= q.shape[1] + assert active_heads <= output.shape[1] + assert active_heads <= attn_sink.shape[0] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, triton.cdiv(active_heads, head_block_size)) + _fp8ds_paged_attention_with_sink_multihead_kernel[grid]( + q, + k_cache, + seq_lens, + gather_lens, + block_table, + attn_sink, + output, + q.stride(0), + q.stride(1), + q.stride(2), + block_table.stride(0), + output.stride(0), + output.stride(1), + output.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + active_heads, + head_dim, + candidate_offset, + num_candidates, + scale, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _fp8ds_global_paged_attention_with_sink_multihead_kernel( + q_ptr, + compressed_k_cache_ptr, + slot_ids_ptr, + topk_lens_ptr, + swa_k_cache_ptr, + seq_lens_ptr, + gather_lens_ptr, + block_table_ptr, + sink_ptr, + output_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + stride_block_table_t, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + compressed_cache_block_size: tl.constexpr, + compressed_block_stride: tl.constexpr, + swa_cache_block_size: tl.constexpr, + swa_block_stride: tl.constexpr, + token_data_size: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_compressed_candidates: tl.constexpr, + num_swa_candidates: tl.constexpr, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + running_max = tl.full((HEAD_BLOCK,), -float("inf"), tl.float32) + running_denom = tl.zeros((HEAD_BLOCK,), tl.float32) + running_acc = tl.zeros((HEAD_BLOCK, BLOCK_D), tl.float32) + + fp8_mask = dim_offsets < fp8_dim + rope_mask = (dim_offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + topk_len = tl.load(topk_lens_ptr + token_idx) + + for candidate_idx in range(0, num_compressed_candidates): + slot_id = tl.load( + slot_ids_ptr + + token_idx * stride_slot_t + + candidate_idx * stride_slot_c + ) + is_valid = (candidate_idx < topk_len) & (slot_id >= 0) + if is_valid: + block_idx = slot_id // compressed_cache_block_size + pos_in_block = slot_id % 
compressed_cache_block_size + cache_block_ptr = ( + compressed_k_cache_ptr + + block_idx.to(tl.int64) * compressed_block_stride + ) + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + compressed_cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + seq_len = tl.load(seq_lens_ptr + token_idx) + gather_len = tl.load(gather_lens_ptr + token_idx) + start_pos = seq_len - gather_len + for candidate_idx in range(0, num_swa_candidates): + is_valid = candidate_idx < gather_len + if is_valid: + pos = start_pos + candidate_idx + block_in_seq = pos // swa_cache_block_size + pos_in_block = pos % swa_cache_block_size + physical_block = tl.load( + block_table_ptr + token_idx * stride_block_table_t + block_in_seq + ) + cache_block_ptr = ( + swa_k_cache_ptr + + physical_block.to(tl.int64) * swa_block_stride + ) + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + swa_cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + sink = tl.load(sink_ptr + head_offsets, mask=head_mask, other=-float("inf")) + has_tokens = running_denom > 0.0 + has_sink = sink > -float("inf") + valid_max = tl.where(has_tokens, running_max, -float("inf")) + valid_sink = tl.where(has_sink, sink, -float("inf")) + merge_max = tl.maximum(valid_max, valid_sink) + has_any = has_tokens | has_sink + safe_merge_max = tl.where(has_any, merge_max, 0.0) + 
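# Fold the per-head attention sink in as one extra logit: the token + # accumulator is rescaled by exp(running_max - merge_max), the sink adds + # exp(sink - merge_max) to the denominator, and the has_tokens/has_sink + # guards keep fully-masked rows at zero instead of NaN. +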
safe_running_max = tl.where(has_tokens, running_max, safe_merge_max) + safe_sink = tl.where(has_sink, sink, safe_merge_max) + subset_scale = tl.where( + has_tokens, tl.exp(safe_running_max - safe_merge_max), 0.0 + ) + sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0) + total_weight = running_denom * subset_scale + sink_weight + inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0) + final = running_acc * subset_scale[:, None] * inv_total[:, None] + + tl.store( + output_ptr + + token_idx * stride_output_t + + head_offsets[:, None] * stride_output_h + + dim_offsets[None, :] * stride_output_d, + final, + mask=matrix_mask, + ) + + +def fp8ds_global_paged_sparse_mla_attention_with_sink_multihead( + q: torch.Tensor, + compressed_k_cache: torch.Tensor, + slot_ids: torch.Tensor, + topk_lens: torch.Tensor, + compressed_block_size: int, + swa_k_cache: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + block_table: torch.Tensor, + swa_block_size: int, + num_compressed_candidates: int, + num_swa_candidates: int, + scale: float, + attn_sink: torch.Tensor, + output: torch.Tensor, + head_block_size: int = 1, + num_heads: int | None = None, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert slot_ids.dim() == 2 + assert slot_ids.shape[0] == q.shape[0] + assert topk_lens.shape[0] == q.shape[0] + assert seq_lens.shape[0] == q.shape[0] + assert gather_lens.shape[0] == q.shape[0] + assert block_table.shape[0] == q.shape[0] + assert output.shape[0] == q.shape[0] + assert output.shape[2] == q.shape[-1] + assert head_block_size in (1, 2, 4) + assert compressed_k_cache.dtype == torch.uint8 + assert swa_k_cache.dtype == torch.uint8 + assert q.is_cuda and compressed_k_cache.is_cuda and swa_k_cache.is_cuda + assert slot_ids.is_cuda and topk_lens.is_cuda + assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + active_heads = num_heads if num_heads is not None else output.shape[1] + assert active_heads <= q.shape[1] + assert active_heads <= output.shape[1] + assert active_heads <= attn_sink.shape[0] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, triton.cdiv(active_heads, head_block_size)) + _fp8ds_global_paged_attention_with_sink_multihead_kernel[grid]( + q, + compressed_k_cache, + slot_ids, + topk_lens, + swa_k_cache, + seq_lens, + gather_lens, + block_table, + attn_sink, + output, + q.stride(0), + q.stride(1), + q.stride(2), + slot_ids.stride(0), + slot_ids.stride(1), + block_table.stride(0), + output.stride(0), + output.stride(1), + output.stride(2), + compressed_block_size, + compressed_k_cache.stride(0), + swa_block_size, + swa_k_cache.stride(0), + token_data_size, + token_fp8_dim, + token_scale_dim, + quant_block_size, + active_heads, + head_dim, + num_compressed_candidates, + num_swa_candidates, + scale, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _finish_attention_state_kernel( + max_score_ptr, + denom_ptr, + acc_ptr, + output_ptr, + lse_ptr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: 
tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + stride_lse_t: tl.constexpr, + stride_lse_h: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + is_valid = running_denom > 0.0 + inv_denom = tl.where(is_valid, 1.0 / running_denom, 0.0) + subset_lse = tl.where( + is_valid, + running_max + tl.log(running_denom), + -float("inf"), + ) + + acc = tl.load( + acc_ptr + + token_idx * stride_acc_t + + head_idx * stride_acc_h + + offsets * stride_acc_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + subset_output = acc * inv_denom + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + subset_output, + mask=dim_mask, + ) + if block_d == 0: + tl.store( + lse_ptr + token_idx * stride_lse_t + head_idx * stride_lse_h, + subset_lse, + ) + + +def finish_gathered_sparse_mla_attention( + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + output: torch.Tensor, + lse: torch.Tensor, +) -> None: + assert max_score.shape == denom.shape + assert acc.shape[:2] == max_score.shape + assert output.shape == acc.shape + assert lse.shape == max_score.shape + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert output.dtype == torch.float32 + assert lse.dtype == torch.float32 + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + assert output.is_cuda and lse.is_cuda + + num_tokens, num_heads, head_dim = acc.shape + block_d = min(128, triton.next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _finish_attention_state_kernel[grid]( + max_score, + denom, + acc, + output, + lse, + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + lse.stride(0), + lse.stride(1), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) + + +@triton.jit +def _finish_attention_state_with_sink_kernel( + max_score_ptr, + denom_ptr, + acc_ptr, + sink_ptr, + output_ptr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + sink = tl.load(sink_ptr + head_idx) + has_tokens = running_denom > 0.0 + has_sink = sink > -float("inf") + valid_max = tl.where(has_tokens, running_max, 
-float("inf")) + valid_sink = tl.where(has_sink, sink, -float("inf")) + merge_max = tl.maximum(valid_max, valid_sink) + has_any = has_tokens | has_sink + safe_merge_max = tl.where(has_any, merge_max, 0.0) + safe_running_max = tl.where(has_tokens, running_max, safe_merge_max) + safe_sink = tl.where(has_sink, sink, safe_merge_max) + subset_scale = tl.where( + has_tokens, tl.exp(safe_running_max - safe_merge_max), 0.0 + ) + subset_weight = running_denom * subset_scale + sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0) + total_weight = subset_weight + sink_weight + inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0) + + acc_values = tl.load( + acc_ptr + + token_idx * stride_acc_t + + head_idx * stride_acc_h + + offsets * stride_acc_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + acc_values = tl.where(has_tokens, acc_values, 0.0) + output = acc_values * subset_scale * inv_total + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + output, + mask=dim_mask, + ) + + +@triton.jit +def _finish_two_attention_states_with_sink_kernel( + max0_ptr, + denom0_ptr, + acc0_ptr, + max1_ptr, + denom1_ptr, + acc1_ptr, + sink_ptr, + output_ptr, + stride_state0_t: tl.constexpr, + stride_state0_h: tl.constexpr, + stride_acc0_t: tl.constexpr, + stride_acc0_h: tl.constexpr, + stride_acc0_d: tl.constexpr, + stride_state1_t: tl.constexpr, + stride_state1_h: tl.constexpr, + stride_acc1_t: tl.constexpr, + stride_acc1_h: tl.constexpr, + stride_acc1_d: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + state0_offset = token_idx * stride_state0_t + head_idx * stride_state0_h + state1_offset = token_idx * stride_state1_t + head_idx * stride_state1_h + max0 = tl.load(max0_ptr + state0_offset) + denom0 = tl.load(denom0_ptr + state0_offset) + max1 = tl.load(max1_ptr + state1_offset) + denom1 = tl.load(denom1_ptr + state1_offset) + sink = tl.load(sink_ptr + head_idx) + + has0 = denom0 > 0.0 + has1 = denom1 > 0.0 + has_sink = sink > -float("inf") + valid_max0 = tl.where(has0, max0, -float("inf")) + valid_max1 = tl.where(has1, max1, -float("inf")) + valid_sink = tl.where(has_sink, sink, -float("inf")) + merge_max = tl.maximum(tl.maximum(valid_max0, valid_max1), valid_sink) + has_any = has0 | has1 | has_sink + safe_merge_max = tl.where(has_any, merge_max, 0.0) + safe_max0 = tl.where(has0, max0, safe_merge_max) + safe_max1 = tl.where(has1, max1, safe_merge_max) + safe_sink = tl.where(has_sink, sink, safe_merge_max) + scale0 = tl.where(has0, tl.exp(safe_max0 - safe_merge_max), 0.0) + scale1 = tl.where(has1, tl.exp(safe_max1 - safe_merge_max), 0.0) + sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0) + total_weight = denom0 * scale0 + denom1 * scale1 + sink_weight + inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0) + + acc0 = tl.load( + acc0_ptr + + token_idx * stride_acc0_t + + head_idx * stride_acc0_h + + offsets * stride_acc0_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + acc1 = tl.load( + acc1_ptr + + token_idx * stride_acc1_t + + head_idx * stride_acc1_h + + offsets * stride_acc1_d, + mask=dim_mask, 
+ other=0.0, + ).to(tl.float32) + acc0 = tl.where(has0, acc0, 0.0) + acc1 = tl.where(has1, acc1, 0.0) + output = (acc0 * scale0 + acc1 * scale1) * inv_total + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + output, + mask=dim_mask, + ) + + +def finish_two_sparse_mla_attention_states_with_sink( + max_score0: torch.Tensor, + denom0: torch.Tensor, + acc0: torch.Tensor, + max_score1: torch.Tensor, + denom1: torch.Tensor, + acc1: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert max_score0.shape == denom0.shape + assert max_score1.shape == denom1.shape + assert max_score0.shape == max_score1.shape + assert acc0.shape == acc1.shape + assert acc0.shape[:2] == max_score0.shape + assert output.shape[0] == acc0.shape[0] + assert output.shape[1] >= acc0.shape[1] + assert output.shape[2] == acc0.shape[2] + assert attn_sink.shape[0] >= acc0.shape[1] + assert max_score0.dtype == torch.float32 + assert denom0.dtype == torch.float32 + assert acc0.dtype == torch.float32 + assert max_score1.dtype == torch.float32 + assert denom1.dtype == torch.float32 + assert acc1.dtype == torch.float32 + assert max_score0.is_cuda and denom0.is_cuda and acc0.is_cuda + assert max_score1.is_cuda and denom1.is_cuda and acc1.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + num_tokens, num_heads, head_dim = acc0.shape + block_d = min(128, triton.next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _finish_two_attention_states_with_sink_kernel[grid]( + max_score0, + denom0, + acc0, + max_score1, + denom1, + acc1, + attn_sink, + output, + max_score0.stride(0), + max_score0.stride(1), + acc0.stride(0), + acc0.stride(1), + acc0.stride(2), + max_score1.stride(0), + max_score1.stride(1), + acc1.stride(0), + acc1.stride(1), + acc1.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) + + +def finish_sparse_mla_attention_with_sink( + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert max_score.shape == denom.shape + assert acc.shape[:2] == max_score.shape + assert output.shape[0] == acc.shape[0] + assert output.shape[1] >= acc.shape[1] + assert output.shape[2] == acc.shape[2] + assert attn_sink.shape[0] >= acc.shape[1] + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + num_tokens, num_heads, head_dim = acc.shape + block_d = min(128, triton.next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _finish_attention_state_with_sink_kernel[grid]( + max_score, + denom, + acc, + attn_sink, + output, + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) diff --git a/vllm/v1/attention/backends/mla/sparse_mla_reference.py b/vllm/v1/attention/backends/mla/sparse_mla_reference.py new file mode 100644 index 000000000000..203b64188202 --- /dev/null +++ b/vllm/v1/attention/backends/mla/sparse_mla_reference.py @@ -0,0 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Reference sparse 
MLA attention helpers. + +The helpers in this module intentionally use PyTorch tensor operations. They +are the correctness-first contract for portable sparse MLA fallbacks and tests; +optimized Triton/CUDA kernels should preserve these semantics. +""" + +import torch + + +def new_reference_attention_state( + q: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if q.dim() == 4: + q_bhd = q[:, 0, :, :].float() + else: + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + q_bhd = q.float() + + num_tokens = q_bhd.shape[0] + num_heads = q_bhd.shape[1] + head_dim = q_bhd.shape[2] + max_score = torch.full( + (num_tokens, num_heads), + float("-inf"), + dtype=torch.float32, + device=q.device, + ) + denom = torch.zeros_like(max_score) + acc = torch.zeros( + (num_tokens, num_heads, head_dim), + dtype=torch.float32, + device=q.device, + ) + return q_bhd, max_score, denom, acc + + +def accumulate_reference_attention_chunk( + q_bhd: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + scale: float, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + kv_btd = kv.float() + kv_btd = torch.where( + valid_tokens[:, :, None], + kv_btd, + torch.zeros((), dtype=kv_btd.dtype, device=kv_btd.device), + ) + scores = torch.einsum("bhd,btd->bht", q_bhd, kv_btd) * scale + scores = scores.masked_fill(~valid_tokens[:, None, :], float("-inf")) + + chunk_max = scores.amax(dim=-1) + next_max = torch.maximum(max_score, chunk_max) + + previous_scale = torch.exp(max_score - next_max) + previous_scale = torch.nan_to_num(previous_scale) + weights = torch.exp(scores - next_max[:, :, None]) + weights = torch.where( + valid_tokens[:, None, :], + weights, + torch.zeros((), dtype=weights.dtype, device=weights.device), + ) + weights = torch.nan_to_num(weights) + + acc = acc * previous_scale[:, :, None] + denom = denom * previous_scale + acc = acc + torch.einsum("bht,btd->bhd", weights, kv_btd) + denom = denom + weights.sum(dim=-1) + return next_max, denom, acc + + +def finish_reference_attention_no_sink( + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + valid = denom > 0 + safe_denom = torch.where(valid, denom, torch.ones_like(denom)) + subset_output = acc / safe_denom[:, :, None] + subset_output = torch.where( + valid[:, :, None], + subset_output, + torch.zeros((), dtype=subset_output.dtype, device=subset_output.device), + ) + subset_lse = torch.where( + valid, + max_score + torch.log(safe_denom), + torch.full_like(max_score, float("-inf")), + ) + return subset_output, subset_lse + + +def reference_attention_no_sink( + q: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + scale: float, +) -> tuple[torch.Tensor, torch.Tensor]: + q_bhd, max_score, denom, acc = new_reference_attention_state(q) + max_score, denom, acc = accumulate_reference_attention_chunk( + q_bhd=q_bhd, + kv=kv, + valid_tokens=valid_tokens, + max_score=max_score, + denom=denom, + acc=acc, + scale=scale, + ) + return finish_reference_attention_no_sink(max_score, denom, acc) + + +def merge_reference_attention_with_sink( + subset_outputs: list[torch.Tensor], + subset_lses: list[torch.Tensor], + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert subset_outputs, "At least one attention subset is required" + assert len(subset_outputs) == len(subset_lses) + + sink = attn_sink[None, :].float() + merge_max = sink + for subset_lse in 
subset_lses: + merge_max = torch.maximum(merge_max, subset_lse) + + safe_merge_max = torch.where( + torch.isfinite(merge_max), merge_max, torch.zeros_like(merge_max) + ) + merged_acc = torch.zeros_like(subset_outputs[0], dtype=torch.float32) + sink_weight = torch.exp(sink - safe_merge_max) + sink_weight = torch.nan_to_num(sink_weight) + merged_denom = sink_weight + for subset_output, subset_lse in zip(subset_outputs, subset_lses): + subset_weight = torch.exp(subset_lse - safe_merge_max) + subset_weight = torch.nan_to_num(subset_weight) + merged_acc = merged_acc + subset_output.float() * subset_weight[:, :, None] + merged_denom = merged_denom + subset_weight + + safe_denom = torch.where( + merged_denom > 0, merged_denom, torch.ones_like(merged_denom) + ) + reference_output = merged_acc / safe_denom[:, :, None] + reference_output = torch.where( + (merged_denom > 0)[:, :, None], + reference_output, + torch.zeros((), dtype=reference_output.dtype, device=reference_output.device), + ) + output.copy_(reference_output.to(dtype=output.dtype)) + + +def sink_aware_reference_attention( + q: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + scale: float, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + subset_output, subset_lse = reference_attention_no_sink( + q=q, + kv=kv, + valid_tokens=valid_tokens, + scale=scale, + ) + merge_reference_attention_with_sink( + subset_outputs=[subset_output], + subset_lses=[subset_lse], + attn_sink=attn_sink, + output=output, + ) + + +def reference_sparse_mla_prefill( + q: torch.Tensor, + kv: torch.Tensor, + combined_indices: torch.Tensor, + combined_lens: torch.Tensor, + scale: float, + attn_sink: torch.Tensor, + output: torch.Tensor, + topk_chunk_size: int, + query_chunk_size: int, +) -> None: + kv_flat = kv.reshape(-1, q.shape[-1]) + topk_chunk_size = min(combined_indices.shape[-1], topk_chunk_size) + query_chunk_size = min(q.shape[0], query_chunk_size) + + for token_start in range(0, q.shape[0], query_chunk_size): + token_end = min(token_start + query_chunk_size, q.shape[0]) + q_chunk = q[token_start:token_end] + lens_chunk = combined_lens[token_start:token_end] + indices_chunk_full = combined_indices[token_start:token_end] + q_bhd, max_score, denom, acc = new_reference_attention_state(q_chunk) + + for index_start in range(0, combined_indices.shape[-1], topk_chunk_size): + index_end = min( + index_start + topk_chunk_size, + combined_indices.shape[-1], + ) + indices_chunk = indices_chunk_full[:, index_start:index_end] + index_offsets = torch.arange( + index_start, + index_end, + device=q.device, + ) + valid_tokens = ( + (index_offsets[None, :] < lens_chunk[:, None]) + & (indices_chunk >= 0) + ) + safe_indices = torch.where( + valid_tokens, + indices_chunk, + torch.zeros((), dtype=indices_chunk.dtype, device=q.device), + ).long() + gathered_kv = kv_flat[safe_indices] + max_score, denom, acc = accumulate_reference_attention_chunk( + q_bhd=q_bhd, + kv=gathered_kv, + valid_tokens=valid_tokens, + max_score=max_score, + denom=denom, + acc=acc, + scale=scale, + ) + + subset_output, subset_lse = finish_reference_attention_no_sink( + max_score, + denom, + acc, + ) + merge_reference_attention_with_sink( + subset_outputs=[subset_output], + subset_lses=[subset_lse], + attn_sink=attn_sink, + output=output[token_start:token_end], + ) diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py new file mode 100644 index 000000000000..6f81aa69eb3b --- /dev/null +++ 
b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -0,0 +1,543 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import ClassVar, cast + +import torch + +from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.triton_utils import tl, triton +from vllm.v1.attention.backend import ( + AttentionBackend, + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + MultipleOf, +) +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + is_sparse_mla_attention_dump_enabled, + is_sparse_mla_reference_attention_enabled, + is_sparse_mla_reference_attention_enabled_for_platform, + sparse_mla_reference_cudagraphs_allowed, +) +from vllm.v1.attention.backends.utils import split_decodes_and_prefills +from vllm.v1.attention.ops.flashmla import FlashMLASchedMeta, get_mla_metadata +from vllm.v1.kv_cache_interface import ( + AttentionSpec, + KVCacheSpec, + MLAAttentionSpec, + SlidingWindowMLASpec, +) + +# DeepseekV4 decode layer types, keyed by compress_ratio. Each type has a distinct +# (topk, extra_topk, extra_page_block_size) config, so they cannot share a +# FlashMLA tile-scheduler plan. Within a type, all ~60 DeepseekV4 layers share one +# plan per step because b / s_q / h_q / page_block_sizes / topks are identical. +_LAYER_TYPE_SWAONLY = "swaonly" +_LAYER_TYPE_C4A = "c4a" +_LAYER_TYPE_C128A = "c128a" + + +def _layer_type_for(compress_ratio: int) -> str: + if compress_ratio <= 1: + return _LAYER_TYPE_SWAONLY + if compress_ratio == 4: + return _LAYER_TYPE_C4A + if compress_ratio == 128: + return _LAYER_TYPE_C128A + raise ValueError( + f"Unsupported DeepseekV4 compress_ratio={compress_ratio}; " + "expected 1, 4, or 128." + ) + + +class DeepseekV4SWACache(torch.nn.Module, AttentionLayerBase): + def __init__( + self, + head_dim: int, + window_size: int, + dtype: torch.dtype, + prefix: str, + cache_config: CacheConfig, + ): + super().__init__() + self.kv_cache = torch.tensor([]) + self.head_dim = head_dim + self.window_size = window_size + self.prefix = prefix + self.cache_config = cache_config + self.dtype = dtype + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + # Block size is constrained by tensor sharing between SWA and C4A KV blocks. + # Since both block types share the same physical tensor, they must use the + # same page size. The C4A KV block shape [256//4, head_dim] = [64, head_dim] + # determines the SWA block size of 64 tokens per block. + # TODO(yifan): make SWA block size automatically determined and configurable. + self.block_size = 64 + assert self.dtype == torch.uint8 + + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: + return SlidingWindowMLASpec( + block_size=self.block_size, + num_kv_heads=1, + head_size=self.head_dim, + dtype=self.dtype, + sliding_window=self.window_size, + cache_dtype_str=self.cache_config.cache_dtype, + alignment=576, # NOTE: FlashMLA requires 576B alignment + model_version="deepseek_v4", + ) + + def forward(self): ... 
+ + def get_attn_backend(self) -> type[AttentionBackend]: + return DeepseekSparseSWABackend + + +class DeepseekSparseSWABackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "DEEPSEEK_SPARSE_SWA" + + @staticmethod + def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: + return [MultipleOf(64)] + + @classmethod + def get_preferred_block_size(cls, default_block_size: int) -> int: + return 256 + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [512] + + @staticmethod + def get_builder_cls() -> type["DeepseekSparseSWAMetadataBuilder"]: + return DeepseekSparseSWAMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + assert num_kv_heads == 1 + if cache_dtype_str == "fp8_ds_mla": + # DeepseekV4 SWA: 584B per token (448 NoPE + 128 RoPE + 8 fp8 scale). + # head_size passed in is the semantic head_dim (512). + return (num_blocks, block_size, 584) + else: + return (num_blocks, block_size, head_size) + + @staticmethod + def get_kv_cache_stride_order( + include_num_layers_dimension: bool = False, + ) -> tuple[int, ...]: + if include_num_layers_dimension: + return (0, 1, 2, 3) + return (0, 1, 2) + + +@dataclass +class DeepseekSparseSWAMetadata: + block_table: torch.Tensor + slot_mapping: torch.Tensor + block_size: int + seq_lens: torch.Tensor | None = None # [num_seqs] + query_start_loc: torch.Tensor | None = None # [num_seqs + 1] + query_start_loc_cpu: torch.Tensor | None = None # [num_seqs + 1] + + is_valid_token: torch.Tensor | None = None # [num_tokens] + token_to_req_indices: torch.Tensor | None = None # [num_tokens] + decode_swa_indices: torch.Tensor | None = None # [num_decode_tokens, window_size] + decode_swa_lens: torch.Tensor | None = None # [num_decode_tokens] + + # Number of decode/prefill requests/tokens (batch is reordered: decodes first) + num_decodes: int = 0 + num_prefills: int = 0 + num_decode_tokens: int = 0 + num_prefill_tokens: int = 0 + + # Pre-computed prefill metadata shared across all DeepseekV4 attention layers. + prefill_seq_lens: torch.Tensor | None = None + prefill_gather_lens: torch.Tensor | None = None + prefill_seq_lens_cpu: torch.Tensor | None = None + prefill_gather_lens_cpu: torch.Tensor | None = None + + # Per-layer-type FlashMLA tile-scheduler metadata. One FlashMLASchedMeta + # per present DeepseekV4 layer type, shared across all ~60 layers of that type + # within a decode step. The first forward call of a given type triggers + # the in-kernel planner (which also allocates tile_scheduler_metadata and + # num_splits via PyTorch's graph-aware allocator); subsequent same-type + # calls skip planning and reuse the plan. Fresh instance per build(), so + # have_initialized is always False at the start of a step and the plan + # is re-derived from current seq_lens / topk_length on replay. + # None for layer types the model does not use (or when num_decode_tokens + # is zero). + tile_sched_swaonly: "FlashMLASchedMeta | None" = None + tile_sched_c4a: "FlashMLASchedMeta | None" = None + tile_sched_c128a: "FlashMLASchedMeta | None" = None + + +class DeepseekSparseSWAMetadataBuilder(AttentionMetadataBuilder): + """Builds metadata for DeepseekV4 SWA cache. + + Similar to the indexer, this handles mixed batches by: + 1. Using split_decodes_and_prefills() to determine the boundary + 2. 
Building separate metadata for decode and prefill portions + + Supports: + - Mixed decode/prefill batches + - MTP (Multi-Token Prediction) where decode has query_len > 1 + - Chunked prefill (aligns with the indexer's chunking) + """ + + # Base threshold: query_len <= 1 is decode + reorder_batch_threshold: int = 1 + _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + + @classmethod + def get_cudagraph_support( + cls, + vllm_config: VllmConfig, + kv_cache_spec: AttentionSpec, + ) -> AttentionCGSupport: + if ( + getattr(kv_cache_spec, "model_version", None) == "deepseek_v4" + and is_sparse_mla_reference_attention_enabled_for_platform() + and not sparse_mla_reference_cudagraphs_allowed(vllm_config) + ): + return AttentionCGSupport.NEVER + return cls._cudagraph_support + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert isinstance(self.kv_cache_spec, SlidingWindowMLASpec | MLAAttentionSpec) + mla_spec = cast(SlidingWindowMLASpec | MLAAttentionSpec, self.kv_cache_spec) + self.head_size = mla_spec.head_size # Already considered quantization. + self.compress_ratio = mla_spec.compress_ratio + self.block_size = mla_spec.block_size + + # Handle MTP: adjust decode_threshold like the indexer does + self.num_speculative_tokens = ( + self.vllm_config.speculative_config.num_speculative_tokens + if self.vllm_config.speculative_config + else 0 + ) + # With MTP, decode can have query_len up to 1 + num_speculative_tokens. + # Must match the threshold used by the indexer and flashmla_sparse so + # that all backends agree on the decode/prefill split. + self.decode_threshold = ( + self.reorder_batch_threshold + self.num_speculative_tokens + ) + + hf_config = self.vllm_config.model_config.hf_config + assert hasattr(hf_config, "sliding_window") + self.window_size = hf_config.sliding_window + + # Detect which DeepseekV4 layer types this model uses so we only build a + # FlashMLA tile-scheduler plan for types that will actually be called. + # Models without compress_ratios (pure SWA) fall back to swaonly. + compress_ratios = getattr(hf_config, "compress_ratios", None) or [1] + self._layer_types: set[str] = set() + for ratio in compress_ratios: + self._layer_types.add(_layer_type_for(int(ratio))) + + max_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens + self.token_to_req_indices = torch.zeros( + max_tokens, + dtype=torch.int32, + device=self.device, + ) + self.decode_swa_indices = torch.zeros( + max_tokens, + 1, + self.window_size, + dtype=torch.int32, + device=self.device, + ) + self.decode_swa_lens = torch.zeros( + max_tokens, + dtype=torch.int32, + device=self.device, + ) + self.is_valid_token = torch.zeros( + max_tokens, + dtype=torch.bool, + device=self.device, + ) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> DeepseekSparseSWAMetadata: + """Build SWA metadata for mixed decode/prefill batches. + + The batch is assumed to be reordered with decodes first (by vLLM scheduler). + We use split_decodes_and_prefills() to find the boundary, then build + separate window_topk_idxs for each portion. + + For prefill, we use chunked prefill to align with the indexer's chunking. 
+ """ + num_reqs = common_attn_metadata.num_reqs + seq_lens = common_attn_metadata.seq_lens + query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + block_table = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping + + # Split into decode and prefill portions using configurable threshold + (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = ( + split_decodes_and_prefills( + common_attn_metadata, decode_threshold=self.decode_threshold + ) + ) + + # NOTE: Ensure all metadata tensors maintain fixed memory addresses + # for CUDA graph compatibility. + query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + x = torch.repeat_interleave(torch.arange(num_reqs), query_lens).pin_memory() + token_to_req_indices = self.token_to_req_indices[: x.shape[0]] + token_to_req_indices.copy_(x, non_blocking=True) + + is_valid_token = self.is_valid_token[: slot_mapping.shape[0]] + is_valid_token.copy_(slot_mapping >= 0) + + if num_decode_tokens > 0: + self.decode_swa_lens[num_decode_tokens:] = 0 + _compute_swa_indices_and_lens_kernel[(num_decode_tokens,)]( + self.decode_swa_indices, + self.decode_swa_indices.stride(0), + self.decode_swa_lens, + self.window_size, + query_start_loc, + seq_lens, + token_to_req_indices, + is_valid_token, + block_table, + block_table.stride(0), + self.block_size, + TRITON_BLOCK_SIZE=1024, + ) + + # Pre-compute DeepseekV4 prefill metadata shared across all attention layers. + deepseek_v4_fields = self._build_deepseek_v4_metadata( + num_decodes, + num_prefills, + seq_lens, + query_start_loc, + query_start_loc_cpu, + common_attn_metadata.seq_lens_cpu_upper_bound, + ) + + # Per-layer-type tile-scheduler plan holders. Empty FlashMLASchedMeta + # per present DeepseekV4 layer type; the first flash_mla_with_kvcache call of + # each type triggers the planner and all same-type layers reuse the + # resulting plan for the rest of the step. + tile_sched = self.build_tile_scheduler(num_decode_tokens) + + return DeepseekSparseSWAMetadata( + seq_lens=seq_lens, + query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + block_table=block_table, + slot_mapping=slot_mapping, + is_valid_token=is_valid_token, + token_to_req_indices=token_to_req_indices, + decode_swa_indices=self.decode_swa_indices[:num_decode_tokens], + decode_swa_lens=self.decode_swa_lens[:num_decode_tokens], + block_size=self.block_size, + num_decodes=num_decodes, + num_prefills=num_prefills, + num_decode_tokens=num_decode_tokens, + num_prefill_tokens=num_prefill_tokens, + tile_sched_swaonly=tile_sched[_LAYER_TYPE_SWAONLY], + tile_sched_c4a=tile_sched[_LAYER_TYPE_C4A], + tile_sched_c128a=tile_sched[_LAYER_TYPE_C128A], + **deepseek_v4_fields, + ) + + def build_tile_scheduler( + self, num_decode_tokens: int + ) -> dict[str, FlashMLASchedMeta | None]: + """Allocate one empty ``FlashMLASchedMeta`` per present DeepseekV4 layer type. + + Returned instances have ``tile_scheduler_metadata`` / ``num_splits`` + set to ``None``; the FlashMLA C++ decode path will allocate them and + run the tile-scheduler planner on the first ``flash_mla_with_kvcache`` + call of each type. Subsequent same-type calls reuse the plan because + the tensors (and ``have_initialized``) are populated on the struct. + + Returns all-``None`` when there are no decode tokens this step, so + ``_forward_decode`` sees a clean sentinel. 
+ """ + out: dict[str, FlashMLASchedMeta | None] = { + _LAYER_TYPE_SWAONLY: None, + _LAYER_TYPE_C4A: None, + _LAYER_TYPE_C128A: None, + } + if num_decode_tokens == 0: + return out + if ( + is_sparse_mla_attention_dump_enabled() + or is_sparse_mla_reference_attention_enabled(self.device) + ): + return out + for layer_type in self._layer_types: + # get_mla_metadata() is the official FlashMLA entry point that + # returns a fresh empty FlashMLASchedMeta; using it keeps this + # call site aligned with the rest of the vLLM FlashMLA backends + # that already go through the same stub. + out[layer_type] = get_mla_metadata()[0] + return out + + def _build_deepseek_v4_metadata( + self, + num_decodes: int, + num_prefills: int, + seq_lens: torch.Tensor, + query_start_loc: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + seq_lens_cpu_upper_bound: torch.Tensor | None, + ) -> dict[str, torch.Tensor | None]: + """Pre-compute DeepseekV4 prefill metadata during the metadata build phase. + + Returns a dict of keyword arguments to pass to the + DeepseekSparseSWAMetadata constructor. + + Note: C128A topk indices are computed by the FlashMLASparse builder + (which owns the C128A block_table), not here. + """ + result: dict[str, torch.Tensor | None] = {} + + # --- Prefill query metadata (single Triton kernel + CPU slicing) --- + if num_prefills > 0: + pfx_gather_lens = torch.empty( + num_prefills, dtype=torch.int32, device=seq_lens.device + ) + _compute_prefill_metadata_kernel[(1,)]( + pfx_gather_lens, + seq_lens, + query_start_loc, + num_prefills, + num_decodes, + self.window_size, + BLOCK_SIZE=triton.next_power_of_2(num_prefills), + ) + + assert seq_lens_cpu_upper_bound is not None + seq_lens_cpu = seq_lens_cpu_upper_bound + prefill_seq_lens_cpu = seq_lens_cpu[ + num_decodes : num_decodes + num_prefills + ] + query_lens_cpu = ( + query_start_loc_cpu[ + num_decodes + 1 : num_decodes + num_prefills + 1 + ] + - query_start_loc_cpu[num_decodes : num_decodes + num_prefills] + ) + prefix_lens_cpu = prefill_seq_lens_cpu - query_lens_cpu + prefill_gather_lens_cpu = query_lens_cpu + torch.minimum( + prefix_lens_cpu, + torch.full_like(prefix_lens_cpu, self.window_size - 1), + ) + + result["prefill_seq_lens"] = seq_lens[num_decodes:] + result["prefill_gather_lens"] = pfx_gather_lens + result["prefill_seq_lens_cpu"] = prefill_seq_lens_cpu + result["prefill_gather_lens_cpu"] = prefill_gather_lens_cpu + + return result + + +@triton.jit +def _compute_prefill_metadata_kernel( + # Outputs + prefill_gather_lens_ptr, + # Inputs + seq_lens_ptr, + query_start_loc_ptr, + num_prefills, + num_decodes, + window_size, + BLOCK_SIZE: tl.constexpr, +): + """Compute prefill gather_lens in a single pass.""" + offset = tl.arange(0, BLOCK_SIZE) + mask = offset < num_prefills + + seq_len = tl.load(seq_lens_ptr + num_decodes + offset, mask=mask) + qsl_start = tl.load(query_start_loc_ptr + num_decodes + offset, mask=mask) + qsl_end = tl.load(query_start_loc_ptr + num_decodes + offset + 1, mask=mask) + + query_len = qsl_end - qsl_start + prefix_len = seq_len - query_len + gather_len = query_len + tl.minimum(prefix_len, window_size - 1) + + tl.store(prefill_gather_lens_ptr + offset, gather_len, mask=mask) + + +@triton.jit +def _compute_swa_indices_and_lens_kernel( + swa_indices_ptr, + swa_indices_stride, + swa_lens_ptr, + window_size, + query_start_loc_ptr, + seq_lens_ptr, + token_to_req_indices_ptr, + is_valid_token_ptr, + block_table_ptr, + block_table_stride, + block_size, + TRITON_BLOCK_SIZE: tl.constexpr, +): + token_idx = tl.program_id(0) 
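+    # One program per query token: derive the token's absolute position, take
+    # the sliding window [max(pos - window_size + 1, 0), pos], record its
+    # length, and translate each in-window position to a physical cache slot
+    # via the block table; slots beyond the window are written as -1.
+    # Illustrative arithmetic: pos=10 with window_size=64 gives start_pos=0
+    # and swa_len=11.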
+ is_valid = tl.load(is_valid_token_ptr + token_idx) + if not is_valid: + tl.store(swa_lens_ptr + token_idx, 0) + return + + req_idx = tl.load(token_to_req_indices_ptr + token_idx) + + query_start = tl.load(query_start_loc_ptr + req_idx) + query_end = tl.load(query_start_loc_ptr + req_idx + 1) + query_len = query_end - query_start + + seq_len = tl.load(seq_lens_ptr + req_idx) + prefix_len = seq_len - query_len + + pos = prefix_len + token_idx - query_start + start_pos = tl.maximum(pos - window_size + 1, 0) + end_pos = pos + 1 + + swa_len = end_pos - start_pos + tl.store(swa_lens_ptr + token_idx, swa_len) + + for i in range(0, window_size, TRITON_BLOCK_SIZE): + offset = i + tl.arange(0, TRITON_BLOCK_SIZE) + + pos_offset = start_pos + offset + block_indices = pos_offset // block_size + block_numbers = tl.load( + block_table_ptr + req_idx * block_table_stride + block_indices, + mask=pos_offset < end_pos, + ) + block_offsets = pos_offset % block_size + slot_ids = block_numbers * block_size + block_offsets + + slot_ids = tl.where(offset < swa_len, slot_ids, -1) + tl.store( + swa_indices_ptr + token_idx * swa_indices_stride + offset, + slot_ids, + mask=offset < window_size, + ) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index b4bdce876d81..54ebd088b95e 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -356,7 +356,7 @@ def make_local_attention_virtual_batches( block_table_tensor=block_table_local, slot_mapping=common_attn_metadata.slot_mapping, causal=True, - seq_lens_cpu_upper_bound=seq_lens_cpu, + seq_lens_cpu_upper_bound=common_attn_metadata.seq_lens_cpu_upper_bound, _seq_lens_cpu=seq_lens_cpu, _num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local), ), make_block_table diff --git a/vllm/v1/attention/ops/common.py b/vllm/v1/attention/ops/common.py index 46c689ce0b8f..98abc7790ea2 100644 --- a/vllm/v1/attention/ops/common.py +++ b/vllm/v1/attention/ops/common.py @@ -265,6 +265,7 @@ def _pack_seq_kernel( D: tl.constexpr, Lmax: tl.constexpr, PAD_VALUE: tl.constexpr, + PAD_IS_UINT8: tl.constexpr, BLOCK_T: tl.constexpr, # timesteps per program BLOCK_D: tl.constexpr, # features per program ): @@ -294,9 +295,15 @@ def _pack_seq_kernel( # out_ptr: row-major [B, Lmax, D] out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:, None] * D + off_d[None, :] - # Initialize with PAD (cast will occur as needed based on out_ptr dtype) + # Initialize with PAD. PAD_IS_UINT8 selects the pad tensor's dtype so + # integer-typed outputs (e.g. MXFP4 packed nibbles, ue8m0 scale bytes) + # get an exact-byte pad rather than going through an fp32→uint8 cast + # that's implementation-defined outside of value 0. d_mask = off_d[None, :] < D - pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32) + if PAD_IS_UINT8: + pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.uint8) + else: + pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32) tl.store(out_row_ptr, pad_vals, mask=t_mask[:, None] & d_mask) # Load & write only where within seq_len @@ -307,23 +314,36 @@ def _pack_seq_kernel( def pack_seq_triton( x: torch.Tensor, lengths: torch.Tensor, - pad_value: float = -float("inf"), + pad_value: float | int = -float("inf"), block_t: int = 64, block_d: int = 64, ) -> torch.Tensor: - """ - Pack sequences of different lengths into a batched tensor. + """Pack sequences of different lengths into a batched tensor. + + Supports float dtypes (any, via fp32 pad) and ``torch.uint8`` (exact-byte + pad — e.g. 
MXFP4 packed nibbles or ue8m0 scale bytes). For uint8 inputs + ``pad_value`` must be an integer in ``[0, 255]``. Args: - x: [N, ...] - input tensor where N is total number of tokens - lengths: [B] - sequence lengths for each batch - pad_value: value to use for padding - block_t: block size for time dimension - block_d: block size for feature dimension + x: [N, ...] — input tensor where N is total number of tokens. + lengths: [B] — sequence lengths for each batch. + pad_value: value to use for padding. Defaults to ``-inf`` which is + only sensible for float dtypes; pass ``0`` (or any byte) for + uint8 inputs. + block_t: block size for time dimension. + block_d: block size for feature dimension. Returns: - packed: [B, Lmax, ...] - packed tensor + packed: [B, Lmax, ...] — packed tensor. """ + is_uint8 = x.dtype == torch.uint8 + if is_uint8: + assert isinstance(pad_value, int) and 0 <= pad_value <= 255, ( + f"uint8 pack requires an integer pad in [0, 255], got {pad_value!r}" + ) + pad_constexpr: int | float = int(pad_value) + else: + pad_constexpr = float(pad_value) # Handle multi-dimensional input by reshaping to (N, -1) original_shape = x.shape @@ -338,8 +358,6 @@ def pack_seq_triton( B = lengths.numel() Lmax = int(lengths.max().item()) - # Starts are computed inside the kernel from lengths - out = torch.empty((B, Lmax, D), device=x.device, dtype=x.dtype) grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d)) @@ -350,17 +368,16 @@ def pack_seq_triton( N, D, Lmax, - PAD_VALUE=float(pad_value), + PAD_VALUE=pad_constexpr, + PAD_IS_UINT8=is_uint8, BLOCK_T=block_t, BLOCK_D=block_d, num_warps=4, num_stages=2, ) - # Reshape output back to original dimensions (except first dimension) if len(original_shape) > 2: - output_shape = (B, Lmax) + original_shape[1:] - out = out.reshape(output_shape) + out = out.reshape((B, Lmax) + original_shape[1:]) return out diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py b/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py new file mode 100644 index 000000000000..763f69b671e4 --- /dev/null +++ b/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from .cache_utils import ( + combine_topk_swa_indices, + compute_global_topk_indices_and_lens, + dequantize_and_gather_k_cache, + dequantize_combined_sparse_mla_decode_kv, + dequantize_global_slots_k_cache, + quantize_and_insert_k_cache, +) +from .fused_indexer_q import MXFP4_BLOCK_SIZE, fused_indexer_q_rope_quant +from .fused_inv_rope_fp8_quant import fused_inv_rope_fp8_quant +from .fused_qk_rmsnorm import fused_q_kv_rmsnorm + +__all__ = [ + "MXFP4_BLOCK_SIZE", + "combine_topk_swa_indices", + "compute_global_topk_indices_and_lens", + "dequantize_and_gather_k_cache", + "dequantize_combined_sparse_mla_decode_kv", + "dequantize_global_slots_k_cache", + "fused_indexer_q_rope_quant", + "fused_inv_rope_fp8_quant", + "fused_q_kv_rmsnorm", + "quantize_and_insert_k_cache", +] diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py new file mode 100644 index 000000000000..0c9170439ee8 --- /dev/null +++ b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py @@ -0,0 +1,715 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Triton kernels for DeepseekV4 paged K-cache management and sparse-attention index +preparation. 
+ +- quantize_and_insert_k_cache: quantize bf16 K to UE8M0 FP8 and insert into + the paged cache. +- dequantize_and_gather_k_cache: gather and dequantize FP8 K from the paged + cache for sparse/SWA prefill. +- compute_global_topk_indices_and_lens: map local topk indices to global KV + cache slots and count valid entries. +- combine_topk_swa_indices: concatenate topk compressed indices with SWA + window indices for sparse prefill. +""" + +import torch + +from vllm.triton_utils import tl, triton + + +@triton.jit +def quantize_and_insert_k_kernel( + # Input tensors + k_ptr, # [num_tokens, 512] bf16 + slot_mapping_ptr, # [num_tokens] int64 + # Output tensor + k_cache_ptr, # [num_blocks, block_bytes] as uint8 (flattened view) + # Dimensions + num_tokens, + input_dim: tl.constexpr, # 512 + fp8_dim: tl.constexpr, # 448 + bf16_dim: tl.constexpr, # 64 + scale_dim: tl.constexpr, # 8 + quant_block: tl.constexpr, # 64 (quantization block size) + cache_block_size: tl.constexpr, # 64 (paged cache block size) + token_data_size: tl.constexpr, # 576 bytes per token data + block_stride: tl.constexpr, # total bytes per block (padded) + fp8_max: tl.constexpr, + n_quant_blocks: tl.constexpr, # 8 (7 real + 1 padding) +): + """ + Quantize K tensor and insert into paged K cache. + + K Cache block layout (block_size=64 tokens): + - [0, 64*576): Token data, each token has 448 fp8 + 128 bf16 + - [64*576, 64*576 + 64*8): Scales, each token has 8 uint8 scales + - [64*576 + 64*8, block_stride): Padding + + One program per token. + """ + pid = tl.program_id(0) + + if pid >= num_tokens: + return + + # Get slot mapping + slot_idx = tl.load(slot_mapping_ptr + pid) + if slot_idx == -1: + return + + block_idx = slot_idx // cache_block_size + pos_in_block = slot_idx % cache_block_size + + # Input pointer for this token + input_row_ptr = k_ptr + pid * input_dim + + # int64: block_idx * block_stride can exceed 2^31 with many KV-cache blocks + # (e.g. >= 57K at block_stride ~37K). Matches gather path below. 
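+    # (Illustrative: assuming the documented 64-token block layout, the padded
+    # block stride is 65 * 576 = 37,440 bytes, so byte offsets cross 2^31 at
+    # roughly 2^31 / 37,440 ≈ 57.4K blocks.)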
+ cache_block_ptr = k_cache_ptr + block_idx.to(tl.int64) * block_stride + + # Token data pointer: token data is stored contiguously at start of block + # Each token's data is at offset pos_in_block * token_data_size + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + + # Scale pointer: scales are stored after ALL token data in the block + # Scale for this token is at offset (64 * 576) + pos_in_block * 8 + token_scale_ptr = ( + cache_block_ptr + cache_block_size * token_data_size + pos_in_block * scale_dim + ) + + # Token data layout: [0:448] fp8, [448:576] bf16 + token_fp8_ptr = token_data_ptr + token_bf16_ptr = token_data_ptr + fp8_dim + + # ========== Quantize and store FP8 portion (first 448 elements) ========== + # Using UE8M0 quantization strategy (scale is power of 2, stored as uint8 exponent) + for qblock_idx in tl.static_range(n_quant_blocks): + qblock_start = qblock_idx * quant_block + + if qblock_start < fp8_dim: + offsets = qblock_start + tl.arange(0, quant_block) + mask = offsets < fp8_dim + + # Load bf16 input + x = tl.load(input_row_ptr + offsets, mask=mask, other=0.0) + + # Compute absmax scale (same as CUDA kernel) + abs_x = tl.abs(x) + block_max = tl.max(abs_x, axis=0) + block_max = tl.maximum(block_max, 1e-4) # Match CUDA: fmaxf(amax, 1e-4) + + # UE8M0: Round scale UP to next power of 2 + # scale = 2^ceil(log2(block_max / fp8_max)) + raw_scale = block_max / fp8_max + log_scale = tl.log2(raw_scale) + exponent = tl.ceil(log_scale) # Round UP to next integer exponent + scale = tl.exp2(exponent) # scale = 2^exponent (power of 2) + + # Quantize to fp8: fp8_value = bf16_value / scale + x_scaled = x / scale + x_clamped = tl.clamp(x_scaled, -fp8_max, fp8_max) + + # Convert to fp8, then bitcast to uint8 for storage + x_fp8 = x_clamped.to(tl.float8e4nv) + x_uint8 = x_fp8.to(tl.uint8, bitcast=True) + + # Store as uint8 (1 byte each) + tl.store(token_fp8_ptr + offsets, x_uint8, mask=mask) + + # UE8M0 scale encoding: stored_value = exponent + 127 (bias) + # During dequant: scale = 2^(stored_value - 127) + encoded_scale = exponent + 127.0 + encoded_scale = tl.maximum(tl.minimum(encoded_scale, 255.0), 0.0) + tl.store(token_scale_ptr + qblock_idx, encoded_scale.to(tl.uint8)) + + # Padding scale at index 7 + tl.store(token_scale_ptr + 7, tl.zeros((), dtype=tl.uint8)) + + # ========== Store BF16 portion (last 64 elements, no quantization) ========== + bf16_input_offset = fp8_dim + + # Process bf16 in chunks of 16 + bf16_out_ptr = token_bf16_ptr.to(tl.pointer_type(tl.bfloat16)) + for i in tl.static_range(bf16_dim // 16): + chunk_offsets = i * 16 + tl.arange(0, 16) + bf16_vals = tl.load(input_row_ptr + bf16_input_offset + chunk_offsets) + tl.store(bf16_out_ptr + chunk_offsets, bf16_vals) + + +def quantize_and_insert_k_cache( + k: torch.Tensor, # [num_tokens, 512] bf16 + k_cache: torch.Tensor, # [num_blocks, block_bytes] uint8 + slot_mapping: torch.Tensor, # [num_tokens] int64 + block_size: int = 64, + is_ue8m0: bool = True, +): + """ + Quantize K tensor and insert into paged K cache. + + K Cache block layout (block_size=64 tokens): + - First 64 * 576 = 36864 bytes: Token data + - Each token: 448 bytes (fp8) + 128 bytes (bf16) + - Next 64 * 8 = 512 bytes: Scales + - Each token: 8 bytes (uint8 scales, 7 real + 1 padding) + - Padded to multiple of 576 + """ + assert k.dim() == 2 and k.shape[1] == 512, ( + f"K must be [num_tokens, 512], got {k.shape}" + ) + assert k.dtype == torch.bfloat16, f"K must be bf16, got {k.dtype}" + assert is_ue8m0, "Only support ue8m0 quantization." 
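+    # Byte budget per cache block, assuming the documented block_size=64
+    # layout: 64 * 576 = 36,864 data bytes + 64 * 8 = 512 scale bytes
+    # = 37,376 bytes, padded up to the next multiple of 576 (37,440). The
+    # actual stride is read from k_cache.stride(0) below rather than
+    # recomputed here.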
+ + # NOTE: When using DP, slot_mapping.shape[0] can be less than k.shape[0] due to + # padding. Always use slot_mapping.shape[0] as the token count. + num_tokens = slot_mapping.shape[0] + block_stride = k_cache.stride(0) # bytes per block + + TOKEN_FP8_DIM = 448 + TOKEN_BF16_DIM = 64 + TOKEN_SCALE_DIM = 8 + QUANT_BLOCK_SIZE = 64 + FP8_MAX = 448.0 + TOKEN_DATA_SIZE = TOKEN_FP8_DIM + TOKEN_BF16_DIM * 2 + + grid = (num_tokens,) + + quantize_and_insert_k_kernel[grid]( + k, + slot_mapping, + k_cache, + num_tokens, + input_dim=512, + fp8_dim=TOKEN_FP8_DIM, + bf16_dim=TOKEN_BF16_DIM, + scale_dim=TOKEN_SCALE_DIM, + quant_block=QUANT_BLOCK_SIZE, + cache_block_size=block_size, + token_data_size=TOKEN_DATA_SIZE, + block_stride=block_stride, + fp8_max=FP8_MAX, + n_quant_blocks=8, + ) + + +@triton.jit +def _dequantize_and_gather_k_kernel( + out_ptr, + out_stride0, + out_stride1, + k_cache_ptr, + seq_lens_ptr, + block_table_ptr, + offset, + gather_lens_ptr, + # Constants + max_blocks_per_seq: tl.constexpr, + fp8_dim: tl.constexpr, # 448 + bf16_dim: tl.constexpr, # 64 + scale_dim: tl.constexpr, # 8 + quant_block: tl.constexpr, # 64 (quantization block size) + cache_block_size: tl.constexpr, # 64 or 128 (paged cache block size) + token_data_size: tl.constexpr, # 576 bytes per token data + block_stride: tl.constexpr, # total bytes per block (padded) int32 + output_dim: tl.constexpr, # 512 + fp8_max: tl.constexpr, + n_quant_blocks: tl.constexpr, # 7 real blocks +): + batch_idx = tl.program_id(0) + worker_id = tl.program_id(1) + num_workers = tl.num_programs(1) + + seq_len = tl.load(seq_lens_ptr + batch_idx) + if gather_lens_ptr is not None: # noqa: SIM108 + gather_len = tl.load(gather_lens_ptr + batch_idx) + else: + # Gather all tokens + gather_len = seq_len + start_pos = seq_len - gather_len + + for i in range(worker_id, gather_len, num_workers): + # Calculate the actual token index in the sequence + pos = start_pos + i + + # Calculate which block and position within block + block_in_seq = pos // cache_block_size + pos_in_block = pos % cache_block_size + + # Get physical block index from block table + block_table_row_ptr = block_table_ptr + batch_idx * max_blocks_per_seq + physical_block_idx = tl.load(block_table_row_ptr + block_in_seq) # int32 + + # int64: physical_block_idx * block_stride can exceed 2^31 with many + # KV-cache blocks (e.g. >= 57K at block_stride ~37K). 
+ cache_block_ptr = k_cache_ptr + physical_block_idx.to(tl.int64) * block_stride + + # Token data pointer + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + + # Scale pointer: after all token data + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + # Token data layout: [0:448] fp8, [448:576] bf16 + token_fp8_ptr = token_data_ptr + token_bf16_ptr = token_data_ptr + fp8_dim + + # Output pointer for this token (flattened) + output_row_ptr = out_ptr + batch_idx * out_stride0 + (offset + i) * out_stride1 + + # ========== Dequantize FP8 portion using UE8M0 ========== + for qblock_idx in tl.static_range(n_quant_blocks): + qblock_start = qblock_idx * quant_block + + if qblock_start < fp8_dim: + offsets = qblock_start + tl.arange(0, quant_block) + mask = offsets < fp8_dim + + # Load quantized fp8 values (stored as uint8) + x_uint8 = tl.load(token_fp8_ptr + offsets, mask=mask, other=0) + + # Bitcast uint8 back to fp8 + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + + # Convert fp8 to float32 for computation + x_float = x_fp8.to(tl.float32) + + # Load and decode UE8M0 scale + # UE8M0: scale = 2^(stored_value - 127) + encoded_scale = tl.load(token_scale_ptr + qblock_idx) + exponent = encoded_scale.to(tl.float32) - 127.0 + scale = tl.exp2(exponent) + + # Dequantize: bf16_value = fp8_value * scale + x_dequant = x_float * scale + + # Store as bf16 + tl.store(output_row_ptr + offsets, x_dequant.to(tl.bfloat16), mask=mask) + + # ========== Copy BF16 portion directly ========== + bf16_output_offset = fp8_dim # After 448 elements in output + + # Read bf16 from cache + bf16_cache_ptr = token_bf16_ptr.to(tl.pointer_type(tl.bfloat16)) + + # Process in chunks of 16 + for j in tl.static_range(bf16_dim // 16): + chunk_offsets = j * 16 + tl.arange(0, 16) + bf16_vals = tl.load(bf16_cache_ptr + chunk_offsets) + tl.store(output_row_ptr + bf16_output_offset + chunk_offsets, bf16_vals) + + +def dequantize_and_gather_k_cache( + # [num_reqs, max_num_tokens, head_size] + out: torch.Tensor, + # [num_blocks, block_size, head_bytes] + k_cache: torch.Tensor, + # [num_reqs] + seq_lens: torch.Tensor, + # [num_reqs] + gather_lens: torch.Tensor | None, + # [num_reqs, max_blocks_per_seq] + block_table: torch.Tensor, + block_size: int, + offset: int, +) -> None: + TOKEN_FP8_DIM = 448 + TOKEN_BF16_DIM = 64 + TOKEN_SCALE_DIM = 8 + QUANT_BLOCK_SIZE = 64 + FP8_MAX = 448.0 + TOKEN_DATA_SIZE = TOKEN_FP8_DIM + TOKEN_BF16_DIM * 2 + + num_reqs = seq_lens.shape[0] + NUM_WORKERS = 128 + _dequantize_and_gather_k_kernel[(num_reqs, NUM_WORKERS)]( + out, + out.stride(0), + out.stride(1), + k_cache, + seq_lens, + block_table, + offset, + gather_lens, + max_blocks_per_seq=block_table.shape[-1], + fp8_dim=TOKEN_FP8_DIM, + bf16_dim=TOKEN_BF16_DIM, + scale_dim=TOKEN_SCALE_DIM, + quant_block=QUANT_BLOCK_SIZE, + cache_block_size=block_size, + token_data_size=TOKEN_DATA_SIZE, + block_stride=k_cache.stride(0), + output_dim=512, + fp8_max=FP8_MAX, + n_quant_blocks=7, + ) + + +@triton.jit +def _dequantize_global_slots_k_kernel( + out_ptr, + out_stride_token, + out_stride_slot, + k_cache_ptr, + slot_ids_ptr, + slot_ids_stride_token, + slot_ids_stride_slot, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + bf16_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + output_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + topk_idx = tl.program_id(1) + 
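+    # slot_ids are already physical cache slots (typically produced by
+    # compute_global_topk_indices_and_lens, which applies the block table);
+    # the div/mod below recovers the page and in-page offset, and a negative
+    # slot_id marks top-k padding, which is written out as an all-zero row.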
+ slot_id = tl.load( + slot_ids_ptr + + token_idx * slot_ids_stride_token + + topk_idx * slot_ids_stride_slot + ) + offsets = tl.arange(0, BLOCK_D) + output_row = out_ptr + token_idx * out_stride_token + topk_idx * out_stride_slot + + if slot_id < 0: + tl.store( + output_row + offsets, + tl.zeros((BLOCK_D,), dtype=tl.float32).to(tl.bfloat16), + mask=offsets < output_dim, + ) + return + + block_idx = slot_id // cache_block_size + pos_in_block = slot_id % cache_block_size + cache_block_ptr = k_cache_ptr + block_idx.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + cache_block_size * token_data_size + pos_in_block * scale_dim + ) + + fp8_offsets = tl.arange(0, 512) + fp8_mask = fp8_offsets < fp8_dim + x_uint8 = tl.load(token_data_ptr + fp8_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + + scale_offsets = fp8_offsets // quant_block + encoded_scale = tl.load(token_scale_ptr + scale_offsets, mask=fp8_mask, other=127) + scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * scale + tl.store(output_row + fp8_offsets, x_dequant.to(tl.bfloat16), mask=fp8_mask) + + bf16_offsets = tl.arange(0, 64) + bf16_cache_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + bf16_vals = tl.load(bf16_cache_ptr + bf16_offsets, mask=bf16_offsets < bf16_dim) + tl.store( + output_row + fp8_dim + bf16_offsets, + bf16_vals, + mask=bf16_offsets < bf16_dim, + ) + + +def dequantize_global_slots_k_cache( + out: torch.Tensor, + k_cache: torch.Tensor, + slot_ids: torch.Tensor, + block_size: int, +) -> None: + """Dequantize fp8_ds_mla cache rows addressed by physical global slot ids.""" + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0, :] + assert slot_ids.dim() == 2, ( + f"slot_ids must be [num_tokens, topk], got {slot_ids.shape}" + ) + assert out.shape[:2] == slot_ids.shape + assert out.shape[-1] == 512 + assert out.dtype == torch.bfloat16 + assert k_cache.dtype == torch.uint8 + + TOKEN_FP8_DIM = 448 + TOKEN_BF16_DIM = 64 + TOKEN_SCALE_DIM = 8 + QUANT_BLOCK_SIZE = 64 + TOKEN_DATA_SIZE = TOKEN_FP8_DIM + TOKEN_BF16_DIM * 2 + + grid = slot_ids.shape + _dequantize_global_slots_k_kernel[grid]( + out, + out.stride(0), + out.stride(1), + k_cache, + slot_ids, + slot_ids.stride(0), + slot_ids.stride(1), + cache_block_size=block_size, + token_data_size=TOKEN_DATA_SIZE, + block_stride=k_cache.stride(0), + fp8_dim=TOKEN_FP8_DIM, + bf16_dim=TOKEN_BF16_DIM, + scale_dim=TOKEN_SCALE_DIM, + quant_block=QUANT_BLOCK_SIZE, + output_dim=512, + BLOCK_D=triton.next_power_of_2(512), + ) + + +def dequantize_combined_sparse_mla_decode_kv( + combined_kv: torch.Tensor, + compressed_k_cache: torch.Tensor, + compressed_slot_ids: torch.Tensor, + compressed_block_size: int, + swa_k_cache: torch.Tensor, + seq_lens: torch.Tensor, + swa_lens: torch.Tensor, + block_table: torch.Tensor, + swa_block_size: int, +) -> None: + """Fill `[compressed, SWA]` decode candidates into one output buffer.""" + assert combined_kv.dim() == 3 + compressed_topk = compressed_slot_ids.shape[-1] + assert combined_kv.shape[0] == compressed_slot_ids.shape[0] + assert combined_kv.shape[-1] == 512 + assert combined_kv.dtype == torch.bfloat16 + assert combined_kv.shape[1] >= compressed_topk + + dequantize_global_slots_k_cache( + combined_kv[:, :compressed_topk], + compressed_k_cache, + compressed_slot_ids, + compressed_block_size, + ) + swa_out = 
combined_kv[:, compressed_topk:] + if swa_out.shape[1] == 0: + return + dequantize_and_gather_k_cache( + swa_out, + swa_k_cache, + seq_lens=seq_lens, + gather_lens=swa_lens, + block_table=block_table, + block_size=swa_block_size, + offset=0, + ) + + +def compute_global_topk_indices_and_lens( + topk_indices: torch.Tensor, + token_to_req_indices: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + is_valid_token: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Map local topk indices to global KV cache slots and count valid entries. + + Fuses three operations into a single kernel: + 1. Block-table lookup (local index → global slot id) + 2. Valid-entry counting (topk_lens per token) + 3. Masking padding tokens to length 0 + """ + num_tokens = topk_indices.shape[0] + global_topk_indices = torch.empty_like(topk_indices) + topk_lens = torch.empty(num_tokens, dtype=torch.int32, device=topk_indices.device) + _compute_global_topk_indices_and_lens_kernel[(num_tokens,)]( + global_topk_indices, + global_topk_indices.stride(0), + topk_lens, + topk_indices, + topk_indices.stride(0), + topk_indices.shape[-1], + token_to_req_indices, + block_table, + block_table.stride(0), + block_size, + is_valid_token, + TRITON_BLOCK_SIZE=1024, + ) + return global_topk_indices, topk_lens + + +@triton.jit +def _compute_global_topk_indices_and_lens_kernel( + global_topk_indices_ptr, + global_topk_indices_stride, + topk_lens_ptr, + topk_indices_ptr, + topk_indices_stride, + topk, + token_to_req_indices_ptr, + block_table_ptr, + block_table_stride, + block_size, + is_valid_token_ptr, + TRITON_BLOCK_SIZE: tl.constexpr, +): + token_idx = tl.program_id(0) + is_valid_token = tl.load(is_valid_token_ptr + token_idx) + req_idx = tl.load(token_to_req_indices_ptr + token_idx) + + count = tl.zeros((), dtype=tl.int32) + for i in range(0, topk, TRITON_BLOCK_SIZE): + offset = i + tl.arange(0, TRITON_BLOCK_SIZE) + mask = offset < topk + + local_idx = tl.load( + topk_indices_ptr + token_idx * topk_indices_stride + offset, + mask=mask, + other=-1, + ) + is_valid = local_idx >= 0 + + block_indices = local_idx // block_size + block_numbers = tl.load( + block_table_ptr + req_idx * block_table_stride + block_indices, + mask=mask & is_valid, + ) + block_offsets = local_idx % block_size + + slot_ids = block_numbers * block_size + block_offsets + slot_ids = tl.where(is_valid, slot_ids, -1) + tl.store( + global_topk_indices_ptr + token_idx * global_topk_indices_stride + offset, + slot_ids, + mask=mask, + ) + count += tl.sum(is_valid.to(tl.int32), axis=0) + + # Zero out length for padding tokens. + tl.store(topk_lens_ptr + token_idx, tl.where(is_valid_token, count, 0)) + + +# FlashMLA sparse prefill asserts `params.topk % B_TOPK == 0` (see +# flashmla/csrc/sm100/prefill/sparse/fwd/head{64,128}/phase1.cuh). B_TOPK is +# 64 for the h_q=64 kernel and 128 for h_q=128; pad to 128 to satisfy both. +# The extra slots stay as -1 sentinels and `combined_lens` caps the valid +# range via `topk_length`, so padding is a no-op at kernel level. 
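+# Illustrative padding arithmetic (values are hypothetical): topk=2048 and
+# window_size=100 give 2,148 candidate slots per token, which
+# combine_topk_swa_indices below pads up to 17 * 128 = 2,176.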
+_SPARSE_PREFILL_TOPK_ALIGNMENT = 128 + + +def combine_topk_swa_indices( + topk_indices: torch.Tensor, + query_start_loc: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + window_size: int, + compress_ratio: int, + topk: int, + M: int, + N: int, +) -> tuple[torch.Tensor, torch.Tensor]: + num_tokens = topk_indices.shape[0] + num_reqs = seq_lens.shape[0] + combined_topk = ( + (topk + window_size + _SPARSE_PREFILL_TOPK_ALIGNMENT - 1) + // _SPARSE_PREFILL_TOPK_ALIGNMENT + * _SPARSE_PREFILL_TOPK_ALIGNMENT + ) + combined_indices = torch.full( + (num_tokens, combined_topk), + fill_value=-1, + dtype=torch.int32, + device=topk_indices.device, + ) + combined_lens = torch.empty( + num_tokens, dtype=torch.int32, device=topk_indices.device + ) + + NUM_WORKERS = 128 + _combine_topk_swa_indices_kernel[(num_reqs, NUM_WORKERS)]( + combined_indices, + combined_indices.stride(0), + combined_lens, + topk_indices, + topk_indices.stride(0), + query_start_loc, + seq_lens, + gather_lens, + M, + N, + TOP_K=topk, + COMPRESS_RATIO=compress_ratio, + WINDOW_SIZE=window_size, + PADDED_TOP_K=triton.next_power_of_2(topk_indices.shape[-1]), + ) + return combined_indices, combined_lens + + +@triton.jit +def _combine_topk_swa_indices_kernel( + combined_indices_ptr, + combined_indices_stride, + combined_lens_ptr, + topk_indices_ptr, + topk_indices_stride, + query_start_loc_ptr, + seq_lens_ptr, + gather_lens_ptr, + M, + N, + TOP_K: tl.constexpr, + COMPRESS_RATIO: tl.constexpr, + WINDOW_SIZE: tl.constexpr, + PADDED_TOP_K: tl.constexpr, +): + batch_idx = tl.program_id(0) + worker_id = tl.program_id(1) + num_workers = tl.num_programs(1) + + # query_start_loc is a global tensor; rebase to chunk-local offsets + # by subtracting the chunk's starting value. + base = tl.load(query_start_loc_ptr) + query_start = tl.load(query_start_loc_ptr + batch_idx) - base + query_end = tl.load(query_start_loc_ptr + batch_idx + 1) - base + query_len = query_end - query_start + seq_len = tl.load(seq_lens_ptr + batch_idx) + gather_len = tl.load(gather_lens_ptr + batch_idx) + start_pos = seq_len - query_len + # The SWA portion of the gathered buffer starts from position + # (seq_len - gather_len), not position 0. We need this offset + # to correctly index into the gathered buffer. + gather_start = seq_len - gather_len + + for token_idx in range(query_start + worker_id, query_end, num_workers): + # topk_len is fully determined by the query token's absolute position: + # both the C4A indexer and the C128A metadata builder emit + # min((pos + 1) // compress_ratio, topk_tokens) valid entries. + # Caller passes TOP_K=0 for SWA-only layers to zero this out. 
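+        # Illustrative: with COMPRESS_RATIO=4 and pos=1023,
+        # topk_len = min(1024 // 4, TOP_K) = min(256, TOP_K) and
+        # swa_len = min(1024, WINDOW_SIZE).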
+ token_idx_in_query = token_idx - query_start + pos = start_pos + token_idx_in_query + topk_len = tl.minimum((pos + 1) // COMPRESS_RATIO, TOP_K) + swa_len = tl.minimum(pos + 1, WINDOW_SIZE) + + offset = tl.arange(0, PADDED_TOP_K) + mask = offset < topk_len + topk_indices = tl.load( + topk_indices_ptr + token_idx * topk_indices_stride + offset, + mask=mask, + ) + tl.store( + combined_indices_ptr + token_idx * combined_indices_stride + offset, + topk_indices + M * batch_idx, + mask=mask, + ) + offset = tl.arange(0, WINDOW_SIZE) + # Index into gathered buffer: N + (position - gather_start) + # For positions [pos - swa_len + 1, pos], the buffer indices are: + # [N + pos - swa_len + 1 - gather_start, N + pos - gather_start] + tl.store( + combined_indices_ptr + + token_idx * combined_indices_stride + + topk_len + + offset, + M * batch_idx + N + offset + pos - swa_len + 1 - gather_start, + mask=offset < swa_len, + ) + + combined_len = topk_len + swa_len + tl.store(combined_lens_ptr + token_idx, combined_len) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fp8_einsum.py b/vllm/v1/attention/ops/deepseek_v4_ops/fp8_einsum.py new file mode 100644 index 000000000000..71a6199e1d9d --- /dev/null +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fp8_einsum.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""SM12x Triton FP8 einsum kernels for DeepSeek V4.""" + +import torch + +from vllm.triton_utils import tl, triton + + +def _upcast_e8m0_to_fp32(scale: torch.Tensor) -> torch.Tensor: + exp_bits = scale.view(torch.uint8).to(torch.int32) + fp32_bits = exp_bits << 23 + return fp32_bits.view(torch.float32) + + +@triton.jit +def _deepseek_v4_sm12_fp8_einsum_kernel( + a_ptr, + a_scale_ptr, + b_ptr, + b_scale_ptr, + out_ptr, + num_tokens: tl.constexpr, + num_groups: tl.constexpr, + out_rank: tl.constexpr, + hidden_size: tl.constexpr, + a_stride_token: tl.constexpr, + a_stride_group: tl.constexpr, + a_stride_hidden: tl.constexpr, + a_scale_stride_token: tl.constexpr, + a_scale_stride_group: tl.constexpr, + a_scale_stride_hidden: tl.constexpr, + b_stride_group: tl.constexpr, + b_stride_out: tl.constexpr, + b_stride_hidden: tl.constexpr, + b_scale_stride_group: tl.constexpr, + b_scale_stride_out: tl.constexpr, + b_scale_stride_hidden: tl.constexpr, + out_stride_token: tl.constexpr, + out_stride_group: tl.constexpr, + out_stride_rank: tl.constexpr, + BLOCK_TOKENS: tl.constexpr, + BLOCK_OUT: tl.constexpr, + BLOCK_HIDDEN: tl.constexpr, +) -> None: + token_block = tl.program_id(0) + out_block = tl.program_id(1) + group = tl.program_id(2) + + token_offsets = token_block * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS) + out_offsets = out_block * BLOCK_OUT + tl.arange(0, BLOCK_OUT) + hidden_offsets = tl.arange(0, BLOCK_HIDDEN) + accum = tl.zeros((BLOCK_TOKENS, BLOCK_OUT), dtype=tl.float32) + + for hidden_start in range(0, hidden_size, BLOCK_HIDDEN): + hidden = hidden_start + hidden_offsets + a = tl.load( + a_ptr + + token_offsets[:, None] * a_stride_token + + group * a_stride_group + + hidden[None, :] * a_stride_hidden, + mask=(token_offsets[:, None] < num_tokens) + & (hidden[None, :] < hidden_size), + other=0.0, + ) + b = tl.load( + b_ptr + + group * b_stride_group + + out_offsets[None, :] * b_stride_out + + hidden[:, None] * b_stride_hidden, + mask=(out_offsets[None, :] < out_rank) + & (hidden[:, None] < hidden_size), + other=0.0, + ) + raw = tl.dot(a, b, out_dtype=tl.float32) + hidden_scale_block = hidden_start // BLOCK_HIDDEN + a_scale = 
tl.load( + a_scale_ptr + + token_offsets * a_scale_stride_token + + group * a_scale_stride_group + + hidden_scale_block * a_scale_stride_hidden, + mask=token_offsets < num_tokens, + other=0.0, + ) + b_scale = tl.load( + b_scale_ptr + + group * b_scale_stride_group + + (out_offsets // 128) * b_scale_stride_out + + hidden_scale_block * b_scale_stride_hidden, + mask=out_offsets < out_rank, + other=0.0, + ) + accum += raw * a_scale[:, None] * b_scale[None, :] + + tl.store( + out_ptr + + token_offsets[:, None] * out_stride_token + + group * out_stride_group + + out_offsets[None, :] * out_stride_rank, + accum, + mask=(token_offsets[:, None] < num_tokens) + & (out_offsets[None, :] < out_rank), + ) + + +def deepseek_v4_sm12_fp8_einsum( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, +) -> None: + """Compute ``bhr,hdr->bhd`` with FP32 block scales on SM12x. + + ``a`` is the transposed output of ``fused_inv_rope_fp8_quant`` with shape + ``[tokens, groups, hidden]``. ``b`` is ``wo_a`` reshaped to + ``[groups, out_rank, hidden]``. + """ + num_tokens, num_groups, hidden_size = a.shape + b_groups, out_rank, b_hidden_size = b.shape + assert b_groups == num_groups + assert b_hidden_size == hidden_size + assert out.shape == (num_tokens, num_groups, out_rank) + assert hidden_size % 128 == 0 + assert out_rank % 128 == 0 + assert a.dtype == torch.float8_e4m3fn + assert b.dtype == torch.float8_e4m3fn + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if a_scale.dtype == e8m0_dtype: + a_scale = _upcast_e8m0_to_fp32(a_scale) + if b_scale.dtype == e8m0_dtype: + b_scale = _upcast_e8m0_to_fp32(b_scale) + assert a_scale.dtype == torch.float32 + assert b_scale.dtype == torch.float32 + + if num_tokens == 0: + return + + block_tokens = 16 + block_out = 128 + block_hidden = 128 + grid = ( + triton.cdiv(num_tokens, block_tokens), + triton.cdiv(out_rank, block_out), + num_groups, + ) + _deepseek_v4_sm12_fp8_einsum_kernel[grid]( + a, + a_scale, + b, + b_scale, + out, + num_tokens, + num_groups, + out_rank, + hidden_size, + a.stride(0), + a.stride(1), + a.stride(2), + a_scale.stride(0), + a_scale.stride(1), + a_scale.stride(2), + b.stride(0), + b.stride(1), + b.stride(2), + b_scale.stride(0), + b_scale.stride(1), + b_scale.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + BLOCK_TOKENS=block_tokens, + BLOCK_OUT=block_out, + BLOCK_HIDDEN=block_hidden, + num_warps=4, + num_stages=3, + ) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py new file mode 100644 index 000000000000..26b076f34238 --- /dev/null +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py @@ -0,0 +1,584 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Fused compressor + FP8/MXFP4 UE8M0 quantization + KV cache insert kernels. + +Three specialized kernels: + - _fused_kv_compress_norm_rope_insert_sparse_attn: + head=512, nope=448 FP8 + rope=64 bf16 + - _fused_kv_compress_norm_rope_insert_indexer_attn: + head=128, all FP8, 1 block/token + - _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn: + head=128, MXFP4 (block=32), 4 ue8m0 bytes + +RoPE is register-based via tl.reshape -> tl.split -> tl.interleave (or the +even/odd halves are consumed directly for MXFP4, no interleave needed). 
+FP8 UE8M0 quant uses tl.reshape to tile [N_QUANT_BLOCKS, QUANT_BLOCK] for +per-block absmax entirely in registers. MXFP4 does the same tiling on the +even/odd halves, producing (N_QUANT_BLOCKS, MXFP4_BLOCK/2) packed nibbles +and N_QUANT_BLOCKS ue8m0 bytes. +""" + +from vllm.triton_utils import tl, triton + +from .fused_indexer_q import _e2m1_nibble + + +# ============================================================================= +# DeepseekV4 Attention path (head=512, nope=448 FP8 + rope=64 bf16) +# ============================================================================= +@triton.jit +def _fused_kv_compress_norm_rope_insert_sparse_attn( + # ── state cache (compressor internal state) ── + state_cache_ptr, + state_cache_stride0, + state_cache_stride1, + # ── metadata ── + token_to_req_indices_ptr, + positions_ptr, + slot_mapping_ptr, + block_table_ptr, + block_table_stride, + block_size, + # ── RMSNorm ── + rms_norm_weight_ptr, + rms_norm_eps, + # ── RoPE ── + cos_sin_cache_ptr, + cos_sin_stride, + # ── KV cache output ── + k_cache_ptr, + kv_slot_mapping_ptr, + kv_cache_block_size, + # ── constexprs ── + HEAD_SIZE: tl.constexpr, + TRITON_BLOCK_SIZE: tl.constexpr, + STATE_WIDTH: tl.constexpr, + COMPRESS_RATIO: tl.constexpr, + OVERLAP: tl.constexpr, + ROPE_HEAD_DIM: tl.constexpr, + FP8_MAX: tl.constexpr, # 448.0 + QUANT_BLOCK: tl.constexpr, # 64 for DeepseekV4 + TOKEN_STRIDE: tl.constexpr, # 576 for DeepseekV4 + SCALE_DIM: tl.constexpr, # 8 for DeepseekV4 (7 real + 1 pad) + KV_BLOCK_STRIDE: tl.constexpr, +): + """Fused compress → RMSNorm → FP8 quant (nope) → RoPE → bf16 store (rope). + + One program per token; early-exits for non-boundary positions. + + Cache block layout (``block_size`` tokens): + [0, bs*576): token data (448 fp8 + 128 bf16 each) + [bs*576, +bs*8): uint8 UE8M0 scales (7 real + 1 pad each) + """ + token_idx = tl.program_id(0) + + slot_id = tl.load(slot_mapping_ptr + token_idx) + if slot_id < 0: + return + + position = tl.load(positions_ptr + token_idx) + if (position + 1) % COMPRESS_RATIO != 0: + return + + req_idx = tl.load(token_to_req_indices_ptr + token_idx) + + # ── Gather state cache entries ──────────────────────────────────── + start = position - (1 + OVERLAP) * COMPRESS_RATIO + 1 + tokens = tl.arange(0, (1 + OVERLAP) * COMPRESS_RATIO) + pos = start + tokens + mask_pos = pos >= 0 + + block_indices = pos // block_size + block_numbers = tl.load( + block_table_ptr + req_idx * block_table_stride + block_indices, + mask=mask_pos, + other=0, + ) + block_offsets = pos % block_size + head_offset = (tokens >= COMPRESS_RATIO).to(tl.int32) * HEAD_SIZE + + block = tl.arange(0, TRITON_BLOCK_SIZE) + mask = block < HEAD_SIZE + block_numbers_i64 = block_numbers.to(tl.int64) + + # Precomputed row base shared by score and kv loads + row_base = ( + state_cache_ptr + + block_numbers_i64 * state_cache_stride0 + + block_offsets * state_cache_stride1 + + head_offset + ) + + combined_mask = mask_pos[:, None] & mask[None, :] + + # ── Softmax + weighted sum ─────────────────────────────────────── + score = tl.load( + row_base[:, None] + STATE_WIDTH + block[None, :], + mask=combined_mask, + other=float("-inf"), + ) + score = tl.softmax(score, dim=0) + + kv = tl.load( + row_base[:, None] + block[None, :], + mask=combined_mask, + other=0.0, + ) + + compressed_kv = tl.sum(kv * score, axis=0) # [TRITON_BLOCK_SIZE] fp32 + + # ── RMSNorm (fp32 throughout) ────────────────────────────────────── + rms_w = tl.load(rms_norm_weight_ptr + block, mask=mask, other=0.0) + variance = 
tl.sum(compressed_kv * compressed_kv, axis=0) / HEAD_SIZE + rrms = tl.rsqrt(variance + rms_norm_eps) + normed = compressed_kv * rrms * rms_w + + # ── KV cache pointers ──────────────────────────────────────────── + kv_slot_idx = tl.load(kv_slot_mapping_ptr + token_idx) + if kv_slot_idx < 0: + return + kv_block_idx = kv_slot_idx // kv_cache_block_size + kv_pos_in_block = kv_slot_idx % kv_cache_block_size + + cache_block_ptr = k_cache_ptr + kv_block_idx.to(tl.int64) * KV_BLOCK_STRIDE + fp8_ptr = cache_block_ptr + kv_pos_in_block * TOKEN_STRIDE + scale_ptr = ( + cache_block_ptr + + kv_cache_block_size * TOKEN_STRIDE + + kv_pos_in_block * SCALE_DIM + ) + + NOPE_HEAD_DIM: tl.constexpr = HEAD_SIZE - ROPE_HEAD_DIM # 448 + HALF_ROPE: tl.constexpr = ROPE_HEAD_DIM // 2 # 32 + + # FP8 UE8M0 quant: cast fp32 → bf16 → fp32 before quant to match reference. + N_QUANT_BLOCKS: tl.constexpr = TRITON_BLOCK_SIZE // QUANT_BLOCK + N_NOPE_BLOCKS: tl.constexpr = NOPE_HEAD_DIM // QUANT_BLOCK # 7 + INV_FP8_MAX: tl.constexpr = 1.0 / FP8_MAX + + quant_input = normed.to(tl.bfloat16).to(tl.float32) + quant_2d = tl.reshape(quant_input, (N_QUANT_BLOCKS, QUANT_BLOCK)) + abs_2d = tl.abs(quant_2d) + block_absmax = tl.max(abs_2d, axis=1) # [N_QUANT_BLOCKS] fp32 + block_absmax = tl.maximum(block_absmax, 1e-4) + + raw_scales = block_absmax * INV_FP8_MAX + exponents = tl.ceil(tl.log2(raw_scales)) + inv_scales = tl.exp2(-exponents) + inv_scales_col = tl.reshape(inv_scales, (N_QUANT_BLOCKS, 1)) + x_scaled = quant_2d * inv_scales_col + x_clamped = tl.clamp(x_scaled, -FP8_MAX, FP8_MAX) + x_fp8 = x_clamped.to(tl.float8e4nv) + x_uint8 = x_fp8.to(tl.uint8, bitcast=True) + x_uint8_flat = tl.reshape(x_uint8, (TRITON_BLOCK_SIZE,)) + + nope_mask = block < NOPE_HEAD_DIM + tl.store(fp8_ptr + block, x_uint8_flat, mask=nope_mask) + + scale_idx = tl.arange(0, N_QUANT_BLOCKS) + encoded = exponents + 127.0 + encoded = tl.maximum(tl.minimum(encoded, 255.0), 0.0) + tl.store( + scale_ptr + scale_idx, + encoded.to(tl.uint8), + mask=scale_idx < N_NOPE_BLOCKS, + ) + tl.store(scale_ptr + N_NOPE_BLOCKS, tl.zeros((), dtype=tl.uint8)) + + # Register-based GPT-J RoPE in fp32. + NUM_PAIRS: tl.constexpr = TRITON_BLOCK_SIZE // 2 + NOPE_PAIRS: tl.constexpr = NOPE_HEAD_DIM // 2 + + pair_2d = tl.reshape(normed, (NUM_PAIRS, 2)) + even, odd = tl.split(pair_2d) # each [NUM_PAIRS] fp32 + + pair_idx = tl.arange(0, NUM_PAIRS) + rope_pair_local = pair_idx - NOPE_PAIRS + is_rope_pair = rope_pair_local >= 0 + cs_idx = tl.maximum(rope_pair_local, 0) + + compressed_pos = (position // COMPRESS_RATIO) * COMPRESS_RATIO + cache_base = cos_sin_cache_ptr + compressed_pos * cos_sin_stride + cos_v = tl.load(cache_base + cs_idx, mask=is_rope_pair, other=1.0) + sin_v = tl.load(cache_base + HALF_ROPE + cs_idx, mask=is_rope_pair, other=0.0) + + new_even = even * cos_v - odd * sin_v + new_odd = odd * cos_v + even * sin_v + result = tl.interleave(new_even, new_odd) # [TRITON_BLOCK_SIZE] fp32 + + # Store rotated rope portion as bf16 into the cache's bf16 area. 
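+    # Per the layout above, this bf16 area is bytes [448, 576) of each token's
+    # 576-byte slot, i.e. the 64 rotated rope values stored at 2 bytes apiece.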
+ bf16_ptr = (fp8_ptr + NOPE_HEAD_DIM).to(tl.pointer_type(tl.bfloat16)) + rope_local = block - NOPE_HEAD_DIM + is_rope = (block >= NOPE_HEAD_DIM) & mask + tl.store(bf16_ptr + rope_local, result.to(tl.bfloat16), mask=is_rope) + + +# ============================================================================= +# Indexer path (head=128, all FP8, single quant block) +# ============================================================================= +@triton.jit +def _fused_kv_compress_norm_rope_insert_indexer_attn( + # ── state cache (compressor internal state) ── + state_cache_ptr, + state_cache_stride0, + state_cache_stride1, + # ── metadata ── + token_to_req_indices_ptr, + positions_ptr, + slot_mapping_ptr, + block_table_ptr, + block_table_stride, + block_size, + # ── RMSNorm ── + rms_norm_weight_ptr, + rms_norm_eps, + # ── RoPE ── + cos_sin_cache_ptr, + cos_sin_stride, + # ── KV cache output ── + k_cache_ptr, + kv_slot_mapping_ptr, + kv_cache_block_size, + # ── constexprs ── + HEAD_SIZE: tl.constexpr, + TRITON_BLOCK_SIZE: tl.constexpr, + STATE_WIDTH: tl.constexpr, + COMPRESS_RATIO: tl.constexpr, + OVERLAP: tl.constexpr, + ROPE_HEAD_DIM: tl.constexpr, + FP8_MAX: tl.constexpr, # 448.0 + QUANT_BLOCK: tl.constexpr, # 128 for indexer + TOKEN_STRIDE: tl.constexpr, # 128 for indexer + SCALE_DIM: tl.constexpr, # 4 for indexer (1 float32) + KV_BLOCK_STRIDE: tl.constexpr, +): + """Fused compress → RMSNorm → RoPE → FP8 quant → store. + + One program per token; early-exits for non-boundary positions. + + Cache block layout: + [0, bs*128): FP8 data (128 bytes/token) + [bs*128, +bs*4): float32 scales (4 bytes/token) + + For head_dim=128 we have exactly one quant block, so we skip the + [N_QUANT_BLOCKS, QUANT_BLOCK] reshape entirely and use a flat + ``tl.max`` reduction. 
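+
+    Illustrative numbers (hypothetical input): absmax = 120 gives
+    raw_scale = 120 / 448 ≈ 0.268 and exponent = ceil(log2(0.268)) = -1, so the
+    values are multiplied by 2 before the FP8 cast and the stored float32
+    scale is 2**-1 = 0.5.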
+ """ + token_idx = tl.program_id(0) + + slot_id = tl.load(slot_mapping_ptr + token_idx) + if slot_id < 0: + return + + position = tl.load(positions_ptr + token_idx) + if (position + 1) % COMPRESS_RATIO != 0: + return + + req_idx = tl.load(token_to_req_indices_ptr + token_idx) + + # ── Gather state cache entries ──────────────────────────────────── + start = position - (1 + OVERLAP) * COMPRESS_RATIO + 1 + tokens = tl.arange(0, (1 + OVERLAP) * COMPRESS_RATIO) + pos = start + tokens + mask_pos = pos >= 0 + + block_indices = pos // block_size + block_numbers = tl.load( + block_table_ptr + req_idx * block_table_stride + block_indices, + mask=mask_pos, + other=0, + ) + block_offsets = pos % block_size + head_offset = (tokens >= COMPRESS_RATIO).to(tl.int32) * HEAD_SIZE + + block = tl.arange(0, TRITON_BLOCK_SIZE) + mask = block < HEAD_SIZE + block_numbers_i64 = block_numbers.to(tl.int64) + + row_base = ( + state_cache_ptr + + block_numbers_i64 * state_cache_stride0 + + block_offsets * state_cache_stride1 + + head_offset + ) + + combined_mask = mask_pos[:, None] & mask[None, :] + + score = tl.load( + row_base[:, None] + STATE_WIDTH + block[None, :], + mask=combined_mask, + other=float("-inf"), + ) + score = tl.softmax(score, dim=0) + + kv = tl.load( + row_base[:, None] + block[None, :], + mask=combined_mask, + other=0.0, + ) + + compressed_kv = tl.sum(kv * score, axis=0) # [TRITON_BLOCK_SIZE] fp32 + + # ── RMSNorm (fp32 throughout) ────────────────────────────────────── + rms_w = tl.load(rms_norm_weight_ptr + block, mask=mask, other=0.0) + variance = tl.sum(compressed_kv * compressed_kv, axis=0) / HEAD_SIZE + rrms = tl.rsqrt(variance + rms_norm_eps) + normed = compressed_kv * rrms * rms_w + + # ── KV cache pointers ──────────────────────────────────────────── + kv_slot_idx = tl.load(kv_slot_mapping_ptr + token_idx) + if kv_slot_idx < 0: + return + kv_block_idx = kv_slot_idx // kv_cache_block_size + kv_pos_in_block = kv_slot_idx % kv_cache_block_size + + cache_block_ptr = k_cache_ptr + kv_block_idx.to(tl.int64) * KV_BLOCK_STRIDE + fp8_ptr = cache_block_ptr + kv_pos_in_block * TOKEN_STRIDE + scale_ptr = ( + cache_block_ptr + + kv_cache_block_size * TOKEN_STRIDE + + kv_pos_in_block * SCALE_DIM + ) + + NOPE_HEAD_DIM: tl.constexpr = HEAD_SIZE - ROPE_HEAD_DIM + HALF_ROPE: tl.constexpr = ROPE_HEAD_DIM // 2 + + # ── Register-based GPT-J forward RoPE in fp32 ───────────────────── + NUM_PAIRS: tl.constexpr = TRITON_BLOCK_SIZE // 2 + NOPE_PAIRS: tl.constexpr = NOPE_HEAD_DIM // 2 + + normed_2d = tl.reshape(normed, (NUM_PAIRS, 2)) + even, odd = tl.split(normed_2d) # each [NUM_PAIRS] fp32 + + pair_idx = tl.arange(0, NUM_PAIRS) + rope_pair_local = pair_idx - NOPE_PAIRS + is_rope_pair = rope_pair_local >= 0 + cs_idx = tl.maximum(rope_pair_local, 0) + + compressed_pos = (position // COMPRESS_RATIO) * COMPRESS_RATIO + cache_base = cos_sin_cache_ptr + compressed_pos * cos_sin_stride + cos_v = tl.load(cache_base + cs_idx, mask=is_rope_pair, other=1.0) + sin_v = tl.load(cache_base + HALF_ROPE + cs_idx, mask=is_rope_pair, other=0.0) + + new_even = even * cos_v - odd * sin_v + new_odd = odd * cos_v + even * sin_v + result = tl.interleave(new_even, new_odd) # fp32 + + # ── FP8 UE8M0 quant: single block, flat reduction ──────────────── + tl.static_assert( + TRITON_BLOCK_SIZE == QUANT_BLOCK, + "Indexer expects one quant block (QUANT_BLOCK == TRITON_BLOCK_SIZE)", + ) + INV_FP8_MAX: tl.constexpr = 1.0 / FP8_MAX + + result_bf16 = result.to(tl.bfloat16).to(tl.float32) + absmax = tl.max(tl.abs(result_bf16), axis=0) # scalar + 
absmax = tl.maximum(absmax, 1e-4) + raw_scale = absmax * INV_FP8_MAX + exponent = tl.ceil(tl.log2(raw_scale)) + inv_scale = tl.exp2(-exponent) + + x_scaled = result_bf16 * inv_scale + x_clamped = tl.clamp(x_scaled, -FP8_MAX, FP8_MAX) + x_fp8 = x_clamped.to(tl.float8e4nv) + x_uint8 = x_fp8.to(tl.uint8, bitcast=True) + + tl.store(fp8_ptr + block, x_uint8, mask=mask) + + # Single float32 scale + scale_val = tl.exp2(exponent) + tl.store(scale_ptr.to(tl.pointer_type(tl.float32)), scale_val) + + +# ============================================================================= +# Indexer path (head=128, MXFP4: 2 nibbles/byte + ue8m0 per 32-elem block) +# ============================================================================= +@triton.jit +def _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn( + # ── state cache (compressor internal state) ── + state_cache_ptr, + state_cache_stride0, + state_cache_stride1, + # ── metadata ── + token_to_req_indices_ptr, + positions_ptr, + slot_mapping_ptr, + block_table_ptr, + block_table_stride, + block_size, + # ── RMSNorm ── + rms_norm_weight_ptr, + rms_norm_eps, + # ── RoPE ── + cos_sin_cache_ptr, + cos_sin_stride, + # ── KV cache output ── + k_cache_ptr, + kv_slot_mapping_ptr, + kv_cache_block_size, + # ── constexprs ── + HEAD_SIZE: tl.constexpr, + TRITON_BLOCK_SIZE: tl.constexpr, + STATE_WIDTH: tl.constexpr, + COMPRESS_RATIO: tl.constexpr, + OVERLAP: tl.constexpr, + ROPE_HEAD_DIM: tl.constexpr, + FP8_MAX: tl.constexpr, # unused for MXFP4 (kept for signature parity) + QUANT_BLOCK: tl.constexpr, # 32 for MXFP4 + TOKEN_STRIDE: tl.constexpr, # HEAD_SIZE // 2 = 64 packed bytes/token + SCALE_DIM: tl.constexpr, # HEAD_SIZE // QUANT_BLOCK = 4 ue8m0 bytes/token + KV_BLOCK_STRIDE: tl.constexpr, +): + """Fused compress → RMSNorm → RoPE → MXFP4 quant → store. + + One program per token; early-exits for non-boundary positions. + + Cache block layout (``block_size`` tokens per cache block): + [0, bs*TOKEN_STRIDE): packed MXFP4 nibbles (2 values/byte) + [bs*TOKEN_STRIDE, +bs*SCALE_DIM): ue8m0 scale bytes (one per 32-elem block) + + MXFP4 format: + - E2M1 4-bit values packed two per byte (low nibble first, then high). + - Per-32-element block scale = 2^ceil(log2(amax / 6.0)), stored ue8m0 + (byte = exponent + 127). + - Max representable magnitude = 6.0. 
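+
+    Illustrative numbers (hypothetical block): amax = 4.0 gives
+    ceil(log2(4.0 / 6.0)) = 0, so the block scale is 2**0 = 1 and the stored
+    ue8m0 byte is 127; a value of 1.6 then falls in the (1.25, 1.75] bucket
+    and is encoded as the E2M1 code for 1.5.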
+ """ + token_idx = tl.program_id(0) + + slot_id = tl.load(slot_mapping_ptr + token_idx) + if slot_id < 0: + return + + position = tl.load(positions_ptr + token_idx) + if (position + 1) % COMPRESS_RATIO != 0: + return + + req_idx = tl.load(token_to_req_indices_ptr + token_idx) + + # ── Gather state cache entries ──────────────────────────────────── + start = position - (1 + OVERLAP) * COMPRESS_RATIO + 1 + tokens = tl.arange(0, (1 + OVERLAP) * COMPRESS_RATIO) + pos = start + tokens + mask_pos = pos >= 0 + + block_indices = pos // block_size + block_numbers = tl.load( + block_table_ptr + req_idx * block_table_stride + block_indices, + mask=mask_pos, + other=0, + ) + block_offsets = pos % block_size + head_offset = (tokens >= COMPRESS_RATIO).to(tl.int32) * HEAD_SIZE + + block = tl.arange(0, TRITON_BLOCK_SIZE) + mask = block < HEAD_SIZE + block_numbers_i64 = block_numbers.to(tl.int64) + + row_base = ( + state_cache_ptr + + block_numbers_i64 * state_cache_stride0 + + block_offsets * state_cache_stride1 + + head_offset + ) + + combined_mask = mask_pos[:, None] & mask[None, :] + + score = tl.load( + row_base[:, None] + STATE_WIDTH + block[None, :], + mask=combined_mask, + other=float("-inf"), + ) + score = tl.softmax(score, dim=0) + + kv = tl.load( + row_base[:, None] + block[None, :], + mask=combined_mask, + other=0.0, + ) + + compressed_kv = tl.sum(kv * score, axis=0) # [TRITON_BLOCK_SIZE] fp32 + + # ── RMSNorm (fp32 throughout) ────────────────────────────────────── + rms_w = tl.load(rms_norm_weight_ptr + block, mask=mask, other=0.0) + variance = tl.sum(compressed_kv * compressed_kv, axis=0) / HEAD_SIZE + rrms = tl.rsqrt(variance + rms_norm_eps) + normed = compressed_kv * rrms * rms_w + + # ── KV cache pointers (segregated: values first, then scales) ──── + kv_slot_idx = tl.load(kv_slot_mapping_ptr + token_idx) + if kv_slot_idx < 0: + return + kv_block_idx = kv_slot_idx // kv_cache_block_size + kv_pos_in_block = kv_slot_idx % kv_cache_block_size + + cache_block_ptr = k_cache_ptr + kv_block_idx.to(tl.int64) * KV_BLOCK_STRIDE + val_ptr = cache_block_ptr + kv_pos_in_block * TOKEN_STRIDE + scale_ptr = ( + cache_block_ptr + + kv_cache_block_size * TOKEN_STRIDE + + kv_pos_in_block * SCALE_DIM + ) + + NOPE_HEAD_DIM: tl.constexpr = HEAD_SIZE - ROPE_HEAD_DIM + HALF_ROPE: tl.constexpr = ROPE_HEAD_DIM // 2 + + # ── Register-based GPT-J forward RoPE in fp32 ───────────────────── + # We keep the even/odd halves (no tl.interleave afterwards) because the + # MXFP4 per-block absmax / pack naturally operates on (even, odd) pairs. + NUM_PAIRS: tl.constexpr = TRITON_BLOCK_SIZE // 2 + NOPE_PAIRS: tl.constexpr = NOPE_HEAD_DIM // 2 + + normed_2d = tl.reshape(normed, (NUM_PAIRS, 2)) + even, odd = tl.split(normed_2d) # each [NUM_PAIRS] fp32 + + pair_idx = tl.arange(0, NUM_PAIRS) + rope_pair_local = pair_idx - NOPE_PAIRS + is_rope_pair = rope_pair_local >= 0 + cs_idx = tl.maximum(rope_pair_local, 0) + + compressed_pos = (position // COMPRESS_RATIO) * COMPRESS_RATIO + cache_base = cos_sin_cache_ptr + compressed_pos * cos_sin_stride + cos_v = tl.load(cache_base + cs_idx, mask=is_rope_pair, other=1.0) + sin_v = tl.load(cache_base + HALF_ROPE + cs_idx, mask=is_rope_pair, other=0.0) + + new_even = even * cos_v - odd * sin_v + new_odd = odd * cos_v + even * sin_v + + # bf16 roundtrip for parity with reference / Q-side kernel numerics. 
+ new_even = new_even.to(tl.bfloat16).to(tl.float32) + new_odd = new_odd.to(tl.bfloat16).to(tl.float32) + + # ── MXFP4 quant: tile even/odd halves into (N_BLOCKS, HALF_BLOCK) ── + # Each MXFP4 block of QUANT_BLOCK elements = HALF_BLOCK consecutive pairs, + # so (N_BLOCKS, HALF_BLOCK) rows of even/odd each land exactly one block. + N_QUANT_BLOCKS: tl.constexpr = HEAD_SIZE // QUANT_BLOCK + HALF_BLOCK: tl.constexpr = QUANT_BLOCK // 2 + tl.static_assert(TRITON_BLOCK_SIZE == HEAD_SIZE) + tl.static_assert(HEAD_SIZE % QUANT_BLOCK == 0) + tl.static_assert(TOKEN_STRIDE == HEAD_SIZE // 2) + tl.static_assert(SCALE_DIM == N_QUANT_BLOCKS) + + even_2d = tl.reshape(new_even, (N_QUANT_BLOCKS, HALF_BLOCK)) + odd_2d = tl.reshape(new_odd, (N_QUANT_BLOCKS, HALF_BLOCK)) + + amax = tl.maximum( + tl.max(tl.abs(even_2d), axis=1), + tl.max(tl.abs(odd_2d), axis=1), + ) + amax = tl.maximum(amax, 1e-4) + + # ue8m0 block scale: 2^ceil(log2(amax / 6.0)), stored as (exp + 127) byte. + log2_ratio = tl.ceil(tl.log2(amax / 6.0)) + log2_ratio = tl.minimum(tl.maximum(log2_ratio, -127.0), 127.0) + inv_scale = tl.exp2(-log2_ratio) + ue8m0 = (log2_ratio + 127.0).to(tl.uint8) # [N_QUANT_BLOCKS] + + inv_scale_col = tl.reshape(inv_scale, (N_QUANT_BLOCKS, 1)) + lo_nib = _e2m1_nibble(even_2d * inv_scale_col) # (N_BLOCKS, HALF_BLOCK) uint8 + hi_nib = _e2m1_nibble(odd_2d * inv_scale_col) + packed = lo_nib | (hi_nib << 4) + packed_flat = tl.reshape(packed, (TOKEN_STRIDE,)) + + tl.store(val_ptr + tl.arange(0, TOKEN_STRIDE), packed_flat) + tl.store(scale_ptr + tl.arange(0, SCALE_DIM), ue8m0) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py new file mode 100644 index 000000000000..0254a46752c6 --- /dev/null +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py @@ -0,0 +1,415 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm.triton_utils import tl, triton + +# MXFP4: 32 elements per block, packed 2 nibbles per byte, ue8m0 block scale. +MXFP4_BLOCK_SIZE = 32 + + +@triton.jit +def _get_cos_sin( + cos_sin_cache_ptr, + cos_sin_cache_stride, + pos, + HALF_ROT_DIM: tl.constexpr, +): + block = tl.arange(0, HALF_ROT_DIM) + cos = tl.load(cos_sin_cache_ptr + pos * cos_sin_cache_stride + block) + cos = cos.to(tl.float32) + sin = tl.load(cos_sin_cache_ptr + pos * cos_sin_cache_stride + block + HALF_ROT_DIM) + sin = sin.to(tl.float32) + return cos, sin + + +@triton.jit +def _e2m1_nibble(x): + """Quantize fp32 x (already scale-divided) to E2M1 4-bit nibble in uint8. + Matches torch.bucketize with boundaries + [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0] and right=False (each boundary + belongs to the lower bucket), plus sign bit.""" + abs_x = tl.minimum(tl.abs(x), 6.0) + code = tl.where( + abs_x <= 0.25, + 0.0, + tl.where( + abs_x <= 0.75, + 1.0, + tl.where( + abs_x <= 1.25, + 2.0, + tl.where( + abs_x <= 1.75, + 3.0, + tl.where( + abs_x <= 2.5, + 4.0, + tl.where(abs_x <= 3.5, 5.0, tl.where(abs_x <= 5.0, 6.0, 7.0)), + ), + ), + ), + ), + ) + code_u8 = code.to(tl.uint8) + sign = ((x < 0) & (code_u8 != 0)).to(tl.uint8) + return code_u8 | (sign << 3) + + +@triton.jit +def _quantize_mxfp4_pair(x_lo, x_hi): + """Quantize a block of MXFP4_BLOCK_SIZE fp32 values given as two + interleaved halves (x_lo = values at even positions in the block, + x_hi = values at odd positions). 
Returns: + - packed : uint8[BLOCK/2] (low nibble = quant(x_lo), high = quant(x_hi)) + - ue8m0 : scalar uint8 (block scale = 2^(ue8m0 - 127)) + """ + amax = tl.maximum(tl.max(tl.abs(x_lo)), tl.max(tl.abs(x_hi))) + amax = tl.maximum(amax, 1e-4) + # ue8m0 block scale: 2^ceil(log2(amax/6.0)). + log2_ratio = tl.math.ceil(tl.math.log2(amax / 6.0)) + log2_ratio = tl.minimum(tl.maximum(log2_ratio, -127.0), 127.0) + scale = tl.math.exp2(log2_ratio) + ue8m0 = (log2_ratio + 127.0).to(tl.uint8) + + inv_scale = 1.0 / scale + lo_nib = _e2m1_nibble(x_lo * inv_scale) + hi_nib = _e2m1_nibble(x_hi * inv_scale) + packed = lo_nib | (hi_nib << 4) + return packed, ue8m0 + + +@triton.jit +def _fused_indexer_q_rope_quant_kernel( + pos_ptr, + # Index Q RoPE + index_q_ptr, + index_q_stride0, + index_q_stride1, + index_q_cos_sin_ptr, + index_q_cos_sin_stride, + INDEX_Q_HALF_ROT_DIM: tl.constexpr, + # Index Q Quantize + index_q_fp8_ptr, + index_q_fp8_stride0, + index_q_fp8_stride1, + INDEX_Q_HEAD_DIM: tl.constexpr, + # Index weights + index_weights_ptr, + index_weights_stride, + index_weights_softmax_scale, + index_weights_head_scale, + index_weights_out_ptr, + index_weights_out_stride, +): + # Layout matches the unfused reference (DeepseekV4ScalingRotaryEmbedding + # + per_token_group_quant_fp8): GPT-J interleaved RoPE applied to the + # LAST rope_dim dims of each head; the leading [0, NOPE_DIM) is passed + # through unchanged. + INDEX_Q_ROT_DIM: tl.constexpr = 2 * INDEX_Q_HALF_ROT_DIM + INDEX_Q_NOPE_DIM: tl.constexpr = INDEX_Q_HEAD_DIM - INDEX_Q_ROT_DIM + tl.static_assert(INDEX_Q_NOPE_DIM >= 0) + + tok_idx = tl.program_id(0) + head_idx = tl.program_id(1) + + pos = tl.load(pos_ptr + tok_idx) + cos, sin = _get_cos_sin( + index_q_cos_sin_ptr, + index_q_cos_sin_stride, + pos, + INDEX_Q_HALF_ROT_DIM, + ) + half_offset = tl.arange(0, INDEX_Q_HALF_ROT_DIM) + base_ptr = index_q_ptr + tok_idx * index_q_stride0 + head_idx * index_q_stride1 + + # Interleaved (GPT-J) RoPE on dims [NOPE_DIM, HEAD_DIM): + # even = q[NOPE_DIM + 2*i], odd = q[NOPE_DIM + 2*i + 1] + rot_base = base_ptr + INDEX_Q_NOPE_DIM + x_even = tl.load(rot_base + half_offset * 2).to(tl.float32) + x_odd = tl.load(rot_base + half_offset * 2 + 1).to(tl.float32) + r_even = x_even * cos - x_odd * sin + r_odd = x_odd * cos + x_even * sin + + # Match reference numerics: fp32 → bf16 → fp32 before the ue8m0 absmax. + # Same pattern as the K-side compressor kernel (fused_compress_quant_cache.py). 
+ r_even = r_even.to(tl.bfloat16).to(tl.float32) + r_odd = r_odd.to(tl.bfloat16).to(tl.float32) + + amax = tl.maximum(tl.max(tl.abs(r_even)), tl.max(tl.abs(r_odd))) + if INDEX_Q_NOPE_DIM > 0: + nope_offset = tl.arange(0, INDEX_Q_NOPE_DIM) + x_nope = tl.load(base_ptr + nope_offset).to(tl.float32) + amax = tl.maximum(amax, tl.max(tl.abs(x_nope))) + index_q_scale = tl.div_rn(tl.maximum(amax, 1e-4), 448.0) + index_q_scale = tl.math.exp2(tl.math.ceil(tl.math.log2(index_q_scale))) + + # Store quantized values to index_q_fp8 + fp8_base_ptr = ( + index_q_fp8_ptr + tok_idx * index_q_fp8_stride0 + head_idx * index_q_fp8_stride1 + ) + if INDEX_Q_NOPE_DIM > 0: + tl.store( + fp8_base_ptr + nope_offset, + tl.div_rn(x_nope, index_q_scale).to(tl.float8e4nv), + ) + fp8_rot_base = fp8_base_ptr + INDEX_Q_NOPE_DIM + tl.store( + fp8_rot_base + half_offset * 2, + tl.div_rn(r_even, index_q_scale).to(tl.float8e4nv), + ) + tl.store( + fp8_rot_base + half_offset * 2 + 1, + tl.div_rn(r_odd, index_q_scale).to(tl.float8e4nv), + ) + + # FP8 weight-fold contract: + # index_weights_out = index_weights * q_scale * softmax_scale * head_scale + # The per-token-per-head q_scale (fp32) IS folded into the output weights + # here because FP8 Q is stored WITHOUT a companion scale tensor — the + # downstream fp8_fp4_mqa_logits/fp8_fp4_paged_mqa_logits kernels use `weights` to + # apply per-token Q scale inline. See the MXFP4 kernel below for the + # contrasting convention (scales live with the Q values, weights are NOT + # q-scaled). + index_weights = tl.load( + index_weights_ptr + tok_idx * index_weights_stride + head_idx + ) + index_weights = index_weights.to(tl.float32) + index_weights *= index_q_scale + index_weights *= index_weights_softmax_scale + index_weights *= index_weights_head_scale + tl.store( + index_weights_out_ptr + tok_idx * index_weights_out_stride + head_idx, + index_weights, + ) + + +@triton.jit +def _fused_indexer_q_rope_mxfp4_kernel( + pos_ptr, + # Index Q RoPE input (fp/bf16) + index_q_ptr, + index_q_stride0, + index_q_stride1, + index_q_cos_sin_ptr, + index_q_cos_sin_stride, + INDEX_Q_HALF_ROT_DIM: tl.constexpr, + # MXFP4 Q outputs + index_q_mxfp4_ptr, # uint8, (T, H, HEAD_DIM // 2) + index_q_mxfp4_stride0, + index_q_mxfp4_stride1, + index_q_scale_ptr, # uint8 ue8m0, (T, H, HEAD_DIM // BLOCK) + index_q_scale_stride0, + index_q_scale_stride1, + INDEX_Q_HEAD_DIM: tl.constexpr, + MXFP4_BLOCK: tl.constexpr, + # Weights (NO per-token q_scale fold for MXFP4; per-block scales stay + # with the Q values in the output scale tensor). 
+ index_weights_ptr, + index_weights_stride, + index_weights_softmax_scale, + index_weights_head_scale, + index_weights_out_ptr, + index_weights_out_stride, +): + INDEX_Q_ROT_DIM: tl.constexpr = 2 * INDEX_Q_HALF_ROT_DIM + INDEX_Q_NOPE_DIM: tl.constexpr = INDEX_Q_HEAD_DIM - INDEX_Q_ROT_DIM + NUM_NOPE_BLOCKS: tl.constexpr = INDEX_Q_NOPE_DIM // MXFP4_BLOCK + NUM_ROPE_BLOCKS: tl.constexpr = INDEX_Q_ROT_DIM // MXFP4_BLOCK + HALF_BLOCK: tl.constexpr = MXFP4_BLOCK // 2 + tl.static_assert(INDEX_Q_NOPE_DIM >= 0) + tl.static_assert(INDEX_Q_NOPE_DIM % MXFP4_BLOCK == 0) + tl.static_assert(INDEX_Q_ROT_DIM % MXFP4_BLOCK == 0) + tl.static_assert(MXFP4_BLOCK % 2 == 0) + + tok_idx = tl.program_id(0) + head_idx = tl.program_id(1) + + pos = tl.load(pos_ptr + tok_idx) + + q_base = index_q_ptr + tok_idx * index_q_stride0 + head_idx * index_q_stride1 + out_base = ( + index_q_mxfp4_ptr + + tok_idx * index_q_mxfp4_stride0 + + head_idx * index_q_mxfp4_stride1 + ) + scale_base = ( + index_q_scale_ptr + + tok_idx * index_q_scale_stride0 + + head_idx * index_q_scale_stride1 + ) + + half_off = tl.arange(0, HALF_BLOCK) + + # ---- NoPE blocks: direct load, pair as (even-index, odd-index) values ---- + for b in tl.static_range(NUM_NOPE_BLOCKS): + base = b * MXFP4_BLOCK + x_lo = tl.load(q_base + base + half_off * 2).to(tl.float32) + x_hi = tl.load(q_base + base + half_off * 2 + 1).to(tl.float32) + packed, ue8m0 = _quantize_mxfp4_pair(x_lo, x_hi) + tl.store(out_base + base // 2 + half_off, packed) + tl.store(scale_base + b, ue8m0) + + # ---- RoPE blocks: apply GPT-J interleaved RoPE to the block's 16 pairs, + # then quantize. Each block covers HALF_BLOCK (=16) cos/sin pairs. ---- + rot_q_base = q_base + INDEX_Q_NOPE_DIM + for b in tl.static_range(NUM_ROPE_BLOCKS): + pair_off = b * HALF_BLOCK + half_off # indices in [0, HALF_ROT_DIM) + cos_b = tl.load( + index_q_cos_sin_ptr + pos * index_q_cos_sin_stride + pair_off + ).to(tl.float32) + sin_b = tl.load( + index_q_cos_sin_ptr + + pos * index_q_cos_sin_stride + + pair_off + + INDEX_Q_HALF_ROT_DIM + ).to(tl.float32) + x_even = tl.load(rot_q_base + pair_off * 2).to(tl.float32) + x_odd = tl.load(rot_q_base + pair_off * 2 + 1).to(tl.float32) + r_even = x_even * cos_b - x_odd * sin_b + r_odd = x_odd * cos_b + x_even * sin_b + # bf16 roundtrip for parity with the FP8 kernel / reference numerics. + r_even = r_even.to(tl.bfloat16).to(tl.float32) + r_odd = r_odd.to(tl.bfloat16).to(tl.float32) + packed, ue8m0 = _quantize_mxfp4_pair(r_even, r_odd) + rope_byte_off = (INDEX_Q_NOPE_DIM + b * MXFP4_BLOCK) // 2 + tl.store(out_base + rope_byte_off + half_off, packed) + tl.store(scale_base + NUM_NOPE_BLOCKS + b, ue8m0) + + # MXFP4 weight-fold contract: + # index_weights_out = index_weights * softmax_scale * head_scale + # NOTE: q_scale is NOT folded here (contrast with the FP8 kernel above). + # MXFP4 Q emits a separate ue8m0 scale tensor of shape + # (T, H, HEAD_DIM // MXFP4_BLOCK) alongside the packed values, so each + # per-block scale is applied by the downstream MXFP4 logits kernel when + # dequantizing Q — there is no per-token scalar to fold into `weights`. 
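+    # In short:
+    #   FP8 path:   weights_out = weights * q_scale * softmax_scale * head_scale
+    #   MXFP4 path: weights_out = weights * softmax_scale * head_scale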
+ index_weights = tl.load( + index_weights_ptr + tok_idx * index_weights_stride + head_idx + ).to(tl.float32) + index_weights *= index_weights_softmax_scale + index_weights *= index_weights_head_scale + tl.store( + index_weights_out_ptr + tok_idx * index_weights_out_stride + head_idx, + index_weights, + ) + + +def fused_indexer_q_rope_quant( + positions: torch.Tensor, + index_q: torch.Tensor, + index_q_cos_sin_cache: torch.Tensor, + # Index weights + index_weights: torch.Tensor, + index_weights_softmax_scale: float, + index_weights_head_scale: float, + use_fp4: bool = False, +) -> tuple[ + torch.Tensor | tuple[torch.Tensor, torch.Tensor], + torch.Tensor, +]: + """Fused RoPE + quantize Q for the sparse indexer. + + Weight-fold semantics (important — the two paths differ): + + FP8 path (use_fp4=False, default): + q_fp8 : (T, H, HEAD_DIM) float8_e4m3fn, per-token-per-head + scalar scale (NOT stored — folded into weights below) + weights_out = weights * q_scale * softmax_scale * head_scale + Rationale: a single per-token q_scale is a scalar the downstream FP8 + logits kernel would otherwise multiply in. Folding it into `weights` + avoids emitting a separate tensor and is free for the logits kernel. + + MXFP4 path (use_fp4=True): + q_packed : (T, H, HEAD_DIM // 2) uint8 (2 E2M1 nibbles per byte) + q_scale : (T, H, HEAD_DIM // MXFP4_BLOCK_SIZE) uint8 ue8m0 bytes + weights_out = weights * softmax_scale * head_scale + Rationale: MXFP4 has PER-BLOCK (32-element) scales that live with + the Q values — they cannot be folded into a per-token weight + scalar, so `weights` carries only the softmax and head scales. + + Returns (q_quant, weights_out) where q_quant is either a Tensor (FP8) or + a (values, scales) tuple (MXFP4). This matches the union type accepted + by `SparseAttnIndexer.forward_*`. + """ + assert positions.ndim == 1 + assert index_q.ndim == 3 + assert index_q_cos_sin_cache.ndim == 2 + + num_tokens = positions.shape[0] + num_index_q_heads = index_q.shape[1] + index_q_head_dim = index_q.shape[2] + + index_weights_out = torch.empty_like(index_weights, dtype=torch.float32) + + if use_fp4: + assert index_q_head_dim % MXFP4_BLOCK_SIZE == 0, ( + f"head_dim={index_q_head_dim} must be a multiple of MXFP4 block " + f"size {MXFP4_BLOCK_SIZE}" + ) + num_scale_blocks = index_q_head_dim // MXFP4_BLOCK_SIZE + index_q_packed = torch.empty( + (num_tokens, num_index_q_heads, index_q_head_dim // 2), + dtype=torch.uint8, + device=index_q.device, + ) + index_q_scale = torch.empty( + (num_tokens, num_index_q_heads, num_scale_blocks), + dtype=torch.uint8, + device=index_q.device, + ) + _fused_indexer_q_rope_mxfp4_kernel[(num_tokens, num_index_q_heads)]( + positions, + index_q, + index_q.stride(0), + index_q.stride(1), + index_q_cos_sin_cache, + index_q_cos_sin_cache.stride(0), + index_q_cos_sin_cache.shape[-1] // 2, + index_q_packed, + index_q_packed.stride(0), + index_q_packed.stride(1), + index_q_scale, + index_q_scale.stride(0), + index_q_scale.stride(1), + index_q_head_dim, + MXFP4_BLOCK_SIZE, + index_weights, + index_weights.stride(0), + index_weights_softmax_scale, + index_weights_head_scale, + index_weights_out, + index_weights_out.stride(0), + num_warps=1, # TODO: Tune this + ) + # Values stay uint8 (2 E2M1 nibbles per byte). Scales are 4 ue8m0 + # bytes per (token, head) reinterpreted as one int32, then squeezed + # from (T, H, 1) to (T, H) to match DeepGEMM's expected q_sf rank + # (prefill wants 2-D (seq_len, num_heads); decode reshapes this to + # 3-D (batch, next_n, num_heads)). 
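+        # (The int32 view relies on head_dim // MXFP4_BLOCK_SIZE == 4, i.e. the
+        # indexer's head_dim of 128, so each (token, head) owns exactly 4 bytes.)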
+ return ( + index_q_packed, + index_q_scale.view(torch.int32).squeeze(-1), + ), index_weights_out + + index_q_fp8 = torch.empty_like(index_q, dtype=torch.float8_e4m3fn) + _fused_indexer_q_rope_quant_kernel[(num_tokens, num_index_q_heads)]( + positions, + index_q, + index_q.stride(0), + index_q.stride(1), + index_q_cos_sin_cache, + index_q_cos_sin_cache.stride(0), + index_q_cos_sin_cache.shape[-1] // 2, + index_q_fp8, + index_q_fp8.stride(0), + index_q_fp8.stride(1), + index_q_head_dim, + index_weights, + index_weights.stride(0), + index_weights_softmax_scale, + index_weights_head_scale, + index_weights_out, + index_weights_out.stride(0), + num_warps=1, # TODO: Tune this + ) + return index_q_fp8, index_weights_out diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py new file mode 100644 index 000000000000..97c9538889a1 --- /dev/null +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py @@ -0,0 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Fused inverse RoPE + block-scaled FP8 quantization kernel for DeepseekV4 attention. + +Output scale format is pre-transformed (MN-major TMA-aligned; FP32 on SM90, +INT32-packed UE8M0 on SM100) so fp8_einsum skips transform_sf_into_required_layout. +""" + +import torch + +from vllm.triton_utils import tl, triton + + +@triton.jit +def _fused_inv_rope_fp8_quant_per_head( + o_ptr, + positions_ptr, + cos_sin_cache_ptr, + fp8_ptr, + scale_ptr, + num_tokens, + heads_per_group: tl.constexpr, + o_stride_token, + o_stride_head, + cache_stride_pos, + fp8_stride_group, + fp8_stride_token, + scale_stride_group, + scale_stride_k, + fp8_max: tl.constexpr, + eps: tl.constexpr, + QUANT_GROUP_SIZE: tl.constexpr, + CHUNKS_PER_HEAD: tl.constexpr, + ROPE_START: tl.constexpr, + HALF_ROPE: tl.constexpr, + TMA_ALIGNED_SCALES: tl.constexpr, +): + # int64: stride multiply overflows int32 past num_tokens=32768 (IMA). + pid_token = tl.program_id(0).to(tl.int64) + pid_gh = tl.program_id(1).to(tl.int64) + + g = pid_gh // heads_per_group + head_in_group = pid_gh % heads_per_group + global_head = pid_gh + qb_start = head_in_group * CHUNKS_PER_HEAD + + # Padding rows in the TMA-aligned scale buffer: fill with zero and skip quant. 
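+    # (The grid's token axis spans the TMA-aligned row count, which can exceed
+    # num_tokens; the extra rows only need their scale slots zero-filled.)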
+ if pid_token >= num_tokens: + if TMA_ALIGNED_SCALES: + scale_addr = ( + scale_ptr + + g * scale_stride_group + + pid_token + + head_in_group * scale_stride_k + ) + tl.store(scale_addr, tl.zeros((), dtype=tl.int32)) + else: + block_offsets = tl.arange(0, CHUNKS_PER_HEAD) + qb_indices = qb_start + block_offsets + scale_addrs = ( + scale_ptr + + g * scale_stride_group + + pid_token + + qb_indices * scale_stride_k + ) + tl.store(scale_addrs, tl.zeros((CHUNKS_PER_HEAD,), dtype=tl.float32)) + return + + input_base = o_ptr + pid_token * o_stride_token + global_head * o_stride_head + + HEAD_DIM: tl.constexpr = CHUNKS_PER_HEAD * QUANT_GROUP_SIZE + offsets = tl.arange(0, HEAD_DIM) + x = tl.load(input_base + offsets).to(tl.float32) + + rope_abs_start: tl.constexpr = (CHUNKS_PER_HEAD - 1) * QUANT_GROUP_SIZE + ROPE_START + pos = tl.load(positions_ptr + pid_token) + cache_base = cos_sin_cache_ptr + pos * cache_stride_pos + is_rope = offsets >= rope_abs_start + rope_local = offsets - rope_abs_start + + x_partner = tl.load(input_base + (offsets ^ 1), mask=is_rope, other=0.0).to( + tl.float32 + ) + cs_idx = tl.maximum(rope_local >> 1, 0) + cos_v = tl.load(cache_base + cs_idx, mask=is_rope, other=1.0) + sin_v = tl.load(cache_base + HALF_ROPE + cs_idx, mask=is_rope, other=0.0) + x_add = x * cos_v + x_partner * sin_v + x_sub = x * cos_v - x_partner * sin_v + is_even = (rope_local & 1) == 0 + rotated = tl.where(is_even, x_add, x_sub) + x = tl.where(is_rope, rotated, x) + + x_2d = tl.reshape(tl.abs(x), (CHUNKS_PER_HEAD, QUANT_GROUP_SIZE)) + block_absmax = tl.maximum(tl.max(x_2d, axis=1), eps) + scale_raw = block_absmax * (1.0 / fp8_max) + scales = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) + + scales_exp = tl.reshape( + tl.broadcast_to( + tl.reshape(scales, (CHUNKS_PER_HEAD, 1)), + (CHUNKS_PER_HEAD, QUANT_GROUP_SIZE), + ), + (HEAD_DIM,), + ) + x_quant = tl.clamp(x / scales_exp, -fp8_max, fp8_max).to(tl.float8e4nv) + + fp8_base = ( + fp8_ptr + + g * fp8_stride_group + + pid_token * fp8_stride_token + + qb_start * QUANT_GROUP_SIZE + ) + tl.store(fp8_base + offsets, x_quant) + + block_offsets = tl.arange(0, CHUNKS_PER_HEAD) + qb_indices = qb_start + block_offsets + if TMA_ALIGNED_SCALES: + scale_bits = scales.to(tl.int32, bitcast=True) + ue8m0_bytes = (scale_bits >> 23) & 0xFF + packed_val = tl.sum(ue8m0_bytes << (block_offsets * 8)) + scale_addr = ( + scale_ptr + + g * scale_stride_group + + pid_token + + head_in_group * scale_stride_k + ) + tl.store(scale_addr, packed_val) + else: + scale_addrs = ( + scale_ptr + g * scale_stride_group + pid_token + qb_indices * scale_stride_k + ) + tl.store(scale_addrs, scales) + + +def fused_inv_rope_fp8_quant( + o: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + n_groups: int, + heads_per_group: int, + nope_dim: int = 448, + rope_dim: int = 64, + quant_group_size: int = 128, + tma_aligned_scales: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + """Fused inverse RoPE + block-scaled FP8 quantization. + + Args: + o: Attention output [num_tokens, num_heads, head_dim] bf16. + positions: Token positions [num_tokens] int64. + cos_sin_cache: Precomputed [max_pos, rope_dim] with cos||sin. + n_groups: Number of output groups. + heads_per_group: Heads per group. + nope_dim: Non-RoPE dimensions per head (default 448). + rope_dim: RoPE dimensions per head (default 64). + quant_group_size: FP8 quantization block size (default 128). + tma_aligned_scales: Output INT32 packed UE8M0 for SM100 (True) + or FP32 for SM90 (False). 
+ + Returns: + o_fp8: [T, G, D] float8_e4m3fn, strides (D, T*D, 1). + o_scale: Pre-transformed scale tensor for fp8_einsum. + """ + from vllm.utils.deep_gemm import get_tma_aligned_size + + num_tokens, num_heads, head_dim = o.shape + assert num_heads == n_groups * heads_per_group + assert head_dim == nope_dim + rope_dim + assert head_dim % quant_group_size == 0 + assert nope_dim % quant_group_size == (quant_group_size - rope_dim) + assert rope_dim % 2 == 0 + assert cos_sin_cache.shape[-1] == rope_dim + assert cos_sin_cache.dtype == torch.float32 + + d = heads_per_group * head_dim + num_scale_blocks = d // quant_group_size + chunks_per_head = head_dim // quant_group_size + + fp8_dtype = torch.float8_e4m3fn + fp8_max = torch.finfo(fp8_dtype).max + + fp8_buf = torch.empty( + (n_groups, num_tokens, d), + dtype=fp8_dtype, + device=o.device, + ) + + tma_aligned_T = get_tma_aligned_size(num_tokens, 4) + if tma_aligned_scales: + packed_sf_k = (num_scale_blocks + 3) // 4 + scale_buf = torch.empty( + n_groups * packed_sf_k * tma_aligned_T, + dtype=torch.int32, + device=o.device, + ).as_strided( + (n_groups, num_tokens, packed_sf_k), + (packed_sf_k * tma_aligned_T, 1, tma_aligned_T), + ) + else: + scale_buf = torch.empty( + n_groups * num_scale_blocks * tma_aligned_T, + dtype=torch.float32, + device=o.device, + ).as_strided( + (n_groups, num_tokens, num_scale_blocks), + (num_scale_blocks * tma_aligned_T, 1, tma_aligned_T), + ) + + common_args = dict( + heads_per_group=heads_per_group, + o_stride_token=o.stride(0), + o_stride_head=o.stride(1), + cache_stride_pos=cos_sin_cache.stride(0), + fp8_stride_group=fp8_buf.stride(0), + fp8_stride_token=fp8_buf.stride(1), + scale_stride_group=scale_buf.stride(0), + scale_stride_k=scale_buf.stride(2), + fp8_max=fp8_max, + eps=1e-10, + QUANT_GROUP_SIZE=quant_group_size, + CHUNKS_PER_HEAD=chunks_per_head, + ROPE_START=nope_dim % quant_group_size, + HALF_ROPE=rope_dim // 2, + TMA_ALIGNED_SCALES=tma_aligned_scales, + num_stages=1, + launch_pdl=False, + ) + + grid = (tma_aligned_T, n_groups * heads_per_group) + _fused_inv_rope_fp8_quant_per_head[grid]( + o, + positions, + cos_sin_cache, + fp8_buf, + scale_buf, + num_tokens, + **common_args, + num_warps=1, + ) + + return fp8_buf.transpose(0, 1), scale_buf.transpose(0, 1) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_qk_rmsnorm.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_qk_rmsnorm.py new file mode 100644 index 000000000000..0dd348a46e26 --- /dev/null +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_qk_rmsnorm.py @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm.triton_utils import tl, triton + + +@triton.jit +def _fused_q_kv_rmsnorm_kernel( + q_ptr, + q_out_ptr, + q_weight_ptr, + q_in_stride, + q_out_stride, + kv_ptr, + kv_out_ptr, + kv_weight_ptr, + kv_in_stride, + kv_out_stride, + eps, + Q_SIZE: tl.constexpr, + KV_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + # num_tokens goes on grid-x (max 2**31 - 1); task goes on grid-y. + # CUDA's grid-y/z are capped at 65535, so putting num_tokens there crashes + # the launch at max-num-batched-tokens >= 65536 with "invalid argument". + # int64: q_in_stride can be ~24K (128 heads × 192) and overflows int32 + # past num_tokens ~87K under large chunked prefill. 
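+    # (2**31 // 24576 ≈ 87.4K tokens, hence the explicit cast below.)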
+ token_idx = tl.program_id(0).to(tl.int64) + pid_task = tl.program_id(1) + + if pid_task == 0: + SIZE = Q_SIZE + row_in = q_ptr + token_idx * q_in_stride + weight_ptr = q_weight_ptr + row_out = q_out_ptr + token_idx * q_out_stride + else: + SIZE = KV_SIZE + row_in = kv_ptr + token_idx * kv_in_stride + weight_ptr = kv_weight_ptr + row_out = kv_out_ptr + token_idx * kv_out_stride + + # RMSNorm in fp32 throughout — matches csrc/layernorm_kernels.cu's + # `(scalar_t)(x * s_variance * w)` and DeepseekV4's compressor kernel, which + # keep x, rrms, and w all in fp32 and perform a single cast at store. + block = tl.arange(0, BLOCK_SIZE) + mask = block < SIZE + x = tl.load(row_in + block, mask=mask, other=0.0).to(tl.float32) + variance = tl.sum(x * x, axis=0) / SIZE + rrms = tl.rsqrt(variance + eps) + w = tl.load(weight_ptr + block, mask=mask, other=0.0).to(tl.float32) + y = x * rrms * w + tl.store(row_out + block, y.to(row_out.dtype.element_ty), mask=mask) + + +def fused_q_kv_rmsnorm( + qr: torch.Tensor, + kv: torch.Tensor, + q_weight: torch.Tensor, + kv_weight: torch.Tensor, + eps: float, +) -> tuple[torch.Tensor, torch.Tensor]: + assert qr.ndim == 2 and kv.ndim == 2 + assert qr.shape[0] == kv.shape[0], ( + f"token dim mismatch: qr={qr.shape}, kv={kv.shape}" + ) + assert qr.stride(-1) == 1 and kv.stride(-1) == 1 + assert q_weight.is_contiguous() and kv_weight.is_contiguous() + + q_size = qr.shape[1] + kv_size = kv.shape[1] + num_tokens = qr.shape[0] + qr_out = torch.empty_like(qr) + kv_out = torch.empty_like(kv) + if num_tokens == 0: + return qr_out, kv_out + + block_size = triton.next_power_of_2(max(q_size, kv_size)) + _fused_q_kv_rmsnorm_kernel[(num_tokens, 2)]( + qr, + qr_out, + q_weight, + qr.stride(0), + qr_out.stride(0), + kv, + kv_out, + kv_weight, + kv.stride(0), + kv_out.stride(0), + eps, + Q_SIZE=q_size, + KV_SIZE=kv_size, + BLOCK_SIZE=block_size, + ) + return qr_out, kv_out diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index eaa95dfe49f7..8b3bf85ab782 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -54,8 +54,14 @@ def __init__( metrics_collector, ) - # Needs special handling for find_longest_cache_hit if eagle is enabled - self.use_eagle = use_eagle + # KV cache group indices that get the EAGLE last-block drop. + self.eagle_group_ids: set[int] = { + i for i, g in enumerate(kv_cache_config.kv_cache_groups) if g.is_eagle_group + } + # Conservatively fall back to flag all groups when no group is flagged. + if use_eagle and not self.eagle_group_ids: + self.eagle_group_ids = set(range(len(kv_cache_config.kv_cache_groups))) + self.single_type_managers = tuple( get_manager_for_kv_cache_spec( kv_cache_spec=kv_cache_group.kv_cache_spec, @@ -357,7 +363,7 @@ def find_longest_cache_hit( kv_cache_group_ids=[0], block_pool=self.block_pool, kv_cache_spec=self.kv_cache_spec, - use_eagle=self.use_eagle, + use_eagle=0 in self.eagle_group_ids, alignment_tokens=self.block_size, dcp_world_size=self.dcp_world_size, pcp_world_size=self.pcp_world_size, @@ -450,6 +456,14 @@ def verify_and_split_kv_cache_groups(self) -> None: block_sizes = [spec.block_size for spec, _, _ in attention_groups] self.lcm_block_size = lcm(*block_sizes) + # Attention-group indices (into ``self.attention_groups``) that + # contain at least one EAGLE/MTP KV cache group. 
+ self.eagle_attn_group_indices: set[int] = { + i + for i, (_, group_ids, _) in enumerate(self.attention_groups) + if any(gid in self.eagle_group_ids for gid in group_ids) + } + def find_longest_cache_hit( self, block_hashes: list[BlockHash], @@ -485,49 +499,62 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: hit_blocks_by_group: list[list[KVCacheBlock] | None] = [None] * num_groups # Simple hybrid (1 full attn + 1 other): one iteration suffices. - # Full attn is always first if it exists. This avoids EAGLE drops - # being applied multiple times to non-full-attn groups. - # FIXME (yifan): However, for complex hybrid models with multiple attn - # groups, we still have the EAGLE spiral block dropping problem. See - # discussion in issue https://github.com/vllm-project/vllm/issues/32802. + # Full attn is always first if it exists. is_simple_hybrid = len(self.attention_groups) == 2 and isinstance( self.attention_groups[0][0], FullAttentionSpec ) + # Attention-group indices whose EAGLE drop is verified at the current + # ``curr_hit_length``. Each eagle group applies the drop at most once + # per candidate length (see issue #32802). + eagle_verified: set[int] = set() + while True: curr_hit_length = hit_length - for spec, group_ids, manager_cls in self.attention_groups: - is_full_attn = isinstance(spec, FullAttentionSpec) - - # Full attention: reuse cached blocks (downward-closed property) + for idx, (spec, group_ids, manager_cls) in enumerate(self.attention_groups): cached_blocks = hit_blocks_by_group[group_ids[0]] - if is_full_attn and cached_blocks is not None: - # For full attention, we only need to compute the cache hit - # length once. Starting from the second iteration, if the - # curr_hit_length is reduced by other groups, we can simply - # keep the first (curr_hit_length // block_size) blocks from - # the last iteration. - num_blocks = curr_hit_length // spec.block_size - curr_hit_length = num_blocks * spec.block_size - else: - hit_blocks = manager_cls.find_longest_cache_hit( - block_hashes=_get_block_hashes(spec), - max_length=curr_hit_length, - kv_cache_group_ids=group_ids, - block_pool=self.block_pool, - kv_cache_spec=spec, - use_eagle=self.use_eagle, - alignment_tokens=self.lcm_block_size, + if isinstance(spec, FullAttentionSpec) and cached_blocks is not None: + # Full attention is downward-closed: we only need to look + # up cached blocks once; on subsequent iterations just trim + # to the (reduced) current hit length. + curr_hit_length = ( + curr_hit_length // spec.block_size * spec.block_size ) - curr_hit_length = len(hit_blocks[0]) * spec.block_size - for group_id, blocks in zip(group_ids, hit_blocks): - hit_blocks_by_group[group_id] = blocks + continue + + use_eagle = ( + idx in self.eagle_attn_group_indices and idx not in eagle_verified + ) + + _max_length = curr_hit_length + if use_eagle: + # Eagle needs to match one more block and then pop the last. 
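+                    # Extending max_length by one block means that, when that
+                    # extra block is also cached, the post-drop hit still
+                    # reaches curr_hit_length instead of shrinking by a block.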
+ _max_length = min( + curr_hit_length + spec.block_size, max_cache_hit_length + ) + hit_blocks = manager_cls.find_longest_cache_hit( + block_hashes=_get_block_hashes(spec), + max_length=_max_length, + kv_cache_group_ids=group_ids, + block_pool=self.block_pool, + kv_cache_spec=spec, + use_eagle=use_eagle, + alignment_tokens=self.lcm_block_size, + ) + _new_hit_length = len(hit_blocks[0]) * spec.block_size + if use_eagle: + eagle_verified.add(idx) + elif _new_hit_length < curr_hit_length: + # length shrunk; invalidate previous eagle verifications + eagle_verified.clear() + curr_hit_length = _new_hit_length + for group_id, blocks in zip(group_ids, hit_blocks): + hit_blocks_by_group[group_id] = blocks if curr_hit_length >= hit_length: break hit_length = curr_hit_length - # Simple hybrid: exit after one iteration if is_simple_hybrid: break diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 8f4963fcc873..879cd0928c1e 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -4,18 +4,19 @@ import copy import hashlib +import math import os from collections import defaultdict from collections.abc import Callable, Iterable, Iterator, Sequence from dataclasses import dataclass, replace from functools import partial -from typing import Any, NewType, TypeAlias, overload +from typing import Any, NewType, TypeAlias, cast, overload from vllm import envs from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils.hashing import sha256_cbor, xxhash_cbor -from vllm.utils.math_utils import cdiv +from vllm.utils.math_utils import cdiv, round_up from vllm.utils.mem_utils import format_gib from vllm.v1.kv_cache_interface import ( ChunkedLocalAttentionSpec, @@ -24,6 +25,9 @@ KVCacheGroupSpec, KVCacheSpec, KVCacheTensor, + MambaSpec, + MLAAttentionSpec, + SlidingWindowMLASpec, SlidingWindowSpec, UniformTypeKVCacheSpecs, ) @@ -562,6 +566,72 @@ def hash_block_tokens( ) +def resolve_kv_cache_block_sizes( + kv_cache_config: KVCacheConfig, + vllm_config: VllmConfig, +) -> tuple[int, int]: + """Resolve (scheduler_block_size, hash_block_size). + + - ``scheduler_block_size`` is the token-alignment invariant used by the + scheduler (e.g. for ``num_computed_tokens`` rounding). Single group: + ``cache_config.block_size * dcp * pcp``. Multiple groups: LCM of every + group's block size — context parallelism is not supported here. + - ``hash_block_size`` is the granularity at which ``Request.block_hashes`` + is computed. Single group: equals scheduler block size. Multiple groups: + ``cache_config.hash_block_size`` override if set, else the GCD of group + block sizes; every group's block size must be divisible by it. Returns + the scheduler block size (i.e. disables finer hashing) if block hashing + is inactive or a mamba group's block size diverges from the cache + block size (mamba_cache_mode != "align"). + """ + cache_config = vllm_config.cache_config + dcp = vllm_config.parallel_config.decode_context_parallel_size + pcp = vllm_config.parallel_config.prefill_context_parallel_size + groups = kv_cache_config.kv_cache_groups + + if len(groups) <= 1: # Single group: block_size * dcp * pcp + bs = cache_config.block_size * dcp * pcp + return bs, bs + + if dcp != 1 or pcp != 1: + raise ValueError( + "Hybrid KV cache groups with multiple block sizes do not " + "support context parallelism (dcp_world_size/pcp_world_size > 1)." 
+ ) + + group_block_sizes = [g.kv_cache_spec.block_size for g in groups] + scheduler_block_size = math.lcm(*group_block_sizes) + + # Block hashes are only consumed by prefix caching and KV connectors + # (P/D, offloading); when neither is active, keep hash_block_size equal + # to the scheduler block size. + connector_enabled = vllm_config.kv_transfer_config is not None + if not (cache_config.enable_prefix_caching or connector_enabled): + return scheduler_block_size, scheduler_block_size + + # Mamba groups with block_size != cache_config.block_size + # (mamba_cache_mode != "align") break divisibility; back off to the + # scheduler block size. + if any( + isinstance(g.kv_cache_spec, MambaSpec) + and g.kv_cache_spec.block_size != cache_config.block_size + for g in groups + ): + return scheduler_block_size, scheduler_block_size + + requested = cache_config.hash_block_size + hash_block_size = ( + requested if requested is not None else math.gcd(*group_block_sizes) + ) + if any(bs % hash_block_size != 0 for bs in group_block_sizes): + raise ValueError( + f"Invalid hash_block_size={hash_block_size}; all KV cache group " + f"block sizes must be divisible by hash_block_size. " + f"Got group block sizes={group_block_sizes}." + ) + return scheduler_block_size, hash_block_size + + def get_request_block_hasher( block_size: int, caching_hash_fn: Callable[[Any], bytes], @@ -1089,6 +1159,63 @@ def _get_kv_cache_groups_uniform_page_size( return create_kv_cache_group_specs(kv_cache_spec, grouped_layers) +def _get_kv_cache_config_deepseek_v4( + vllm_config: VllmConfig, + kv_cache_groups: list[KVCacheGroupSpec], + available_memory: int, +) -> tuple[int, list[KVCacheTensor]]: + """DeepseekV4 KV cache tensor layout planning. + + Precondition: kv_cache_groups[0] is the full-MLA group; its page sizes + define the canonical bucket set. Non-full-MLA groups must have been + page_size-padded upstream (see _get_kv_cache_groups_uniform_groups) so + every layer's page_size matches one of the full-MLA bucket sizes. + + For each group, bucket its layers by page_size_bytes and place each + layer at tuple_idx = position-within-bucket. Emit one KVCacheTensor + per (tuple_idx, bucket) whose shared_by is the union of per-group + layers at that slot. + """ + full_mla_spec = kv_cache_groups[0].kv_cache_spec + assert isinstance(full_mla_spec, UniformTypeKVCacheSpecs) + page_sizes = sorted(full_mla_spec.get_page_sizes()) + layer_tuple_page_bytes = sum(page_sizes) + + # Pre-bucket each group's layers by page_size (registration order within + # bucket). bucketed[g_idx][page_size] = [layer_name, ...]. + bucketed: list[dict[int, list[str]]] = [] + for group in kv_cache_groups: + assert isinstance(group.kv_cache_spec, UniformTypeKVCacheSpecs) + specs = group.kv_cache_spec.kv_cache_specs + b: dict[int, list[str]] = defaultdict(list) + for name in group.layer_names: + b[specs[name].page_size_bytes].append(name) + bucketed.append(b) + + # num_layer_tuples = longest bucket list across all groups. For the + # full-MLA group this equals the count of layers in the largest + # per-page-size bucket (= get_num_layer_tuples()); for SWA sub-groups + # this equals the sub-group size (each has a single page_size). 
+ num_layer_tuples = max(len(layers) for b in bucketed for layers in b.values()) + + num_blocks = available_memory // (layer_tuple_page_bytes * num_layer_tuples) + num_blocks = may_override_num_blocks(vllm_config, num_blocks) + + kv_cache_tensors: list[KVCacheTensor] = [] + for tuple_idx in range(num_layer_tuples): + for ps in page_sizes: + shared_by: list[str] = [] + for b in bucketed: + bucket = b.get(ps) + if bucket is not None and tuple_idx < len(bucket): + shared_by.append(bucket[tuple_idx]) + kv_cache_tensors.append( + KVCacheTensor(size=ps * num_blocks, shared_by=shared_by) + ) + + return num_blocks, kv_cache_tensors + + def get_kv_cache_config_from_groups( vllm_config: VllmConfig, kv_cache_groups: list[KVCacheGroupSpec], @@ -1120,7 +1247,7 @@ def get_kv_cache_config_from_groups( kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs ): # Special case: all layers have the same type of KV cache but with - # different hidden size. Allocate different amount of memory for each + # different hidden sizes. Allocate different amount of memory for each # layer based on its hidden size. num_blocks = ( available_memory // kv_cache_groups[0].kv_cache_spec.page_size_bytes @@ -1136,6 +1263,15 @@ def get_kv_cache_config_from_groups( ) for layer_name in kv_cache_groups[0].layer_names ] + elif all( + isinstance(group.kv_cache_spec, UniformTypeKVCacheSpecs) + for group in kv_cache_groups + ): + # DeepseekV4: UniformTypeKVCacheSpecs but multiple groups. + # Delegate to the DeepseekV4-specific allocator. + num_blocks, kv_cache_tensors = _get_kv_cache_config_deepseek_v4( + vllm_config, kv_cache_groups, available_memory + ) else: # General case: # We will have group_size memory pools, each is shared by one layer from @@ -1206,9 +1342,41 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): has_chunked_local_attention = any( isinstance(spec, ChunkedLocalAttentionSpec) for spec in kv_cache_spec.values() ) + has_swa_mla = any( + isinstance(spec, SlidingWindowMLASpec) for spec in kv_cache_spec.values() + ) + + uniform_block_size: int | None = None + if has_swa_mla: + # For DeepseekV4, block sizes can be different for different KV cache groups. + # E.g., Full MLA: 256; SWA MLA: 64; C4 partial states: 4, C128 states: 8. + assert has_full_attention + any_full_spec = next( + iter( + spec + for spec in kv_cache_spec.values() + if isinstance(spec, FullAttentionSpec) + ) + ) + uniform_block_size = any_full_spec.block_size + if has_full_attention and (has_sliding_window or has_chunked_local_attention): for layer_name, spec in kv_cache_spec.items(): - if isinstance(spec, SlidingWindowSpec): + if isinstance(spec, SlidingWindowMLASpec): + kv_cache_spec[layer_name] = MLAAttentionSpec( + block_size=uniform_block_size + if uniform_block_size is not None + else spec.block_size, + num_kv_heads=spec.num_kv_heads, + head_size=spec.head_size, + dtype=spec.dtype, + page_size_padded=spec.page_size_padded, + cache_dtype_str=spec.cache_dtype_str, + alignment=spec.alignment, + compress_ratio=spec.compress_ratio, + model_version=spec.model_version, + ) + elif isinstance(spec, SlidingWindowSpec): kv_cache_spec[layer_name] = FullAttentionSpec( block_size=spec.block_size, num_kv_heads=spec.num_kv_heads, @@ -1237,6 +1405,204 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): ) +def group_and_unify_kv_cache_specs( + kv_cache_spec: dict[str, KVCacheSpec], +) -> list[UniformTypeKVCacheSpecs] | None: + """ + Group the KV cache specs and unify each group into one UniformTypeKVCacheSpecs. 
+ Currently, this is only used for DeepseekV4. + """ + if not any( + isinstance(spec, SlidingWindowMLASpec) for spec in kv_cache_spec.values() + ): + return None + + mla_specs: dict[str, KVCacheSpec] = {} + grouped_swa_mla_specs: dict[tuple[int, int], dict[str, KVCacheSpec]] = defaultdict( + dict + ) + # NOTE: Here we group SWA layers by (block_size, sliding_window), which separates + # SWA layers, C4I+C4A layers, and C128A layers into three different groups. It can + # be fragile with only block_size and sliding_window as keys, but fine for now. + for name, spec in kv_cache_spec.items(): + if isinstance(spec, SlidingWindowMLASpec): + grouped_swa_mla_specs[(spec.block_size, spec.sliding_window)][name] = spec + elif isinstance(spec, MLAAttentionSpec): + mla_specs[name] = spec + + assert len(mla_specs) > 0 + mla_uniform_spec = UniformTypeKVCacheSpecs.from_specs(mla_specs) + assert mla_uniform_spec is not None + + swa_uniform_specs: list[UniformTypeKVCacheSpecs] = [] + for spec_dict in grouped_swa_mla_specs.values(): + uniform_spec = UniformTypeKVCacheSpecs.from_specs(spec_dict) + assert uniform_spec is not None + swa_uniform_specs.append(uniform_spec) + + return [mla_uniform_spec, *swa_uniform_specs] + + +def _approximate_gcd(values: Sequence[int], *, lower_bound: int | None = None) -> int: + """Pick a chunk size that minimizes total upward padding. + + Each x is rounded up to a multiple of d: + + x -> ceil(x / d) * d + + Total padding is: + + pad(d) = sum_i (ceil(x_i / d) * d - x_i) + + We brute-force d in [lower_bound, max(values)] (fine for small lists / small + maxima) and return the d with minimum padding. Ties prefer larger d. + """ + if not values: + raise ValueError("values must be non-empty") + if any(x <= 0 for x in values): + raise ValueError(f"values must be positive, got: {list(values)!r}") + + min_d = max(1, lower_bound if lower_bound is not None else 1) + max_d = max(values) + if min_d > max_d: + return min_d + + best_d = min_d + best_pad: int | None = None + for d in range(min_d, max_d + 1): + pad = sum((d - (x % d)) % d for x in values) + if best_pad is None or pad < best_pad or (pad == best_pad and d > best_d): + best_pad = pad + best_d = d + + return best_d + + +def _get_kv_cache_groups_uniform_groups( + grouped_specs: list[UniformTypeKVCacheSpecs], +) -> list[KVCacheGroupSpec]: + """ + Generate the KV cache groups from the grouped specs. + """ + assert len(grouped_specs) > 0 and all( + isinstance(spec, UniformTypeKVCacheSpecs) for spec in grouped_specs + ) + # For now, we restrict the first grouped_spec to be UniformTypeKVCacheSpecs + # containing only MLAAttentionSpec. + full_mla_spec = grouped_specs[0] + assert all( + isinstance(spec, MLAAttentionSpec) + for spec in full_mla_spec.kv_cache_specs.values() + ) + full_mla_group = KVCacheGroupSpec( + layer_names=list(full_mla_spec.kv_cache_specs.keys()), + kv_cache_spec=full_mla_spec, + ) + + # We define a layer tuple as a group of layers with different page sizes, and + # one UniformTypeKVCacheSpecs contains a list of layer tuples. + # For example, if we have 11 C4 layers and 10 C128 layers, we can define a layer + # tuple as [C4I, C4A, C128], and the full_mla_group will contain "11" layer tuples. + # The other uniform KV cache specs will be similarly partitioned into layer tuples. + # Say we have 21 SWA layers, all with the same page size, then we will have "21" + # layer tuples. 
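To make the padding heuristic above concrete, here is a standalone re-derivation of the `_approximate_gcd` brute force on the counts used in the comment (11 full-MLA layer tuples, 21 SWA layer tuples); the numbers are illustrative only.

def pad_cost(values: list[int], d: int) -> int:
    # Total upward padding when every count is rounded up to a multiple of d.
    return sum((d - x % d) % d for x in values)

counts = [11, 21]        # [full-MLA tuples, SWA tuples], toy values
lower_bound = counts[0]  # never choose fewer tuples than the full-MLA group has

best_d = max(
    range(lower_bound, max(counts) + 1),
    key=lambda d: (-pad_cost(counts, d), d),  # minimize padding, ties -> larger d
)
print(best_d, pad_cost(counts, best_d))  # 11 1  -> 21 gets padded up to 22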
+ num_layer_tuples_per_group: list[int] = [ + g_spec.get_num_layer_tuples() for g_spec in grouped_specs + ] + # Choose `num_layer_tuples` to minimize total padding across groups. + num_layer_tuples = _approximate_gcd( + num_layer_tuples_per_group, lower_bound=num_layer_tuples_per_group[0] + ) + # Round up to the nearest multiple of `num_layer_tuples` (i.e., padding) + num_layer_tuples_per_group = [ + round_up(x, num_layer_tuples) for x in num_layer_tuples_per_group + ] + + swa_mla_specs = grouped_specs[1:] + assert all( + isinstance(spec, SlidingWindowMLASpec) + for group in swa_mla_specs + for spec in group.kv_cache_specs.values() + ) + + # Split each SWA UniformKV group into smaller groups to align their #(layer tuples) + # Possibly padding layer tuples for this. + # Additionally, we also pad KV blocks in each SWA layer, to align the page size + # with the corresponding layer in the full-MLA group. + all_page_sizes = full_mla_spec.get_page_sizes() + swa_mla_groups = [] + for sm_spec in swa_mla_specs: + sm_page_sizes = sm_spec.get_page_sizes() + layers_per_size: dict[int, list[str]] = defaultdict(list) + assert max(sm_page_sizes) <= max(all_page_sizes) + + # Unify page size by padding layers' page_size to the nearest larger page_size. + # Compute candidate (nearest larger page_size) for each unique page size. + size_to_candidate: dict[int, int] = {} + for ps in sm_page_sizes: + size_to_candidate[ps] = min(x for x in all_page_sizes if x >= ps) + # Pad and collect layer names per page size. + for layer_name, layer_spec in sm_spec.kv_cache_specs.items(): + current_size = layer_spec.page_size_bytes + candidate = size_to_candidate[current_size] + if current_size < candidate: + object.__setattr__(layer_spec, "page_size_padded", candidate) + layers_per_size[candidate].append(layer_name) + # NOTE(yifan): for now, inside a UniformKV group, each page_size should + # have the same number of layers. This also means we don't need to pad layers + # inside a partial-full layer tuple. + assert len(set(len(layers) for layers in layers_per_size.values())) == 1 + num_layers_per_size = len(next(iter(layers_per_size.values()))) + + # Split layers inside each UniformKV group for aligned #(layers). + # See `_get_kv_cache_groups_uniform_page_size` for more details. + num_tuple_groups = cdiv(num_layers_per_size, num_layer_tuples) + layer_tuples = list(zip(*layers_per_size.values())) + for i in range(num_tuple_groups): + group_layer_tuples = layer_tuples[i::num_tuple_groups] + # Flatten tuples and build dict for from_specs + group_layer_names = [ + name for layer_tuple in group_layer_tuples for name in layer_tuple + ] + group_layer_specs = { + name: sm_spec.kv_cache_specs[name] for name in group_layer_names + } + sub_sm_spec = UniformTypeKVCacheSpecs.from_specs(group_layer_specs) + assert sub_sm_spec is not None + swa_mla_groups.append( + KVCacheGroupSpec( + layer_names=group_layer_names, + kv_cache_spec=sub_sm_spec, + ) + ) + + return [full_mla_group, *swa_mla_groups] + + +def _annotate_eagle_groups_deepseek_v4( + vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + kv_cache_groups: list[KVCacheGroupSpec], +) -> None: + spec_config = vllm_config.speculative_config + if spec_config is None or not spec_config.use_eagle(): + return + # Detection uses the merged MLA spec's model_version. 
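Backing up to the interleaved split in `_get_kv_cache_groups_uniform_groups` above, a minimal sketch of how 21 layer tuples fold into cdiv(21, 11) = 2 sub-groups (hypothetical layer names):

layer_tuples = [(f"swa.{i}",) for i in range(21)]  # 21 single-layer tuples
num_layer_tuples = 11
num_tuple_groups = -(-len(layer_tuples) // num_layer_tuples)  # cdiv -> 2

sub_groups = [layer_tuples[i::num_tuple_groups] for i in range(num_tuple_groups)]
print([len(g) for g in sub_groups])  # [11, 10] -- the shorter group is the padded one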
+ if not any( + getattr(spec, "model_version", None) == "deepseek_v4" + for spec in kv_cache_spec.values() + ): + return + # DeepseekV4's MTP attention layer is always the last layer, and we flag whichever + # group contains it. + # FIXME(yifan): avoid/generalize this hacky check. + last_layer = next(reversed(kv_cache_spec)) + for group in kv_cache_groups: + if last_layer in group.layer_names: + group.is_eagle_group = True + break + + def get_kv_cache_groups( vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec] ) -> list[KVCacheGroupSpec]: @@ -1268,6 +1634,14 @@ def get_kv_cache_groups( # full attention, or all layers are sliding window attention with the # same window size). Put all layers into one group. return _get_kv_cache_groups_uniform_type(uniform_spec) + elif grouped_specs := group_and_unify_kv_cache_specs(kv_cache_spec): + # DeepseekV4 case: All layers need the same number of token slots, + # yet some layers are full attention while others are sliding window + # attention in different sizes. Need to group layers into multiple + # UniformTypeKVCacheSpecs. + kv_cache_groups = _get_kv_cache_groups_uniform_groups(grouped_specs) + _annotate_eagle_groups_deepseek_v4(vllm_config, kv_cache_spec, kv_cache_groups) + return kv_cache_groups # As KVCacheManager can only allocate memory of one size, we need to unify # the page size of the layers. For cases cannot be unified, this function @@ -1360,15 +1734,40 @@ def _max_memory_usage_bytes_from_groups( if not kv_cache_groups: return 0 - # UniformTypeKVCacheSpecs special case (single group, per-layer specs) if len(kv_cache_groups) == 1 and isinstance( kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs ): + # UniformTypeKVCacheSpecs special case (single group, per-layer specs) per_layer_specs = kv_cache_groups[0].kv_cache_spec.kv_cache_specs return sum( spec.max_memory_usage_bytes(vllm_config) for spec in per_layer_specs.values() ) + elif all( + isinstance(group.kv_cache_spec, UniformTypeKVCacheSpecs) + for group in kv_cache_groups + ): + # Special case (only DeepseekV4 for now): all groups are + # UniformTypeKVCacheSpecs. + # They must already be page_size aligned and share a common padded + # layer-tuple layout. Even groups with fewer actual tuples still reserve + # the global number of tuple slots in the shared tensor layout. 
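A toy instantiation of the worst-case accounting that follows (every figure below is invented): each group reserves the global number of tuple slots, so the bound is the sum over groups of num_layer_tuples * max_pages * layer_tuple_bytes.

page_sizes = [37_888, 151_552]          # per-layer page sizes in one tuple
layer_tuple_bytes = sum(page_sizes)     # 189_440
num_layer_tuples = 11                   # max over all groups
max_pages_per_group = [512, 128]        # full-MLA group, one SWA sub-group

total = sum(
    num_layer_tuples * pages * layer_tuple_bytes for pages in max_pages_per_group
)
print(total)  # 1_333_657_600 bytes (~1.24 GiB)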
+ full_mla_spec = cast(UniformTypeKVCacheSpecs, kv_cache_groups[0].kv_cache_spec) + layer_tuple_bytes = sum(full_mla_spec.get_page_sizes()) + num_layer_tuples = max( + cast(UniformTypeKVCacheSpecs, group.kv_cache_spec).get_num_layer_tuples() + for group in kv_cache_groups + ) + + total_max_mem_usage_bytes = 0 + for group in kv_cache_groups: + group_spec = cast(UniformTypeKVCacheSpecs, group.kv_cache_spec) + g_max_mem_usage_pages = group_spec.max_memory_usage_pages(vllm_config) + g_max_mem_usage_page_bytes = ( + num_layer_tuples * g_max_mem_usage_pages * layer_tuple_bytes + ) + total_max_mem_usage_bytes += g_max_mem_usage_page_bytes + return total_max_mem_usage_bytes # General case: group_size pools, each shared by one layer per group # Memory = group_size * page_size * blocks_for_max_len @@ -1515,7 +1914,13 @@ def _project_kv_cache_groups_to_worker( for layer_name in worker_layer_names }, ) - projected_groups.append(KVCacheGroupSpec(worker_layer_names, group_spec)) + projected_groups.append( + KVCacheGroupSpec( + worker_layer_names, + group_spec, + is_eagle_group=group.is_eagle_group and bool(worker_layer_names), + ) + ) return projected_groups @@ -1698,10 +2103,7 @@ def __iter__(self) -> Iterator[BlockHash]: def _get_value_at(self, idx: int) -> BlockHash: base = idx * self.scale_factor end = base + self.scale_factor - merged_hash: bytes = self.block_hashes[base] - for i in range(base + 1, end): - merged_hash += self.block_hashes[i] - return BlockHash(merged_hash) + return BlockHash(b"".join(self.block_hashes[base:end])) BlockHashList = list[BlockHash] | BlockHashListWithBlockSize diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index b44f2db1926b..264811a556d3 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -41,6 +41,7 @@ def __init__( kv_cache_config: "KVCacheConfig", structured_output_manager: "StructuredOutputManager", block_size: int, + hash_block_size: int, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, include_finished_set: bool = False, log_stats: bool = False, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 40b5899f0457..2f22adf8a8e6 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -71,6 +71,7 @@ def __init__( kv_cache_config: KVCacheConfig, structured_output_manager: StructuredOutputManager, block_size: int, + hash_block_size: int | None = None, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, include_finished_set: bool = False, log_stats: bool = False, @@ -222,6 +223,8 @@ def __init__( self.num_lookahead_tokens = self.num_spec_tokens # Create the KV cache manager. + if hash_block_size is None: + hash_block_size = block_size self.kv_cache_manager = KVCacheManager( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, @@ -231,7 +234,7 @@ def __init__( enable_kv_cache_events=self.enable_kv_cache_events, dcp_world_size=self.dcp_world_size, pcp_world_size=self.pcp_world_size, - hash_block_size=self.block_size, + hash_block_size=hash_block_size, metrics_collector=self.kv_metrics_collector, ) # Bind GPU block pool to the KV connector. This must happen after @@ -2018,7 +2021,7 @@ def _connector_finished( # the connector. 
self.kv_cache_manager.remove_skipped_blocks( request_id=request.request_id, - total_computed_tokens=request.num_tokens, + total_computed_tokens=request.num_computed_tokens, ) block_ids = self.kv_cache_manager.get_block_ids(request.request_id) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 30061462008f..f63e954400eb 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -20,6 +20,7 @@ MambaSpec, MLAAttentionSpec, SinkFullAttentionSpec, + SlidingWindowMLASpec, SlidingWindowSpec, TQFullAttentionSpec, ) @@ -534,12 +535,10 @@ def find_longest_cache_hit( ): # Skip prefix matching check if the block is not aligned with # `alignment_tokens`. - if ( - num_contiguous_blocks == 0 - and block_size != alignment_tokens # Faster for common case. - and (i + 1) * block_size % alignment_tokens != 0 - ): - continue + if num_contiguous_blocks == 0 and block_size != alignment_tokens: + post_pop_blocks = i if use_eagle else i + 1 + if (post_pop_blocks * block_size) % alignment_tokens != 0: + continue # Add the cached block to the computed blocks. for computed, cached in zip(computed_blocks, cached_block): computed[i] = cached @@ -1118,6 +1117,7 @@ def __init__( TQFullAttentionSpec: FullAttentionManager, MLAAttentionSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager, + SlidingWindowMLASpec: SlidingWindowManager, ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, MambaSpec: MambaManager, CrossAttentionSpec: CrossAttentionManager, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 0be3273c5aca..36864ba738bf 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -45,6 +45,7 @@ get_kv_cache_configs, get_request_block_hasher, init_none_hash, + resolve_kv_cache_block_sizes, ) from vllm.v1.core.sched.interface import PauseState, SchedulerInterface from vllm.v1.core.sched.output import SchedulerOutput @@ -137,10 +138,8 @@ def __init__( logger.warning("Disabling chunked prefill for model without KVCache") vllm_config.scheduler_config.enable_chunked_prefill = False - scheduler_block_size = ( - vllm_config.cache_config.block_size - * vllm_config.parallel_config.decode_context_parallel_size - * vllm_config.parallel_config.prefill_context_parallel_size + scheduler_block_size, hash_block_size = resolve_kv_cache_block_sizes( + kv_cache_config, vllm_config ) self.scheduler: SchedulerInterface = Scheduler( @@ -150,6 +149,7 @@ def __init__( include_finished_set=include_finished_set, log_stats=self.log_stats, block_size=scheduler_block_size, + hash_block_size=hash_block_size, ) self.use_spec_decode = vllm_config.speculative_config is not None if self.scheduler.connector is not None: # type: ignore @@ -207,7 +207,7 @@ def __init__( init_none_hash(caching_hash_fn) self.request_block_hasher = get_request_block_hasher( - scheduler_block_size, caching_hash_fn + hash_block_size, caching_hash_fn ) self.step_fn = ( diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index bc8422d4f4b5..1d296c0bb4c1 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -4,6 +4,7 @@ from __future__ import annotations import copy +from collections import Counter from dataclasses import dataclass, fields, replace from enum import IntEnum from math import prod @@ -13,11 +14,11 @@ from typing_extensions import Self from vllm.logger import init_logger +from vllm.utils.math_utils import cdiv, round_up +from vllm.utils.torch_utils import 
get_dtype_size, nvfp4_kv_cache_full_dim if TYPE_CHECKING: from vllm.config import VllmConfig -from vllm.utils.math_utils import cdiv -from vllm.utils.torch_utils import get_dtype_size, nvfp4_kv_cache_full_dim logger = init_logger(__name__) @@ -95,6 +96,10 @@ def page_size_bytes(self) -> int: """ raise NotImplementedError + @property + def storage_block_size(self) -> int: + return self.block_size + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: """ The maximum possible memory usage of this KV cache in bytes. @@ -269,6 +274,15 @@ def real_page_size_bytes(self) -> int: ) +def _apply_alignment_padding(spec: MLAAttentionSpec | SlidingWindowMLASpec): + if spec.alignment is None: + return + actual_page_size = spec.real_page_size_bytes + padded_page_size = round_up(actual_page_size, spec.alignment) + if padded_page_size != actual_page_size: + object.__setattr__(spec, "page_size_padded", padded_page_size) + + @dataclass(frozen=True, kw_only=True) class TQFullAttentionSpec(FullAttentionSpec): """FullAttentionSpec with TQ-aware page size. @@ -299,15 +313,31 @@ def merge(cls, specs: list[Self]) -> Self: class MLAAttentionSpec(FullAttentionSpec): # TODO(Lucas/Chen): less hacky way to do this cache_dtype_str: str | None = None + # DeepseekV4 only fields. Non-DeepseekV4 MLA models leave these at defaults. + alignment: int | None = None # Default to None for no padding. + compress_ratio: int = 1 # Default to 1 for no compression. + model_version: str | None = None + + def __post_init__(self): + super().__post_init__() + _apply_alignment_padding(self) + + @property + def storage_block_size(self) -> int: + return self.block_size // self.compress_ratio @property def real_page_size_bytes(self) -> int: if self.cache_dtype_str == "fp8_ds_mla": - # See `vllm/v1/attention/backends/mla/flashmla_sparse.py` - # for details. + if self.model_version == "deepseek_v4": + # DeepseekV4: 448B NoPE + 128B RoPE + 8B fp8 scale = 584B per token. + # head_size stays semantic (512); bytes are determined here. + return self.storage_block_size * 584 + # V3.2 main MLA: 656-byte custom layout (kv_lora_rank=512 + + # qk_rope_head_dim=64, head_size=576). See flashmla_sparse.py. return self.block_size * 656 return ( - self.block_size + self.storage_block_size * self.num_kv_heads * self.head_size * get_dtype_size(self.dtype) @@ -319,9 +349,15 @@ def merge(cls, specs: list[Self]) -> Self: "All attention layers in the same KV cache group must be MLAAttentionSpec." ) cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs) - assert len(cache_dtype_str_set) == 1, ( + compress_ratio_set = set(spec.compress_ratio for spec in specs) + model_version_set = set(spec.model_version for spec in specs) + assert ( + len(cache_dtype_str_set) == 1 + and len(compress_ratio_set) == 1 + and len(model_version_set) == 1 + ), ( "All attention layers in the same KV cache group must use the same " - "quantization method." + "quantization method, compress ratio, and model version." 
) return cls( block_size=specs[0].block_size, @@ -331,6 +367,8 @@ def merge(cls, specs: list[Self]) -> Self: kv_quant_mode=specs[0].kv_quant_mode, page_size_padded=specs[0].page_size_padded, cache_dtype_str=cache_dtype_str_set.pop(), + compress_ratio=compress_ratio_set.pop(), + model_version=model_version_set.pop(), ) @@ -393,6 +431,71 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes +@dataclass(frozen=True, kw_only=True) +class SlidingWindowMLASpec(SlidingWindowSpec): + """Sliding window attention with MLA cache format.""" + + cache_dtype_str: str | None = None + # DeepseekV4-only: see MLAAttentionSpec.model_version. + alignment: int | None = None # Default to None for no padding. + compress_ratio: int = 1 + model_version: str | None = None + + def __post_init__(self): + _apply_alignment_padding(self) + + @property + def storage_block_size(self) -> int: + return self.block_size // self.compress_ratio + + @property + def real_page_size_bytes(self) -> int: + if self.model_version == "deepseek_v4": + # DeepseekV4: 448B NoPE + 128B RoPE + 8B fp8 scale = 584B per token. + return self.storage_block_size * 584 + assert self.model_version is None, ( + f"Unsupported model version: {self.model_version}" + ) + return ( + self.storage_block_size + * self.num_kv_heads + * self.head_size + * get_dtype_size(self.dtype) + ) + + @classmethod + def merge(cls, specs: list[Self]) -> Self: + assert all(isinstance(spec, SlidingWindowMLASpec) for spec in specs), ( + "All attention layers in the same KV cache group must be " + "SlidingWindowMLASpec." + ) + cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs) + compress_ratio_set = set(spec.compress_ratio for spec in specs) + model_version_set = set(spec.model_version for spec in specs) + sliding_window_set = set(spec.sliding_window for spec in specs) + assert ( + len(cache_dtype_str_set) == 1 + and len(compress_ratio_set) == 1 + and len(model_version_set) == 1 + and len(sliding_window_set) == 1 + ), ( + "All attention layers in the same KV cache group must use the same " + "quantization method, compress ratio, model version and sliding " + "window size." + ) + return cls( + block_size=specs[0].block_size, + num_kv_heads=specs[0].num_kv_heads, + head_size=specs[0].head_size, + dtype=specs[0].dtype, + page_size_padded=specs[0].page_size_padded, + sliding_window=sliding_window_set.pop(), + cache_dtype_str=cache_dtype_str_set.pop(), + compress_ratio=compress_ratio_set.pop(), + model_version=model_version_set.pop(), + ) + + @dataclass(frozen=True) class MambaSpec(KVCacheSpec): shapes: tuple[tuple[int, ...], ...] @@ -527,7 +630,17 @@ def is_uniform_type(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> bool: # Different block sizes, not uniform. return False one_spec = next(iter(kv_cache_specs.values())) - if isinstance(one_spec, FullAttentionSpec): + # NOTE: Check subclasses before parent classes since isinstance() + # returns True for subclasses. + if isinstance(one_spec, SlidingWindowMLASpec): + # SlidingWindowMLASpec is uniform if all specs are SlidingWindowMLASpec + # with the same sliding_window size. 
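Putting the page-size arithmetic of the two spec classes above in one place: per token the DeepseekV4 layout costs 448 B of NoPE fp8 + 128 B of RoPE bf16 + 8 B of UE8M0 scales = 584 B, storage_block_size divides block_size by compress_ratio, and alignment (when set) rounds the page up. The block size, compress ratio, and alignment below are toy values, not the model's real configuration.

def deepseek_v4_page_size(block_size: int, compress_ratio: int,
                          alignment: int | None) -> int:
    storage_block_size = block_size // compress_ratio   # tokens actually stored
    real = storage_block_size * 584                     # bytes per page
    if alignment is None:
        return real
    return -(-real // alignment) * alignment            # round_up(real, alignment)

print(deepseek_v4_page_size(256, 1, None))   # 149504
print(deepseek_v4_page_size(256, 4, 4096))   # 64 tokens stored: 37376 -> 40960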
+ return all( + isinstance(spec, SlidingWindowMLASpec) + and spec.sliding_window == one_spec.sliding_window + for spec in kv_cache_specs.values() + ) + elif isinstance(one_spec, FullAttentionSpec): return all( isinstance(spec, FullAttentionSpec) for spec in kv_cache_specs.values() ) @@ -571,6 +684,21 @@ def from_specs(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> Self | None: else: return None + # NOTE: below util functions are only used by DeepseekV4 for now. + def get_page_sizes(self) -> list[int]: + return list(set(spec.page_size_bytes for spec in self.kv_cache_specs.values())) + + def get_num_layer_tuples(self) -> int: + return Counter( + spec.page_size_bytes for spec in self.kv_cache_specs.values() + ).most_common(1)[0][1] + + def max_memory_usage_pages(self, vllm_config: VllmConfig) -> int: + return max( + cdiv(spec.max_memory_usage_bytes(vllm_config), spec.page_size_bytes) + for spec in self.kv_cache_specs.values() + ) + @dataclass class KVCacheTensor: @@ -593,6 +721,8 @@ class KVCacheGroupSpec: layer_names: list[str] # The KV cache spec of this manager layer kv_cache_spec: KVCacheSpec + # Whether this group contains EAGLE/MTP draft attention layers. + is_eagle_group: bool = False @dataclass diff --git a/vllm/v1/spec_decode/llm_base_proposer.py b/vllm/v1/spec_decode/llm_base_proposer.py index 44156b60c0da..94e09b209cfd 100644 --- a/vllm/v1/spec_decode/llm_base_proposer.py +++ b/vllm/v1/spec_decode/llm_base_proposer.py @@ -84,6 +84,15 @@ def __init__( self.hidden_size = self.draft_model_config.get_hidden_size() self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size() + # DeepSeek V4 MTP consumes the target's pre-hc_head residual stream, + # shape (T, hc_mult * hidden_size). Expand the hidden_states buffer + # so target_hidden_states fits; detect DeepseekV4 via draft hf_config. + draft_hf_config = self.draft_model_config.hf_config + if hasattr(draft_hf_config, "compress_ratios") and hasattr( + draft_hf_config, "hc_mult" + ): + self.hidden_size = self.hidden_size * draft_hf_config.hc_mult + # Unifying eagle, draft model, and parallel drafting support. # DFlash always uses parallel drafting (all tokens in one pass), # but has an additional slot for the next_token_id (does not shift like EAGLE) @@ -1308,9 +1317,12 @@ def load_model(self, target_model: nn.Module) -> None: self.vllm_config, AttentionLayerBase, # type: ignore[type-abstract] ) - self._draft_attn_layer_names = ( - set(all_attn_layers.keys()) - target_attn_layer_names - ) + # Filter to only layers that have KV cache specs. + self._draft_attn_layer_names = { + name + for name in (set(all_attn_layers.keys()) - target_attn_layer_names) + if all_attn_layers[name].get_kv_cache_spec(self.vllm_config) is not None + } if self.supports_mm_inputs: # Even if the target model is multimodal, we can also use @@ -1514,6 +1526,17 @@ def _maybe_share_lm_head(self, target_language_model: nn.Module) -> None: "Shared target model lm_head with MTP shared_head.head." ) + if hasattr(target_language_model.model, "topk_indices_buffer"): + if hasattr(self.model.model, "topk_indices_buffer"): + del self.model.model.topk_indices_buffer + self.model.model.topk_indices_buffer = ( + target_language_model.model.topk_indices_buffer + ) + logger.info( + "Detected MTP model with topk_indices_buffer. " + "Sharing target model topk_indices_buffer with the draft model." 
+ ) + if self.use_local_argmax_reduction: if not hasattr(self.model, "get_top_tokens"): raise ValueError( diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py index 354be3cd2a40..732219a4bafb 100644 --- a/vllm/v1/worker/gpu/attn_utils.py +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -20,6 +20,11 @@ KVCacheSpec, UniformTypeKVCacheSpecs, ) +from vllm.v1.worker.kv_cache_view_utils import ( + get_kv_cache_block_axis, + get_kv_cache_stride_order, + view_kv_cache_with_layout, +) from vllm.v1.worker.utils import AttentionGroup, bind_kv_cache @@ -162,29 +167,32 @@ def _reshape_kv_cache( attn_backend = attn_backends[layer_name] kv_cache_shape = attn_backend.get_kv_cache_shape( num_blocks, - kv_cache_spec.block_size, + kv_cache_spec.storage_block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, cache_dtype, ) - # FIXME(woosuk): Add kv_cache_stride_order to all attention backends. - try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() - assert len(kv_cache_stride_order) == len(kv_cache_shape) - except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple(range(len(kv_cache_shape))) - - kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) - inv_order = [ - kv_cache_stride_order.index(i) - for i in range(len(kv_cache_stride_order)) - ] - - dtype = kv_cache_spec.dtype - raw_tensor = raw_tensor.view(dtype) - raw_tensor = raw_tensor.view(kv_cache_shape) - kv_caches[layer_name] = raw_tensor.permute(*inv_order) + kv_cache_stride_order = get_kv_cache_stride_order( + attn_backend, + kv_cache_shape, + ) + block_axis = get_kv_cache_block_axis( + attn_backend, + kv_cache_spec.storage_block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + cache_dtype, + ) + kv_caches[layer_name] = view_kv_cache_with_layout( + raw_tensor=raw_tensor, + kv_cache_shape=kv_cache_shape, + kv_cache_stride_order=kv_cache_stride_order, + block_axis=block_axis, + dtype=kv_cache_spec.dtype, + page_size_bytes=kv_cache_spec.page_size_bytes, + page_size_padded=kv_cache_spec.page_size_padded, + ) return kv_caches @@ -230,6 +238,7 @@ def build_attn_metadata( seq_lens_cpu_upper_bound: torch.Tensor | None = None, dcp_local_seq_lens: torch.Tensor | None = None, encoder_seq_lens: dict[int, tuple[torch.Tensor, np.ndarray]] | None = None, + positions: torch.Tensor | None = None, ) -> dict[str, Any]: seq_lens = seq_lens[:num_reqs] if dcp_local_seq_lens is not None: @@ -256,6 +265,7 @@ def build_attn_metadata( slot_mapping=slot_mapping, causal=True, dcp_local_seq_lens=dcp_local_seq_lens, + positions=positions, ) if encoder_seq_lens and i in encoder_seq_lens: encoder_seq_lens_gpu, encoder_seq_lens_cpu = encoder_seq_lens[i] diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index b1bf56ec16b2..14fb51587cb3 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -486,11 +486,20 @@ def _dummy_run( device=self.device, ), ) + + # Let the target override the hidden state fed to the drafter + # (e.g. DeepSeek V4 MTP needs the pre-hc_head residual). The + # target returns a persistent buffer sized at max_num_batched_tokens; + # slice to the active token count that propose() expects. 
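A shape-level sketch of what the slicing below does, with invented sizes (the real buffer is allocated by the target model with max_num_batched_tokens rows):

import torch

max_num_batched_tokens, hc_mult, hidden_size = 1024, 2, 64
pre_hc_buffer = torch.empty(max_num_batched_tokens, hc_mult * hidden_size)

num_tokens = 37                              # live tokens in this step
spec_hidden_states = pre_hc_buffer[:num_tokens]
print(spec_hidden_states.shape)              # torch.Size([37, 128])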
+ spec_hidden_states = hidden_states + if hasattr(self.model, "get_mtp_target_hidden_states"): + pre_hc_hidden_states = self.model.get_mtp_target_hidden_states() + spec_hidden_states = pre_hc_hidden_states[: hidden_states.shape[0]] # type: ignore[union-attr] self.speculator.propose( input_batch=input_batch, attn_metadata=attn_metadata, slot_mappings=slot_mappings_by_layer, - last_hidden_states=hidden_states, + last_hidden_states=spec_hidden_states, aux_hidden_states=aux_hidden_states, num_sampled=torch.ones( input_batch.num_reqs, dtype=torch.int32, device=self.device @@ -808,7 +817,6 @@ def prepare_inputs( out=seq_lens_cpu_upper_bound_np[:num_reqs], ) seq_lens_cpu_upper_bound = torch.from_numpy(seq_lens_cpu_upper_bound_np) - return InputBatch( req_ids=req_ids, num_reqs=num_reqs, @@ -1233,11 +1241,19 @@ def sample_tokens( if self.speculator is not None: assert self.sampler is not None + # Let the target override the hidden state fed to the drafter + # (e.g. DeepSeek V4 MTP needs the pre-hc_head residual). The + # target returns a persistent buffer sized at max_num_batched_tokens; + # slice to the active token count that propose() expects. + spec_hidden_states = hidden_states + if hasattr(self.model, "get_mtp_target_hidden_states"): + pre_hc_hidden_states = self.model.get_mtp_target_hidden_states() + spec_hidden_states = pre_hc_hidden_states[: hidden_states.shape[0]] # type: ignore[union-attr] draft_tokens = self.speculator.propose( input_batch, attn_metadata, slot_mappings_by_layer, - hidden_states, + spec_hidden_states, aux_hidden_states, num_sampled, num_rejected, diff --git a/vllm/v1/worker/gpu/model_states/default.py b/vllm/v1/worker/gpu/model_states/default.py index 5d36b12f9c27..1b8ee066eeff 100644 --- a/vllm/v1/worker/gpu/model_states/default.py +++ b/vllm/v1/worker/gpu/model_states/default.py @@ -193,5 +193,6 @@ def prepare_attn( kv_cache_config=kv_cache_config, seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound, dcp_local_seq_lens=input_batch.dcp_local_seq_lens, + positions=input_batch.positions, ) return attn_metadata diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py index b1126983539f..c3e5cf1f6888 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py @@ -53,6 +53,11 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): # the draft model's hidden size can be different from the target model's # hidden size (e.g., Llama 3.3 70B). self.hidden_size = self.draft_model_config.get_hidden_size() + # Widen for HC-multiplexed residuals (e.g. DeepSeek V4 feeds the MTP + # draft the target's pre-hc_head (T, hc_mult * hidden_size) residual). + # Non-HC models default to hc_mult=1 and are unaffected. 
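The getattr default keeps non-HC drafters untouched; a minimal check with two hypothetical configs:

from types import SimpleNamespace

for hf_config, hidden_size in ((SimpleNamespace(hc_mult=2), 4096),
                               (SimpleNamespace(), 4096)):
    hc_mult = getattr(hf_config, "hc_mult", 1)
    print(hidden_size * hc_mult)   # 8192 for the HC config, 4096 otherwise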
+ hc_mult = getattr(self.draft_model_config.hf_config, "hc_mult", 1) + self.hidden_size = self.hidden_size * hc_mult self.vocab_size = self.draft_model_config.get_vocab_size() self.dtype = vllm_config.model_config.dtype diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/utils.py b/vllm/v1/worker/gpu/spec_decode/eagle/utils.py index ee37eadb2a8e..39514ee7e91f 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle/utils.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle/utils.py @@ -49,4 +49,19 @@ def load_eagle_model(target_model: nn.Module, vllm_config: VllmConfig) -> nn.Mod del eagle_model.lm_head eagle_model.lm_head = target_model.lm_head + # MTP models call compute_logits via shared_head.head (a + # ParallelLMHead inside each MTP layer), not self.model.lm_head. + # If the checkpoint omits a copy of the lm_head weights at the + # MTP layer path, shared_head.head stays uninitialised and + # produces zero/NaN logits. Share it explicitly from the target. + inner = getattr(eagle_model, "model", None) + layers = getattr(inner, "layers", None) if inner is not None else None + if layers is not None: + items = layers.values() if isinstance(layers, nn.ModuleDict) else layers + for layer in items: + sh = getattr(layer, "shared_head", None) + if sh is not None and hasattr(sh, "head"): + del sh.head + sh.head = target_model.lm_head + return eagle_model diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py index b2683966b315..6268ea0ba673 100644 --- a/vllm/v1/worker/gpu/states.py +++ b/vllm/v1/worker/gpu/states.py @@ -104,6 +104,7 @@ def add_request( self.num_computed_prefill_tokens[req_idx] = num_computed_tokens self.num_computed_tokens_np[req_idx] = num_computed_tokens self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens) + self.num_computed_tokens_np[req_idx] = num_computed_tokens if num_computed_tokens > 0 and num_computed_tokens <= prefill_len: # For PD disagg or resumed requests: set last_sampled to the last diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e9b72ae7c920..1445428f54a4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -191,6 +191,11 @@ from vllm.v1.worker.gpu.pool.late_interaction_runner import LateInteractionRunner from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper +from vllm.v1.worker.kv_cache_view_utils import ( + get_kv_cache_block_axis, + get_kv_cache_stride_order, + view_kv_cache_with_layout, +) from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.ubatch_utils import ( @@ -2184,6 +2189,7 @@ def _get_block_table(kv_cache_gid: int): slot_mapping=slot_mapping_gid_0, causal=True, is_prefilling=is_prefilling, + positions=self.positions[:num_tokens_padded], ) if self.dcp_world_size > 1: @@ -4671,6 +4677,16 @@ def propose_draft_token_ids( next_token_ids, valid_sampled_tokens_count ) + # Let the target override the hidden state fed to the drafter + # (e.g. DeepSeek V4 MTP needs the pre-hc_head residual). Safe to + # rebind here: hidden_states was already consumed for sampling + # above and is not used again in this branch. 
+ alt = getattr( + self.get_model(), "get_mtp_target_hidden_states", lambda: None + )() + if alt is not None: + hidden_states = alt + num_rejected_tokens_gpu = None if spec_decode_metadata is None: token_indices_to_sample = None @@ -6587,38 +6603,42 @@ def _reshape_kv_cache_tensors( ) kernel_num_blocks = num_blocks * num_blocks_per_kv_block + # For MLA with compression, storage_block_size != block_size + if kv_cache_spec.storage_block_size != kv_cache_spec.block_size: + shape_block_size = kv_cache_spec.storage_block_size + else: + shape_block_size = kernel_block_size + kv_cache_shape = attn_backend.get_kv_cache_shape( kernel_num_blocks, - kernel_block_size, + shape_block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, cache_dtype_str=self.cache_config.cache_dtype, ) dtype = kv_cache_spec.dtype - try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() - assert len(kv_cache_stride_order) == len(kv_cache_shape) - except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple(range(len(kv_cache_shape))) - # The allocation respects the backend-defined stride order - # to ensure the semantic remains consistent for each - # backend. We first obtain the generic kv cache shape and - # then permute it according to the stride order which could - # result in a non-contiguous tensor. - kv_cache_shape = tuple( - kv_cache_shape[i] for i in kv_cache_stride_order + kv_cache_stride_order = get_kv_cache_stride_order( + attn_backend, + kv_cache_shape, ) - # Maintain original KV shape view. - inv_order = [ - kv_cache_stride_order.index(i) - for i in range(len(kv_cache_stride_order)) - ] - kv_caches[layer_name] = ( - kv_cache_raw_tensors[layer_name] - .view(dtype) - .view(kv_cache_shape) - .permute(*inv_order) + block_axis = get_kv_cache_block_axis( + attn_backend, + shape_block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + self.cache_config.cache_dtype, ) + assert kv_cache_shape[block_axis] == kernel_num_blocks + kv_caches[layer_name] = view_kv_cache_with_layout( + raw_tensor=kv_cache_raw_tensors[layer_name], + kv_cache_shape=kv_cache_shape, + kv_cache_stride_order=kv_cache_stride_order, + block_axis=block_axis, + dtype=dtype, + page_size_bytes=kv_cache_spec.page_size_bytes, + page_size_padded=kv_cache_spec.page_size_padded, + ) + elif isinstance(kv_cache_spec, MambaSpec): has_mamba = True raw_tensor = kv_cache_raw_tensors[layer_name] diff --git a/vllm/v1/worker/kv_cache_view_utils.py b/vllm/v1/worker/kv_cache_view_utils.py new file mode 100644 index 000000000000..92fc00eabd56 --- /dev/null +++ b/vllm/v1/worker/kv_cache_view_utils.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence + +import torch + +from vllm.utils.torch_utils import get_dtype_size +from vllm.v1.attention.backend import AttentionBackend + + +def get_kv_cache_stride_order( + attn_backend: type[AttentionBackend], + kv_cache_shape: Sequence[int], +) -> tuple[int, ...]: + try: + stride_order = attn_backend.get_kv_cache_stride_order() + assert len(stride_order) == len(kv_cache_shape) + return stride_order + except (AttributeError, NotImplementedError): + return tuple(range(len(kv_cache_shape))) + + +def get_kv_cache_block_axis( + attn_backend: type[AttentionBackend], + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str, +) -> int: + try: + return attn_backend.get_kv_cache_block_axis( + block_size, + num_kv_heads, + head_size, + 
cache_dtype_str=cache_dtype_str, + ) + except (AttributeError, NotImplementedError): + return 0 + + +def _paged_kv_cache_strides( + ordered_shape: Sequence[int], + ordered_block_axis: int, + page_stride: int, +) -> tuple[int, ...]: + """Build strides for a block-paged raw allocation. + + The block axis crosses page boundaries. All other physical axes are laid out + contiguously inside a page, regardless of whether they appear before or + after the block axis in the backend's stride order. + """ + strides = [0] * len(ordered_shape) + inner_stride = 1 + for axis in reversed(range(len(ordered_shape))): + if axis == ordered_block_axis: + continue + strides[axis] = inner_stride + inner_stride *= ordered_shape[axis] + + assert page_stride >= inner_stride + strides[ordered_block_axis] = page_stride + return tuple(strides) + + +def view_kv_cache_with_layout( + *, + raw_tensor: torch.Tensor, + kv_cache_shape: Sequence[int], + kv_cache_stride_order: Sequence[int], + block_axis: int, + dtype: torch.dtype, + page_size_bytes: int, + page_size_padded: int | None, +) -> torch.Tensor: + """View a raw KV allocation as the backend's semantic cache shape.""" + ordered_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) + inv_order = tuple( + kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) + ) + + raw_tensor = raw_tensor.view(dtype) + if page_size_padded is None: + ordered_kv_cache = raw_tensor.view(ordered_shape) + else: + ordered_block_axis = kv_cache_stride_order.index(block_axis) + dtype_size = get_dtype_size(dtype) + page_stride = page_size_bytes // dtype_size + strides = _paged_kv_cache_strides( + ordered_shape, + ordered_block_axis, + page_stride, + ) + ordered_kv_cache = torch.as_strided( + raw_tensor, + size=ordered_shape, + stride=strides, + ) + + return ordered_kv_cache.permute(*inv_order)
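A small check of the padded-page stride construction above, mirroring `_paged_kv_cache_strides` with toy shapes (4 blocks of 16 tokens x 8 elements in bf16, each page padded by 64 bytes); only the mechanism is taken from the patch, the sizes are invented:

import torch

num_blocks, block_size, dim = 4, 16, 8
elems_per_page = block_size * dim            # 128 real elements per page
page_stride = elems_per_page + 64 // 2       # 160 elements once the bf16 padding is added

raw = torch.zeros(num_blocks * page_stride, dtype=torch.bfloat16)
view = raw.as_strided((num_blocks, block_size, dim), (page_stride, dim, 1))
print(view.stride())                 # (160, 8, 1) -- block axis jumps a full padded page
view[2, 0, 0] = 1.0
print(raw[2 * page_stride].item())   # 1.0 -- block 2 begins two padded pages into the raw buffer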