diff --git a/CMakeLists.txt b/CMakeLists.txt index bf4ac05e4f29..13788fa87437 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -307,12 +307,12 @@ set(VLLM_EXT_SRC "csrc/quantization/activation_kernels.cu" "csrc/cuda_utils_kernels.cu" "csrc/custom_all_reduce.cu" - "csrc/torch_bindings.cpp") + "csrc/torch_bindings.cpp" + "csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu") if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_EXT_SRC - "csrc/minimax_reduce_rms_kernel.cu" - "csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu") + "csrc/minimax_reduce_rms_kernel.cu") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") @@ -1047,13 +1047,13 @@ endif() set(VLLM_MOE_EXT_SRC "csrc/moe/torch_bindings.cpp" "csrc/moe/moe_align_sum_kernels.cu" - "csrc/moe/topk_softmax_kernels.cu") + "csrc/moe/topk_softmax_kernels.cu" + "csrc/moe/topk_softplus_sqrt_kernels.cu") if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu" - "csrc/moe/grouped_topk_kernels.cu" - "csrc/moe/topk_softplus_sqrt_kernels.cu") + "csrc/moe/grouped_topk_kernels.cu") endif() if(VLLM_GPU_LANG STREQUAL "CUDA") diff --git a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu index e96017d86dad..f328e3acb79b 100644 --- a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu +++ b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu @@ -29,7 +29,11 @@ */ #include -#include +#ifndef USE_ROCM + #include +#else + #include +#endif #include #include @@ -42,7 +46,27 @@ #include "type_convert.cuh" #ifndef FINAL_MASK - #define FINAL_MASK 0xffffffffu + #ifdef USE_ROCM + #define FINAL_MASK 0xffffffffffffffffULL + #else + #define FINAL_MASK 0xffffffffu + #endif +#endif + +#ifdef USE_ROCM +// ROCm-compatible FP8 conversion helpers +__device__ __forceinline__ uint8_t rocm_cvt_float_to_fp8_e4m3(float val) { + // HIP defines HIP_FP8_TYPE_OCP based on HIP version, not GPU arch. On gfx942 + // mfma only supports FNUZ fp8, and the rest of vLLM's gfx942 path (Triton + // indexer / current_platform.fp8_dtype()) uses FNUZ. Gate OCP on __gfx950__ + // so the K cache encoding matches what the reader expects. + #if defined(HIP_FP8_TYPE_OCP) && defined(__gfx950__) + __hip_fp8_e4m3 fp8_val(val); + #else + __hip_fp8_e4m3_fnuz fp8_val(val); + #endif + return reinterpret_cast(fp8_val); +} #endif namespace vllm { @@ -65,7 +89,13 @@ constexpr int kQuantBlock = 64; constexpr int kNumQuantBlocks = kNopeDim / kQuantBlock; // 7 constexpr int kScaleBytesPerToken = kNumQuantBlocks + 1; // 8 (7 real + 1 pad) constexpr int kTokenDataBytes = kNopeDim + kRopeDim * 2; // 448 + 128 = 576 +// Match the encoding chosen in rocm_cvt_float_to_fp8_e4m3: FNUZ on gfx942 +// (max 240), OCP on gfx950 (max 448). +#if defined(USE_ROCM) && (!defined(HIP_FP8_TYPE_OCP) || !defined(__gfx950__)) +constexpr float kFp8Max = 240.0f; +#else constexpr float kFp8Max = 448.0f; +#endif // Per-warp layout: 32 lanes × 16 elems/lane = 512 elems = HEAD_DIM. constexpr int kNumLanes = 32; @@ -314,9 +344,13 @@ __global__ void fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel( for (int i = 0; i < kElemsPerLane; i++) { float scaled = elements[i] * inv_scale; scaled = fminf(fmaxf(scaled, -kFp8Max), kFp8Max); +#ifndef USE_ROCM __nv_fp8_storage_t s = __nv_cvt_float_to_fp8(scaled, __NV_SATFINITE, __NV_E4M3); out_bytes[i] = static_cast(s); +#else + out_bytes[i] = rocm_cvt_float_to_fp8_e4m3(scaled); +#endif } // One 16-byte STG per lane. 
*reinterpret_cast(token_fp8_ptr + dim_base) = @@ -384,6 +418,7 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert( // PDL: enable programmatic stream serialization whenever the hardware // supports it (SM90+). On pre-Hopper GPUs the attribute is unavailable, // so leave numAttrs = 0 and launch as a regular kernel. +#ifndef USE_ROCM static int const sm_version = getSMVersion(); // Host-side guard: the device kernel body is compiled as a no-op for // bf16 on pre-Ampere (sm_70/sm_75) because _typeConvert is @@ -410,6 +445,15 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert( q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, eps, num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size, kv_block_stride); +#else + // ROCm: use standard kernel launch syntax (no PDL/stream serialization) + // clang-format off + fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel + <<>>( + q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, + eps, num_tokens_full, num_tokens_insert, num_heads_q, + cache_block_size, kv_block_stride); +#endif } } // namespace deepseek_v4_fused_ops diff --git a/csrc/moe/topk_softplus_sqrt_kernels.cu b/csrc/moe/topk_softplus_sqrt_kernels.cu index 50a8540a7374..43d461a0179a 100644 --- a/csrc/moe/topk_softplus_sqrt_kernels.cu +++ b/csrc/moe/topk_softplus_sqrt_kernels.cu @@ -60,15 +60,6 @@ __device__ __forceinline__ float toFloat(T value) { } } -#define FINAL_MASK 0xffffffff -template -__inline__ __device__ T warpReduceSum(T val) { -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val += __shfl_xor_sync(FINAL_MASK, val, mask, 32); - return val; -} - // ====================== TopK softplus_sqrt things // =============================== @@ -272,8 +263,14 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ } } // Compute per-thread scale (using warp reduction when renormalizing). + // THREADS_PER_ROW-parameterized butterfly works for both warp sizes (32 + // on CUDA, 64 on ROCm CDNA) and any THREADS_PER_ROW the dispatch picks. if (renormalize) { - selected_sum = warpReduceSum(selected_sum); +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + selected_sum += + VLLM_SHFL_XOR_SYNC_WIDTH(selected_sum, mask, THREADS_PER_ROW); + } } float scale = static_cast(routed_scaling_factor); if (renormalize) { @@ -544,7 +541,6 @@ void topkGatingSoftplusSqrtKernelLauncher( const IndType* tid2eid, cudaStream_t stream) { static constexpr int WARPS_PER_TB = 4; static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16; -#ifndef USE_ROCM // for bfloat16 dtype, we need 4 bytes loading to make sure num_experts // elements can be loaded by a warp static constexpr int BYTES_PER_LDG_MULTIPLE_64 = @@ -552,6 +548,19 @@ void topkGatingSoftplusSqrtKernelLauncher( std::is_same_v) ? 4 : 8; + // Narrower LDG (ELTS_PER_LDG=1) used by 192/320/448/576 on ROCm WARP_SIZE=64 + // where ELTS_PER_LDG=2 fails the EXPERTS%(ELTS_PER_LDG*WARP_SIZE)==0 check. + // On CUDA WARP_SIZE=32 the wider LDG already aligns, so the alias collapses + // back to BYTES_PER_LDG_MULTIPLE_64 — no behavioral change for CUDA. +#ifdef USE_ROCM + static constexpr int BYTES_PER_LDG_MULTIPLE_64_NARROW = + (std::is_same_v || + std::is_same_v) + ? 
2 + : 4; +#else + static constexpr int BYTES_PER_LDG_MULTIPLE_64_NARROW = + BYTES_PER_LDG_MULTIPLE_64; #endif switch (num_experts) { case 1: @@ -584,27 +593,29 @@ void topkGatingSoftplusSqrtKernelLauncher( case 512: LAUNCH_SOFTPLUS_SQRT(512, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); break; - // (CUDA only) support multiples of 64 when num_experts is not power of 2. - // ROCm uses WARP_SIZE 64 so 8 bytes loading won't fit for some of - // num_experts, alternatively we can test 4 bytes loading and enable it in - // future. -#ifndef USE_ROCM + // Multiples of 64 that are not powers of 2. The kernel requires + // EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0. With ELTS_PER_LDG=2 + // (BYTES_PER_LDG_MULTIPLE_64), this holds for all five values on CUDA + // WARP_SIZE=32 but only for 384 on ROCm WARP_SIZE=64. The other four + // use BYTES_PER_LDG_MULTIPLE_64_NARROW (ELTS_PER_LDG=1), which + // satisfies the assertion for any multiple of 64 on either backend; + // on CUDA the narrow alias collapses back to the wider load, so CUDA + // behavior is unchanged. case 192: - LAUNCH_SOFTPLUS_SQRT(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); + LAUNCH_SOFTPLUS_SQRT(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64_NARROW); break; case 320: - LAUNCH_SOFTPLUS_SQRT(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); + LAUNCH_SOFTPLUS_SQRT(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64_NARROW); break; case 384: LAUNCH_SOFTPLUS_SQRT(384, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); break; case 448: - LAUNCH_SOFTPLUS_SQRT(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); + LAUNCH_SOFTPLUS_SQRT(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64_NARROW); break; case 576: - LAUNCH_SOFTPLUS_SQRT(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); + LAUNCH_SOFTPLUS_SQRT(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64_NARROW); break; -#endif default: { TORCH_CHECK(false, "Unsupported expert number: ", num_experts); } diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index b737cb54353c..8940e341cd01 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -16,14 +16,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "bias) -> ()"); m.impl("topk_sigmoid", torch::kCUDA, &topk_sigmoid); -#ifndef USE_ROCM m.def( "topk_softplus_sqrt(Tensor! topk_weights, Tensor! topk_indices, Tensor! " "token_expert_indices, Tensor gating_output, bool renormalize, float " "routed_scaling_factor, Tensor? " "bias, Tensor? input_ids, Tensor? tid2eid) -> ()"); m.impl("topk_softplus_sqrt", torch::kCUDA, &topk_softplus_sqrt); -#endif + // Calculate the result of moe by summing up the partial results // from all selected experts. m.def("moe_sum(Tensor input, Tensor! output) -> ()"); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 8d8f7bed0441..e695497fd88f 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -183,7 +183,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "int forced_token_heads_per_warp=-1) -> ()"); ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope); -#ifndef USE_ROCM // Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and // GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one // kernel launch. 
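// (For reference, derived from the kernel constants above: UE8M0 scales are exponent-only bytes with value 2^(byte - 127), one scale per 64-element quant block of the 448-dim nope part.)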
@@ -194,7 +193,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "float eps, int cache_block_size) -> ()"); ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA, &fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert); -#endif // Apply repetition penalties to logits in-place ops.def( diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index bdbe46ad9177..4a2e909d50ba 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -216,7 +216,7 @@ configuration. | `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 512, 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x | | `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x | | `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | -| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A | +| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 1, 64 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | | `TRITON_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any | | `XPU_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | Any | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | Any | diff --git a/requirements/rocm.txt b/requirements/rocm.txt index 0b472b90c026..037b20874b52 100644 --- a/requirements/rocm.txt +++ b/requirements/rocm.txt @@ -21,3 +21,6 @@ timm>=1.0.17 # amd-quark: required for Quark quantization on ROCm # To be consistent with test_quark.py amd-quark>=0.8.99 +# tilelang has to be installed for mhc module to be +# imported correctly. +tilelang==0.1.9 diff --git a/tests/kernels/moe/test_topk_softplus_sqrt.py b/tests/kernels/moe/test_topk_softplus_sqrt.py index 7f5aacb383db..1b68213fafef 100644 --- a/tests/kernels/moe/test_topk_softplus_sqrt.py +++ b/tests/kernels/moe/test_topk_softplus_sqrt.py @@ -70,7 +70,8 @@ def test_sqrtsoftplus_bias_uses_deepseek_v4_routing_method(): @pytest.mark.skipif( - not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." + not current_platform.is_cuda_alike(), + reason="This test is skipped on non-CUDA platform.", ) @pytest.mark.parametrize("num_tokens", [1, 33, 128]) @pytest.mark.parametrize("hidden_size", [1024, 2048]) @@ -125,7 +126,8 @@ def test_fused_topk_softplus_sqrt( @pytest.mark.skipif( - not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." 
+ not current_platform.is_cuda_alike(), + reason="This test is skipped on non-CUDA platform.", ) @pytest.mark.parametrize("num_tokens", [1, 33, 128]) @pytest.mark.parametrize("hidden_size", [1024, 2048]) diff --git a/vllm/config/kernel.py b/vllm/config/kernel.py index 93fb4c54b7f1..3de89eacdd34 100644 --- a/vllm/config/kernel.py +++ b/vllm/config/kernel.py @@ -115,6 +115,7 @@ def with_default( "flashinfer_cutlass", "flashinfer_cutedsl", "marlin", + "triton_unfused", "aiter", "emulation", ] @@ -145,6 +146,7 @@ class KernelConfig: - "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels - "flashinfer_cutedsl": Use FlashInfer with CuteDSL kernels (FP4 only) - "marlin": Use Marlin kernels (weight-only quantization) + - "triton_unfused": Use Triton unfused MoE kernels - "aiter": Use AMD AITer kernels (ROCm only) - "emulation": use BF16/FP16 GEMM, dequantizing weights and running QDQ on activations. diff --git a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py index 8a8650d22135..5ded5ca798ad 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py @@ -312,6 +312,21 @@ def apply_block_scaled_mm( As: torch.Tensor, Bs: torch.Tensor, ) -> torch.Tensor: + if As.dtype != Bs.dtype: + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _upcast_e8m0_to_fp32, + ) + + if As.dtype == torch.float8_e8m0fnu: + As = _upcast_e8m0_to_fp32(As).contiguous() + else: + As = As.to(torch.float32) + + if Bs.dtype == torch.float8_e8m0fnu: + Bs = _upcast_e8m0_to_fp32(Bs).contiguous() + else: + Bs = Bs.to(torch.float32) + out_dtype = self.config.out_dtype if self.use_triton: gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 59cc95f18c58..df9459012ae8 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -169,7 +169,9 @@ class SiluAndMulWithClamp(CustomOp): def __init__(self, swiglu_limit: float, *, compile_native: bool = True): super().__init__(compile_native=compile_native) self.swiglu_limit = float(swiglu_limit) - if current_platform.is_cuda_alike() or current_platform.is_xpu(): + if current_platform.is_rocm(): + self._forward_method = self.forward_native + elif current_platform.is_cuda_alike() or current_platform.is_xpu(): self.op = torch.ops._C.silu_and_mul_with_clamp elif current_platform.is_cpu(): self._forward_method = self.forward_native diff --git a/vllm/model_executor/layers/deepseek_compressor.py b/vllm/model_executor/layers/deepseek_compressor.py index cae80c35316a..1d69e140b669 100644 --- a/vllm/model_executor/layers/deepseek_compressor.py +++ b/vllm/model_executor/layers/deepseek_compressor.py @@ -174,6 +174,54 @@ def get_attn_backend(self) -> type[AttentionBackend]: return CompressorBackend +def hadamard_transform_ref(x: torch.Tensor) -> torch.Tensor: + hidden_size = x.shape[-1] + assert hidden_size > 0 and (hidden_size & (hidden_size - 1)) == 0, ( + f"Hidden size must be a power of 2, got {hidden_size}" + ) + dtype = x.dtype + y = x.to(torch.float32).reshape(-1, hidden_size) + h = 1 + while h < hidden_size: + y = y.view(-1, hidden_size // (2 * h), 2, h) + a = y[:, :, 0, :] + b = y[:, :, 1, :] + y = torch.cat((a + b, a - b), dim=-1) + h *= 2 + y = y.view(*x.shape) * (hidden_size**-0.5) + return y.to(dtype) + + +def apply_gptj_rope_ref( + x: torch.Tensor, + positions: torch.Tensor, + 
cos_sin_cache: torch.Tensor, + rope_dim: int, +) -> torch.Tensor: + if rope_dim == 0 or x.numel() == 0: + return x + half_rot = rope_dim // 2 + nope_dim = x.shape[-1] - rope_dim + dtype = x.dtype + x = x.to(torch.float32) + cache = cos_sin_cache.index_select(0, positions.to(torch.long)) + cos = cache[:, :half_rot].to(torch.float32) + sin = cache[:, half_rot : 2 * half_rot].to(torch.float32) + view_shape = (positions.shape[0],) + (1,) * (x.dim() - 2) + (half_rot,) + cos = cos.view(view_shape) + sin = sin.view(view_shape) + rope = x[..., nope_dim:] + x_even = rope[..., 0::2] + x_odd = rope[..., 1::2] + rope_out = torch.stack( + (x_even * cos - x_odd * sin, x_odd * cos + x_even * sin), + dim=-1, + ).flatten(-2) + x = x.clone() + x[..., nope_dim:] = rope_out + return x.to(dtype) + + class DeepseekCompressor(nn.Module): def __init__( self, @@ -300,6 +348,7 @@ def forward( state_cache = self.state_cache.kv_cache # kv_state stored in first half, score_state stored in second half state_width = state_cache.shape[-1] // 2 + pdl_kwargs = {} if current_platform.is_rocm() else {"launch_pdl": False} # Store the KV and score (with fused APE addition) in the state. # NOTE: PDL is disabled — both this kernel and _fused_kernel below @@ -324,7 +373,7 @@ def forward( TRITON_BLOCK_SIZE=triton.next_power_of_2(kv.shape[-1]), STATE_WIDTH=state_width, COMPRESS_RATIO=self.compress_ratio, - launch_pdl=False, + **pdl_kwargs, ) # Fused: compress → RMSNorm → RoPE → FP8 quant → KV cache write. @@ -373,7 +422,7 @@ def forward( SCALE_DIM=self._scale_dim, KV_BLOCK_STRIDE=kv_cache.stride(0), num_warps=self._num_warps, - launch_pdl=False, + **pdl_kwargs, ) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 74b494dc4851..eae9f81da8b0 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -4,6 +4,7 @@ DeepseekV4 MLA Attention Layer """ +import math from collections.abc import Callable from dataclasses import dataclass from typing import TYPE_CHECKING, Any, cast @@ -43,7 +44,11 @@ from vllm.logger import init_logger from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.model_executor.layers.deepseek_compressor import DeepseekCompressor +from vllm.model_executor.layers.deepseek_compressor import ( + DeepseekCompressor, + apply_gptj_rope_ref, + hadamard_transform_ref, +) from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.input_quant_fp8 import ( @@ -52,6 +57,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, ) +from vllm.platforms import current_platform from vllm.utils.multi_stream_utils import ( execute_in_parallel, maybe_execute_in_parallel, @@ -197,8 +203,6 @@ def __init__( # Pick fp8_einsum recipe based on GPU arch: # SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128 # SM100: INT32 packed scales become [g, r, ...] 
→ sfb_gran_mn=1 - from vllm.platforms import current_platform - cap = current_platform.get_device_capability() assert cap is not None, "DeepseekV4 attention requires a CUDA device" self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128) @@ -209,6 +213,8 @@ self.topk_indices_buffer = mla_modules.topk_indices_buffer self.indexer = mla_modules.indexer + # Keep ROCm on the BF16 reference wo_a path until the kernel is ready. + self.use_ref_wo_a_path = current_platform.is_rocm() # Per-head RMS normalization for Q (no learnable weights) self.q_head_norm = RMSNorm(head_dim, eps=self.eps, has_weight=False) @@ -221,6 +227,7 @@ + 1 # 1B pad ) + # Will be None on ROCm for now. self.aux_stream_list = mla_modules.aux_stream_list # [0]: GEMM start / post-GEMM event0. [1..3]: GEMM done events; # [1] doubles as post-GEMM event1. Reuse is safe: GEMM fully joins @@ -302,6 +309,32 @@ ) o = o_padded[:, : self.n_local_heads, :] + if self.use_ref_wo_a_path: + o_ref = _apply_inv_rope_ref( + self.rotary_emb, o, positions, self.rope_head_dim + ).to(torch.bfloat16) + o_ref = o_ref.view(num_tokens, self.n_local_groups, -1) + + hidden_dim = o_ref.shape[-1] + if hasattr(self.wo_a, "weight_scale_inv"): + wo_a_weight = self.wo_a.weight.view( + self.n_local_groups, self.o_lora_rank, hidden_dim + ).to(torch.float32) + wo_a_scale = _expand_2d_block_scales( + self.wo_a.weight_scale_inv.view( + self.n_local_groups, -1, self.wo_a.weight_scale_inv.shape[-1] + ), + self.o_lora_rank, + hidden_dim, + ) + wo_a_weight = (wo_a_weight * wo_a_scale).to(torch.bfloat16) + else: + wo_a_weight = self.wo_a.weight.view( + self.n_local_groups, self.o_lora_rank, hidden_dim + ).to(torch.bfloat16) + z = torch.einsum("tgd,grd->tgr", o_ref, wo_a_weight) + return self.wo_b(z.flatten(1)) + # O projection: inverse RoPE + FP8 quant + einsum + wo_b o_fp8, o_scale = fused_inv_rope_fp8_quant( o, @@ -335,12 +368,15 @@ return self.wo_b(z.flatten(1)) def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]: - assert self.aux_stream_list is not None - assert len(self.aux_stream_list) >= 3 + aux_streams = self.aux_stream_list + if aux_streams is not None: + assert len(aux_streams) >= 3 + aux_streams = aux_streams[:3] # fused_wqa_wkv (heaviest) on default; the three lighter input GEMMs # on aux streams 0..2 when their owning module exists. ln_events[0] # is the fan-out start event; ln_events[1..3] are per-aux done events. + # On ROCm, aux_streams is None and the work below runs serially instead. aux_fns: list[Callable[[], Any] | None] = [None, None, None] if self.compressor is not None: @@ -379,12 +415,19 @@ def fused_wqa_wkv() -> torch.Tensor: qr_kv, _ = self.fused_wqa_wkv(hidden_states) return qr_kv + if aux_streams is None: + qr_kv = fused_wqa_wkv() + kv_score = aux_fns[0]() if aux_fns[0] is not None else None + indexer_weights = aux_fns[1]() if aux_fns[1] is not None else None + indexer_kv_score = aux_fns[2]() if aux_fns[2] is not None else None + return qr_kv, kv_score, indexer_kv_score, indexer_weights + qr_kv, (kv_score, indexer_weights, indexer_kv_score) = execute_in_parallel( fused_wqa_wkv, aux_fns, self.ln_events[0], self.ln_events[1:4], - self.aux_stream_list[:3], + aux_streams, ) return qr_kv, kv_score, indexer_kv_score, indexer_weights @@ -416,8 +459,9 @@ def attention_impl( # downstream reads q on default). Indexer/compressor go on aux for # overlap with default's GEMM + cache write. 
if self.indexer is not None: - assert self.aux_stream_list is not None - aux_stream = self.aux_stream_list[0] + aux_stream = ( + self.aux_stream_list[0] if self.aux_stream_list is not None else None + ) indexer = self.indexer # Local ref so the closure keeps a non-None type for mypy. assert self.compressor is not None @@ -445,8 +489,9 @@ def wq_b_kv_insert_and_compress() -> torch.Tensor: ) elif self.compressor is not None: # wq_b + kv_insert on default, compressor on aux. - assert self.aux_stream_list is not None - aux_stream = self.aux_stream_list[0] + aux_stream = ( + self.aux_stream_list[0] if self.aux_stream_list is not None else None + ) compressor = self.compressor def wq_b_kv_insert() -> torch.Tensor: @@ -568,7 +613,128 @@ def deepseek_v4_fp8_einsum( equation: str, recipe: list[int], ) -> None: - fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe)) + try: + fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe)) + except RuntimeError as exc: + if "DeepGEMM backend is not available or outdated" not in str(exc): + raise + _deepseek_v4_fp8_einsum_fallback(a, a_scale, b, b_scale, out, equation) + + +def _decode_e8m0_scales(scale: torch.Tensor) -> torch.Tensor: + if scale.dtype == torch.float8_e8m0fnu: + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _upcast_e8m0_to_fp32, + ) + + return _upcast_e8m0_to_fp32(scale).contiguous() + return scale.to(torch.float32) + + +def _expand_last_dim_scales(scale: torch.Tensor, last_dim: int) -> torch.Tensor: + scale = _decode_e8m0_scales(scale) + block = math.ceil(last_dim / scale.shape[-1]) + return torch.repeat_interleave(scale, block, dim=-1)[..., :last_dim] + + +def _expand_2d_block_scales( + scale: torch.Tensor, + rows: int, + cols: int, +) -> torch.Tensor: + scale = _decode_e8m0_scales(scale) + row_blocks, col_blocks = scale.shape[-2:] + row_block = math.ceil(rows / row_blocks) + col_block = math.ceil(cols / col_blocks) + scale = torch.repeat_interleave(scale, row_block, dim=-2)[..., :rows, :] + scale = torch.repeat_interleave(scale, col_block, dim=-1)[..., :, :cols] + return scale + + +def _apply_gptj_inv_rope_ref( + x: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + rope_dim: int, +) -> torch.Tensor: + if rope_dim == 0 or x.numel() == 0: + return x + half_rot = rope_dim // 2 + nope_dim = x.shape[-1] - rope_dim + dtype = x.dtype + x = x.to(torch.float32) + cache = cos_sin_cache.index_select(0, positions.to(torch.long)) + cos = cache[:, :half_rot].to(torch.float32) + sin = cache[:, half_rot : 2 * half_rot].to(torch.float32) + view_shape = (positions.shape[0],) + (1,) * (x.dim() - 2) + (half_rot,) + cos = cos.view(view_shape) + sin = sin.view(view_shape) + rope = x[..., nope_dim:] + y_even = rope[..., 0::2] + y_odd = rope[..., 1::2] + rope_out = torch.stack( + (y_even * cos + y_odd * sin, y_odd * cos - y_even * sin), + dim=-1, + ).flatten(-2) + x = x.clone() + x[..., nope_dim:] = rope_out + return x.to(dtype) + + +def _apply_inv_rope_ref( + rotary_emb: torch.nn.Module, + x: torch.Tensor, + positions: torch.Tensor, + rope_dim: int, +) -> torch.Tensor: + if hasattr(rotary_emb, "forward_native"): + try: + query, _ = rotary_emb.forward_native( + positions, + x.clone(), + None, + inverse=True, + ) + return query + except TypeError: + pass + return _apply_gptj_inv_rope_ref(x, positions, rotary_emb.cos_sin_cache, rope_dim) + + +def _deepseek_v4_fp8_einsum_fallback( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: 
torch.Tensor, + equation: str, +) -> None: + if equation != "bhr,hdr->bhd": + raise RuntimeError(f"Unsupported fallback equation: {equation}") + + num_groups = a.shape[1] + hidden_dim = a.shape[2] + output_dim = b.shape[0] // num_groups + + if b.shape[0] % num_groups != 0: + raise RuntimeError( + f"Cannot reshape weight of shape {tuple(b.shape)} into " + f"({num_groups}, {output_dim}, {hidden_dim})." + ) + + a_deq = (a.to(torch.float32) * _expand_last_dim_scales(a_scale, hidden_dim)).to( + torch.bfloat16 + ) + + b_deq = b.view(num_groups, output_dim, hidden_dim).to(torch.float32) + b_scale_deq = _expand_2d_block_scales( + b_scale.view(num_groups, -1, b_scale.shape[-1]), + output_dim, + hidden_dim, + ) + b_deq = (b_deq * b_scale_deq).to(torch.bfloat16) + + out.copy_(torch.einsum(equation, a_deq, b_deq).to(out.dtype)) def deepseek_v4_fp8_einsum_fake( @@ -665,8 +831,11 @@ def __init__( vllm_config.scheduler_config.max_num_batched_tokens ) self.max_model_len = vllm_config.model_config.max_model_len - # DeepseekV4 only supports fp8 kv-cache format for now + # DeepseekV4 only supports fp8 kv-cache format for now. Treat "auto" + # as the model default and normalize it before the fp8-only checks. kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "fp8" + if kv_cache_dtype == "auto": + kv_cache_dtype = "fp8" assert kv_cache_dtype.startswith("fp8"), ( f"DeepseekV4 only supports fp8 kv-cache format for now, " @@ -813,6 +982,20 @@ def _forward_decode( swa_indices = swa_metadata.decode_swa_indices swa_lens = swa_metadata.decode_swa_lens + if current_platform.is_rocm(): + self._forward_decode_fallback( + q=q, + kv_cache=kv_cache, + swa_metadata=swa_metadata, + swa_only=swa_only, + topk_indices=topk_indices, + topk_lens=topk_lens, + swa_indices=swa_indices, + swa_lens=swa_lens, + output=output, + ) + return + # We treat queries in the same seq as different queries # and later we only attend by generated indices. # q arrives pre-padded to self.padded_heads by the outer wrapper. 
@@ -941,6 +1124,7 @@ def _forward_prefill( block_table=block_table[chunk_start:chunk_end], block_size=attn_metadata.block_size // self.compress_ratio, offset=0, + use_fnuz=current_platform.is_fp8_fnuz(), ) # Gather SWA KV @@ -953,6 +1137,7 @@ def _forward_prefill( block_table=swa_block_table[chunk_start:chunk_end], block_size=swa_metadata.block_size, offset=N, + use_fnuz=current_platform.is_fp8_fnuz(), ) # Combine the topk indices and SWA indices for gathered KV cache @@ -977,15 +1162,243 @@ def _forward_prefill( N, ) - output_chunk, _, _ = flash_mla_sparse_fwd( - q=q[query_start:query_end], - kv=kv.view(-1, 1, q.shape[-1]), - indices=combined_indices.unsqueeze(1), - sm_scale=self.scale, - attn_sink=self.attn_sink, - topk_length=combined_lens, - out=output[query_start:query_end], + if current_platform.is_rocm(): + output_chunk = self._ref_sparse_attn_prefill( + q=q[query_start:query_end], + kv=kv.view(-1, 1, q.shape[-1]), + indices=combined_indices.unsqueeze(1), + topk_length=combined_lens, + ) + output[query_start:query_end].copy_(output_chunk.to(output.dtype)) + else: + output_chunk, _, _ = flash_mla_sparse_fwd( + q=q[query_start:query_end], + kv=kv.view(-1, 1, q.shape[-1]), + indices=combined_indices.unsqueeze(1), + sm_scale=self.scale, + attn_sink=self.attn_sink, + topk_length=combined_lens, + out=output[query_start:query_end], + ) + + def _decode_e8m0_scales(self, scale: torch.Tensor) -> torch.Tensor: + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _upcast_e8m0_to_fp32, + ) + + return _upcast_e8m0_to_fp32(scale.contiguous()) + + def _dequantize_cache_rows(self, rows: torch.Tensor) -> torch.Tensor: + rows = rows.reshape(-1, rows.shape[-1]) + fp8_dtype = current_platform.fp8_dtype() + fp8_dim = self.nope_head_dim + rope_bytes = self.rope_head_dim * 2 + scale_dim = fp8_dim // 64 + + fp8_vals = rows[:, :fp8_dim].contiguous().view(fp8_dtype) + rope_vals = ( + rows[:, fp8_dim : fp8_dim + rope_bytes].contiguous().view(torch.bfloat16) + ) + scale_bytes = rows[ + :, fp8_dim + rope_bytes : fp8_dim + rope_bytes + scale_dim + ].contiguous() + scales = self._decode_e8m0_scales(scale_bytes) + scales = torch.repeat_interleave(scales, 64, dim=-1) + nope = fp8_vals.to(torch.float32) * scales + return torch.cat([nope, rope_vals.to(torch.float32)], dim=-1).to(torch.bfloat16) + + def _gather_dequantized_cache_tokens( + self, + cache: torch.Tensor, + slot_ids: torch.Tensor, + block_size: int, + ) -> torch.Tensor: + if slot_ids.numel() == 0: + return torch.empty( + (0, self.head_dim), dtype=torch.bfloat16, device=cache.device ) + slot_ids = slot_ids.to(torch.int64) + rows = cache[slot_ids // block_size, slot_ids % block_size] + return self._dequantize_cache_rows(rows).reshape(-1, self.head_dim) + + def _forward_decode_fallback( + self, + q: torch.Tensor, + kv_cache: torch.Tensor | None, + swa_metadata: "DeepseekSparseSWAMetadata", + swa_only: bool, + topk_indices: torch.Tensor | None, + topk_lens: torch.Tensor | None, + swa_indices: torch.Tensor, + swa_lens: torch.Tensor, + output: torch.Tensor, + ) -> None: + blocked_swa = self._dequantize_blocked_k_cache(self.swa_cache_layer.kv_cache) + blocked_extra = None if swa_only else self._dequantize_blocked_k_cache(kv_cache) + attn_out = self._ref_sparse_attn_decode( + q=q.unsqueeze(1), + blocked_k=blocked_swa, + indices_in_kvcache=swa_indices.unsqueeze(1), + topk_length=swa_lens, + attn_sink=self.attn_sink[: q.shape[1]], + extra_blocked_k=blocked_extra, + extra_indices_in_kvcache=topk_indices, + extra_topk_length=topk_lens, + ) + 
output.copy_(attn_out.to(output.dtype)) + + def _dequantize_blocked_k_cache(self, quant_k_cache: torch.Tensor) -> torch.Tensor: + fp8_dtype = current_platform.fp8_dtype() + d = self.head_dim + d_nope = self.nope_head_dim + d_rope = self.rope_head_dim + tile_size = 64 + num_tiles = d_nope // tile_size + + num_blocks, block_size, _ = quant_k_cache.shape + quant_k_cache = quant_k_cache.view(num_blocks, -1) + input_nope_rope = quant_k_cache[:, : block_size * (d_nope + 2 * d_rope)].view( + num_blocks, block_size, d_nope + 2 * d_rope + ) + input_nope = input_nope_rope[:, :, :d_nope].view(fp8_dtype) + input_rope = input_nope_rope[:, :, d_nope:].view(torch.bfloat16) + input_scale = ( + quant_k_cache[:, block_size * (d_nope + 2 * d_rope) :] + .view(num_blocks, block_size, 8)[:, :, :num_tiles] + .view(torch.float8_e8m0fnu) + ) + + result = torch.empty( + (num_blocks, block_size, 1, d), + dtype=torch.bfloat16, + device=quant_k_cache.device, + ) + result[..., d_nope:] = input_rope.unsqueeze(2) + for tile_idx in range(num_tiles): + cur_nope = input_nope[ + ..., tile_idx * tile_size : (tile_idx + 1) * tile_size + ].to(torch.bfloat16) + cur_scales = input_scale[:, :, tile_idx].to(torch.bfloat16).unsqueeze(-1) + result[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = ( + cur_nope * cur_scales + ).unsqueeze(2) + return result + + def _ref_sparse_attn_prefill( + self, + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + topk_length: torch.Tensor | None, + ) -> torch.Tensor: + indices = indices.clone().squeeze(1) + s_q, h_q, d_qk = q.shape + topk = indices.shape[-1] + s_kv = kv.shape[0] + if topk_length is not None: + mask = torch.arange(topk, device=indices.device).unsqueeze( + 0 + ) >= topk_length.unsqueeze(1) + indices[mask] = -1 + invalid_mask = (indices < 0) | (indices >= s_kv) + indices[invalid_mask] = 0 + + qf = q.float() + gathered_kv = ( + kv.index_select(0, indices.flatten()).reshape(s_q, topk, d_qk).float() + ) + scores = qf @ gathered_kv.transpose(1, 2) + scores *= self.scale + scores[invalid_mask.unsqueeze(1).expand_as(scores)] = float("-inf") + + orig_lse = torch.logsumexp(scores, dim=-1) + lse_for_o = orig_lse + if self.attn_sink is not None: + lse_for_o = torch.logsumexp( + torch.stack( + [orig_lse, self.attn_sink[:h_q].view(1, h_q).expand_as(orig_lse)], + dim=0, + ), + dim=0, + ) + lse_for_o = lse_for_o.clone() + lse_for_o[lse_for_o == float("-inf")] = float("+inf") + probs = torch.exp(scores - lse_for_o.unsqueeze(-1)) + out = probs @ gathered_kv[..., : self.head_dim] + lonely_q_mask = orig_lse == float("-inf") + out[lonely_q_mask.unsqueeze(-1).expand_as(out)] = 0.0 + return out.to(torch.bfloat16) + + def _ref_sparse_attn_decode( + self, + q: torch.Tensor, + blocked_k: torch.Tensor, + indices_in_kvcache: torch.Tensor, + topk_length: torch.Tensor | None, + attn_sink: torch.Tensor | None, + extra_blocked_k: torch.Tensor | None = None, + extra_indices_in_kvcache: torch.Tensor | None = None, + extra_topk_length: torch.Tensor | None = None, + ) -> torch.Tensor: + b, s_q, h_q, d_qk = q.shape + d_v = self.head_dim + + def process_scope( + cur_blocked_k: torch.Tensor, + cur_indices: torch.Tensor, + cur_topk_length: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + cur_indices = cur_indices.reshape(b, s_q, -1) + topk = cur_indices.size(-1) + fixed_indices = torch.clamp_min(cur_indices, 0) + gathered_kv = ( + cur_blocked_k.view(-1, d_qk) + .index_select(0, fixed_indices.view(-1)) + .view(b, s_q, topk, d_qk) + ) + invalid_mask = cur_indices == -1 + if 
cur_topk_length is not None: + cur_topk_length = cur_topk_length.reshape(b) + invalid_mask |= torch.arange(0, topk, device=invalid_mask.device).view( + 1, 1, topk + ) >= cur_topk_length.view(b, 1, 1) + return gathered_kv, invalid_mask + + gathered_kv, invalid_mask = process_scope( + blocked_k, indices_in_kvcache, topk_length + ) + if extra_blocked_k is not None: + assert extra_indices_in_kvcache is not None + gathered_kv1, invalid_mask1 = process_scope( + extra_blocked_k, extra_indices_in_kvcache, extra_topk_length + ) + gathered_kv = torch.cat([gathered_kv, gathered_kv1], dim=2) + invalid_mask = torch.cat([invalid_mask, invalid_mask1], dim=2) + + gathered_kv = gathered_kv.view(b * s_q, -1, d_qk).float() + gathered_kv[gathered_kv != gathered_kv] = 0.0 + qf = q.float().view(b * s_q, h_q, d_qk) + attn_weight = qf @ gathered_kv.transpose(-1, -2) + attn_weight *= self.scale + attn_weight[ + invalid_mask.view(b * s_q, 1, -1).expand( + b * s_q, h_q, invalid_mask.size(-1) + ) + ] = float("-inf") + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.exp(attn_weight - lse.unsqueeze(-1)) + output = attn_weight @ gathered_kv[..., :d_v] + output = output.view(b, s_q, h_q, d_v) + lse = lse.view(b, s_q, h_q) + + if attn_sink is not None: + output *= ( + 1.0 / (1.0 + torch.exp(attn_sink.view(1, 1, h_q) - lse)) + ).unsqueeze(-1) + + lonely_q_mask = lse == float("-inf") + output[lonely_q_mask.unsqueeze(-1).expand_as(output)] = 0.0 + return output.squeeze(1).to(torch.bfloat16) class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase): @@ -1150,3 +1563,25 @@ def forward( use_fp4=self.use_fp4_kv, ) return self.indexer_op(hidden_states, q_quant, k, weights) + + def _quantize_indexer_q_torch( + self, + q: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + weights: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + q = apply_gptj_rope_ref(q, positions, cos_sin_cache, self.rope_dim).to( + torch.bfloat16 + ) + q = hadamard_transform_ref(q).to(torch.float32) + fp8_max = 224.0 if current_platform.is_fp8_fnuz() else 448.0 + q_scale = torch.abs(q).amax(dim=-1).clamp(min=1e-12) / fp8_max + q_quant = (q / q_scale.unsqueeze(-1)).to(current_platform.fp8_dtype()) + weights = ( + weights.to(torch.float32) + * q_scale + * self.softmax_scale + * (self.n_head**-0.5) + ) + return q_quant, weights diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py index c1423362d737..da0c8a688a1f 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, FusedMoEQuantDesc, + RoutingMethodType, mxfp4_mxfp8_moe_quant_config, mxfp4_w4a16_moe_quant_config, ocp_mx_moe_quant_config, @@ -217,6 +218,8 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]: TRTLLM MXFP8; SM90 falls through to Triton_unfused or Marlin (the backend-level ``is_supported_config`` check filters by device capability). """ + if current_platform.is_rocm(): + return [Mxfp4MoeBackend.AITER] _AVAILABLE_BACKENDS = [ Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, Mxfp4MoeBackend.DEEPGEMM_MXFP4, @@ -484,8 +487,22 @@ def _return_or_raise( activation_format, ) + # DeepSeek-V4 on ROCm is more accurate with the unfused Triton MXFP4 path + # than the default AITER path. Prefer Triton-unfused for this routing mode, + # while keeping AITER as a fallback if Triton-unfused rejects the config. 
+ if ( + current_platform.is_rocm() + and config.routing_method == RoutingMethodType.DeepseekV4 + ): + priority_backends = [ + Mxfp4MoeBackend.TRITON_UNFUSED, + Mxfp4MoeBackend.AITER, + ] + else: + priority_backends = _get_priority_backends() + # Iterate priority backends: TRTLLM MXFP8, then Triton. - for backend in _get_priority_backends(): + for backend in priority_backends: activation_key = _backend_activation_key(backend) for k_cls in backend_to_kernel_cls(backend): supported, reason = k_cls.is_supported_config( @@ -1107,6 +1124,64 @@ def convert_weight_to_mxfp4_moe_kernel_format( w2_bias, ) + elif mxfp4_backend == Mxfp4MoeBackend.AITER: + from vllm._aiter_ops import rocm_aiter_ops + + if w13_bias is not None: + w13_bias = w13_bias.data.to(torch.float32) + if w2_bias is not None: + w2_bias = w2_bias.data.to(torch.float32) + + e, n, k = w13_weight.shape + + w13_weight.view(torch.uint8).copy_( + w13_weight.data.view(torch.uint8) + .view(e, n // 2, 2, k) + .permute(0, 2, 1, 3) + .contiguous() + .view(e, n, k) + ) + w13_weight_scale.data = ( + w13_weight_scale.data.view(e, n // 2, 2, -1) + .permute(0, 2, 1, 3) + .contiguous() + .view(e, n, -1) + ) + + w13_weight.data = w13_weight.data.view(torch.float4_e2m1fn_x2) + w2_weight.data = w2_weight.data.view(torch.float4_e2m1fn_x2) + + w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w13_weight, 16, True) + shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4( + w13_weight_scale.view(-1, w13_weight_scale.shape[-1]), + num_experts, + True, + ) + + w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w2_weight, 16, False) + shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4( + w2_weight_scale.view(-1, w2_weight_scale.shape[-1]), + num_experts, + False, + ) + + if w13_bias is not None: + w13_bias = ( + w13_bias.data.view(-1, n // 2, 2) + .permute(0, 2, 1) + .contiguous() + .view(-1, n) + ) + + return ( + w13_weight, + w2_weight, + shuffled_w13_scale, + shuffled_w2_scale, + w13_bias, + w2_bias, + ) + elif mxfp4_backend in TRITON_BACKENDS: from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig @@ -1162,7 +1237,7 @@ def shuffle_weight(w: torch.Tensor) -> torch.Tensor: else: raise ValueError( f"Unsupported mxfp4_backend for Mxfp4MoEMethod: {mxfp4_backend}. " - f"Expected TRTLLM or Triton backend." + f"Expected TRTLLM, Triton, or AITER backend." 
) diff --git a/vllm/model_executor/layers/mhc.py b/vllm/model_executor/layers/mhc.py index 1521a6b601bf..68b47a89b1a6 100644 --- a/vllm/model_executor/layers/mhc.py +++ b/vllm/model_executor/layers/mhc.py @@ -234,6 +234,39 @@ def mhc_pre( num_tokens = residual_flat.shape[0] fn_flat = fn + if current_platform.is_rocm(): + x = residual_flat.view(num_tokens, hc_mult * hidden_size).to(torch.float32) + mixes = torch.matmul(x, fn_flat.t()) + sqrsum = x.square().sum(dim=-1, keepdim=True) + mixes = mixes * torch.rsqrt(sqrsum / (hc_mult * hidden_size) + rms_eps) + + pre_logits = mixes[:, :hc_mult] * hc_scale[0] + hc_base[:hc_mult] + pre_mix = torch.sigmoid(pre_logits) + hc_pre_eps + + post_logits = ( + mixes[:, hc_mult : 2 * hc_mult] * hc_scale[1] + + hc_base[hc_mult : 2 * hc_mult] + ) + post_mix = torch.sigmoid(post_logits) * hc_post_mult_value + + comb_logits = mixes[:, 2 * hc_mult :].view( + num_tokens, hc_mult, hc_mult + ) * hc_scale[2] + hc_base[2 * hc_mult :].view(1, hc_mult, hc_mult) + comb_mix = torch.softmax(comb_logits, dim=-1) + hc_sinkhorn_eps + comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps) + for _ in range(sinkhorn_repeat - 1): + comb_mix = comb_mix / (comb_mix.sum(dim=-1, keepdim=True) + hc_sinkhorn_eps) + comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps) + + layer_input = torch.sum( + pre_mix.unsqueeze(-1) * residual_flat.to(torch.float32), dim=1 + ).to(torch.bfloat16) + return ( + post_mix.view(*outer_shape, hc_mult, 1), + comb_mix.view(*outer_shape, hc_mult, hc_mult), + layer_input.view(*outer_shape, hidden_size), + ) + # these number are from deepgemm kernel impl block_k = 64 block_m = 64 @@ -414,6 +447,14 @@ def mhc_post( post_layer_mix: torch.Tensor, comb_res_mix: torch.Tensor, ) -> torch.Tensor: + if current_platform.is_rocm(): + mixed_residual = torch.einsum( + "...ij,...ih->...jh", + comb_res_mix.to(torch.float32), + residual.to(torch.float32), + ) + post_term = post_layer_mix.to(torch.float32) * x.unsqueeze(-2).to(torch.float32) + return (mixed_residual + post_term).to(residual.dtype) out = torch.empty_like(residual) mhc_post_tilelang( comb_res_mix, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 9613b11d35e2..74736d75f0dc 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -843,6 +843,15 @@ def w8a8_triton_block_scaled_mm( assert len(block_size) == 2 block_n, block_k = block_size[0], block_size[1] + # Triton cannot currently bind E8M0 scale tensors directly. On ROCm, + # DeepSeek-V4 checkpoints store block scales in exponent-only E8M0 format, + # so decode them to fp32 before launching the kernel. + if current_platform.is_rocm(): + if As.dtype == torch.float8_e8m0fnu: + As = _upcast_e8m0_to_fp32(As).contiguous() + if Bs.dtype == torch.float8_e8m0fnu: + Bs = _upcast_e8m0_to_fp32(Bs).contiguous() + assert A.shape[-1] == B.shape[-1] assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] @@ -1280,9 +1289,24 @@ def process_fp8_weight_block_strategy( ) if current_platform.is_fp8_fnuz(): - weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=weight, weight_scale=weight_scale - ) + if weight_scale.dtype == torch.float8_e8m0fnu: + # e8m0 stores exponent-only values (2^(exp-127)). + # Doubling == incrementing the exponent byte by 1. 
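+ # For example, an exponent byte of 128 encodes 2.0; bumping it to 129 encodes 4.0.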
+ weight_as_int8 = weight.view(torch.int8) + ROCM_FP8_NAN_AS_INT = -128 + weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0 + weight = weight_as_int8.view(torch.float8_e4m3fnuz) + exp_bytes = weight_scale.view(torch.uint8) + weight_scale = ( + (exp_bytes.to(torch.int16) + 1) + .clamp(max=254) + .to(torch.uint8) + .view(torch.float8_e8m0fnu) + ) + else: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, weight_scale=weight_scale + ) weight = _maybe_pad_fp8_weight(weight) return weight, weight_scale diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index ca82f2feb7ef..4bf52a49c43f 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -499,13 +499,31 @@ def forward_hip( k: torch.Tensor, weights: torch.Tensor, ): - assert not self.skip_k_cache_insert, ( - "AMD platform doesn't support skip cache insert yet" - ) assert not self.use_fp4_cache, "AMD platform doesn't support fp4 cache yet" assert isinstance(q_quant, torch.Tensor), ( "AMD sparse_attn_indexer expects a single FP8 q_quant tensor" ) + if self.skip_k_cache_insert or not rocm_aiter_ops.is_enabled(): + from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( + rocm_aiter_sparse_attn_indexer_native, + ) + + return rocm_aiter_sparse_attn_indexer_native( + hidden_states, + _encode_layer_name(self.k_cache.prefix), + self.k_cache.kv_cache, + q_quant, + k, + weights, + self.quant_block_size, + self.scale_fmt, + self.topk_tokens, + self.head_dim, + self.max_model_len, + self.max_total_seq_len, + self.topk_indices_buffer, + skip_k_cache_insert=self.skip_k_cache_insert, + ) if rocm_aiter_ops.is_enabled(): return torch.ops.vllm.rocm_aiter_sparse_attn_indexer( hidden_states, @@ -522,8 +540,4 @@ def forward_hip( self.max_total_seq_len, self.topk_indices_buffer, ) - else: - raise RuntimeError( - "Sparse attention indexer ROCm custom op requires ROCm " - "Aiter ops to be enabled." - ) + raise RuntimeError("Sparse attention indexer ROCm path could not be selected.") diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 54d2cb53b0d1..15913c418b05 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -674,30 +674,45 @@ def forward( ) -> torch.Tensor: q, _ = self.wq_b(qr) q = q.view(-1, self.n_head, self.head_dim) - q_pe, q_nope = torch.split( - q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 - ) - # Fused wk + weights_proj: one GEMM, then split - kw, _ = self.wk_weights_proj(hidden_states) - k = kw[:, : self.head_dim] - weights = kw[:, self.head_dim :] - - k = self.k_norm(k) - k_pe, k_nope = torch.split( - k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 - ) - q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1)) - # Note: RoPE (NeoX) can introduce extra leading dimensions during compilation - # so we need to reshape back to token-flattened shapes - q_pe = q_pe.reshape(-1, self.n_head, self.rope_dim) - k_pe = k_pe.reshape(-1, 1, self.rope_dim) - - # `rotary_emb` is shape-preserving; `q_pe` is already - # [num_tokens, n_head, rope_dim]. - q = torch.cat([q_pe, q_nope], dim=-1) - # `k_pe` is [num_tokens, 1, rope_dim] (MQA). 
- k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1) + if current_platform.is_rocm(): + # This path should work on all platforms; the extra + # branches will be removed in the future + # Fused wk + weights_proj: one GEMM, then split + kw, _ = self.wk_weights_proj(hidden_states) + k = kw[:, : self.head_dim] + weights = kw[:, self.head_dim :] + + k = self.k_norm(k) + + rotary_emb( + positions, q[..., : self.rope_dim], k[..., : self.rope_dim].unsqueeze(1) + ) + else: + q_pe, q_nope = torch.split( + q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 + ) + # Fused wk + weights_proj: one GEMM, then split + kw, _ = self.wk_weights_proj(hidden_states) + k = kw[:, : self.head_dim] + weights = kw[:, self.head_dim :] + + k = self.k_norm(k) + k_pe, k_nope = torch.split( + k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 + ) + + q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1)) + # Note: RoPE (NeoX) can introduce extra leading dimensions during + # compilation so we need to reshape back to token-flattened shapes + q_pe = q_pe.reshape(-1, self.n_head, self.rope_dim) + k_pe = k_pe.reshape(-1, 1, self.rope_dim) + + # `rotary_emb` is shape-preserving; `q_pe` is already + # [num_tokens, n_head, rope_dim]. + q = torch.cat([q_pe, q_nope], dim=-1) + # `k_pe` is [num_tokens, 1, rope_dim] (MQA). + k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1) # we only quant q here since k quant is fused with cache insertion q = q.view(-1, self.head_dim) diff --git a/vllm/model_executor/models/deepseek_v4.py b/vllm/model_executor/models/deepseek_v4.py index 5521a9764a9c..73aa3538acad 100644 --- a/vllm/model_executor/models/deepseek_v4.py +++ b/vllm/model_executor/models/deepseek_v4.py @@ -1240,7 +1240,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # DeepseekV4MultiHeadLatentAttentionWrapper.attn_gemm_parallel_execute # (compressor kv_score, indexer.weights_proj, indexer.compressor # kv_score). fused_wqa_wkv stays on the default stream. - aux_stream_list = [torch.cuda.Stream() for _ in range(3)] + # Disable them on ROCm because of hang issues. + aux_stream_list = ( + None + if current_platform.is_rocm() + else [torch.cuda.Stream() for _ in range(3)] + ) self.device = current_platform.device_type # Reserved topk indices buffer for all Indexer layers to reuse. diff --git a/vllm/model_executor/models/deepseek_v4_mtp.py b/vllm/model_executor/models/deepseek_v4_mtp.py index a3724e5ebe80..195709c9dacf 100644 --- a/vllm/model_executor/models/deepseek_v4_mtp.py +++ b/vllm/model_executor/models/deepseek_v4_mtp.py @@ -167,8 +167,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) # Three aux streams shared across all MTP layers, mirroring - # DeepseekV4Model. - aux_stream_list = [torch.cuda.Stream() for _ in range(3)] + # DeepseekV4Model. ROCm runs the same work serially for now. 
+ aux_stream_list = ( + None + if current_platform.is_rocm() + else [torch.cuda.Stream() for _ in range(3)] + ) # to map the exact layer index from weights self.layers = torch.nn.ModuleDict( diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 866b9ffd1a6d..70534c106ffe 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -409,6 +409,7 @@ class RocmPlatform(Platform): "gptq", "gptq_marlin", # will be overwritten with gptq "fp8", + "deepseek_v4_fp8", "compressed-tensors", "fbgemm_fp8", "gguf", diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 6b89f5c33203..f66818ab8972 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -473,7 +473,11 @@ def tf32_hc_prenorm_gemm( """ _lazy_init() if _tf32_hc_prenorm_gemm_impl is None: - return _missing() + out.zero_() + sqrsum.zero_() + out[0].copy_(torch.matmul(x.to(torch.float32), fn.t().to(torch.float32))) + sqrsum[0].copy_(x.to(torch.float32).square().sum(dim=-1)) + return out return _tf32_hc_prenorm_gemm_impl( x, fn, diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 5d12d27e7625..7c0715a9e8b6 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -122,7 +122,7 @@ def get_name() -> str: @staticmethod def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: - return [1 if current_platform.is_rocm() else 64] + return [1, 64] if current_platform.is_rocm() else [64] @classmethod def get_supported_head_sizes(cls) -> list[int]: diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index a66a97311fbc..2106226118ef 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -396,6 +396,7 @@ class AiterMLAHelper: """ _AITER_MIN_MLA_HEADS: Final = 16 + _AITER_UNSUPPORTED_HEADS = [32] @staticmethod def check_num_heads_validity(num_heads: int): @@ -419,6 +420,9 @@ def get_actual_mla_num_heads(num_heads: int) -> int: @staticmethod def get_mla_padded_q(num_heads: int, q: torch.Tensor) -> torch.Tensor: + assert num_heads not in AiterMLAHelper._AITER_UNSUPPORTED_HEADS, ( + f"unsupported head_num: {num_heads}" + ) return ( q if num_heads >= AiterMLAHelper._AITER_MIN_MLA_HEADS diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py index 503bb509b105..dc343b639f6c 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py @@ -7,6 +7,7 @@ import numpy as np import torch +from vllm import _custom_ops as ops from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig from vllm.config.cache import CacheDType @@ -14,6 +15,7 @@ from vllm.model_executor.layers.attention.mla_attention import ( get_mla_dims, ) +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.v1.attention.backend import ( AttentionBackend, @@ -25,9 +27,6 @@ MultipleOf, SparseMLAAttentionImpl, ) -from vllm.v1.attention.backends.mla.flashmla_sparse import ( - triton_convert_req_index_to_global_index, -) from vllm.v1.attention.backends.mla.rocm_aiter_mla import ( AiterMLAHelper, ) @@ -38,6 +37,188 @@ logger = init_logger(__name__) +@triton.jit +def _convert_req_index_to_global_index_kernel( + req_id_ptr, # int32 [num_tokens] + block_table_ptr, # int32 [num_requests, max_num_blocks_per_req] + token_indices_ptr, # int32 
[num_tokens, NUM_TOPK_TOKENS] + cu_seqlens_ptr, # int32 [num_tokens + 1] + out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] + # shapes (compile-time where possible) + max_num_blocks_per_req: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, # tile width along columns + # strides (in elements) + bt_stride0, + bt_stride1, + ti_stride0, + ti_stride1, +): + # program_id(0) -> token_id (row) + # program_id(1) -> tile index along columns + token_id = tl.program_id(0) + tile_id = tl.program_id(1) + + # Each program covers BLOCK_N consecutive columns + indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N) + + # Load request id for this token (no mask: grid is exact) + req = tl.load(req_id_ptr + token_id) + + # Load cumulative sequence lengths to get starting index of this request + seq_start = tl.load(cu_seqlens_ptr + token_id) + seq_end = tl.load(cu_seqlens_ptr + token_id + 1) + + if tile_id * BLOCK_N + seq_start >= seq_end: + return + + # Load token indices for this tile + ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1 + tok = tl.load(ti_ptr) # int32 + + # Only token == -1 should propagate as -1 + is_invalid_tok = tok < 0 + + # Compute block id and in-block offset + block_id = tok // BLOCK_SIZE + inblock_off = tok % BLOCK_SIZE + + # Guard block_table access + valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0) + bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1 + base = tl.load(bt_ptr, mask=valid_block, other=0) + + # # If token == -1 OR block_id OOB, output 0; else base * BLOCK_SIZE + offset + out_val = tl.where( + is_invalid_tok | (~valid_block), 0, base * BLOCK_SIZE + inblock_off + ) + out_ptr_ij = out_ptr + seq_start + indice_id + out_ptr_ij_mask = (seq_start + indice_id) < seq_end + + # store the results with mask + tl.store(out_ptr_ij, out_val, mask=out_ptr_ij_mask) + + +def triton_convert_req_index_to_global_index( + req_id: torch.Tensor, # int32 [num_tokens] + block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req] + token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS] + cu_seqlens: torch.Tensor, # int32 [num_tokens + 1] + paged_kv_indices: torch.Tensor, # int32 [num_tokens * topk] out_buffer + BLOCK_SIZE: int = 64, + NUM_TOPK_TOKENS: int = 2048, + BLOCK_N: int = 128, # tile width along columns +): + """ + out[token_id, indice_id] = + block_table[req_id[token_id], + token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE + + token_indices[token_id, indice_id] % BLOCK_SIZE + + Only when token_indices[token_id, indice_id] == -1 do we output -1. + For safety, we also output -1 if the derived block_id would be + out-of-bounds. 
+ """ + assert req_id.dtype == torch.int32 + assert block_table.dtype == torch.int32 + assert token_indices.dtype == torch.int32 + assert token_indices.shape[1] == NUM_TOPK_TOKENS + assert NUM_TOPK_TOKENS % BLOCK_N == 0, ( + f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible byBLOCK_N ({BLOCK_N})" + ) + # print("req_id: ", req_id, flush=True) + num_tokens = req_id.shape[0] + _, max_num_blocks_per_req = block_table.shape + tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N + + # Ensure contiguous tensors on the same device + req_id_c = req_id.contiguous() + block_table_c = block_table.contiguous() + token_indices_c = token_indices.contiguous() + + # Strides in elements + bt_stride0, bt_stride1 = block_table_c.stride() + ti_stride0, ti_stride1 = token_indices_c.stride() + + # Exact 2D grid: tokens × column tiles + grid = (num_tokens, tiles_per_row) + + _convert_req_index_to_global_index_kernel[grid]( + req_id_c, + block_table_c, + token_indices_c, + cu_seqlens, + paged_kv_indices, + # shapes / constexprs + max_num_blocks_per_req, + BLOCK_SIZE, + BLOCK_N, + # strides + bt_stride0, + bt_stride1, + ti_stride0, + ti_stride1, + ) + return + + +@triton.jit +def generate_sparse_seqlen_kernel( + seq_len_ptr, # [num_seq] + cu_query_lens_ptr, # [num_seq] + out_ptr, # [num_query_tokens] + topk_token: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + seq_id = tl.program_id(0) + query_offset = tl.program_id(1) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + query_start = tl.load(cu_query_lens_ptr + seq_id) + query_end = tl.load(cu_query_lens_ptr + seq_id + 1) + if query_start + tl.program_id(1) * BLOCK_SIZE > query_end: + return + query_len = query_end - query_start + query_mask = query_offset + query_start < query_end + seq_len = tl.load(seq_len_ptr + seq_id) + # Just return since the out_ptr is zero initialized. 
+ if seq_len == 0: + return + context_start_point = seq_len - query_len + sparse_seqlen = context_start_point + query_offset + sparse_seqlen_masked = tl.where( + sparse_seqlen + 1 < topk_token, sparse_seqlen + 1, topk_token + ) + tl.store( + out_ptr + query_start + query_offset, sparse_seqlen_masked, mask=query_mask + ) + + +def generate_sparse_seqlen_triton( + query_lens: torch.Tensor, + seq_lens: torch.Tensor, + cu_query_lens: torch.Tensor, + topk_token: int, + num_tokens: int, + max_query_len: int, +): + num_seqs = query_lens.size(0) + # zero initialize the tensor to make sure invalid positions will be zero + out = torch.zeros([num_tokens], dtype=torch.int32, device=query_lens.device) + block_size = 64 + num_block_per_row = triton.cdiv(max_query_len, block_size) + grid = ( + num_seqs, + num_block_per_row, + ) + generate_sparse_seqlen_kernel[grid]( + seq_lens, + cu_query_lens, + out, + topk_token, + block_size, + ) + return out + + @triton.jit def fetch_id_to_ragged_kernel( in_tensor_ptr, # [num_seq, topk] @@ -86,11 +267,13 @@ class ROCMAiterMLASparseBackend(AttentionBackend): "auto", "float16", "bfloat16", + "fp8", + "fp8_e4m3", ] @staticmethod def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: - return [1] + return [1, 64] @staticmethod def get_name() -> str: @@ -144,7 +327,7 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata): paged_kv_last_page_len: torch.Tensor paged_kv_indices: torch.Tensor paged_kv_indptr: torch.Tensor - paged_kv_indptr_rest: torch.Tensor + attn_out_dtype: torch.dtype block_size: int = 1 topk_tokens: int = 2048 @@ -167,6 +350,7 @@ def __init__( ): self.kv_cache_spec = kv_cache_spec self.model_config = vllm_config.model_config + self.model_dtype = vllm_config.model_config.dtype parallel_config = vllm_config.parallel_config self.device = device max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens @@ -174,9 +358,6 @@ def __init__( self.num_heads = self.model_config.get_num_attention_heads(parallel_config) self.mla_dims = get_mla_dims(self.model_config) self.topk_tokens = vllm_config.model_config.hf_config.index_topk - self.topk_tokens_tensor = torch.tensor( - [self.topk_tokens], device=device, dtype=torch.int32 - ) self.max_model_len_tensor = torch.tensor( [self.model_config.max_model_len], device=device, dtype=torch.int32 ) @@ -222,18 +403,33 @@ def build( ) # Zero-fill for cudagraphs self.req_id_per_token_buffer.fill_(0) + self.paged_kv_indices.fill_(0) + self.paged_kv_indptr.fill_(0) self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_( torch.from_numpy(req_id_per_token), non_blocking=True ) - self.paged_kv_indices.fill_(0) - self.paged_kv_indptr.fill_(0) + query_lens = ( + common_attn_metadata.query_start_loc[1:] + - common_attn_metadata.query_start_loc[:-1] + ) + seq_lens = common_attn_metadata.seq_lens + sparse_seqlen = generate_sparse_seqlen_triton( + query_lens, + seq_lens, + common_attn_metadata.query_start_loc, + self.topk_tokens, + num_tokens, + common_attn_metadata.max_query_len, + ) + + torch.cumsum(sparse_seqlen, dim=0, out=self.paged_kv_indptr[1 : num_tokens + 1]) + self.paged_kv_indptr[num_tokens + 1 :].fill_(self.paged_kv_indptr[num_tokens]) req_id_per_token = self.req_id_per_token_buffer[:num_tokens] qo_indptr = self.qo_indptr[: num_tokens + 1] paged_kv_last_page_len = self.paged_kv_last_page_len[:num_tokens] - paged_kv_indices = self.paged_kv_indices[: num_tokens * self.topk_tokens] paged_kv_indptr = self.paged_kv_indptr[: num_tokens + 1] - paged_kv_indptr_rest = self.paged_kv_indptr[num_tokens + 
1 :] + paged_kv_indices = self.paged_kv_indices[: num_tokens * self.topk_tokens] metadata = ROCMAiterMLASparseMetadata( num_reqs=common_attn_metadata.num_reqs, @@ -245,12 +441,12 @@ def build( block_table=common_attn_metadata.block_table_tensor, req_id_per_token=req_id_per_token, block_size=self.kv_cache_spec.block_size, + attn_out_dtype=self.model_dtype, topk_tokens=self.topk_tokens, qo_indptr=qo_indptr, paged_kv_last_page_len=paged_kv_last_page_len, paged_kv_indices=paged_kv_indices, paged_kv_indptr=paged_kv_indptr, - paged_kv_indptr_rest=paged_kv_indptr_rest, ) return metadata @@ -314,29 +510,20 @@ def __init__( assert indexer is not None self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer - def _forward_bf16_kv( + def _forward_mla( self, + layer: AttentionLayer, q: torch.Tensor, # [sq, heads, d_qk] kv_c_and_k_pe_cache: torch.Tensor, # [blocks, heads, d_qk] - topk_indices: torch.Tensor, # [sq, topk] attn_metadata: ROCMAiterMLASparseMetadata, ) -> torch.Tensor: num_tokens = q.shape[0] mla_num_heads = AiterMLAHelper.get_actual_mla_num_heads(self.num_heads) output = torch.empty( [num_tokens, mla_num_heads, self.kv_lora_rank], - dtype=q.dtype, + dtype=attn_metadata.attn_out_dtype, device=q.device, ) - seq_len = (topk_indices != -1).sum(dim=-1) - torch.cumsum(seq_len, dim=0, out=attn_metadata.paged_kv_indptr[1:]) - attn_metadata.paged_kv_indptr_rest.fill_(attn_metadata.paged_kv_indptr[-1]) - fetch_id_to_ragged_triton( - topk_indices, - attn_metadata.paged_kv_indptr, - attn_metadata.paged_kv_indices, - attn_metadata.topk_tokens, - ) rocm_aiter_ops.mla_decode_fwd( q, @@ -348,6 +535,8 @@ def _forward_bf16_kv( attn_metadata.paged_kv_indptr, attn_metadata.paged_kv_indices, attn_metadata.paged_kv_last_page_len, + q_scale=layer._q_scale, + kv_scale=layer._k_scale, ) return AiterMLAHelper.get_mla_unpadded_o(self.num_heads, output) @@ -366,23 +555,32 @@ def forward_mqa( if isinstance(q, tuple): q = torch.cat(q, dim=-1) - num_actual_toks = q.shape[0] + num_actual_toks = attn_metadata.num_actual_tokens # Get topk indices assert self.topk_indices_buffer is not None topk_indices = self.topk_indices_buffer[:num_actual_toks] - topk_indices_global = triton_convert_req_index_to_global_index( + triton_convert_req_index_to_global_index( attn_metadata.req_id_per_token, attn_metadata.block_table, topk_indices, + attn_metadata.paged_kv_indptr, + attn_metadata.paged_kv_indices, BLOCK_SIZE=attn_metadata.block_size, NUM_TOPK_TOKENS=attn_metadata.topk_tokens, ) + # write the latent and rope to kv cache + fp8_attention = self.kv_cache_dtype.startswith("fp8") + if fp8_attention: + original_q_shape = q.shape + kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(current_platform.fp8_dtype()) + q, _ = ops.scaled_fp8_quant(q.view(q.shape[0], -1), layer._q_scale) + q = q.view(original_q_shape) mla_padded_q = AiterMLAHelper.get_mla_padded_q(self.num_heads, q) - attn_out = self._forward_bf16_kv( - mla_padded_q, kv_c_and_k_pe_cache, topk_indices_global, attn_metadata + attn_out = self._forward_mla( + layer, mla_padded_q, kv_c_and_k_pe_cache, attn_metadata ) return attn_out, None diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py index b17fd5d34418..28564e6a97d3 100644 --- a/vllm/v1/attention/backends/mla/sparse_swa.py +++ b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -7,6 +7,7 @@ from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.platforms 
import current_platform from vllm.triton_utils import tl, triton from vllm.v1.attention.backend import ( AttentionBackend, @@ -360,7 +361,7 @@ def build_tile_scheduler( _LAYER_TYPE_C4A: None, _LAYER_TYPE_C128A: None, } - if num_decode_tokens == 0: + if num_decode_tokens == 0 or current_platform.is_rocm(): return out for layer_type in self._layer_types: # get_mla_metadata() is the official FlashMLA entry point that diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py index 69d20c107e11..ae2e300f98b5 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py @@ -215,6 +215,7 @@ def _dequantize_and_gather_k_kernel( output_dim: tl.constexpr, # 512 fp8_max: tl.constexpr, n_quant_blocks: tl.constexpr, # 7 real blocks + use_fnuz: tl.constexpr = False, ): batch_idx = tl.program_id(0) worker_id = tl.program_id(1) @@ -272,8 +273,11 @@ def _dequantize_and_gather_k_kernel( # Load quantized fp8 values (stored as uint8) x_uint8 = tl.load(token_fp8_ptr + offsets, mask=mask, other=0) - # Bitcast uint8 back to fp8 - x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + # Bitcast uint8 back to fp8 (FNUZ on gfx942, OCP otherwise) + if use_fnuz: + x_fp8 = x_uint8.to(tl.float8e4b8, bitcast=True) + else: + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) # Convert fp8 to float32 for computation x_float = x_fp8.to(tl.float32) @@ -316,6 +320,7 @@ def dequantize_and_gather_k_cache( block_table: torch.Tensor, block_size: int, offset: int, + use_fnuz: bool = False, ) -> None: TOKEN_FP8_DIM = 448 TOKEN_BF16_DIM = 64 @@ -346,6 +351,7 @@ def dequantize_and_gather_k_cache( output_dim=512, fp8_max=FP8_MAX, n_quant_blocks=7, + use_fnuz=use_fnuz, ) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py index 84647d6120d8..68d33f1aa105 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py @@ -9,6 +9,7 @@ import torch +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.torch_utils import direct_register_custom_op @@ -242,6 +243,7 @@ def _fused_inv_rope_fp8_quant_kernel_impl( (scale_inner * tma_aligned_T, 1, tma_aligned_T), ) grid = (tma_aligned_T, n_groups * heads_per_group) + pdl_kwargs = {} if current_platform.is_rocm() else {"launch_pdl": False} _fused_inv_rope_fp8_quant_per_head[grid]( o, positions, @@ -265,7 +267,7 @@ def _fused_inv_rope_fp8_quant_kernel_impl( HALF_ROPE=half_rope, TMA_ALIGNED_SCALES=tma_aligned_scales, num_stages=1, - launch_pdl=False, + **pdl_kwargs, num_warps=1, ) return fp8_buf, scale_buf diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 81cc489db0d8..ffa22b61ef32 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -5,6 +5,7 @@ from importlib.util import find_spec import torch +import torch.nn.functional as F from vllm.forward_context import get_forward_context from vllm.platforms import current_platform @@ -13,9 +14,6 @@ from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton -if current_platform.is_cuda_alike(): - from vllm import _custom_ops as ops - @triton.jit def _indexer_k_quant_and_cache_kernel( @@ -97,7 
+95,8 @@ def indexer_k_quant_and_cache_triton( # In real layout, we store the first portion as kv cache value # and second portion as kv cache scale kv_cache = kv_cache.view(num_blocks, -1) - kv_cache_value = kv_cache[:, : block_size * head_dim] + fp8_dtype = current_platform.fp8_dtype() + kv_cache_value = kv_cache[:, : block_size * head_dim].view(fp8_dtype) kv_cache_scale = kv_cache[:, block_size * head_dim :].view(torch.float32) head_tile_size = head_tile_size // kv_cache.element_size() grid = (num_tokens,) @@ -111,7 +110,7 @@ def indexer_k_quant_and_cache_triton( block_size, num_tokens, head_dim, - "NHD", + "SHUFFLE", block_tile_size, head_tile_size, IS_FNUZ=current_platform.fp8_dtype() == torch.float8_e4m3fnuz, @@ -212,7 +211,7 @@ def cp_gather_indexer_k_quant_cache_triton( block_table_stride, k_cache_value.stride(0), k_cache_scale.stride(0), - "NHD", + "SHUFFLE", head_dim, block_tile_size, head_tile_size, @@ -232,6 +231,43 @@ def fp8_paged_mqa_logits_torch( fp8_dtype = current_platform.fp8_dtype() batch_size, next_n, _, dim = q.size() + if next_n == 1: + block_size = kv_cache.shape[1] + logits = torch.full( + [batch_size, max_model_len], + float("-inf"), + device=q.device, + dtype=torch.float32, + ) + if context_lens.dim() > 1: + context_lens = context_lens.squeeze(-1) + kv_cache_flat = kv_cache.view(-1, block_size * (dim + 4)) + for i in range(batch_size): + q_i = q[i, 0].to(torch.float32) + q_scale = weights[i] + seq_len = int(context_lens[i].item()) + assert seq_len <= max_model_len + num_pages = cdiv(seq_len, block_size) + padded_seq_len = num_pages * block_size + pages = block_tables[i, :num_pages] + cache = kv_cache_flat[pages] + scale_offset = block_size * dim + cache_value = ( + cache[..., :scale_offset].view(dtype=fp8_dtype).to(torch.float32) + ) + cache_scale = ( + cache[..., scale_offset:].view(dtype=torch.float32).contiguous() + ) + cache_value = cache_value.view(padded_seq_len, dim) + cache_scale = cache_scale.view(padded_seq_len) + score = F.linear(cache_value, q_i) + score = F.relu(score) + score *= q_scale[None, :] + score = score.sum(dim=1) + score *= cache_scale + logits[i, :seq_len] = score[:seq_len] + return logits + kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:] scale = scale.contiguous().view(torch.float) q = q.float() @@ -243,20 +279,30 @@ def fp8_paged_mqa_logits_torch( device=q.device, dtype=torch.float32, ) - context_lens = context_lens.tolist() for i in range(batch_size): context_len = context_lens[i] - q_offsets = torch.arange(context_len - next_n, context_len, device="cuda") + if context_len.ndim == 0: + context_len_i = int(context_len.item()) + q_offsets = torch.arange( + context_len_i - next_n, context_len_i, device=q.device + ) + context_limit = torch.full( + (next_n,), context_len_i, dtype=torch.int32, device=q.device + ) + else: + context_limit = context_len.to(device=q.device, dtype=torch.int32) + q_offsets = context_limit - 1 weight_slice = ( weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous() ) - for block_rk in range(cdiv(context_len, block_size)): + max_context_len = int(context_limit.max().item()) + for block_rk in range(cdiv(max_context_len, block_size)): block_idx = block_tables[i][block_rk] qx, kx = q[i], kv_cache[block_idx] k_offsets = torch.arange( - block_rk * block_size, (block_rk + 1) * block_size, device="cuda" + block_rk * block_size, (block_rk + 1) * block_size, device=q.device ) - mask = (k_offsets[None, :] < context_len) & ( + mask = (k_offsets[None, :] < context_limit[:, None]) & ( k_offsets[None, :] 
<= q_offsets[:, None] ) s = torch.where( @@ -325,33 +371,38 @@ def rocm_fp8_paged_mqa_logits( from vllm._aiter_ops import rocm_aiter_ops aiter_paged_mqa_logits_module = None + # if rocm_aiter_ops.is_enabled(): + batch_size, next_n, heads, head_dim = q_fp8.shape + num_blocks, block_size, _, _ = kv_cache_fp8.shape + if rocm_aiter_ops.is_enabled(): aiter_paged_mqa_logits_module = paged_mqa_logits_module() if aiter_paged_mqa_logits_module is not None: - deepgemm_fp8_paged_mqa_logits_stage1 = ( - aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits_stage1 + deepgemm_fp8_paged_mqa_logits = ( + aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits ) batch_size, next_n, heads, _ = q_fp8.shape - out_qk = torch.full( - (heads, batch_size * next_n, max_model_len), + out_logits = torch.full( + [batch_size * next_n, max_model_len], float("-inf"), device="cuda", dtype=torch.float32, ) - # TODO: 1. Replace _stage1 and out_qk.sum with another fused variant; - # 2. Remove ChunkQ when AITER PR #2891 merged - deepgemm_fp8_paged_mqa_logits_stage1( + deepgemm_fp8_paged_mqa_logits( q_fp8, kv_cache_fp8, weights, - out_qk, + out_logits, context_lens, block_tables, max_model_len, - ChunkQ=heads, + ChunkK=256, + Preshuffle=block_size == 64, + KVBlockSize=block_size, + WavePerEU=2, ) - return out_qk.sum(dim=0) + return out_logits else: return fp8_paged_mqa_logits_torch( q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len @@ -461,6 +512,27 @@ def rocm_fp8_mqa_logits( return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) +def _topk_indices_torch(logits: torch.Tensor, topk_tokens: int) -> torch.Tensor: + k = min(topk_tokens, logits.shape[-1]) + values, indices = torch.topk(logits, k=k, dim=-1) + indices = indices.to(torch.int32) + indices = torch.where( + values == float("-inf"), + torch.full_like(indices, -1, dtype=torch.int32), + indices, + ) + if k == topk_tokens: + return indices + padded = torch.full( + (logits.shape[0], topk_tokens), + -1, + dtype=torch.int32, + device=logits.device, + ) + padded[:, :k] = indices + return padded + + def rocm_aiter_sparse_attn_indexer_fake( hidden_states: torch.Tensor, k_cache_prefix: LayerNameType, @@ -479,8 +551,9 @@ def rocm_aiter_sparse_attn_indexer_fake( # profile run # NOTE(Chen): create the max possible flattened_kv. So that # profile_run can get correct memory usage. + device = hidden_states.device if k is None else k.device _flattened_kv = torch.empty( - [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8 + [total_seq_lens, head_dim + 4], device=device, dtype=torch.uint8 ) fp8_dtype = current_platform.fp8_dtype() _k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous() @@ -488,7 +561,7 @@ def rocm_aiter_sparse_attn_indexer_fake( return topk_indices_buffer -def rocm_aiter_sparse_attn_indexer( +def rocm_aiter_sparse_attn_indexer_native( hidden_states: torch.Tensor, k_cache_prefix: LayerNameType, kv_cache: torch.Tensor, @@ -502,6 +575,7 @@ def rocm_aiter_sparse_attn_indexer( max_model_len: int, total_seq_lens: int, topk_indices_buffer: torch.Tensor | None, + skip_k_cache_insert: bool = False, ) -> torch.Tensor: # careful! 
this will be None in dummy run attn_metadata = get_forward_context().attn_metadata @@ -534,19 +608,24 @@ def rocm_aiter_sparse_attn_indexer( has_decode = layer_attn_metadata.num_decodes > 0 has_prefill = layer_attn_metadata.num_prefills > 0 num_decode_tokens = layer_attn_metadata.num_decode_tokens + device = hidden_states.device if k is None else k.device # during speculative decoding, k may be padded to the CUDA graph batch # size while slot_mapping only covers actual tokens. num_tokens = slot_mapping.shape[0] - k = k[:num_tokens] + if k is not None: + k = k[:num_tokens] + elif not skip_k_cache_insert: + raise ValueError("k must be provided when skip_k_cache_insert is False") - ops.indexer_k_quant_and_cache( - k, - kv_cache, - slot_mapping, - quant_block_size, - scale_fmt, - ) + if not skip_k_cache_insert: + indexer_k_quant_and_cache_triton( + k, + kv_cache, + slot_mapping, + quant_block_size, + scale_fmt, + ) topk_indices_buffer[: hidden_states.shape[0]] = -1 if has_prefill: @@ -555,21 +634,21 @@ def rocm_aiter_sparse_attn_indexer( for chunk in prefill_metadata.chunks: k_fp8 = torch.empty( [chunk.total_seq_lens, head_dim], - device=k.device, + device=device, dtype=fp8_dtype, ) k_scale = torch.empty( [chunk.total_seq_lens, 4], - device=k.device, + device=device, dtype=torch.uint8, ) - - ops.cp_gather_indexer_k_quant_cache( + cp_gather_indexer_k_quant_cache_triton( kv_cache, k_fp8, k_scale, chunk.block_table, chunk.cu_seq_lens, + chunk.token_to_seq, ) logits = rocm_fp8_mqa_logits( @@ -579,21 +658,10 @@ def rocm_aiter_sparse_attn_indexer( chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, ) - num_rows = logits.shape[0] - assert topk_tokens == 2048, "top_k_per_row assumes size 2048" topk_indices = topk_indices_buffer[ chunk.token_start : chunk.token_end, :topk_tokens ] - torch.ops._C.top_k_per_row_prefill( - logits, - chunk.cu_seqlen_ks, - chunk.cu_seqlen_ke, - topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), - topk_tokens, - ) + topk_indices.copy_(_topk_indices_torch(logits, topk_tokens)) if has_decode: decode_metadata = layer_attn_metadata.decode @@ -630,19 +698,8 @@ def rocm_aiter_sparse_attn_indexer( max_model_len=max_model_len, ) - num_rows = logits.shape[0] - assert topk_tokens == 2048, "top_k_per_row assumes size 2048" topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] - torch.ops._C.top_k_per_row_decode( - logits, - next_n, - decode_metadata.seq_lens, - topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), - topk_tokens, - ) + topk_indices.copy_(_topk_indices_torch(logits, topk_tokens)[:num_decode_tokens]) if decode_metadata.requires_padding: # if padded, we need to unpack @@ -656,3 +713,36 @@ def rocm_aiter_sparse_attn_indexer( ) return topk_indices_buffer + + +def rocm_aiter_sparse_attn_indexer( + hidden_states: torch.Tensor, + k_cache_prefix: LayerNameType, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: str | None, + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: torch.Tensor | None, +) -> torch.Tensor: + return rocm_aiter_sparse_attn_indexer_native( + hidden_states, + k_cache_prefix, + kv_cache, + q_fp8, + k, + weights, + quant_block_size, + scale_fmt, + topk_tokens, + head_dim, + max_model_len, + total_seq_lens, + topk_indices_buffer, + skip_k_cache_insert=False, + )
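
The new _convert_req_index_to_global_index_kernel packs, for every query token, its request-local topk indices into a flat ragged buffer of global KV-cache slots, with the per-token range taken from cu_seqlens (paged_kv_indptr in the builder). A minimal, unoptimized torch sketch of that mapping — hypothetical helper name, host-side loop only; note the kernel's tl.where writes 0 (not -1) for invalid or out-of-range entries, and this sketch mirrors that:

    import torch

    def convert_req_index_to_global_index_ref(
        req_id: torch.Tensor,         # int32 [num_tokens]
        block_table: torch.Tensor,    # int32 [num_requests, max_num_blocks_per_req]
        token_indices: torch.Tensor,  # int32 [num_tokens, NUM_TOPK_TOKENS]
        cu_seqlens: torch.Tensor,     # int32 [num_tokens + 1]
        out: torch.Tensor,            # int32 flat buffer, same layout as paged_kv_indices
        block_size: int = 64,
    ) -> None:
        max_blocks = block_table.shape[1]
        for t in range(token_indices.shape[0]):
            start, end = int(cu_seqlens[t]), int(cu_seqlens[t + 1])
            # Only the first (end - start) topk entries of this token are kept.
            tok = token_indices[t, : end - start].long()
            block_id = tok // block_size
            in_block = tok % block_size
            valid = (tok >= 0) & (block_id < max_blocks)
            base = block_table[int(req_id[t]), block_id.clamp(0, max_blocks - 1)].long()
            out[start:end] = torch.where(
                valid, base * block_size + in_block, torch.zeros_like(tok)
            ).to(torch.int32)

In forward_mqa the cu_seqlens argument is attn_metadata.paged_kv_indptr, i.e. the cumulative sum of the sparse sequence lengths built below, so end - start is the number of KV entries that token actually attends to (at most topk_tokens).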
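
generate_sparse_seqlen_triton computes, per query token, how many KV entries its sparse attention will read: one more than the token's absolute position in the sequence, clamped to topk_token, and 0 for empty sequences. A per-token loop that should produce the same values (hypothetical name, CPU-side, for checking the kernel):

    import torch

    def generate_sparse_seqlen_ref(
        query_lens: torch.Tensor,     # int32 [num_seqs]
        seq_lens: torch.Tensor,       # int32 [num_seqs]
        cu_query_lens: torch.Tensor,  # int32 [num_seqs + 1]
        topk_token: int,
        num_tokens: int,
    ) -> torch.Tensor:
        out = torch.zeros(num_tokens, dtype=torch.int32)
        for i in range(query_lens.numel()):
            q_len, s_len = int(query_lens[i]), int(seq_lens[i])
            if s_len == 0:
                continue  # invalid positions stay zero, as in the kernel
            start = int(cu_query_lens[i])
            context_start = s_len - q_len
            for q in range(q_len):
                out[start + q] = min(context_start + q + 1, topk_token)
        return out

The builder cumsums this into paged_kv_indptr[1 : num_tokens + 1] and fills the tail of the indptr buffer with the last value, so CUDA-graph-padded slots see empty ranges.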
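
In _dequantize_and_gather_k_kernel the stored uint8 bytes are now bitcast to tl.float8e4b8 (FNUZ) when use_fnuz is set and to tl.float8e4nv (OCP) otherwise, so the reader matches the encoding the writer used. The equivalent host-side reinterpretation, handy when inspecting cache contents in tests (a sketch; assumes a torch build exposing both fp8 dtypes):

    import torch

    def bitcast_k_cache_bytes(x_uint8: torch.Tensor, use_fnuz: bool) -> torch.Tensor:
        # Reinterpret raw cache bytes with the writer's fp8 encoding
        # (FNUZ on gfx942, OCP elsewhere), then widen for comparison.
        fp8_dtype = torch.float8_e4m3fnuz if use_fnuz else torch.float8_e4m3fn
        return x_uint8.view(fp8_dtype).to(torch.float32)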