diff --git a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu index 2f2e7ecc1829..2f817bc5e6a5 100644 --- a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu +++ b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu @@ -91,6 +91,32 @@ constexpr float kFp8Max = 448.0f; constexpr int kNumLanes = 32; constexpr int kElemsPerLane = kHeadDim / kNumLanes; // 16 +__device__ __forceinline__ uint4 packFp8E4M3x16(float const* values, + float const scale) { +#ifndef USE_ROCM + uint4 out; + auto* out2 = reinterpret_cast<__nv_fp8x2_storage_t*>(&out); + #pragma unroll + for (int i = 0; i < kElemsPerLane / 2; i++) { + float2 scaled = + make_float2(values[2 * i] * scale, values[2 * i + 1] * scale); + scaled.x = fminf(fmaxf(scaled.x, -kFp8Max), kFp8Max); + scaled.y = fminf(fmaxf(scaled.y, -kFp8Max), kFp8Max); + out2[i] = __nv_cvt_float2_to_fp8x2(scaled, __NV_SATFINITE, __NV_E4M3); + } + return out; +#else + uint8_t out_bytes[kElemsPerLane]; + #pragma unroll + for (int i = 0; i < kElemsPerLane; i++) { + float scaled = values[i] * scale; + scaled = fminf(fmaxf(scaled, -kFp8Max), kFp8Max); + out_bytes[i] = rocm_cvt_float_to_fp8_e4m3(scaled); + } + return *reinterpret_cast(out_bytes); +#endif +} + // ──────────────────────────────────────────────────────────────────────────── // Small inline helpers // ──────────────────────────────────────────────────────────────────────────── @@ -127,20 +153,27 @@ __device__ __forceinline__ float warpSum(float val) { // them). The KV branch only inserts the first `num_tokens_insert` tokens // (= slot_mapping length) into the paged cache. // -template +template __global__ void fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel( - scalar_t_in* __restrict__ q_inout, // [N, H, 512] bf16, in place - scalar_t_in const* __restrict__ kv_in, // [N, 512] bf16 - uint8_t* __restrict__ k_cache, // [num_blocks, block_stride] - int64_t const* __restrict__ slot_mapping, // [num_tokens_insert] i64 - int64_t const* __restrict__ position_ids, // [N] i64 - float const* __restrict__ cos_sin_cache, // [max_pos, 64] fp32 + scalar_t_in* __restrict__ q_inout, // [N, H, 512] bf16 + uint8_t* __restrict__ q_fp8_out, // [N, H, 512] fp8, optional + int64_t const q_fp8_stride0, // elements, fp8 == bytes + int64_t const q_fp8_stride1, // elements, fp8 == bytes + scalar_t_in const* __restrict__ kv_in, // [N, 512] bf16 + uint8_t* __restrict__ k_cache, // legacy uint8 or full fp8 + int64_t const* __restrict__ slot_mapping, // [num_tokens_insert] i64 + int64_t const* __restrict__ position_ids, // [N] i64 + float const* __restrict__ cos_sin_cache, // [max_pos, 64] fp32 + float const* __restrict__ fp8_scale_ptr, // scalar, full-cache fp8 only + float const* __restrict__ q_fp8_scale_inv, // scalar, q fp8 only float const eps, - int const num_tokens_full, // = q.size(0) = kv.size(0) - int const num_tokens_insert, // = slot_mapping.size(0), ≤ num_tokens_full - int const num_heads_q, // H - int const cache_block_size, // tokens per paged-cache block - int const kv_block_stride) { // bytes per paged-cache block + int const num_tokens_full, // = q.size(0) = kv.size(0) + int const num_tokens_insert, // = slot_mapping.size(0), ≤ num_tokens_full + int const num_heads_q, // H + int const cache_block_size, // tokens per paged-cache block + int64_t const kv_block_stride, // bytes per paged-cache block + int64_t const kv_token_stride) { // bytes per token, unused by legacy #if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM) // BF16 _typeConvert specialization is unavailable on pre-Ampere. The // DeepseekV4 kernel only runs with bf16 inputs in practice, so compile a @@ -256,30 +289,41 @@ __global__ void fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel( } // ═══════════════════════════════════════════════════════════════════════ - // Q branch: cast to bf16 and store back in place. + // Q branch: cast and store. Legacy writes bf16 in place. Full-cache + // per-tensor-FP8 writes q_fp8 and leaves q unchanged. // ═══════════════════════════════════════════════════════════════════════ if (!isKV) { - uint4 out0, out1; - typename Converter::packed_hip_type* po0 = - reinterpret_cast(&out0); - typename Converter::packed_hip_type* po1 = - reinterpret_cast(&out1); + if constexpr (STORE_Q_FP8) { + float const scale_inv = VLLM_LDG(q_fp8_scale_inv); + uint4 const out = packFp8E4M3x16(elements, scale_inv); + uint8_t* dst = q_fp8_out + + static_cast(tokenIdx) * q_fp8_stride0 + + static_cast(slotIdx) * q_fp8_stride1 + dim_base; + *reinterpret_cast(dst) = out; + } else { + uint4 out0, out1; + typename Converter::packed_hip_type* po0 = + reinterpret_cast(&out0); + typename Converter::packed_hip_type* po1 = + reinterpret_cast(&out1); #pragma unroll - for (int i = 0; i < 4; i++) { - po0[i] = Converter::convert( - make_float2(elements[2 * i], elements[2 * i + 1])); - } + for (int i = 0; i < 4; i++) { + po0[i] = Converter::convert( + make_float2(elements[2 * i], elements[2 * i + 1])); + } #pragma unroll - for (int i = 0; i < 4; i++) { - po1[i] = Converter::convert( - make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1])); + for (int i = 0; i < 4; i++) { + po1[i] = Converter::convert( + make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1])); + } + scalar_t_in* dst = + q_inout + + (static_cast(tokenIdx) * num_heads_q + slotIdx) * + kHeadDim + + dim_base; + *reinterpret_cast(dst) = out0; + *reinterpret_cast(dst + 8) = out1; } - scalar_t_in* dst = - q_inout + - (static_cast(tokenIdx) * num_heads_q + slotIdx) * kHeadDim + - dim_base; - *reinterpret_cast(dst) = out0; - *reinterpret_cast(dst + 8) = out1; #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) cudaTriggerProgrammaticLaunchCompletion(); #endif @@ -299,6 +343,39 @@ __global__ void fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel( int64_t const block_idx = slot_id / cache_block_size; int64_t const pos_in_block = slot_id % cache_block_size; + if constexpr (STORE_FULL_CACHE) { + uint8_t* cache_row = k_cache + block_idx * kv_block_stride + + pos_in_block * kv_token_stride; + if constexpr (STORE_KV_FP8) { + float const inv_scale = 1.0f / VLLM_LDG(fp8_scale_ptr); + uint4 const out = packFp8E4M3x16(elements, inv_scale); + *reinterpret_cast(cache_row + dim_base) = out; + } else { + uint4 out0, out1; + typename Converter::packed_hip_type* po0 = + reinterpret_cast(&out0); + typename Converter::packed_hip_type* po1 = + reinterpret_cast(&out1); +#pragma unroll + for (int i = 0; i < 4; i++) { + po0[i] = Converter::convert( + make_float2(elements[2 * i], elements[2 * i + 1])); + } +#pragma unroll + for (int i = 0; i < 4; i++) { + po1[i] = Converter::convert( + make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1])); + } + scalar_t_in* dst = reinterpret_cast(cache_row) + dim_base; + *reinterpret_cast(dst) = out0; + *reinterpret_cast(dst + 8) = out1; + } +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + cudaTriggerProgrammaticLaunchCompletion(); +#endif + return; + } + uint8_t* block_base = k_cache + block_idx * static_cast(kv_block_stride); uint8_t* token_fp8_ptr = block_base + pos_in_block * kTokenDataBytes; @@ -431,18 +508,127 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert( config.numAttrs = (sm_version >= 90) ? 1 : 0; cudaLaunchKernelEx( - &config, fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel, - q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, eps, - num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size, - kv_block_stride); + &config, + fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel, + q_inout, nullptr, 0, 0, kv_in, k_cache, slot_mapping, position_ids, + cos_sin_cache, nullptr, nullptr, eps, num_tokens_full, num_tokens_insert, + num_heads_q, cache_block_size, kv_block_stride, 0); #else // ROCm: use standard kernel launch syntax (no PDL/stream serialization) // clang-format off - fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel + fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel + <<>>( + q_inout, nullptr, 0, 0, kv_in, k_cache, slot_mapping, position_ids, + cos_sin_cache, nullptr, nullptr, eps, num_tokens_full, + num_tokens_insert, num_heads_q, cache_block_size, kv_block_stride, 0); +#endif +} + +template +void launchFusedDeepseekV4QNormRopeFullCacheBF16Insert( + scalar_t_in* q_inout, scalar_t_in const* kv_in, uint8_t* k_cache, + int64_t const* slot_mapping, int64_t const* position_ids, + float const* cos_sin_cache, float const eps, int const num_tokens_full, + int const num_tokens_insert, int const num_heads_q, + int const cache_block_size, int64_t const kv_block_stride, + int64_t const kv_token_stride, cudaStream_t stream) { + constexpr int kBlockSize = 256; + constexpr int kWarpsPerBlock = kBlockSize / 32; + int64_t const total_warps = + static_cast(num_tokens_full) * (num_heads_q + 1); + int const grid = + static_cast((total_warps + kWarpsPerBlock - 1) / kWarpsPerBlock); + +#ifndef USE_ROCM + static int const sm_version = getSMVersion(); + TORCH_CHECK( + sm_version >= 80, + "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert requires " + "sm_80+ (Ampere or newer); got sm_", + sm_version); + cudaLaunchConfig_t config; + config.gridDim = dim3(grid); + config.blockDim = dim3(kBlockSize); + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = 1; + config.attrs = attrs; + config.numAttrs = (sm_version >= 90) ? 1 : 0; + + cudaLaunchKernelEx( + &config, + fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel, + q_inout, nullptr, 0, 0, kv_in, k_cache, slot_mapping, position_ids, + cos_sin_cache, nullptr, nullptr, eps, num_tokens_full, num_tokens_insert, + num_heads_q, cache_block_size, kv_block_stride, kv_token_stride); +#else + fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel <<>>( - q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, - eps, num_tokens_full, num_tokens_insert, num_heads_q, - cache_block_size, kv_block_stride); + q_inout, nullptr, 0, 0, kv_in, k_cache, slot_mapping, position_ids, + cos_sin_cache, nullptr, nullptr, eps, num_tokens_full, + num_tokens_insert, num_heads_q, cache_block_size, kv_block_stride, + kv_token_stride); +#endif +} + +template +void launchFusedDeepseekV4QNormRopeFullCacheFP8Insert( + scalar_t_in* q_in, scalar_t_in const* kv_in, uint8_t* q_fp8_out, + int64_t const q_fp8_stride0, int64_t const q_fp8_stride1, + uint8_t* k_cache, int64_t const* slot_mapping, + int64_t const* position_ids, float const* cos_sin_cache, + float const* fp8_scale, float const* q_fp8_scale_inv, float const eps, + int const num_tokens_full, int const num_tokens_insert, + int const num_heads_q, int const cache_block_size, + int64_t const kv_block_stride, int64_t const kv_token_stride, + cudaStream_t stream) { + constexpr int kBlockSize = 256; + constexpr int kWarpsPerBlock = kBlockSize / 32; + int64_t const total_warps = + static_cast(num_tokens_full) * (num_heads_q + 1); + int const grid = + static_cast((total_warps + kWarpsPerBlock - 1) / kWarpsPerBlock); + +#ifndef USE_ROCM + static int const sm_version = getSMVersion(); + TORCH_CHECK( + sm_version >= 80, + "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert requires " + "sm_80+ (Ampere or newer); got sm_", + sm_version); + cudaLaunchConfig_t config; + config.gridDim = dim3(grid); + config.blockDim = dim3(kBlockSize); + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = 1; + config.attrs = attrs; + config.numAttrs = (sm_version >= 90) ? 1 : 0; + + cudaLaunchKernelEx( + &config, + fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel, + q_in, q_fp8_out, q_fp8_stride0, q_fp8_stride1, kv_in, k_cache, + slot_mapping, position_ids, cos_sin_cache, fp8_scale, q_fp8_scale_inv, + eps, num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size, + kv_block_stride, kv_token_stride); +#else + fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel + <<>>( + q_in, q_fp8_out, q_fp8_stride0, q_fp8_stride1, kv_in, k_cache, + slot_mapping, position_ids, cos_sin_cache, fp8_scale, + q_fp8_scale_inv, eps, num_tokens_full, num_tokens_insert, + num_heads_q, cache_block_size, kv_block_stride, kv_token_stride); #endif } @@ -509,3 +695,143 @@ void fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( cache_block_size_i, kv_block_stride, stream); }); } + +void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert( + torch::Tensor const& q, // [N, H, 512] bf16, read-only + torch::Tensor const& kv, // [N, 512] bf16, read-only + torch::Tensor& q_fp8, // [N, H, 512] fp8 e4m3 + torch::Tensor& k_cache, // [num_blocks, block_size, 512] fp8 + torch::Tensor const& slot_mapping, // [num_tokens_insert] int64 + torch::Tensor const& position_ids, // [N] int64 + torch::Tensor const& cos_sin_cache, // [max_pos, rope_dim] float32 + torch::Tensor const& fp8_scale, // scalar float32 + torch::Tensor const& q_fp8_scale_inv, // scalar float32 + double eps, int64_t cache_block_size) { + TORCH_CHECK(q.is_cuda() && q.is_contiguous(), "q must be contiguous CUDA"); + TORCH_CHECK(kv.is_cuda() && kv.is_contiguous(), "kv must be contiguous CUDA"); + TORCH_CHECK(q_fp8.is_cuda() && q_fp8.is_contiguous(), + "q_fp8 must be contiguous CUDA"); + TORCH_CHECK(k_cache.is_cuda(), "k_cache must be CUDA"); + TORCH_CHECK(slot_mapping.is_cuda() && slot_mapping.dtype() == torch::kInt64, + "slot_mapping must be int64 CUDA"); + TORCH_CHECK(position_ids.is_cuda() && position_ids.dtype() == torch::kInt64, + "position_ids must be int64 CUDA"); + TORCH_CHECK(cos_sin_cache.is_cuda(), "cos_sin_cache must be CUDA"); + TORCH_CHECK(fp8_scale.is_cuda() && fp8_scale.dtype() == torch::kFloat32 && + fp8_scale.numel() == 1, + "fp8_scale must be a scalar float32 CUDA tensor"); + TORCH_CHECK(q_fp8_scale_inv.is_cuda() && + q_fp8_scale_inv.dtype() == torch::kFloat32 && + q_fp8_scale_inv.numel() == 1, + "q_fp8_scale_inv must be a scalar float32 CUDA tensor"); + TORCH_CHECK(q.dim() == 3 && q.size(2) == 512, "q shape [N, H, 512]"); + TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]"); + TORCH_CHECK(q.dtype() == kv.dtype(), "q and kv dtype must match"); + TORCH_CHECK(q_fp8.sizes() == q.sizes(), "q_fp8 must match q shape"); + TORCH_CHECK(q_fp8.dtype() == torch::kFloat8_e4m3fn, + "q_fp8 must be float8_e4m3fn"); + TORCH_CHECK(k_cache.dim() == 3 && k_cache.size(1) == cache_block_size && + k_cache.size(2) == 512, + "k_cache shape [num_blocks, cache_block_size, 512]"); + TORCH_CHECK(k_cache.dtype() == torch::kFloat8_e4m3fn, + "k_cache must be float8_e4m3fn"); + TORCH_CHECK(k_cache.stride(2) == 1, + "k_cache last dimension must be contiguous"); + TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64, + "cos_sin_cache shape [max_pos, 64]"); + TORCH_CHECK(cos_sin_cache.dtype() == torch::kFloat32, + "cos_sin_cache must be float32"); + + int const num_tokens_full = static_cast(q.size(0)); + int const num_tokens_insert = static_cast(slot_mapping.size(0)); + TORCH_CHECK(static_cast(kv.size(0)) == num_tokens_full && + static_cast(position_ids.size(0)) == num_tokens_full, + "q/kv/position_ids row counts must match"); + TORCH_CHECK(num_tokens_insert <= num_tokens_full, + "slot_mapping must not exceed q row count"); + int const num_heads_q = static_cast(q.size(1)); + int const cache_block_size_i = static_cast(cache_block_size); + + at::cuda::OptionalCUDAGuard device_guard(device_of(q)); + auto stream = at::cuda::getCurrentCUDAStream(); + + VLLM_DISPATCH_HALF_TYPES( + q.scalar_type(), "fused_deepseek_v4_qnorm_rope_full_cache_fp8_insert", [&] { + using qkv_scalar_t = scalar_t; + vllm::deepseek_v4_fused_ops:: + launchFusedDeepseekV4QNormRopeFullCacheFP8Insert( + reinterpret_cast(q.data_ptr()), + reinterpret_cast(kv.data_ptr()), + reinterpret_cast(q_fp8.data_ptr()), + q_fp8.stride(0), q_fp8.stride(1), + reinterpret_cast(k_cache.data_ptr()), + reinterpret_cast(slot_mapping.data_ptr()), + reinterpret_cast(position_ids.data_ptr()), + cos_sin_cache.data_ptr(), fp8_scale.data_ptr(), + q_fp8_scale_inv.data_ptr(), static_cast(eps), + num_tokens_full, num_tokens_insert, num_heads_q, + cache_block_size_i, k_cache.stride(0) * k_cache.element_size(), + k_cache.stride(1) * k_cache.element_size(), stream); + }); +} + +void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert( + torch::Tensor& q, // [N, H, 512] bf16, in place + torch::Tensor const& kv, // [N, 512] bf16, read-only + torch::Tensor& k_cache, // [num_blocks, block_size, 512] bf16 + torch::Tensor const& slot_mapping, // [num_tokens_insert] int64 + torch::Tensor const& position_ids, // [N] int64 + torch::Tensor const& cos_sin_cache, // [max_pos, rope_dim] float32 + double eps, int64_t cache_block_size) { + TORCH_CHECK(q.is_cuda() && q.is_contiguous(), "q must be contiguous CUDA"); + TORCH_CHECK(kv.is_cuda() && kv.is_contiguous(), "kv must be contiguous CUDA"); + TORCH_CHECK(k_cache.is_cuda(), "k_cache must be CUDA"); + TORCH_CHECK(slot_mapping.is_cuda() && slot_mapping.dtype() == torch::kInt64, + "slot_mapping must be int64 CUDA"); + TORCH_CHECK(position_ids.is_cuda() && position_ids.dtype() == torch::kInt64, + "position_ids must be int64 CUDA"); + TORCH_CHECK(cos_sin_cache.is_cuda(), "cos_sin_cache must be CUDA"); + TORCH_CHECK(q.dim() == 3 && q.size(2) == 512, "q shape [N, H, 512]"); + TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]"); + TORCH_CHECK(q.dtype() == kv.dtype(), "q and kv dtype must match"); + TORCH_CHECK(q.dtype() == torch::kBFloat16, "q and kv must be bfloat16"); + TORCH_CHECK(k_cache.dim() == 3 && k_cache.size(1) == cache_block_size && + k_cache.size(2) == 512, + "k_cache shape [num_blocks, cache_block_size, 512]"); + TORCH_CHECK(k_cache.dtype() == torch::kBFloat16, "k_cache must be bfloat16"); + TORCH_CHECK(k_cache.stride(2) == 1, + "k_cache last dimension must be contiguous"); + TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64, + "cos_sin_cache shape [max_pos, 64]"); + TORCH_CHECK(cos_sin_cache.dtype() == torch::kFloat32, + "cos_sin_cache must be float32"); + + int const num_tokens_full = static_cast(q.size(0)); + int const num_tokens_insert = static_cast(slot_mapping.size(0)); + TORCH_CHECK(static_cast(kv.size(0)) == num_tokens_full && + static_cast(position_ids.size(0)) == num_tokens_full, + "q/kv/position_ids row counts must match"); + TORCH_CHECK(num_tokens_insert <= num_tokens_full, + "slot_mapping must not exceed q row count"); + int const num_heads_q = static_cast(q.size(1)); + int const cache_block_size_i = static_cast(cache_block_size); + + at::cuda::OptionalCUDAGuard device_guard(device_of(q)); + auto stream = at::cuda::getCurrentCUDAStream(); + + VLLM_DISPATCH_HALF_TYPES( + q.scalar_type(), "fused_deepseek_v4_qnorm_rope_full_cache_bf16_insert", [&] { + using qkv_scalar_t = scalar_t; + vllm::deepseek_v4_fused_ops:: + launchFusedDeepseekV4QNormRopeFullCacheBF16Insert( + reinterpret_cast(q.data_ptr()), + reinterpret_cast(kv.data_ptr()), + reinterpret_cast(k_cache.data_ptr()), + reinterpret_cast(slot_mapping.data_ptr()), + reinterpret_cast(position_ids.data_ptr()), + cos_sin_cache.data_ptr(), static_cast(eps), + num_tokens_full, num_tokens_insert, num_heads_q, + cache_block_size_i, k_cache.stride(0) * k_cache.element_size(), + k_cache.stride(1) * k_cache.element_size(), stream); + }); +} diff --git a/csrc/ops.h b/csrc/ops.h index 16a78f570cf6..3ffa9a5bed44 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -105,6 +105,18 @@ void fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( torch::Tensor const& slot_mapping, torch::Tensor const& position_ids, torch::Tensor const& cos_sin_cache, double eps, int64_t cache_block_size); +void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert( + torch::Tensor const& q, torch::Tensor const& kv, torch::Tensor& q_fp8, + torch::Tensor& k_cache, torch::Tensor const& slot_mapping, + torch::Tensor const& position_ids, torch::Tensor const& cos_sin_cache, + torch::Tensor const& fp8_scale, torch::Tensor const& q_fp8_scale_inv, + double eps, int64_t cache_block_size); + +void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert( + torch::Tensor& q, torch::Tensor const& kv, torch::Tensor& k_cache, + torch::Tensor const& slot_mapping, torch::Tensor const& position_ids, + torch::Tensor const& cos_sin_cache, double eps, int64_t cache_block_size); + void apply_repetition_penalties_(torch::Tensor& logits, const torch::Tensor& prompt_mask, const torch::Tensor& output_mask, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 7562d90c0b99..e99beca249d2 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -194,6 +194,28 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA, &fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert); + // Full-cache per-tensor FP8 variant for FlashInfer sparse MLA. Reuses the + // same CUDA warp-slot kernel structure as the legacy UE8M0 op, but writes Q + // to a separate FP8 tensor and KV into a full 512-wide FP8 paged cache. + ops.def( + "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert(" + "Tensor q, Tensor kv, Tensor! q_fp8, Tensor! k_cache, " + "Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, " + "Tensor fp8_scale, Tensor q_fp8_scale_inv, float eps, " + "int cache_block_size) -> ()"); + ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert", + torch::kCUDA, + &fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert); + + ops.def( + "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert(" + "Tensor! q, Tensor kv, Tensor! k_cache, Tensor slot_mapping, " + "Tensor position_ids, Tensor cos_sin_cache, float eps, " + "int cache_block_size) -> ()"); + ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert", + torch::kCUDA, + &fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert); + // Apply repetition penalties to logits in-place ops.def( "apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, " diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index 3f6feecf5f78..416d7073ffa1 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -217,7 +217,7 @@ MLA decode backends are selected using the standard | `FLASHINFER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x | | `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | 576 | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x | | `FLASHMLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x | -| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 512, 576 | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x | +| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_per_tensor`, `fp8_inc`, `fp8_ds_mla`, `fp8_e4m3` | 64 | 512, 576 | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x | | `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x | | `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %1 | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 1, 64 | Any | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | N/A | diff --git a/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py b/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py index 46d226e0f74e..03c64aa3fd83 100644 --- a/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py +++ b/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py @@ -68,7 +68,7 @@ def apply_rope_gptj_last_k( nope_dim = head_dim - rope_dim # Gather cos/sin for each token position: [num_tokens, rope_dim] - cs = cos_sin_cache[positions].to(torch.float32) # [N, rope_dim] + cs = cos_sin_cache[positions.long()].to(torch.float32) # [N, rope_dim] cos = cs[..., :half] # [N, half] sin = cs[..., half:] # [N, half] @@ -113,6 +113,18 @@ def _op_available() -> bool: return hasattr(torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert") +def _full_cache_fp8_op_available() -> bool: + return hasattr( + torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert" + ) + + +def _full_cache_bf16_op_available() -> bool: + return hasattr( + torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert" + ) + + pytestmark = pytest.mark.skipif( not torch.cuda.is_available() or not _op_available(), reason="CUDA not available or fused DeepseekV4 op not built in", @@ -125,6 +137,109 @@ def _call_fused(q, kv, k_cache, slot_mapping, positions, cos_sin_cache, eps, bs) ) +def _call_full_cache_fp8_fused( + q, + kv, + q_fp8, + k_cache, + slot_mapping, + positions, + cos_sin_cache, + fp8_scale, + q_fp8_scale_inv, + eps, + bs, +): + torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert( + q, + kv, + q_fp8, + k_cache, + slot_mapping, + positions.long(), + cos_sin_cache, + fp8_scale, + q_fp8_scale_inv, + eps, + bs, + ) + + +def _call_full_cache_bf16_fused( + q, + kv, + k_cache, + slot_mapping, + positions, + cos_sin_cache, + eps, + bs, +): + torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert( + q, + kv, + k_cache, + slot_mapping, + positions.long(), + cos_sin_cache, + eps, + bs, + ) + + +def _fp8_full_cache_reference( + q, + kv, + k_cache, + q_fp8, + slot_mapping, + positions, + cos_sin_cache, + eps, + block_size, + fp8_scale, + q_fp8_scale_inv, +): + q_ref = rmsnorm_no_weight(q, eps) + q_ref = apply_rope_gptj_last_k(q_ref, positions, cos_sin_cache) + q_fp8.copy_( + torch.clamp(q_ref.float() * q_fp8_scale_inv, -FP8_MAX, FP8_MAX).to( + torch.float8_e4m3fn + ) + ) + + kv_ref = apply_rope_gptj_last_k(kv, positions, cos_sin_cache) + valid = slot_mapping >= 0 + slots = slot_mapping[valid] + block_idx = slots // block_size + pos_in_block = slots % block_size + k_cache[block_idx, pos_in_block] = torch.clamp( + kv_ref[valid].float() / fp8_scale, -FP8_MAX, FP8_MAX + ).to(torch.float8_e4m3fn) + + +def _bf16_full_cache_reference( + q, + kv, + k_cache, + slot_mapping, + positions, + cos_sin_cache, + eps, + block_size, +): + q_ref = rmsnorm_no_weight(q, eps) + q_ref = apply_rope_gptj_last_k(q_ref, positions, cos_sin_cache) + + kv_ref = apply_rope_gptj_last_k(kv, positions, cos_sin_cache) + valid = slot_mapping >= 0 + slots = slot_mapping[valid] + block_idx = slots // block_size + pos_in_block = slots % block_size + k_cache[block_idx, pos_in_block] = kv_ref[valid] + return q_ref + + # ── Test 1: Q path numerical parity ────────────────────────────────────────── @@ -357,3 +472,133 @@ def test_combined_q_and_kv(num_tokens: int, n_heads: int, block_size: int): torch.testing.assert_close(q_fused, q_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0) + + +@pytest.mark.skipif( + not _full_cache_fp8_op_available(), + reason="full-cache per-tensor FP8 DeepseekV4 op not built in", +) +@pytest.mark.parametrize("num_tokens", [4, 17]) +@pytest.mark.parametrize("n_heads", [8, 17]) +@pytest.mark.parametrize("positions_dtype", [torch.int32, torch.int64]) +def test_full_cache_per_tensor_fp8_matches_reference( + num_tokens: int, + n_heads: int, + positions_dtype: torch.dtype, +): + torch.manual_seed(4) + device = "cuda" + dtype = torch.bfloat16 + eps = 1e-6 + block_size = 16 + max_pos = 4096 + + q = torch.randn(num_tokens, n_heads, HEAD_DIM, dtype=dtype, device=device) + kv = torch.randn(num_tokens, HEAD_DIM, dtype=dtype, device=device) + positions = torch.arange(num_tokens, dtype=positions_dtype, device=device) + cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device) + + num_blocks = (num_tokens + block_size - 1) // block_size + 1 + slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device) + fp8_scale = torch.tensor([1.0], dtype=torch.float32, device=device) + q_fp8_scale_inv = torch.tensor([1.0], dtype=torch.float32, device=device) + + q_fp8_ref = torch.empty_like(q, dtype=torch.float8_e4m3fn) + q_fp8_fused = torch.empty_like(q, dtype=torch.float8_e4m3fn) + k_cache_ref = torch.zeros( + num_blocks, block_size, HEAD_DIM, dtype=torch.float8_e4m3fn, device=device + ) + k_cache_fused = torch.zeros_like(k_cache_ref) + + _fp8_full_cache_reference( + q, + kv, + k_cache_ref, + q_fp8_ref, + slot_mapping, + positions, + cos_sin_cache, + eps, + block_size, + fp8_scale, + q_fp8_scale_inv, + ) + + _call_full_cache_fp8_fused( + q.clone(), + kv, + q_fp8_fused, + k_cache_fused, + slot_mapping, + positions, + cos_sin_cache, + fp8_scale, + q_fp8_scale_inv, + eps, + block_size, + ) + + torch.testing.assert_close( + q_fp8_fused.float(), q_fp8_ref.float(), rtol=0, atol=0.25 + ) + torch.testing.assert_close( + k_cache_fused.float(), k_cache_ref.float(), rtol=0, atol=0.25 + ) + + +@pytest.mark.skipif( + not _full_cache_bf16_op_available(), + reason="full-cache BF16 DeepseekV4 op not built in", +) +@pytest.mark.parametrize("num_tokens", [4, 17]) +@pytest.mark.parametrize("n_heads", [8, 17]) +@pytest.mark.parametrize("positions_dtype", [torch.int32, torch.int64]) +def test_full_cache_bf16_matches_reference( + num_tokens: int, + n_heads: int, + positions_dtype: torch.dtype, +): + torch.manual_seed(5) + device = "cuda" + dtype = torch.bfloat16 + eps = 1e-6 + block_size = 16 + max_pos = 4096 + + q = torch.randn(num_tokens, n_heads, HEAD_DIM, dtype=dtype, device=device) + kv = torch.randn(num_tokens, HEAD_DIM, dtype=dtype, device=device) + positions = torch.arange(num_tokens, dtype=positions_dtype, device=device) + cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device) + + num_blocks = (num_tokens + block_size - 1) // block_size + 1 + slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device) + + q_fused = q.clone() + k_cache_ref = torch.zeros( + num_blocks, block_size, HEAD_DIM, dtype=torch.bfloat16, device=device + ) + k_cache_fused = torch.zeros_like(k_cache_ref) + q_ref = _bf16_full_cache_reference( + q, + kv, + k_cache_ref, + slot_mapping, + positions, + cos_sin_cache, + eps, + block_size, + ) + + _call_full_cache_bf16_fused( + q_fused, + kv, + k_cache_fused, + slot_mapping, + positions, + cos_sin_cache, + eps, + block_size, + ) + + torch.testing.assert_close(q_fused, q_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0) diff --git a/tests/kernels/test_fused_inv_rope_fp8_quant.py b/tests/kernels/test_fused_inv_rope_fp8_quant.py index 10561a8a0304..cf001e37b97f 100644 --- a/tests/kernels/test_fused_inv_rope_fp8_quant.py +++ b/tests/kernels/test_fused_inv_rope_fp8_quant.py @@ -725,6 +725,7 @@ def test_einsum_end_to_end(num_tokens, num_heads, n_groups): This catches stride/layout bugs that only manifest when the einsum kernel actually consumes the quantized activations. """ + from deep_gemm.testing import calc_diff from deep_gemm.utils.math import ceil_div from vllm.utils.deep_gemm import ( @@ -809,8 +810,6 @@ def test_einsum_end_to_end(num_tokens, num_heads, n_groups): # Einsum output: Triton and CUDA both rotate in fp32 now, so diffs # come from fp32 ordering and UE8M0 boundary shifts only. # Use relative diff (same metric as test_fp8_einsum.py). - from deep_gemm.testing import calc_diff - z_diff = calc_diff(z_fused, z_ref) assert z_diff < 0.01, ( f"Einsum output diff too large: {z_diff:.6f} (expected < 0.01)" diff --git a/tests/test_config.py b/tests/test_config.py index 57d1e1bc686b..5ceaeaf3d19f 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -14,6 +14,7 @@ import vllm.config.vllm as vllm_config_module from vllm.compilation.backends import VllmBackend from vllm.config import ( + CacheConfig, CompilationConfig, KernelConfig, ModelConfig, @@ -33,10 +34,21 @@ OptimizationLevel, ) from vllm.platforms import current_platform +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE DEVICE_TYPE = current_platform.device_type +def test_fp8_per_tensor_cache_dtype(): + cfg = CacheConfig(cache_dtype="fp8_per_tensor") + + assert cfg.cache_dtype == "fp8_per_tensor" + assert ( + STR_DTYPE_TO_TORCH_DTYPE["fp8_per_tensor"] + is STR_DTYPE_TO_TORCH_DTYPE["fp8_inc"] + ) + + def test_compile_config_repr_succeeds(): # setup: VllmBackend mutates the config object config = VllmConfig() diff --git a/vllm/config/cache.py b/vllm/config/cache.py index ae5023f1e348..ccd56f68e875 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -22,6 +22,7 @@ "fp8", "fp8_e4m3", "fp8_e5m2", + "fp8_per_tensor", "fp8_inc", "fp8_ds_mla", "turboquant_k8v4", diff --git a/vllm/model_executor/layers/deepseek_compressor.py b/vllm/model_executor/layers/deepseek_compressor.py index 48628fec46e0..8d5e917bc98a 100644 --- a/vllm/model_executor/layers/deepseek_compressor.py +++ b/vllm/model_executor/layers/deepseek_compressor.py @@ -129,11 +129,13 @@ def __init__( dtype: torch.dtype, compress_ratio: int, prefix: str, + alignment: int | None = 576, ): super().__init__() self.state_dim = state_dim self.dtype = dtype self.prefix = prefix + self.alignment = alignment self.kv_cache = torch.tensor([]) compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: @@ -165,7 +167,7 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: head_size=self.state_dim, dtype=self.dtype, sliding_window=self.sliding_window, - alignment=576, # NOTE: FlashMLA requires 576B alignment + alignment=self.alignment, ) def forward(self): ... @@ -185,6 +187,7 @@ def __init__( prefix: str = "", k_cache_prefix="", use_fp4_cache: bool = False, + state_cache_alignment: int | None = 576, ): super().__init__() self.compress_ratio = compress_ratio @@ -232,6 +235,7 @@ def __init__( dtype=state_dtype, compress_ratio=compress_ratio, prefix=f"{prefix}.state_cache", + alignment=state_cache_alignment, ) # Save reference to static_forward_context for forward-time KV cache lookup. @@ -337,7 +341,60 @@ def forward( # - position used: (positions // compress_ratio) * compress_ratio cos_sin_cache = rotary_emb.cos_sin_cache k_cache_metadata = cast(Any, attn_metadata[self.k_cache_prefix]) - kv_cache = self._static_forward_context[self.k_cache_prefix].kv_cache + k_cache_layer = self._static_forward_context[self.k_cache_prefix] + kv_cache = k_cache_layer.kv_cache + + if self.head_dim == 512: + if kv_cache.dtype != torch.uint8: + assert kv_cache.dtype in (torch.bfloat16, torch.float8_e4m3fn) + self._fused_kernel[(num_actual,)]( + # state cache + state_cache, + state_cache.stride(0), + state_cache.stride(1), + # metadata + token_to_req_indices, + positions, + slot_mapping, + block_table, + block_table.stride(0), + block_size, + # RMSNorm + self.norm.weight, + self.rms_norm_eps, + # RoPE + cos_sin_cache, + cos_sin_cache.stride(0), + # KV cache + kv_cache, + k_cache_metadata.slot_mapping, + kv_cache.shape[1], + ( + k_cache_layer._flashinfer_fp8_kv_scale + if kv_cache.dtype != torch.uint8 + else self.norm.weight + ), + # constexprs + HEAD_SIZE=self.head_dim, + TRITON_BLOCK_SIZE=triton.next_power_of_2(self.head_dim), + STATE_WIDTH=state_width, + COMPRESS_RATIO=self.compress_ratio, + OVERLAP=self.overlap, + ROPE_HEAD_DIM=self.rope_head_dim, + FP8_MAX=448.0, + QUANT_BLOCK=self._quant_block, + TOKEN_STRIDE=self._token_stride, + SCALE_DIM=self._scale_dim, + KV_BLOCK_STRIDE=kv_cache.stride(0), + KV_TOKEN_STRIDE=( + kv_cache.stride(1) if kv_cache.dtype != torch.uint8 else 0 + ), + STORE_FULL_CACHE=kv_cache.dtype != torch.uint8, + STORE_FULL_FP8=kv_cache.dtype == torch.float8_e4m3fn, + num_warps=self._num_warps, + **pdl_kwargs, + ) + return self._fused_kernel[(num_actual,)]( # state cache diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index e9a6ec9a587c..356cb2d61d2a 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -21,6 +21,7 @@ from vllm.utils.deep_gemm import fp8_einsum from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.ops.deepseek_v4_ops import ( + build_flashinfer_mixed_sparse_indices, combine_topk_swa_indices, compute_global_topk_indices_and_lens, dequantize_and_gather_k_cache, @@ -55,10 +56,12 @@ GroupShape, ) from vllm.platforms import current_platform +from vllm.utils.flashinfer import flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw from vllm.utils.multi_stream_utils import ( execute_in_parallel, maybe_execute_in_parallel, ) +from vllm.utils.torch_utils import kv_cache_dtype_str_to_dtype from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata from vllm.v1.attention.backends.mla.flashmla_sparse import ( DeepseekV4FlashMLASparseBackend, @@ -79,12 +82,72 @@ logger = init_logger(__name__) +_FLASHINFER_DSV4_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 +_flashinfer_dsv4_workspace_by_device: dict[torch.device, torch.Tensor] = {} +FlashInferSparseIndexMetadata = tuple[ + torch.Tensor, # compressed KV cache consumed by FlashInfer. + torch.Tensor, # query_start_loc. + torch.Tensor, # query_start_loc_cpu. + torch.Tensor, # seq_lens. + torch.Tensor, # sparse_indices. + torch.Tensor, # sparse_topk_lens. +] + # Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather # workspace allocated at _forward_prefill (and the matching profile-time # reservation in attention_impl's dummy-run branch). PREFILL_CHUNK_SIZE = 4 +def _normalize_dsv4_kv_cache_dtype( + cache_config: CacheConfig | None, +) -> str: + kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "auto" + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + assert cache_config is not None + cache_config.cache_dtype = "fp8_ds_mla" + return "fp8_ds_mla" + if kv_cache_dtype == "fp8_inc": + assert cache_config is not None + cache_config.cache_dtype = "fp8_per_tensor" + return "fp8_per_tensor" + return kv_cache_dtype + + +def _dsv4_kv_cache_torch_dtype( + kv_cache_dtype: str, + vllm_config: VllmConfig, +) -> torch.dtype: + if kv_cache_dtype == "fp8_ds_mla": + return torch.uint8 + if kv_cache_dtype in ("fp8_per_tensor", "fp8_inc"): + return torch.float8_e4m3fn + if kv_cache_dtype == "bfloat16": + return torch.bfloat16 + if kv_cache_dtype == "auto": + dtype = kv_cache_dtype_str_to_dtype(kv_cache_dtype, vllm_config.model_config) + if dtype == torch.bfloat16: + return dtype + raise ValueError( + "DeepSeek V4 FlashInfer sparse MLA supports only BF16 or per-tensor " + f"FP8 E4M3 KV cache; got kv_cache_dtype={kv_cache_dtype}. Use " + "`bfloat16`/`auto` for BF16, `fp8_per_tensor` for per-tensor FP8, or " + "`fp8`/`fp8_ds_mla` for the legacy UE8M0 FlashMLA path." + ) + + +def _get_flashinfer_dsv4_workspace(device: torch.device) -> torch.Tensor: + workspace = _flashinfer_dsv4_workspace_by_device.get(device) + if workspace is None: + workspace = torch.zeros( + _FLASHINFER_DSV4_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=device, + ) + _flashinfer_dsv4_workspace_by_device[device] = workspace + return workspace + + @dataclass class DeepseekV4MLAModules: """Modules used in DeepseekV4 MLA.""" @@ -230,10 +293,16 @@ def __init__( self.ln_events = [torch.cuda.Event() for _ in range(4)] assert cache_config is not None, "DeepseekV4 attention requires cache_config" + kv_cache_dtype = _normalize_dsv4_kv_cache_dtype(cache_config) + kv_cache_torch_dtype = _dsv4_kv_cache_torch_dtype( + kv_cache_dtype, mla_modules.vllm_config + ) + if kv_cache_dtype == "fp8_ds_mla": + logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.") self.swa_cache_layer = DeepseekV4SWACache( head_dim=self.head_dim, window_size=self.window_size, - dtype=torch.uint8, + dtype=kv_cache_torch_dtype, prefix=f"{prefix}.swa_cache", cache_config=cache_config, ) @@ -278,6 +347,12 @@ def __init__( rotate=True, prefix=f"{prefix}.compressor", k_cache_prefix=self.mla_attn.prefix, + # Legacy FlashMLA state/KV pages need 576B alignment. The + # FlashInfer BF16/per-tensor FP8 path shares state pages with + # contiguous C4 KV pages, so padding would break page matching. + state_cache_alignment=( + 576 if self.mla_attn.kv_cache_dtype == "fp8_ds_mla" else None + ), ) def forward( @@ -286,11 +361,18 @@ def forward( hidden_states: torch.Tensor, llama_4_scaling: torch.Tensor | None = None, ) -> torch.Tensor: - # Pre-allocate attention output with FlashMLA-padded head count. - # The op writes into `o_padded`; we slice to n_local_heads after. + # FlashMLA requires 64/128 heads. FlashInfer full-cache modes run on + # the actual local head count, avoiding padded Q/output work. num_tokens = hidden_states.shape[0] - o_padded = torch.empty( - (num_tokens, self.padded_heads, self.head_dim), + use_flashinfer_full_cache = ( + self.mla_attn.kv_cache_torch_dtype != torch.uint8 + and not current_platform.is_rocm() + ) + output_heads = ( + self.n_local_heads if use_flashinfer_full_cache else self.padded_heads + ) + o_attn = torch.empty( + (num_tokens, output_heads, self.head_dim), dtype=hidden_states.dtype, device=hidden_states.device, ) @@ -299,10 +381,10 @@ def forward( torch.ops.vllm.deepseek_v4_attention( hidden_states, positions, - o_padded, + o_attn, self.layer_name, ) - o = o_padded[:, : self.n_local_heads, :] + o = o_attn if use_flashinfer_full_cache else o_attn[:, : self.n_local_heads, :] # Keep ROCm on the BF16 reference wo_a path util kernel ready. if current_platform.is_rocm(): @@ -433,24 +515,27 @@ def attention_impl( # wq_b + kv_insert (+ MLA compressor when an indexer is present) ride # on the default stream so q stays on its consumer stream (mla_attn - # downstream reads q on default). Indexer/compressor go on aux for + # downstream reads q on current). Indexer/compressor go on aux for # overlap with default's GEMM + cache write. if self.indexer is not None: aux_stream = ( self.aux_stream_list[0] if self.aux_stream_list is not None else None ) indexer = self.indexer - # Local ref so the closure keeps a non-None type for mypy. assert self.compressor is not None compressor = self.compressor - def wq_b_kv_insert_and_compress() -> torch.Tensor: + def wq_b_kv_insert_and_compress() -> tuple[ + torch.Tensor, torch.Tensor | None + ]: q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) - self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) + q_fp8 = self._fused_qnorm_rope_kv_insert( + q, kv, positions, attn_metadata + ) compressor(kv_score, positions, self.rotary_emb) - return q + return q, q_fp8 - q, _ = maybe_execute_in_parallel( + (q, q_fp8), _ = maybe_execute_in_parallel( wq_b_kv_insert_and_compress, lambda: indexer( hidden_states, @@ -471,12 +556,14 @@ def wq_b_kv_insert_and_compress() -> torch.Tensor: ) compressor = self.compressor - def wq_b_kv_insert() -> torch.Tensor: + def wq_b_kv_insert() -> tuple[torch.Tensor, torch.Tensor | None]: q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) - self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) - return q + q_fp8 = self._fused_qnorm_rope_kv_insert( + q, kv, positions, attn_metadata + ) + return q, q_fp8 - q, _ = maybe_execute_in_parallel( + (q, q_fp8), _ = maybe_execute_in_parallel( wq_b_kv_insert, lambda: compressor(kv_score, positions, self.rotary_emb), self.ln_events[0], @@ -486,35 +573,51 @@ def wq_b_kv_insert() -> torch.Tensor: else: # SWA-only layer: no compressor, no overlap. q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) - self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) + q_fp8 = self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) # Handle dummy run (no metadata). if not isinstance(attn_metadata, dict): # Reserve _forward_prefill's bf16-gather workspace; the dummy # run returns before mla_attn runs, so without this the shared - # workspace locks below the real prefill size. + # workspace locks below the real prefill size. The per-tensor FP8 + # FlashInfer path reads the cache directly and does not need it. sub = self.mla_attn - swa_only = sub.compress_ratio <= 1 - N = ( - 0 - if swa_only - else (sub.max_model_len + sub.compress_ratio - 1) // sub.compress_ratio - ) - M = N + sub.window_size + sub.max_num_batched_tokens - current_workspace_manager().get_simultaneous( - ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), - ) + if sub.kv_cache_torch_dtype != torch.float8_e4m3fn: + swa_only = sub.compress_ratio <= 1 + N = ( + 0 + if swa_only + else (sub.max_model_len + sub.compress_ratio - 1) + // sub.compress_ratio + ) + M = N + sub.window_size + sub.max_num_batched_tokens + current_workspace_manager().get_simultaneous( + ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), + ) out.zero_() return - # Pad q to FlashMLA-required head count (64 or 128) - if self.n_local_heads < self.padded_heads: + q_for_attn = q_fp8 if q_fp8 is not None else q + + # Pad q only for the legacy FlashMLA path, which requires 64 or 128 + # heads. FlashInfer full-cache modes keep the actual local head count. + if ( + q_fp8 is None + and self.mla_attn.kv_cache_torch_dtype == torch.uint8 + and self.n_local_heads < self.padded_heads + ): pad_size = self.padded_heads - self.n_local_heads - q = F.pad(q, (0, 0, 0, pad_size), value=0.0) + q_for_attn = F.pad(q_for_attn, (0, 0, 0, pad_size), value=0.0) - # MLA attention writes into the pre-allocated `out` buffer - # ([num_tokens, padded_heads, head_dim]). - self.mla_attn(q, kv, positions, output=out) + # MLA attention writes into the pre-allocated `out` buffer. FlashMLA + # gets a padded-head buffer; FlashInfer full-cache modes get actual + # local heads. + self.mla_attn( + q_for_attn, + kv, + positions, + output=out, + ) def _fused_qnorm_rope_kv_insert( self, @@ -524,9 +627,9 @@ def _fused_qnorm_rope_kv_insert( attn_metadata: ( dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None ), - ) -> None: + ) -> torch.Tensor | None: if not isinstance(attn_metadata, dict): - return + return None swa_metadata = cast( "DeepseekSparseSWAMetadata | None", @@ -535,22 +638,61 @@ def _fused_qnorm_rope_kv_insert( assert swa_metadata is not None swa_kv_cache = self.swa_cache_layer.kv_cache - swa_kv_cache_2d = swa_kv_cache.view(swa_kv_cache.shape[0], -1) # Horizontally fused: # Q side: q_head_norm (per-head RMSNorm, no weight) + GPT-J RoPE - # KV side: GPT-J RoPE + UE8M0 FP8 quant + paged cache insert + # KV side: GPT-J RoPE + paged cache insert. The uint8 cache keeps the + # legacy UE8M0 layout; BF16/FP8 caches store the full 512-wide vector. # kv is unchanged; mla_attn reads kv solely via swa_kv_cache. - torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( - q, - kv, - swa_kv_cache_2d, - swa_metadata.slot_mapping, - positions.to(torch.int64), - self.rotary_emb.cos_sin_cache, - self.eps, - swa_metadata.block_size, - ) + if swa_kv_cache.dtype == torch.uint8: + swa_kv_cache_2d = swa_kv_cache.view(swa_kv_cache.shape[0], -1) + torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( + q, + kv, + swa_kv_cache_2d, + swa_metadata.slot_mapping, + positions.to(torch.int64), + self.rotary_emb.cos_sin_cache, + self.eps, + swa_metadata.block_size, + ) + return None + + if swa_kv_cache.dtype == torch.bfloat16: + torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert( + q, + kv, + swa_kv_cache, + swa_metadata.slot_mapping, + positions.to(torch.int64), + self.rotary_emb.cos_sin_cache, + self.eps, + swa_metadata.block_size, + ) + return None + + if swa_kv_cache.dtype == torch.float8_e4m3fn: + q_fp8 = torch.empty( + (q.shape[0], self.n_local_heads, q.shape[-1]), + dtype=torch.float8_e4m3fn, + device=q.device, + ) + torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert( + q, + kv, + q_fp8, + swa_kv_cache, + swa_metadata.slot_mapping, + positions.to(torch.int64), + self.rotary_emb.cos_sin_cache, + self.mla_attn._flashinfer_fp8_kv_scale, + self.mla_attn._flashinfer_fp8_q_scale_inv, + self.eps, + swa_metadata.block_size, + ) + return q_fp8 + + raise AssertionError(f"Unsupported SWA KV cache dtype {swa_kv_cache.dtype}") def deepseek_v4_attention( @@ -687,29 +829,39 @@ def __init__( vllm_config.scheduler_config.max_num_batched_tokens ) self.max_model_len = vllm_config.model_config.max_model_len - # DeepseekV4 only supports fp8 kv-cache format for now. - kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "fp8" - - assert kv_cache_dtype.startswith("fp8"), ( - f"DeepseekV4 only supports fp8 kv-cache format for now, " - f"got {kv_cache_dtype}" - ) assert issubclass(self.get_attn_backend(), FlashMLASparseBackend), ( - "Only FlashMLA Sparse Attention backend is supported for DeepseekV4 for now" + "DeepseekV4 requires the sparse MLA metadata/cache backend" + ) + kv_cache_dtype = _normalize_dsv4_kv_cache_dtype(cache_config) + self.kv_cache_torch_dtype = _dsv4_kv_cache_torch_dtype( + kv_cache_dtype, vllm_config ) - # FlashMLA Sparse Attention fp8 backend uses "fp8_ds_mla" kv-cache format - # Automatically convert fp8 kv-cache format to "fp8_ds_mla" - if ( - issubclass(self.get_attn_backend(), FlashMLASparseBackend) - and kv_cache_dtype.startswith("fp8") - and kv_cache_dtype != "fp8_ds_mla" - ): - assert cache_config is not None - cache_config.cache_dtype = "fp8_ds_mla" - kv_cache_dtype = "fp8_ds_mla" - logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.") self.kv_cache_dtype = kv_cache_dtype + fp8_q_scale = 1.0 + fp8_kv_scale = 1.0 + if self.kv_cache_torch_dtype == torch.float8_e4m3fn: + # TODO: load the per-tensor FP8 Q and KV scales from checkpoint + # weights. Use unit scales until the scale tensor names are wired. + fp8_q_scale = 1.0 + fp8_kv_scale = 1.0 + self.register_buffer( + "_flashinfer_fp8_q_scale", + torch.tensor([fp8_q_scale], dtype=torch.float32), + persistent=False, + ) + self.register_buffer( + "_flashinfer_fp8_q_scale_inv", + torch.tensor([1.0 / fp8_q_scale], dtype=torch.float32), + persistent=False, + ) + self.register_buffer( + "_flashinfer_fp8_kv_scale", + torch.tensor([fp8_kv_scale], dtype=torch.float32), + persistent=False, + ) + self._flashinfer_fp8_bmm1_scale = self.scale * fp8_q_scale * fp8_kv_scale + self._flashinfer_fp8_bmm2_scale = fp8_kv_scale # Register with compilation context for metadata lookup compilation_config = vllm_config.compilation_config @@ -738,10 +890,14 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: block_size=vllm_config.cache_config.block_size, num_kv_heads=1, head_size=self.head_dim, - dtype=torch.uint8, + dtype=self.kv_cache_torch_dtype, compress_ratio=self.compress_ratio, cache_dtype_str=self.kv_cache_dtype, - alignment=576, # NOTE: FlashMLA requires 576B alignment + # FlashMLA's legacy fp8_ds_mla layout needs 576B page alignment. + # FlashInfer DSV4 BF16/per-tensor FP8 sparse decode treats the KV + # pool as a flat contiguous token array, so padding would skew + # physical sparse indices after the first page. + alignment=576 if self.kv_cache_dtype == "fp8_ds_mla" else None, model_version="deepseek_v4", ) @@ -755,8 +911,12 @@ def forward( assert output.shape == q.shape, ( f"output buffer shape {output.shape} must match q shape {q.shape}" ) - assert output.dtype == q.dtype, ( - f"output buffer dtype {output.dtype} must match q dtype {q.dtype}" + expected_output_dtype = ( + torch.bfloat16 if q.dtype == torch.float8_e4m3fn else q.dtype + ) + assert output.dtype == expected_output_dtype, ( + f"output buffer dtype {output.dtype} must match expected attention " + f"output dtype {expected_output_dtype} for q dtype {q.dtype}" ) if current_platform.is_rocm(): @@ -791,6 +951,23 @@ def forward( num_prefills = swa_metadata.num_prefills num_decode_tokens = swa_metadata.num_decode_tokens + if self.kv_cache_torch_dtype != torch.uint8: + if current_platform.is_rocm(): + raise NotImplementedError( + "DeepSeek V4 BF16/per-tensor FP8 FlashInfer sparse MLA " + "cache path is CUDA-only." + ) + self._forward_flashinfer( + q=q, + kv_cache=self_kv_cache, + swa_k_cache=swa_kv_cache, + swa_metadata=swa_metadata, + attn_metadata=flashmla_metadata, + swa_only=swa_only, + output=output, + ) + return + if num_prefills > 0: self._forward_prefill( q=q[num_decode_tokens:], @@ -849,6 +1026,8 @@ def _forward_decode( swa_indices = swa_metadata.decode_swa_indices swa_lens = swa_metadata.decode_swa_lens + assert self.kv_cache_torch_dtype == torch.uint8 + # We treat queries in the same seq as different queries # and later we only attend by generated indices. # q arrives pre-padded to self.padded_heads by the outer wrapper. @@ -903,6 +1082,238 @@ def _forward_decode( out=output.unsqueeze(1), ) + def _build_flashinfer_sparse_index_metadata( + self, + kv_cache: torch.Tensor | None, + swa_k_cache: torch.Tensor, + swa_metadata: "DeepseekSparseSWAMetadata", + attn_metadata: FlashMLASparseMetadata | None, + swa_only: bool, + ) -> FlashInferSparseIndexMetadata: + num_decodes = swa_metadata.num_decodes + num_prefills = swa_metadata.num_prefills + num_decode_tokens = swa_metadata.num_decode_tokens + num_prefill_tokens = swa_metadata.num_prefill_tokens + num_reqs = num_decodes + num_prefills + num_tokens = num_decode_tokens + num_prefill_tokens + + assert swa_metadata.seq_lens is not None + assert swa_metadata.query_start_loc is not None + assert swa_metadata.query_start_loc_cpu is not None + assert swa_metadata.token_to_req_indices is not None + assert swa_metadata.decode_swa_indices is not None + assert swa_metadata.block_table is not None + + decode_swa_indices = swa_metadata.decode_swa_indices.reshape( + num_decode_tokens, self.window_size + ) + decode_compressed_topk_lens = None + decode_compressed_indices_are_local = False + decode_is_valid_token = None + + if swa_only: + assert self.topk_indices_buffer is not None + compressed_kv_cache = swa_k_cache + decode_compressed_indices = None + prefill_topk_indices = self.topk_indices_buffer[ + num_decode_tokens:num_tokens, :0 + ] + compressed_block_table = None + compressed_block_size = swa_metadata.block_size + top_k = 0 + else: + assert kv_cache is not None + assert attn_metadata is not None + compressed_kv_cache = kv_cache + compressed_block_table = attn_metadata.block_table[:num_reqs] + compressed_block_size = attn_metadata.block_size // self.compress_ratio + + if self.compress_ratio == 4: + assert self.topk_indices_buffer is not None + if num_prefill_tokens > 0: + prefill_topk_indices = self.topk_indices_buffer[ + num_decode_tokens:num_tokens + ] + top_k = prefill_topk_indices.shape[-1] + else: + prefill_topk_indices = self.topk_indices_buffer[:0, :0] + top_k = 0 + + decode_compressed_indices_are_local = True + assert swa_metadata.is_valid_token is not None + decode_is_valid_token = swa_metadata.is_valid_token[:num_decode_tokens] + if num_decode_tokens > 0: + decode_compressed_indices = self.topk_indices_buffer[ + :num_decode_tokens + ] + else: + # The decode-side pointers are unused when there are no + # decode tokens. Keep their logical width aligned with the + # mixed-batch case so pure-prefill steps reuse the same + # Triton specialization compiled during graph capture. + decode_compressed_indices = prefill_topk_indices[:0] + else: + if num_prefill_tokens > 0: + assert attn_metadata.c128a_prefill_topk_indices is not None + prefill_topk_indices = attn_metadata.c128a_prefill_topk_indices + top_k = prefill_topk_indices.shape[-1] + else: + prefill_topk_indices = decode_swa_indices[:0, :0] + top_k = 0 + + if num_decode_tokens > 0: + assert attn_metadata.c128a_global_decode_topk_indices is not None + assert attn_metadata.c128a_decode_topk_lens is not None + decode_compressed_indices = ( + attn_metadata.c128a_global_decode_topk_indices.view( + num_decode_tokens, -1 + ) + ) + decode_compressed_topk_lens = attn_metadata.c128a_decode_topk_lens + if num_prefill_tokens == 0: + prefill_topk_indices = decode_compressed_indices[:0, :0] + else: + # As above, these decode tensors are unused for pure prefill. + # Preserve the C128A topk width and lens-present flag to + # share the mixed-batch sparse-index kernel variant. + decode_compressed_indices = prefill_topk_indices[:0] + decode_compressed_topk_lens = swa_metadata.seq_lens[:0] + + query_start_loc = swa_metadata.query_start_loc[: num_reqs + 1] + query_start_loc_cpu = swa_metadata.query_start_loc_cpu[: num_reqs + 1] + seq_lens = swa_metadata.seq_lens[:num_reqs] + assert seq_lens.dtype == torch.int32 + sparse_indices, sparse_topk_lens = build_flashinfer_mixed_sparse_indices( + decode_swa_indices, + decode_compressed_indices, + decode_compressed_topk_lens, + prefill_topk_indices[:num_prefill_tokens], + query_start_loc, + seq_lens, + swa_metadata.token_to_req_indices[:num_tokens], + swa_metadata.block_table[:num_reqs], + swa_metadata.block_size, + compressed_block_table, + compressed_block_size, + self.window_size, + self.compress_ratio, + top_k, + decode_compressed_indices_are_local=decode_compressed_indices_are_local, + decode_is_valid_token=decode_is_valid_token, + ) + return ( + compressed_kv_cache, + query_start_loc, + query_start_loc_cpu, + seq_lens, + sparse_indices, + sparse_topk_lens, + ) + + def _forward_flashinfer( + self, + q: torch.Tensor, + kv_cache: torch.Tensor | None, + swa_k_cache: torch.Tensor, + swa_metadata: "DeepseekSparseSWAMetadata", + attn_metadata: FlashMLASparseMetadata | None, + swa_only: bool, + output: torch.Tensor, + ) -> None: + assert self.kv_cache_torch_dtype in (torch.bfloat16, torch.float8_e4m3fn) + num_decodes = swa_metadata.num_decodes + num_prefills = swa_metadata.num_prefills + num_decode_tokens = swa_metadata.num_decode_tokens + num_prefill_tokens = swa_metadata.num_prefill_tokens + num_reqs = num_decodes + num_prefills + num_tokens = num_decode_tokens + num_prefill_tokens + if num_tokens == 0: + return + + flashinfer_sparse_metadata = self._build_flashinfer_sparse_index_metadata( + kv_cache=kv_cache, + swa_k_cache=swa_k_cache, + swa_metadata=swa_metadata, + attn_metadata=attn_metadata, + swa_only=swa_only, + ) + ( + compressed_kv_cache, + query_start_loc, + query_start_loc_cpu, + seq_lens, + sparse_indices, + sparse_topk_lens, + ) = flashinfer_sparse_metadata + + # CUDA graph execution can pad q/output past the scheduled token count. + # The FlashInfer DSV4 launcher validates sparse_indices against real + # tokens, so keep the tensors restricted to the scheduled token range. + query = q[:num_tokens] + output = output[:num_tokens] + bmm1_scale: float | torch.Tensor = self.scale + bmm2_scale: float | torch.Tensor = 1.0 + if self.kv_cache_torch_dtype == torch.float8_e4m3fn: + assert query.dtype == torch.float8_e4m3fn + bmm1_scale = self._flashinfer_fp8_bmm1_scale + bmm2_scale = self._flashinfer_fp8_bmm2_scale + else: + assert query.dtype == torch.bfloat16 + query = query.contiguous() + + workspace = _get_flashinfer_dsv4_workspace(q.device) + + if num_decode_tokens > 0: + decode_query_start_loc = query_start_loc[: num_decodes + 1] + decode_query_start_loc_cpu = query_start_loc_cpu[: num_decodes + 1] + decode_query_lens_cpu = ( + decode_query_start_loc_cpu[1:] - decode_query_start_loc_cpu[:-1] + ) + flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw( + query=query[:num_decode_tokens], + swa_kv_cache=swa_k_cache, + workspace_buffer=workspace, + sparse_indices=sparse_indices[:num_decode_tokens], + compressed_kv_cache=compressed_kv_cache, + sparse_topk_lens=sparse_topk_lens[:num_decode_tokens], + seq_lens=seq_lens[:num_decodes], + out=output[:num_decode_tokens], + bmm1_scale=bmm1_scale, + bmm2_scale=bmm2_scale, + sinks=self.attn_sink, + cum_seq_lens_q=decode_query_start_loc, + max_q_len=int(decode_query_lens_cpu.max().item()), + ) + + if num_prefill_tokens > 0: + assert swa_metadata.prefill_query_start_loc is not None + prefill_query_start_loc = swa_metadata.prefill_query_start_loc + prefill_query_start_loc_cpu = query_start_loc_cpu[ + num_decodes : num_reqs + 1 + ] + prefill_query_lens_cpu = ( + prefill_query_start_loc_cpu[1:] - prefill_query_start_loc_cpu[:-1] + ) + prefill_query = query[num_decode_tokens:num_tokens] + prefill_output = output[num_decode_tokens:num_tokens] + prefill_sparse_indices = sparse_indices[num_decode_tokens:num_tokens] + prefill_sparse_topk_lens = sparse_topk_lens[num_decode_tokens:num_tokens] + flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw( + query=prefill_query, + swa_kv_cache=swa_k_cache, + workspace_buffer=workspace, + sparse_indices=prefill_sparse_indices, + compressed_kv_cache=compressed_kv_cache, + sparse_topk_lens=prefill_sparse_topk_lens, + seq_lens=seq_lens[num_decodes:num_reqs], + out=prefill_output, + bmm1_scale=bmm1_scale, + bmm2_scale=bmm2_scale, + sinks=self.attn_sink, + cum_seq_lens_q=prefill_query_start_loc, + max_q_len=int(prefill_query_lens_cpu.max().item()), + ) + def _forward_prefill( self, q: torch.Tensor, diff --git a/vllm/model_executor/models/deepseek_v4.py b/vllm/model_executor/models/deepseek_v4.py index 791d9b1bf5ed..3cd5a9cb73a1 100644 --- a/vllm/model_executor/models/deepseek_v4.py +++ b/vllm/model_executor/models/deepseek_v4.py @@ -172,6 +172,19 @@ def expert_dtype(self) -> str: @property def is_scale_e8m0(self) -> bool: + try: + hf_config = get_current_vllm_config().model_config.hf_config + except Exception: + hf_config = None + + scale_fmt = getattr(hf_config, "scale_fmt", None) + if scale_fmt is None and hf_config is not None: + quantization_config = getattr(hf_config, "quantization_config", None) + if isinstance(quantization_config, dict): + scale_fmt = quantization_config.get("scale_fmt") + if scale_fmt is not None: + return scale_fmt == "ue8m0" + # FP4 checkpoints store FP8 linear scales as e8m0fnu; FP8 expert # checkpoints (Flash-Base) store them as float32. return self.expert_dtype == "fp4" @@ -470,10 +483,11 @@ def __init__( # Register in the static forward context so the custom-op wrapper # can look up this module by name from within a torch.compile graph. - compilation_config = vllm_config.compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self + if prefix: + compilation_config = vllm_config.compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self def _map_global_expert_id(self, expert_id: int) -> int: if expert_id < self.experts_start_idx or expert_id >= self.experts_end_idx: diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 70abd8a6c503..d08624d6fca2 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -45,6 +45,8 @@ def kernel_warmup(worker: "Worker"): elif has_flashinfer() and current_platform.has_device_capability(90): flashinfer_autotune(worker.model_runner) + deepseek_v4_flashinfer_sparse_mla_warmup(worker) + # FlashInfer attention warmup # Only warmup if the model has FlashInfer attention groups # and is not a pooling model @@ -78,6 +80,130 @@ def _is_flashinfer_backend(backend): ) +def deepseek_v4_flashinfer_sparse_mla_warmup(worker: "Worker") -> None: + """Warm the DSV4 FlashInfer sparse-index builder variants. + + CUDA graph capture exercises mixed batches, but Triton can still see the + first real prefill wave as a separate specialization for the per-layer C4A + and C128A index shapes. Compile those tiny index-builder launches during + engine warmup so they do not appear as inference-time bubbles. + """ + from vllm.v1.attention.backends.mla.sparse_swa import ( + _compute_prefill_metadata_kernel, + ) + from vllm.v1.attention.ops.deepseek_v4_ops.cache_utils import ( + build_flashinfer_mixed_sparse_indices, + ) + + hf_config = worker.vllm_config.model_config.hf_config + compress_ratios = { + int(ratio) for ratio in getattr(hf_config, "compress_ratios", ()) + } + if not compress_ratios: + return + + window_size = int(getattr(hf_config, "sliding_window", 0)) + if window_size <= 0: + return + + logger.info("Warming up DeepSeek V4 FlashInfer sparse MLA index kernels.") + device = worker.model_runner.device + index_topk = int(getattr(hf_config, "index_topk", 0)) + max_model_len = worker.vllm_config.model_config.max_model_len + max_num_seqs = max(1, worker.scheduler_config.max_num_seqs) + + def _prefill_batch_sizes() -> list[int]: + sizes: list[int] = [] + size = 1 + while size < max_num_seqs: + sizes.append(size) + size *= 2 + sizes.append(max_num_seqs) + return sizes + + max_prefill_reqs = max(_prefill_batch_sizes()) + seq_lens = torch.ones((max_prefill_reqs,), device=device, dtype=torch.int32) + query_start_loc = torch.arange( + max_prefill_reqs + 1, device=device, dtype=torch.int32 + ) + prefill_query_start_loc = torch.empty( + max_prefill_reqs + 1, device=device, dtype=torch.int32 + ) + prefill_gather_lens = torch.empty( + max_prefill_reqs, device=device, dtype=torch.int32 + ) + for num_prefills in _prefill_batch_sizes(): + _compute_prefill_metadata_kernel[(1,)]( + prefill_query_start_loc[: num_prefills + 1], + prefill_gather_lens[:num_prefills], + seq_lens[:num_prefills], + query_start_loc[: num_prefills + 1], + num_prefills, + 0, + window_size, + BLOCK_SIZE=1 << num_prefills.bit_length(), + ) + + for compress_ratio in sorted(compress_ratios): + if compress_ratio == 4: + topk = index_topk + decode_compressed_indices_are_local = True + has_decode_compressed_lens = False + elif compress_ratio == 128: + topk = (max_model_len + compress_ratio - 1) // compress_ratio + topk = ((topk + 127) // 128) * 128 + decode_compressed_indices_are_local = False + has_decode_compressed_lens = True + else: + continue + + if topk <= 0: + continue + + decode_swa_indices = torch.zeros( + (1, window_size), device=device, dtype=torch.int32 + ) + decode_compressed_indices = torch.zeros( + (1, topk), device=device, dtype=torch.int32 + ) + prefill_topk_indices = torch.zeros((1, topk), device=device, dtype=torch.int32) + query_start_loc = torch.tensor([0, 1, 2], device=device, dtype=torch.int32) + seq_lens = torch.tensor([1, 2], device=device, dtype=torch.int32) + token_to_req_indices = torch.tensor([0, 1], device=device, dtype=torch.int32) + swa_block_table = torch.zeros((2, 1), device=device, dtype=torch.int32) + compressed_block_table = torch.zeros((2, 1), device=device, dtype=torch.int32) + decode_compressed_topk_lens = ( + torch.ones((1,), device=device, dtype=torch.int32) + if has_decode_compressed_lens + else None + ) + decode_is_valid_token = ( + torch.ones((1,), device=device, dtype=torch.bool) + if decode_compressed_indices_are_local + else None + ) + + build_flashinfer_mixed_sparse_indices( + decode_swa_indices, + decode_compressed_indices, + decode_compressed_topk_lens, + prefill_topk_indices, + query_start_loc, + seq_lens, + token_to_req_indices, + swa_block_table, + 256, + compressed_block_table, + max(1, 256 // compress_ratio), + window_size, + compress_ratio, + topk, + decode_compressed_indices_are_local=decode_compressed_indices_are_local, + decode_is_valid_token=decode_is_valid_token, + ) + torch.accelerator.synchronize() + + def flashinfer_autotune(runner: "GPUModelRunner") -> None: """ Autotune FlashInfer operations. diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index e6c497c0b450..ac62e9b279af 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -83,7 +83,6 @@ def __getitem__(self, key): return getattr(configs, value) - _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( afmoe="AfmoeConfig", bagel="BagelConfig", diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 44fcc19c2d2b..23dfc985bf0a 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -71,6 +71,14 @@ def _missing(*_: Any, **__: Any) -> NoReturn: ) +def _missing_dsv4_sparse_mla(*_: Any, **__: Any) -> NoReturn: + raise RuntimeError( + "flashinfer.mla.trtllm_batch_decode_sparse_mla_dsv4 is not available. " + "Install a FlashInfer build that includes DeepSeek V4 sparse MLA " + "TRTLLM-GEN support." + ) + + def _get_submodule(module_name: str) -> Any | None: """Safely import a submodule and return it, or None if not available.""" try: @@ -137,6 +145,97 @@ def wrapper(*args, **kwargs): trtllm_fp4_block_scale_moe = _lazy_import_wrapper( "flashinfer", "trtllm_fp4_block_scale_moe" ) +flashinfer_trtllm_batch_decode_sparse_mla_dsv4 = _lazy_import_wrapper( + "flashinfer.mla", + "trtllm_batch_decode_sparse_mla_dsv4", + fallback_fn=_missing_dsv4_sparse_mla, +) + + +@functools.cache +def _get_dsv4_sparse_mla_raw_impl(): + if not has_flashinfer(): + return None + core = _get_submodule("flashinfer.mla._core") + if core is None: + return None + op = core.get_trtllm_gen_fmha_module() + run_func = getattr(op, "trtllm_paged_attention_decode_sparse_mla_dsv4", None) + if run_func is None: + run_func = getattr(op, "dsv4_sparse_mla", None) + if run_func is None: + return None + return run_func, core.device_support_pdl, core.get_device_sm_count + + +def flashinfer_trtllm_batch_decode_sparse_mla_dsv4_raw( + *, + query: torch.Tensor, + swa_kv_cache: torch.Tensor, + workspace_buffer: torch.Tensor, + sparse_indices: torch.Tensor, + compressed_kv_cache: torch.Tensor, + sparse_topk_lens: torch.Tensor, + seq_lens: torch.Tensor, + out: torch.Tensor, + bmm1_scale: float | torch.Tensor = 1.0, + bmm2_scale: float | torch.Tensor = 1.0, + sinks: torch.Tensor | None = None, + cum_seq_lens_q: torch.Tensor | None = None, + max_q_len: int | None = None, + enable_pdl: bool | None = None, +) -> torch.Tensor: + """Unchecked DeepSeek V4 sparse MLA launcher for hot vLLM decode paths. + + The caller must provide HND-compatible 3D/4D KV caches, contiguous INT32 + metadata, a BF16 output tensor, and launcher-ready scale tensors. This skips + FlashInfer's Python validation, which otherwise adds syncs and pointwise + kernels on every attention layer. + """ + impl = _get_dsv4_sparse_mla_raw_impl() + if impl is None: + return _missing_dsv4_sparse_mla() + + run_func, device_support_pdl, get_device_sm_count = impl + if enable_pdl is None: + enable_pdl = device_support_pdl(query.device) + + if swa_kv_cache.ndim == 3: + swa_kv_cache = swa_kv_cache.unsqueeze(1) + if compressed_kv_cache.ndim == 3: + compressed_kv_cache = compressed_kv_cache.unsqueeze(1) + + if cum_seq_lens_q is None: + batch_size, q_len_per_request = query.shape[:2] + query_flat = query.flatten(0, 1) + else: + batch_size = cum_seq_lens_q.numel() - 1 + assert max_q_len is not None + q_len_per_request = max_q_len + query_flat = query + + run_func( + out, + query_flat, + compressed_kv_cache, + swa_kv_cache, + workspace_buffer, + sparse_indices, + seq_lens, + sparse_topk_lens, + bmm1_scale, + bmm2_scale, + batch_size, + q_len_per_request, + get_device_sm_count(query.device), + enable_pdl, + workspace_buffer.numel() * workspace_buffer.element_size(), + sinks, + cum_seq_lens_q, + ) + return out + + # Special case for autotune since it returns a context manager autotune = _lazy_import_wrapper( "flashinfer.autotuner", @@ -891,6 +990,7 @@ def is_flashinfer_cudnn_fp8_prefill_attn_supported() -> bool: "flashinfer_cute_dsl_fused_moe_nvfp4", "flashinfer_convert_sf_to_mma_layout", "trtllm_fp4_block_scale_moe", + "flashinfer_trtllm_batch_decode_sparse_mla_dsv4", "autotune", "has_flashinfer_moe", "has_flashinfer_comm", diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index 12ec5b0fcc66..3db6cc2195ae 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -41,6 +41,7 @@ "int8": torch.int8, "int8_per_token_head": torch.int8, "fp8_per_token_head": torch.uint8, + "fp8_per_tensor": torch.float8_e4m3fn, "fp8_inc": torch.float8_e4m3fn, "fp8_ds_mla": torch.uint8, "turboquant_k8v4": torch.uint8, diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 797179076969..99d3b002558d 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -95,8 +95,11 @@ class FlashMLASparseBackend(AttentionBackend): supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ "auto", "bfloat16", + "fp8_per_tensor", + "fp8_inc", "fp8_ds_mla", "fp8", # alias for fp8_ds_mla + "fp8_e4m3", # alias for fp8_ds_mla ] @staticmethod @@ -668,7 +671,7 @@ def _build_c128a_metadata( # `c128a_global_decode_topk_indices.shape[0]` lines up with q in # `_forward_decode`. The per-token C128A kernel handles non-uniform # query lengths. - (num_decodes, _, num_decode_tokens, num_prefill_tokens) = ( + (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = ( split_decodes_and_prefills( cm, decode_threshold=self.reorder_batch_threshold or 1, diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py index bfa3b7285dbd..8364e044c150 100644 --- a/vllm/v1/attention/backends/mla/sparse_swa.py +++ b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -73,7 +73,7 @@ def __init__( # determines the SWA block size of 64 tokens per block. # TODO(yifan): make SWA block size automatically determined and configurable. self.block_size = 64 - assert self.dtype == torch.uint8 + assert self.dtype in (torch.uint8, torch.bfloat16, torch.float8_e4m3fn) def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: return SlidingWindowMLASpec( @@ -83,7 +83,10 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: dtype=self.dtype, sliding_window=self.window_size, cache_dtype_str=self.cache_config.cache_dtype, - alignment=576, # NOTE: FlashMLA requires 576B alignment + # The legacy fp8_ds_mla FlashMLA layout needs 576B alignment. + # FlashInfer DSV4 BF16/per-tensor FP8 reads sparse indices as + # flat token offsets, so those cache pages must remain contiguous. + alignment=576 if self.cache_config.cache_dtype == "fp8_ds_mla" else None, model_version="deepseek_v4", ) @@ -166,6 +169,7 @@ class DeepseekSparseSWAMetadata: num_prefill_tokens: int = 0 # Pre-computed prefill metadata shared across all DeepseekV4 attention layers. + prefill_query_start_loc: torch.Tensor | None = None prefill_seq_lens: torch.Tensor | None = None prefill_gather_lens: torch.Tensor | None = None @@ -278,6 +282,7 @@ def build( query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu block_table = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping + assert seq_lens.dtype == torch.int32 # Split into decode and prefill portions using configurable threshold (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = ( @@ -396,28 +401,34 @@ def _build_deepseek_v4_metadata( # --- Prefill query metadata (single Triton kernel + CPU slicing) --- if num_prefills > 0: + pfx_query_start_loc = torch.empty( + num_prefills + 1, dtype=torch.int32, device=seq_lens.device + ) pfx_gather_lens = torch.empty( num_prefills, dtype=torch.int32, device=seq_lens.device ) _compute_prefill_metadata_kernel[(1,)]( + pfx_query_start_loc, pfx_gather_lens, seq_lens, query_start_loc, num_prefills, num_decodes, self.window_size, - BLOCK_SIZE=triton.next_power_of_2(num_prefills), + BLOCK_SIZE=triton.next_power_of_2(num_prefills + 1), ) + result["prefill_query_start_loc"] = pfx_query_start_loc result["prefill_seq_lens"] = seq_lens[num_decodes:] result["prefill_gather_lens"] = pfx_gather_lens return result -@triton.jit +@triton.jit(do_not_specialize=["num_prefills", "num_decodes", "window_size"]) def _compute_prefill_metadata_kernel( # Outputs + prefill_query_start_loc_ptr, prefill_gather_lens_ptr, # Inputs seq_lens_ptr, @@ -427,8 +438,18 @@ def _compute_prefill_metadata_kernel( window_size, BLOCK_SIZE: tl.constexpr, ): - """Compute prefill gather_lens in a single pass.""" + """Compute prefill-local query offsets and gather_lens in a single pass.""" offset = tl.arange(0, BLOCK_SIZE) + qsl_base = tl.load(query_start_loc_ptr + num_decodes) + + qsl_mask = offset < (num_prefills + 1) + qsl_value = tl.load( + query_start_loc_ptr + num_decodes + offset, + mask=qsl_mask, + other=qsl_base, + ) + tl.store(prefill_query_start_loc_ptr + offset, qsl_value - qsl_base, mask=qsl_mask) + mask = offset < num_prefills seq_len = tl.load(seq_lens_ptr + num_decodes + offset, mask=mask) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py b/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py index 959a79f292a5..bc247adfb6b4 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from .cache_utils import ( + build_flashinfer_mixed_sparse_indices, combine_topk_swa_indices, compute_global_topk_indices_and_lens, dequantize_and_gather_k_cache, @@ -13,6 +14,7 @@ __all__ = [ "MXFP4_BLOCK_SIZE", + "build_flashinfer_mixed_sparse_indices", "combine_topk_swa_indices", "compute_global_topk_indices_and_lens", "dequantize_and_gather_k_cache", diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py index dfb107b515eb..e71e5fe6f497 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py @@ -14,12 +14,26 @@ window indices for sparse prefill. """ +from functools import lru_cache + import torch from vllm.triton_utils import tl, triton from vllm.utils.import_utils import has_cutedsl +@lru_cache(maxsize=1) +def _has_dequant_gather_k_cutedsl() -> bool: + if not has_cutedsl(): + return False + try: + from cutlass import cute + + return hasattr(cute.nvgpu, "LoadCacheMode") + except Exception: + return False + + @triton.jit def quantize_and_insert_k_kernel( # Input tensors @@ -364,7 +378,8 @@ def dequantize_and_gather_k_cache( block_size: int, offset: int, ) -> None: - if has_cutedsl(): + assert k_cache.dtype == torch.uint8 + if _has_dequant_gather_k_cutedsl(): # lazily import, otherwise some tests fail due to CUDA driver init failure. from .dequant_gather_k_cutedsl import dequantize_and_gather_k_cache_cutedsl @@ -412,6 +427,296 @@ def compute_global_topk_indices_and_lens( return global_topk_indices, topk_lens +def build_flashinfer_mixed_sparse_indices( + decode_swa_indices: torch.Tensor, + decode_compressed_indices: torch.Tensor | None, + decode_compressed_topk_lens: torch.Tensor | None, + prefill_topk_indices: torch.Tensor, + query_start_loc: torch.Tensor, + seq_lens: torch.Tensor, + token_to_req_indices: torch.Tensor, + swa_block_table: torch.Tensor, + swa_block_size: int, + compressed_block_table: torch.Tensor | None, + compressed_block_size: int, + window_size: int, + compress_ratio: int, + topk: int, + decode_compressed_indices_are_local: bool = False, + decode_is_valid_token: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Build FlashInfer DSV4 sparse indices for decode-first mixed batches.""" + assert decode_swa_indices.dtype == torch.int32 + assert decode_swa_indices.dim() == 2 + assert decode_swa_indices.shape[-1] == window_size + if decode_compressed_topk_lens is not None: + assert decode_compressed_topk_lens.dtype == torch.int32 + assert prefill_topk_indices.dtype == torch.int32 + assert prefill_topk_indices.dim() == 2 + assert query_start_loc.dtype == torch.int32 + assert seq_lens.dtype == torch.int32 + assert token_to_req_indices.dtype == torch.int32 + assert swa_block_table.dtype == torch.int32 + + num_decode_tokens = decode_swa_indices.shape[0] + num_prefill_tokens = prefill_topk_indices.shape[0] + num_tokens = num_decode_tokens + num_prefill_tokens + assert token_to_req_indices.shape[0] >= num_tokens + if decode_compressed_topk_lens is not None: + assert decode_compressed_topk_lens.shape[0] >= num_decode_tokens + + decode_compressed_topk = 0 + if decode_compressed_indices is None: + decode_compressed_indices = prefill_topk_indices + else: + assert decode_compressed_indices.dtype == torch.int32 + assert decode_compressed_indices.dim() == 2 + assert decode_compressed_indices.shape[0] == num_decode_tokens + decode_compressed_topk = decode_compressed_indices.shape[-1] + if decode_compressed_topk > 0 and decode_compressed_indices_are_local: + assert decode_is_valid_token is not None + assert decode_is_valid_token.dtype == torch.bool + assert decode_is_valid_token.shape[0] >= num_decode_tokens + else: + decode_is_valid_token = token_to_req_indices + + if compressed_block_table is None: + compressed_block_table = swa_block_table + assert compressed_block_table.dtype == torch.int32 + has_decode_compressed_lens = decode_compressed_topk_lens is not None + if decode_compressed_topk_lens is None: + decode_compressed_topk_lens = token_to_req_indices + + padded_topk = max(topk, decode_compressed_topk) + padded_topk = (padded_topk + 3) // 4 * 4 + sparse_indices = torch.empty( + (num_tokens, window_size + padded_topk), + dtype=torch.int32, + device=decode_swa_indices.device, + ) + sparse_topk_lens = torch.empty( + num_tokens, dtype=torch.int32, device=decode_swa_indices.device + ) + if num_tokens == 0: + return sparse_indices, sparse_topk_lens + + window_block_size = triton.next_power_of_2(max(window_size, 1)) + topk_block_size = triton.next_power_of_2(max(padded_topk, 1)) + max_block_size = max(window_block_size, topk_block_size) + num_warps = 4 if max_block_size >= 256 else 1 + + _build_flashinfer_mixed_sparse_indices_kernel[(num_tokens,)]( + sparse_indices, + sparse_indices.stride(0), + sparse_topk_lens, + decode_swa_indices, + decode_swa_indices.stride(0), + decode_compressed_indices, + decode_compressed_indices.stride(0), + decode_compressed_topk_lens, + decode_is_valid_token, + prefill_topk_indices, + prefill_topk_indices.stride(0), + query_start_loc, + seq_lens, + token_to_req_indices, + swa_block_table, + swa_block_table.stride(0), + swa_block_size, + compressed_block_table, + compressed_block_table.stride(0), + compressed_block_size, + NUM_DECODE_TOKENS=num_decode_tokens, + WINDOW_SIZE=window_size, + COMPRESS_RATIO=compress_ratio, + TOP_K=topk, + PADDED_TOP_K=padded_topk, + PREFILL_TOPK_STRIDE=prefill_topk_indices.shape[-1], + DECODE_COMPRESSED_TOPK=decode_compressed_topk, + DECODE_COMPRESSED_INDICES_ARE_LOCAL=decode_compressed_indices_are_local, + HAS_DECODE_COMPRESSED_LENS=has_decode_compressed_lens, + WINDOW_BLOCK_SIZE=window_block_size, + TOPK_BLOCK_SIZE=topk_block_size, + num_warps=num_warps, + ) + return sparse_indices, sparse_topk_lens + + +@triton.jit( + do_not_specialize=[ + "sparse_indices_stride", + "decode_swa_stride", + "decode_compressed_stride", + "prefill_topk_stride", + "swa_block_table_stride", + "swa_block_size", + "compressed_block_table_stride", + "compressed_block_size", + "NUM_DECODE_TOKENS", + "PREFILL_TOPK_STRIDE", + ] +) +def _build_flashinfer_mixed_sparse_indices_kernel( + sparse_indices_ptr, + sparse_indices_stride, + sparse_topk_lens_ptr, + decode_swa_indices_ptr, + decode_swa_stride, + decode_compressed_indices_ptr, + decode_compressed_stride, + decode_compressed_topk_lens_ptr, + decode_is_valid_token_ptr, + prefill_topk_indices_ptr, + prefill_topk_stride, + query_start_loc_ptr, + seq_lens_ptr, + token_to_req_indices_ptr, + swa_block_table_ptr, + swa_block_table_stride, + swa_block_size, + compressed_block_table_ptr, + compressed_block_table_stride, + compressed_block_size, + NUM_DECODE_TOKENS, + WINDOW_SIZE: tl.constexpr, + COMPRESS_RATIO: tl.constexpr, + TOP_K: tl.constexpr, + PADDED_TOP_K: tl.constexpr, + PREFILL_TOPK_STRIDE, + DECODE_COMPRESSED_TOPK: tl.constexpr, + DECODE_COMPRESSED_INDICES_ARE_LOCAL: tl.constexpr, + HAS_DECODE_COMPRESSED_LENS: tl.constexpr, + WINDOW_BLOCK_SIZE: tl.constexpr, + TOPK_BLOCK_SIZE: tl.constexpr, +): + token_idx = tl.program_id(0) + + if token_idx < NUM_DECODE_TOKENS: + for i in range(0, WINDOW_SIZE, WINDOW_BLOCK_SIZE): + offset = i + tl.arange(0, WINDOW_BLOCK_SIZE) + mask = offset < WINDOW_SIZE + values = tl.load( + decode_swa_indices_ptr + token_idx * decode_swa_stride + offset, + mask=mask, + other=-1, + ) + tl.store( + sparse_indices_ptr + token_idx * sparse_indices_stride + offset, + values, + mask=mask, + ) + + compressed_len = tl.zeros((), dtype=tl.int32) + for i in range(0, PADDED_TOP_K, TOPK_BLOCK_SIZE): + offset = i + tl.arange(0, TOPK_BLOCK_SIZE) + mask = offset < PADDED_TOP_K + values = tl.load( + decode_compressed_indices_ptr + + token_idx * decode_compressed_stride + + offset, + mask=offset < DECODE_COMPRESSED_TOPK, + other=-1, + ) + if DECODE_COMPRESSED_INDICES_ARE_LOCAL: + token_valid = tl.load(decode_is_valid_token_ptr + token_idx) + is_valid = values >= 0 + req_idx = tl.load(token_to_req_indices_ptr + token_idx) + block_indices = values // compressed_block_size + block_numbers = tl.load( + compressed_block_table_ptr + + req_idx * compressed_block_table_stride + + block_indices, + mask=mask & is_valid, + other=-1, + ) + block_offsets = values % compressed_block_size + values = block_numbers * compressed_block_size + block_offsets + values = tl.where(is_valid, values, -1) + compressed_len += tl.sum((is_valid & token_valid).to(tl.int32), axis=0) + tl.store( + sparse_indices_ptr + + token_idx * sparse_indices_stride + + WINDOW_SIZE + + offset, + values, + mask=mask, + ) + + if DECODE_COMPRESSED_TOPK == 0: + compressed_len = tl.zeros((), dtype=tl.int32) + elif not DECODE_COMPRESSED_INDICES_ARE_LOCAL: + if HAS_DECODE_COMPRESSED_LENS: + compressed_len = tl.load(decode_compressed_topk_lens_ptr + token_idx) + else: + compressed_len = tl.full((), DECODE_COMPRESSED_TOPK, dtype=tl.int32) + + tl.store(sparse_topk_lens_ptr + token_idx, WINDOW_SIZE + compressed_len) + return + + prefill_idx = token_idx - NUM_DECODE_TOKENS + req_idx = tl.load(token_to_req_indices_ptr + token_idx) + query_start = tl.load(query_start_loc_ptr + req_idx) + query_end = tl.load(query_start_loc_ptr + req_idx + 1) + query_len = query_end - query_start + seq_len = tl.load(seq_lens_ptr + req_idx) + start_pos = seq_len - query_len + token_idx_in_query = token_idx - query_start + pos = start_pos + token_idx_in_query + swa_len = tl.minimum(pos + 1, WINDOW_SIZE) + swa_start_pos = pos - swa_len + 1 + topk_len = tl.minimum((pos + 1) // COMPRESS_RATIO, TOP_K) + + for i in range(0, WINDOW_SIZE, WINDOW_BLOCK_SIZE): + offset = i + tl.arange(0, WINDOW_BLOCK_SIZE) + mask = offset < WINDOW_SIZE + pos_offset = swa_start_pos + offset + block_indices = pos_offset // swa_block_size + block_numbers = tl.load( + swa_block_table_ptr + req_idx * swa_block_table_stride + block_indices, + mask=mask & (offset < swa_len), + other=-1, + ) + block_offsets = pos_offset % swa_block_size + slot_ids = block_numbers * swa_block_size + block_offsets + slot_ids = tl.where(offset < swa_len, slot_ids, -1) + tl.store( + sparse_indices_ptr + token_idx * sparse_indices_stride + offset, + slot_ids, + mask=mask, + ) + + for i in range(0, PADDED_TOP_K, TOPK_BLOCK_SIZE): + offset = i + tl.arange(0, TOPK_BLOCK_SIZE) + mask = offset < PADDED_TOP_K + local_idx = tl.load( + prefill_topk_indices_ptr + prefill_idx * prefill_topk_stride + offset, + mask=(offset < PREFILL_TOPK_STRIDE) & (offset < topk_len), + other=-1, + ) + is_valid = local_idx >= 0 + block_indices = local_idx // compressed_block_size + block_numbers = tl.load( + compressed_block_table_ptr + + req_idx * compressed_block_table_stride + + block_indices, + mask=mask & is_valid, + other=-1, + ) + block_offsets = local_idx % compressed_block_size + slot_ids = block_numbers * compressed_block_size + block_offsets + slot_ids = tl.where((offset < topk_len) & is_valid, slot_ids, -1) + tl.store( + sparse_indices_ptr + + token_idx * sparse_indices_stride + + WINDOW_SIZE + + offset, + slot_ids, + mask=mask, + ) + + tl.store(sparse_topk_lens_ptr + token_idx, WINDOW_SIZE + topk_len) + + @triton.jit def _compute_global_topk_indices_and_lens_kernel( global_topk_indices_ptr, @@ -459,7 +764,6 @@ def _compute_global_topk_indices_and_lens_kernel( ) count += tl.sum(is_valid.to(tl.int32), axis=0) - # Zero out length for padding tokens. tl.store(topk_lens_ptr + token_idx, tl.where(is_valid_token, count, 0)) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py index 2f97d8733c95..4e68c312b406 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_compress_quant_cache.py @@ -50,6 +50,7 @@ def _fused_kv_compress_norm_rope_insert_sparse_attn( k_cache_ptr, kv_slot_mapping_ptr, kv_cache_block_size, + fp8_scale_ptr, # ── constexprs ── HEAD_SIZE: tl.constexpr, TRITON_BLOCK_SIZE: tl.constexpr, @@ -62,6 +63,9 @@ def _fused_kv_compress_norm_rope_insert_sparse_attn( TOKEN_STRIDE: tl.constexpr, # 576 for DeepseekV4 SCALE_DIM: tl.constexpr, # 8 for DeepseekV4 (7 real + 1 pad) KV_BLOCK_STRIDE: tl.constexpr, + KV_TOKEN_STRIDE: tl.constexpr, + STORE_FULL_CACHE: tl.constexpr, + STORE_FULL_FP8: tl.constexpr, ): """Fused compress → RMSNorm → FP8 quant (nope) → RoPE → bf16 store (rope). @@ -141,16 +145,53 @@ def _fused_kv_compress_norm_rope_insert_sparse_attn( kv_block_idx = kv_slot_idx // kv_cache_block_size kv_pos_in_block = kv_slot_idx % kv_cache_block_size + NOPE_HEAD_DIM: tl.constexpr = HEAD_SIZE - ROPE_HEAD_DIM # 448 + HALF_ROPE: tl.constexpr = ROPE_HEAD_DIM // 2 # 32 + NUM_PAIRS: tl.constexpr = TRITON_BLOCK_SIZE // 2 + NOPE_PAIRS: tl.constexpr = NOPE_HEAD_DIM // 2 + + if STORE_FULL_CACHE: + cache_row = ( + k_cache_ptr + + kv_block_idx.to(tl.int64) * KV_BLOCK_STRIDE + + kv_pos_in_block * KV_TOKEN_STRIDE + ) + + pair_2d = tl.reshape(normed, (NUM_PAIRS, 2)) + even, odd = tl.split(pair_2d) + pair_idx = tl.arange(0, NUM_PAIRS) + rope_pair_local = pair_idx - NOPE_PAIRS + is_rope_pair = rope_pair_local >= 0 + cs_idx = tl.maximum(rope_pair_local, 0) + + compressed_pos = (position // COMPRESS_RATIO) * COMPRESS_RATIO + cache_base = cos_sin_cache_ptr + compressed_pos * cos_sin_stride + cos_v = tl.load(cache_base + cs_idx, mask=is_rope_pair, other=1.0) + sin_v = tl.load(cache_base + HALF_ROPE + cs_idx, + mask=is_rope_pair, + other=0.0) + + new_even = tl.where(is_rope_pair, even * cos_v - odd * sin_v, even) + new_odd = tl.where(is_rope_pair, odd * cos_v + even * sin_v, odd) + result = tl.interleave(new_even, new_odd).to(tl.bfloat16).to(tl.float32) + + if STORE_FULL_FP8: + fp8_scale = tl.load(fp8_scale_ptr) + result = tl.clamp(result / fp8_scale, -448.0, 448.0) + tl.store(cache_row + block, result.to(tl.float8e4nv), mask=mask) + else: + tl.store(cache_row + block, result.to(tl.bfloat16), mask=mask) + return + cache_block_ptr = k_cache_ptr + kv_block_idx.to(tl.int64) * KV_BLOCK_STRIDE - fp8_ptr = cache_block_ptr + kv_pos_in_block * TOKEN_STRIDE + fp8_ptr = (cache_block_ptr + kv_pos_in_block * TOKEN_STRIDE).to( + tl.pointer_type(tl.uint8) + ) scale_ptr = ( cache_block_ptr + kv_cache_block_size * TOKEN_STRIDE + kv_pos_in_block * SCALE_DIM - ) - - NOPE_HEAD_DIM: tl.constexpr = HEAD_SIZE - ROPE_HEAD_DIM # 448 - HALF_ROPE: tl.constexpr = ROPE_HEAD_DIM // 2 # 32 + ).to(tl.pointer_type(tl.uint8)) # FP8 UE8M0 quant: cast fp32 → bf16 → fp32 before quant to match reference. N_QUANT_BLOCKS: tl.constexpr = TRITON_BLOCK_SIZE // QUANT_BLOCK @@ -187,9 +228,6 @@ def _fused_kv_compress_norm_rope_insert_sparse_attn( tl.store(scale_ptr + N_NOPE_BLOCKS, tl.zeros((), dtype=tl.uint8)) # Register-based GPT-J RoPE in fp32. - NUM_PAIRS: tl.constexpr = TRITON_BLOCK_SIZE // 2 - NOPE_PAIRS: tl.constexpr = NOPE_HEAD_DIM // 2 - pair_2d = tl.reshape(normed, (NUM_PAIRS, 2)) even, odd = tl.split(pair_2d) # each [NUM_PAIRS] fp32 diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index b57e10b67faa..e0fa89b2c1c8 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1541,7 +1541,11 @@ def _get_kv_cache_groups_uniform_groups( for sm_spec in swa_mla_specs: sm_page_sizes = sm_spec.get_page_sizes() layers_per_size: dict[int, list[str]] = defaultdict(list) - assert max(sm_page_sizes) <= max(all_page_sizes) + if max(sm_page_sizes) > max(all_page_sizes): + raise AssertionError( + "DeepseekV4 SWA page size exceeds full-MLA page sizes: " + f"swa={sorted(sm_page_sizes)}, full={sorted(all_page_sizes)}" + ) # Unify page size by padding layers' page_size to the nearest larger page_size. # Compute candidate (nearest larger page_size) for each unique page size. diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index cf50dbff179a..56e8e1dfab83 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -482,10 +482,11 @@ def storage_block_size(self) -> int: @property def real_page_size_bytes(self) -> int: - if self.model_version == "deepseek_v4": - # DeepseekV4: 448B NoPE + 128B RoPE + 8B fp8 scale = 584B per token. + if self.model_version == "deepseek_v4" and self.cache_dtype_str == "fp8_ds_mla": + # DeepseekV4 legacy UE8M0 layout: + # 448B NoPE + 128B RoPE + 8B fp8 scale = 584B per token. return self.storage_block_size * 584 - assert self.model_version is None, ( + assert self.model_version in (None, "deepseek_v4"), ( f"Unsupported model version: {self.model_version}" ) return (