From b620af685be93db60da36efe801ea52c52fb2a4f Mon Sep 17 00:00:00 2001 From: ganyi Date: Fri, 24 Apr 2026 07:30:32 +0000 Subject: [PATCH 01/17] make kernel compatible with rocm platform Signed-off-by: ganyi Made-with: Cursor --- ...deepseek_v4_qnorm_rope_kv_insert_kernel.cu | 37 ++++++++++++++++++- csrc/moe/topk_softplus_sqrt_kernels.cu | 6 ++- 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu index e96017d86da..3e903080c7f 100644 --- a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu +++ b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu @@ -29,7 +29,11 @@ */ #include -#include +#ifndef USE_ROCM + #include +#else + #include +#endif #include #include @@ -42,7 +46,23 @@ #include "type_convert.cuh" #ifndef FINAL_MASK - #define FINAL_MASK 0xffffffffu + #ifdef USE_ROCM + #define FINAL_MASK 0xffffffffffffffffULL + #else + #define FINAL_MASK 0xffffffffu + #endif +#endif + +#ifdef USE_ROCM +// ROCm-compatible FP8 conversion helpers +__device__ __forceinline__ uint8_t rocm_cvt_float_to_fp8_e4m3(float val) { + #if defined(HIP_FP8_TYPE_OCP) + __hip_fp8_e4m3 fp8_val(val); + #else + __hip_fp8_e4m3_fnuz fp8_val(val); + #endif + return reinterpret_cast(fp8_val); +} #endif namespace vllm { @@ -314,9 +334,13 @@ __global__ void fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel( for (int i = 0; i < kElemsPerLane; i++) { float scaled = elements[i] * inv_scale; scaled = fminf(fmaxf(scaled, -kFp8Max), kFp8Max); +#ifndef USE_ROCM __nv_fp8_storage_t s = __nv_cvt_float_to_fp8(scaled, __NV_SATFINITE, __NV_E4M3); out_bytes[i] = static_cast(s); +#else + out_bytes[i] = rocm_cvt_float_to_fp8_e4m3(scaled); +#endif } // One 16-byte STG per lane. *reinterpret_cast(token_fp8_ptr + dim_base) = @@ -384,6 +408,7 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert( // PDL: enable programmatic stream serialization whenever the hardware // supports it (SM90+). 
On pre-Hopper GPUs the attribute is unavailable, // so leave numAttrs = 0 and launch as a regular kernel. +#ifndef USE_ROCM static int const sm_version = getSMVersion(); // Host-side guard: the device kernel body is compiled as a no-op for // bf16 on pre-Ampere (sm_70/sm_75) because _typeConvert is @@ -410,6 +435,14 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert( q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, eps, num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size, kv_block_stride); +#else + // ROCm: use standard kernel launch syntax (no PDL/stream serialization) + fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel + <<>>( + q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, eps, + num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size, + kv_block_stride); +#endif } } // namespace deepseek_v4_fused_ops diff --git a/csrc/moe/topk_softplus_sqrt_kernels.cu b/csrc/moe/topk_softplus_sqrt_kernels.cu index 50a8540a737..7cbba6b6354 100644 --- a/csrc/moe/topk_softplus_sqrt_kernels.cu +++ b/csrc/moe/topk_softplus_sqrt_kernels.cu @@ -60,7 +60,11 @@ __device__ __forceinline__ float toFloat(T value) { } } -#define FINAL_MASK 0xffffffff +#ifdef USE_ROCM + #define FINAL_MASK 0xffffffffffffffffULL +#else + #define FINAL_MASK 0xffffffff +#endif template __inline__ __device__ T warpReduceSum(T val) { #pragma unroll From a6a6a0f3dced9320f9c38bf0a48abbcc552bf897 Mon Sep 17 00:00:00 2001 From: whx-sjtu Date: Tue, 28 Apr 2026 07:12:04 +0000 Subject: [PATCH 02/17] ref supoort of dsv4 Signed-off-by: whx-sjtu --- vllm/config/kernel.py | 2 + vllm/forward_context.py | 8 +- .../kernels/linear/scaled_mm/aiter.py | 15 + vllm/model_executor/layers/activation.py | 4 +- .../layers/deepseek_compressor.py | 242 +++++++++- .../layers/deepseek_v4_attention.py | 438 +++++++++++++++++- .../layers/fused_moe/oracle/mxfp4.py | 78 +++- vllm/model_executor/layers/mhc.py | 42 ++ .../layers/quantization/utils/fp8_utils.py | 9 + 
.../layers/sparse_attn_indexer.py | 32 +- vllm/model_executor/models/deepseek_v4.py | 1 - vllm/platforms/rocm.py | 1 + vllm/utils/deep_gemm.py | 6 +- vllm/v1/attention/backends/mla/sparse_swa.py | 3 +- .../v1/attention/ops/rocm_aiter_mla_sparse.py | 165 +++++-- vllm/v1/worker/gpu_model_runner.py | 2 + 16 files changed, 976 insertions(+), 72 deletions(-) diff --git a/vllm/config/kernel.py b/vllm/config/kernel.py index 93fb4c54b7f..3de89eacdd3 100644 --- a/vllm/config/kernel.py +++ b/vllm/config/kernel.py @@ -115,6 +115,7 @@ def with_default( "flashinfer_cutlass", "flashinfer_cutedsl", "marlin", + "triton_unfused", "aiter", "emulation", ] @@ -145,6 +146,7 @@ class KernelConfig: - "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels - "flashinfer_cutedsl": Use FlashInfer with CuteDSL kernels (FP4 only) - "marlin": Use Marlin kernels (weight-only quantization) + - "triton_unfused": Use Triton unfused MoE kernels - "aiter": Use AMD AITer kernels (ROCm only) - "emulation": use BF16/FP16 GEMM, dequantizing weights and running QDQ on activations. 
diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 537a28a4252..54d565ef340 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -257,6 +257,7 @@ def set_forward_context( batch_descriptor: BatchDescriptor | None = None, ubatch_slices: UBatchSlices | None = None, slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None, + additional_kwargs: dict[str, Any] | None = None, skip_compiled: bool = False, ): """A context manager that stores the current forward context, @@ -296,7 +297,7 @@ def set_forward_context( if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None: batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens) - additional_kwargs = current_platform.set_additional_forward_context( + platform_additional_kwargs = current_platform.set_additional_forward_context( attn_metadata=attn_metadata, vllm_config=vllm_config, dp_metadata=dp_metadata, @@ -306,6 +307,9 @@ def set_forward_context( batch_descriptor=batch_descriptor, ubatch_slices=ubatch_slices, ) + merged_additional_kwargs = dict(platform_additional_kwargs) + if additional_kwargs: + merged_additional_kwargs.update(additional_kwargs) forward_context = create_forward_context( attn_metadata, @@ -315,7 +319,7 @@ def set_forward_context( batch_descriptor, ubatch_slices, slot_mapping, - additional_kwargs, + merged_additional_kwargs, skip_compiled, ) diff --git a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py index 8a8650d2213..5ded5ca798a 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py @@ -312,6 +312,21 @@ def apply_block_scaled_mm( As: torch.Tensor, Bs: torch.Tensor, ) -> torch.Tensor: + if As.dtype != Bs.dtype: + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _upcast_e8m0_to_fp32, + ) + + if As.dtype == torch.float8_e8m0fnu: + As = 
_upcast_e8m0_to_fp32(As).contiguous() + else: + As = As.to(torch.float32) + + if Bs.dtype == torch.float8_e8m0fnu: + Bs = _upcast_e8m0_to_fp32(Bs).contiguous() + else: + Bs = Bs.to(torch.float32) + out_dtype = self.config.out_dtype if self.use_triton: gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 59cc95f18c5..df9459012ae 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -169,7 +169,9 @@ class SiluAndMulWithClamp(CustomOp): def __init__(self, swiglu_limit: float, *, compile_native: bool = True): super().__init__(compile_native=compile_native) self.swiglu_limit = float(swiglu_limit) - if current_platform.is_cuda_alike() or current_platform.is_xpu(): + if current_platform.is_rocm(): + self._forward_method = self.forward_native + elif current_platform.is_cuda_alike() or current_platform.is_xpu(): self.op = torch.ops._C.silu_and_mul_with_clamp elif current_platform.is_cpu(): self._forward_method = self.forward_native diff --git a/vllm/model_executor/layers/deepseek_compressor.py b/vllm/model_executor/layers/deepseek_compressor.py index cae80c35316..e92d9173c84 100644 --- a/vllm/model_executor/layers/deepseek_compressor.py +++ b/vllm/model_executor/layers/deepseek_compressor.py @@ -7,6 +7,7 @@ import torch from torch import nn +from vllm import _custom_ops as ops from vllm.config import VllmConfig, get_current_vllm_config from vllm.forward_context import get_forward_context from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase @@ -28,6 +29,7 @@ _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn, _fused_kv_compress_norm_rope_insert_sparse_attn, ) +from vllm.v1.attention.ops.deepseek_v4_ops.cache_utils import quantize_and_insert_k_cache from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import ( MXFP4_BLOCK_SIZE, ) @@ -174,6 +176,54 @@ def 
get_attn_backend(self) -> type[AttentionBackend]: return CompressorBackend +def hadamard_transform_ref(x: torch.Tensor) -> torch.Tensor: + hidden_size = x.shape[-1] + assert hidden_size > 0 and (hidden_size & (hidden_size - 1)) == 0, ( + f"Hidden size must be a power of 2, got {hidden_size}" + ) + dtype = x.dtype + y = x.to(torch.float32).reshape(-1, hidden_size) + h = 1 + while h < hidden_size: + y = y.view(-1, hidden_size // (2 * h), 2, h) + a = y[:, :, 0, :] + b = y[:, :, 1, :] + y = torch.cat((a + b, a - b), dim=-1) + h *= 2 + y = y.view(*x.shape) * (hidden_size**-0.5) + return y.to(dtype) + + +def apply_gptj_rope_ref( + x: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + rope_dim: int, +) -> torch.Tensor: + if rope_dim == 0 or x.numel() == 0: + return x + half_rot = rope_dim // 2 + nope_dim = x.shape[-1] - rope_dim + dtype = x.dtype + x = x.to(torch.float32) + cache = cos_sin_cache.index_select(0, positions.to(torch.long)) + cos = cache[:, :half_rot].to(torch.float32) + sin = cache[:, half_rot : 2 * half_rot].to(torch.float32) + view_shape = (positions.shape[0],) + (1,) * (x.dim() - 2) + (half_rot,) + cos = cos.view(view_shape) + sin = sin.view(view_shape) + rope = x[..., nope_dim:] + x_even = rope[..., 0::2] + x_odd = rope[..., 1::2] + rope_out = torch.stack( + (x_even * cos - x_odd * sin, x_odd * cos + x_even * sin), + dim=-1, + ).flatten(-2) + x = x.clone() + x[..., nope_dim:] = rope_out + return x.to(dtype) + + class DeepseekCompressor(nn.Module): def __init__( self, @@ -239,6 +289,9 @@ def __init__( self._static_forward_context = ( vllm_config.compilation_config.static_forward_context ) + self._old_kv_state: dict[str, torch.Tensor] = {} + self._old_score_state: dict[str, torch.Tensor] = {} + self._old_need_hadamard = self.head_dim == 128 if self.head_dim == 512: assert not use_fp4_cache, ( @@ -324,7 +377,6 @@ def forward( TRITON_BLOCK_SIZE=triton.next_power_of_2(kv.shape[-1]), STATE_WIDTH=state_width, 
COMPRESS_RATIO=self.compress_ratio, - launch_pdl=False, ) # Fused: compress → RMSNorm → RoPE → FP8 quant → KV cache write. @@ -373,9 +425,195 @@ def forward( SCALE_DIM=self._scale_dim, KV_BLOCK_STRIDE=kv_cache.stride(0), num_warps=self._num_warps, - launch_pdl=False, ) + @property + def _state_width(self) -> int: + return self.coff * self.head_dim + + @property + def _state_len(self) -> int: + return self.compress_ratio * self.coff + + def _get_old_state(self, req_id: str, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: + kv_state = self._old_kv_state.get(req_id) + score_state = self._old_score_state.get(req_id) + if kv_state is None or score_state is None or kv_state.device != device: + kv_state = torch.zeros( + (self._state_len, self._state_width), + dtype=torch.float32, + device=device, + ) + score_state = torch.full( + (self._state_len, self._state_width), + float("-inf"), + dtype=torch.float32, + device=device, + ) + self._old_kv_state[req_id] = kv_state + self._old_score_state[req_id] = score_state + return kv_state, score_state + + def _clear_old_state(self, kv_state: torch.Tensor, score_state: torch.Tensor) -> None: + kv_state.zero_() + score_state.fill_(float("-inf")) + + def _overlap_transform(self, tensor: torch.Tensor, fill_value: Any) -> torch.Tensor: + assert tensor.dim() == 3 + assert tensor.shape[1:] == (self.compress_ratio, 2 * self.head_dim) + s, r, d = tensor.shape[0], self.compress_ratio, self.head_dim + new_tensor = tensor.new_full((s, 2 * r, d), fill_value) + new_tensor[:, r:] = tensor[:, :, d:] + new_tensor[1:, :r] = tensor[:-1, :, :d] + return new_tensor + + def _ref_rms_norm(self, x: torch.Tensor) -> torch.Tensor: + x = x.to(torch.float32) + x = x * torch.rsqrt(x.square().mean(dim=-1, keepdim=True) + self.rms_norm_eps) + return x * self.norm.weight.to(torch.float32) + + @staticmethod + def _compute_state_len(seq_len: int, ratio: int) -> int: + return seq_len % ratio + (ratio == 4) * ratio + + def _forward_old( + self, + x: 
torch.Tensor, + positions: torch.Tensor, + rotary_emb, + ) -> None: + num_tokens, _ = x.shape + kv_score = cublas_gemm_bf16_bf16_fp32(x, self.fused_wkv_wgate.weight) + kv, score = kv_score.split([self._state_width, self._state_width], dim=-1) + + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + if not isinstance(attn_metadata, dict): + return + + state_metadata = attn_metadata[self.state_cache.prefix] + token_to_req_indices = state_metadata.token_to_req_indices[:num_tokens] + k_cache_metadata = attn_metadata[self.k_cache_prefix] + slot_mapping = k_cache_metadata.slot_mapping[:num_tokens] + kv_cache = self._static_forward_context[self.k_cache_prefix].kv_cache + req_ids = forward_context.additional_kwargs.get("req_ids") + if req_ids is None: + max_req_idx = int(token_to_req_indices.max().item()) + 1 if num_tokens > 0 else 0 + req_ids = [str(i) for i in range(max_req_idx)] + + _, counts = torch.unique_consecutive(token_to_req_indices.to(torch.int64), return_counts=True) + start = 0 + for req_idx_tensor, count_tensor in zip( + torch.unique_consecutive(token_to_req_indices.to(torch.int64)), counts + ): + req_idx = int(req_idx_tensor.item()) + count = int(count_tensor.item()) + end = start + count + req_id = req_ids[req_idx] + kv_state, score_state = self._get_old_state(req_id, x.device) + + positions_i = positions[start:end].to(torch.int64) + prefix_len = int(positions_i[0].item()) + query_len = end - start + if prefix_len == 0: + self._clear_old_state(kv_state, score_state) + + pre_state_len = self._compute_state_len(prefix_len, self.compress_ratio) + valid_kv_len = pre_state_len + query_len + + temp_kv = torch.empty( + (valid_kv_len, self._state_width), dtype=torch.float32, device=x.device + ) + temp_score = torch.empty_like(temp_kv) + if pre_state_len > 0: + temp_kv[:pre_state_len] = kv_state[:pre_state_len] + temp_score[:pre_state_len] = score_state[:pre_state_len] + temp_kv[pre_state_len:] = kv[start:end] + 
temp_score[pre_state_len:] = score[start:end] + + post_state_len = self._compute_state_len(valid_kv_len, self.compress_ratio) + kv_state[:post_state_len] = temp_kv[valid_kv_len - post_state_len : valid_kv_len] + score_state[:post_state_len] = temp_score[ + valid_kv_len - post_state_len : valid_kv_len + ] + if post_state_len < self._state_len: + kv_state[post_state_len:].zero_() + score_state[post_state_len:].fill_(float("-inf")) + + compress_len = valid_kv_len // self.compress_ratio * self.compress_ratio + if compress_len == 0: + start = end + continue + + kv_to_compress = temp_kv[:compress_len].view( + compress_len // self.compress_ratio, + self.compress_ratio, + self._state_width, + ) + score_to_compress = temp_score[:compress_len].view_as(kv_to_compress) + score_to_compress = score_to_compress + self.ape.unsqueeze(0) + + if self.overlap: + kv_to_compress = self._overlap_transform(kv_to_compress, 0.0) + score_to_compress = self._overlap_transform( + score_to_compress, float("-inf") + ) + kv_to_compress = kv_to_compress[1:] + score_to_compress = score_to_compress[1:] + if kv_to_compress.numel() == 0: + start = end + continue + + kv_compressed = ( + kv_to_compress * torch.softmax(score_to_compress, dim=1) + ).sum(dim=1) + kv_compressed = self._ref_rms_norm(kv_compressed) + + first_compressed_pos = prefix_len + first_compressed_pos += self.compress_ratio - 1 - first_compressed_pos % self.compress_ratio + compressed_positions = torch.arange( + first_compressed_pos, + prefix_len + query_len, + self.compress_ratio, + device=x.device, + dtype=torch.int64, + ) + if compressed_positions.numel() != kv_compressed.shape[0]: + raise RuntimeError( + f"Compressed positions mismatch for req {req_id}: " + f"{compressed_positions.numel()} vs {kv_compressed.shape[0]}" + ) + + kv_compressed = apply_gptj_rope_ref( + kv_compressed, compressed_positions, rotary_emb.cos_sin_cache, self.rope_head_dim + ).to(torch.bfloat16) + if self._old_need_hadamard: + kv_compressed = 
hadamard_transform_ref(kv_compressed) + + local_output = torch.zeros( + (query_len, self.head_dim), dtype=torch.bfloat16, device=x.device + ) + local_output[(compressed_positions - prefix_len).to(torch.long)] = kv_compressed + + if self.head_dim == 512: + quantize_and_insert_k_cache( + local_output, + kv_cache, + slot_mapping[start:end], + block_size=kv_cache.shape[1], + is_ue8m0=True, + ) + else: + ops.indexer_k_quant_and_cache( + local_output, + kv_cache, + slot_mapping[start:end], + self._quant_block, + "ue8m0", + ) + + start = end + @triton.jit def _save_partial_states_kernel( diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 74b494dc485..f83be6c5af3 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -5,6 +5,8 @@ """ from collections.abc import Callable +import math + from dataclasses import dataclass from typing import TYPE_CHECKING, Any, cast @@ -43,7 +45,11 @@ from vllm.logger import init_logger from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.model_executor.layers.deepseek_compressor import DeepseekCompressor +from vllm.model_executor.layers.deepseek_compressor import ( + DeepseekCompressor, + apply_gptj_rope_ref, + hadamard_transform_ref, +) from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.input_quant_fp8 import ( @@ -56,6 +62,7 @@ execute_in_parallel, maybe_execute_in_parallel, ) +from vllm.platforms import current_platform from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata from vllm.v1.attention.backends.mla.flashmla_sparse import ( DeepseekV4FlashMLASparseBackend, @@ -197,8 +204,6 @@ def __init__( # Pick fp8_einsum recipe based on GPU arch: # SM90: FP32 
block scales stay [g, r/128, d/128] → sfb_gran_mn=128 # SM100: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1 - from vllm.platforms import current_platform - cap = current_platform.get_device_capability() assert cap is not None, "DeepseekV4 attention requires a CUDA device" self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128) @@ -209,6 +214,8 @@ def __init__( self.topk_indices_buffer = mla_modules.topk_indices_buffer self.indexer = mla_modules.indexer + # Keep ROCm on the BF16 reference wo_a path util kernel ready. + self.use_ref_wo_a_path = current_platform.is_rocm() # Per-head RMS normalization for Q (no learnable weights) self.q_head_norm = RMSNorm(head_dim, eps=self.eps, has_weight=False) @@ -302,6 +309,32 @@ def forward( ) o = o_padded[:, : self.n_local_heads, :] + if self.use_ref_wo_a_path: + o_ref = _apply_inv_rope_ref( + self.rotary_emb, o, positions, self.rope_head_dim + ).to(torch.bfloat16) + o_ref = o_ref.view(num_tokens, self.n_local_groups, -1) + + hidden_dim = o_ref.shape[-1] + if hasattr(self.wo_a, "weight_scale_inv"): + wo_a_weight = self.wo_a.weight.view( + self.n_local_groups, self.o_lora_rank, hidden_dim + ).to(torch.float32) + wo_a_scale = _expand_2d_block_scales( + self.wo_a.weight_scale_inv.view( + self.n_local_groups, -1, self.wo_a.weight_scale_inv.shape[-1] + ), + self.o_lora_rank, + hidden_dim, + ) + wo_a_weight = (wo_a_weight * wo_a_scale).to(torch.bfloat16) + else: + wo_a_weight = self.wo_a.weight.view( + self.n_local_groups, self.o_lora_rank, hidden_dim + ).to(torch.bfloat16) + z = torch.einsum("tgd,grd->tgr", o_ref, wo_a_weight) + return self.wo_b(z.flatten(1)) + # O projection: inverse RoPE + FP8 quant + einsum + wo_b o_fp8, o_scale = fused_inv_rope_fp8_quant( o, @@ -568,7 +601,128 @@ def deepseek_v4_fp8_einsum( equation: str, recipe: list[int], ) -> None: - fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe)) + try: + fp8_einsum(equation, (a, a_scale), (b, b_scale), out, 
recipe=tuple(recipe)) + except RuntimeError as exc: + if "DeepGEMM backend is not available or outdated" not in str(exc): + raise + _deepseek_v4_fp8_einsum_fallback(a, a_scale, b, b_scale, out, equation) + + +def _decode_e8m0_scales(scale: torch.Tensor) -> torch.Tensor: + if scale.dtype == torch.float8_e8m0fnu: + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _upcast_e8m0_to_fp32, + ) + + return _upcast_e8m0_to_fp32(scale).contiguous() + return scale.to(torch.float32) + + +def _expand_last_dim_scales(scale: torch.Tensor, last_dim: int) -> torch.Tensor: + scale = _decode_e8m0_scales(scale) + block = math.ceil(last_dim / scale.shape[-1]) + return torch.repeat_interleave(scale, block, dim=-1)[..., :last_dim] + + +def _expand_2d_block_scales( + scale: torch.Tensor, + rows: int, + cols: int, +) -> torch.Tensor: + scale = _decode_e8m0_scales(scale) + row_blocks, col_blocks = scale.shape[-2:] + row_block = math.ceil(rows / row_blocks) + col_block = math.ceil(cols / col_blocks) + scale = torch.repeat_interleave(scale, row_block, dim=-2)[..., :rows, :] + scale = torch.repeat_interleave(scale, col_block, dim=-1)[..., :, :cols] + return scale + + +def _apply_gptj_inv_rope_ref( + x: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + rope_dim: int, +) -> torch.Tensor: + if rope_dim == 0 or x.numel() == 0: + return x + half_rot = rope_dim // 2 + nope_dim = x.shape[-1] - rope_dim + dtype = x.dtype + x = x.to(torch.float32) + cache = cos_sin_cache.index_select(0, positions.to(torch.long)) + cos = cache[:, :half_rot].to(torch.float32) + sin = cache[:, half_rot : 2 * half_rot].to(torch.float32) + view_shape = (positions.shape[0],) + (1,) * (x.dim() - 2) + (half_rot,) + cos = cos.view(view_shape) + sin = sin.view(view_shape) + rope = x[..., nope_dim:] + y_even = rope[..., 0::2] + y_odd = rope[..., 1::2] + rope_out = torch.stack( + (y_even * cos + y_odd * sin, y_odd * cos - y_even * sin), + dim=-1, + ).flatten(-2) + x = x.clone() + 
x[..., nope_dim:] = rope_out + return x.to(dtype) + + +def _apply_inv_rope_ref( + rotary_emb: torch.nn.Module, + x: torch.Tensor, + positions: torch.Tensor, + rope_dim: int, +) -> torch.Tensor: + if hasattr(rotary_emb, "forward_native"): + try: + query, _ = rotary_emb.forward_native( + positions, + x.clone(), + None, + inverse=True, + ) + return query + except TypeError: + pass + return _apply_gptj_inv_rope_ref(x, positions, rotary_emb.cos_sin_cache, rope_dim) + + +def _deepseek_v4_fp8_einsum_fallback( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, + equation: str, +) -> None: + if equation != "bhr,hdr->bhd": + raise RuntimeError(f"Unsupported fallback equation: {equation}") + + num_groups = a.shape[1] + hidden_dim = a.shape[2] + output_dim = b.shape[0] // num_groups + + if b.shape[0] % num_groups != 0: + raise RuntimeError( + f"Cannot reshape weight of shape {tuple(b.shape)} into " + f"({num_groups}, {output_dim}, {hidden_dim})." + ) + + a_deq = (a.to(torch.float32) * _expand_last_dim_scales(a_scale, hidden_dim)).to( + torch.bfloat16 + ) + + b_deq = b.view(num_groups, output_dim, hidden_dim).to(torch.float32) + b_scale_deq = _expand_2d_block_scales( + b_scale.view(num_groups, -1, b_scale.shape[-1]), + output_dim, + hidden_dim, + ) + b_deq = (b_deq * b_scale_deq).to(torch.bfloat16) + + out.copy_(torch.einsum(equation, a_deq, b_deq).to(out.dtype)) def deepseek_v4_fp8_einsum_fake( @@ -665,8 +819,11 @@ def __init__( vllm_config.scheduler_config.max_num_batched_tokens ) self.max_model_len = vllm_config.model_config.max_model_len - # DeepseekV4 only supports fp8 kv-cache format for now + # DeepseekV4 only supports fp8 kv-cache format for now. Treat "auto" + # as the model default and normalize it before the fp8-only checks. 
kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "fp8" + if kv_cache_dtype == "auto": + kv_cache_dtype = "fp8" assert kv_cache_dtype.startswith("fp8"), ( f"DeepseekV4 only supports fp8 kv-cache format for now, " @@ -813,6 +970,20 @@ def _forward_decode( swa_indices = swa_metadata.decode_swa_indices swa_lens = swa_metadata.decode_swa_lens + if current_platform.is_rocm(): + self._forward_decode_fallback( + q=q, + kv_cache=kv_cache, + swa_metadata=swa_metadata, + swa_only=swa_only, + topk_indices=topk_indices, + topk_lens=topk_lens, + swa_indices=swa_indices, + swa_lens=swa_lens, + output=output, + ) + return + # We treat queries in the same seq as different queries # and later we only attend by generated indices. # q arrives pre-padded to self.padded_heads by the outer wrapper. @@ -977,15 +1148,234 @@ def _forward_prefill( N, ) - output_chunk, _, _ = flash_mla_sparse_fwd( - q=q[query_start:query_end], - kv=kv.view(-1, 1, q.shape[-1]), - indices=combined_indices.unsqueeze(1), - sm_scale=self.scale, - attn_sink=self.attn_sink, - topk_length=combined_lens, - out=output[query_start:query_end], + if current_platform.is_rocm(): + output_chunk = self._ref_sparse_attn_prefill( + q=q[query_start:query_end], + kv=kv.view(-1, 1, q.shape[-1]), + indices=combined_indices.unsqueeze(1), + topk_length=combined_lens, + ) + output[query_start:query_end].copy_(output_chunk.to(output.dtype)) + else: + output_chunk, _, _ = flash_mla_sparse_fwd( + q=q[query_start:query_end], + kv=kv.view(-1, 1, q.shape[-1]), + indices=combined_indices.unsqueeze(1), + sm_scale=self.scale, + attn_sink=self.attn_sink, + topk_length=combined_lens, + out=output[query_start:query_end], + ) + + def _decode_e8m0_scales(self, scale: torch.Tensor) -> torch.Tensor: + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _upcast_e8m0_to_fp32, + ) + + return _upcast_e8m0_to_fp32(scale.contiguous()) + + def _dequantize_cache_rows(self, rows: torch.Tensor) -> torch.Tensor: + 
rows = rows.reshape(-1, rows.shape[-1]) + fp8_dtype = current_platform.fp8_dtype() + fp8_dim = self.nope_head_dim + rope_bytes = self.rope_head_dim * 2 + scale_dim = fp8_dim // 64 + + fp8_vals = rows[:, :fp8_dim].contiguous().view(fp8_dtype) + rope_vals = rows[:, fp8_dim : fp8_dim + rope_bytes].contiguous().view( + torch.bfloat16 + ) + scale_bytes = rows[ + :, fp8_dim + rope_bytes : fp8_dim + rope_bytes + scale_dim + ].contiguous() + scales = self._decode_e8m0_scales(scale_bytes) + scales = torch.repeat_interleave(scales, 64, dim=-1) + nope = fp8_vals.to(torch.float32) * scales + return torch.cat([nope, rope_vals.to(torch.float32)], dim=-1).to(torch.bfloat16) + + def _gather_dequantized_cache_tokens( + self, + cache: torch.Tensor, + slot_ids: torch.Tensor, + block_size: int, + ) -> torch.Tensor: + if slot_ids.numel() == 0: + return torch.empty((0, self.head_dim), dtype=torch.bfloat16, device=cache.device) + slot_ids = slot_ids.to(torch.int64) + rows = cache[slot_ids // block_size, slot_ids % block_size] + return self._dequantize_cache_rows(rows).reshape(-1, self.head_dim) + + def _forward_decode_fallback( + self, + q: torch.Tensor, + kv_cache: torch.Tensor | None, + swa_metadata: "DeepseekSparseSWAMetadata", + swa_only: bool, + topk_indices: torch.Tensor | None, + topk_lens: torch.Tensor | None, + swa_indices: torch.Tensor, + swa_lens: torch.Tensor, + output: torch.Tensor, + ) -> None: + blocked_swa = self._dequantize_blocked_k_cache(self.swa_cache_layer.kv_cache) + blocked_extra = None if swa_only else self._dequantize_blocked_k_cache(kv_cache) + attn_out = self._ref_sparse_attn_decode( + q=q.unsqueeze(1), + blocked_k=blocked_swa, + indices_in_kvcache=swa_indices.unsqueeze(1), + topk_length=swa_lens, + attn_sink=self.attn_sink[: q.shape[1]], + extra_blocked_k=blocked_extra, + extra_indices_in_kvcache=topk_indices, + extra_topk_length=topk_lens, + ) + output.copy_(attn_out.to(output.dtype)) + + def _dequantize_blocked_k_cache(self, quant_k_cache: torch.Tensor) -> 
torch.Tensor: + fp8_dtype = current_platform.fp8_dtype() + d = self.head_dim + d_nope = self.nope_head_dim + d_rope = self.rope_head_dim + tile_size = 64 + num_tiles = d_nope // tile_size + + num_blocks, block_size, _ = quant_k_cache.shape + quant_k_cache = quant_k_cache.view(num_blocks, -1) + input_nope_rope = quant_k_cache[:, : block_size * (d_nope + 2 * d_rope)].view( + num_blocks, block_size, d_nope + 2 * d_rope + ) + input_nope = input_nope_rope[:, :, :d_nope].view(fp8_dtype) + input_rope = input_nope_rope[:, :, d_nope:].view(torch.bfloat16) + input_scale = ( + quant_k_cache[:, block_size * (d_nope + 2 * d_rope) :] + .view(num_blocks, block_size, 8)[:, :, :num_tiles] + .view(torch.float8_e8m0fnu) + ) + + result = torch.empty( + (num_blocks, block_size, 1, d), + dtype=torch.bfloat16, + device=quant_k_cache.device, + ) + result[..., d_nope:] = input_rope.unsqueeze(2) + for tile_idx in range(num_tiles): + cur_nope = input_nope[ + ..., tile_idx * tile_size : (tile_idx + 1) * tile_size + ].to(torch.bfloat16) + cur_scales = input_scale[:, :, tile_idx].to(torch.bfloat16).unsqueeze(-1) + result[ + ..., tile_idx * tile_size : (tile_idx + 1) * tile_size + ] = (cur_nope * cur_scales).unsqueeze(2) + return result + + def _ref_sparse_attn_prefill( + self, + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + topk_length: torch.Tensor | None, + ) -> torch.Tensor: + indices = indices.clone().squeeze(1) + s_q, h_q, d_qk = q.shape + topk = indices.shape[-1] + s_kv = kv.shape[0] + if topk_length is not None: + mask = torch.arange(topk, device=indices.device).unsqueeze(0) >= topk_length.unsqueeze(1) + indices[mask] = -1 + invalid_mask = (indices < 0) | (indices >= s_kv) + indices[invalid_mask] = 0 + + qf = q.float() + gathered_kv = ( + kv.index_select(0, indices.flatten()).reshape(s_q, topk, d_qk).float() + ) + scores = qf @ gathered_kv.transpose(1, 2) + scores *= self.scale + scores[invalid_mask.unsqueeze(1).expand_as(scores)] = float("-inf") + + orig_lse = 
torch.logsumexp(scores, dim=-1) + lse_for_o = orig_lse + if self.attn_sink is not None: + lse_for_o = torch.logsumexp( + torch.stack([orig_lse, self.attn_sink[:h_q].view(1, h_q).expand_as(orig_lse)], dim=0), + dim=0, + ) + lse_for_o = lse_for_o.clone() + lse_for_o[lse_for_o == float("-inf")] = float("+inf") + probs = torch.exp(scores - lse_for_o.unsqueeze(-1)) + out = probs @ gathered_kv[..., : self.head_dim] + lonely_q_mask = orig_lse == float("-inf") + out[lonely_q_mask.unsqueeze(-1).expand_as(out)] = 0.0 + return out.to(torch.bfloat16) + + def _ref_sparse_attn_decode( + self, + q: torch.Tensor, + blocked_k: torch.Tensor, + indices_in_kvcache: torch.Tensor, + topk_length: torch.Tensor | None, + attn_sink: torch.Tensor | None, + extra_blocked_k: torch.Tensor | None = None, + extra_indices_in_kvcache: torch.Tensor | None = None, + extra_topk_length: torch.Tensor | None = None, + ) -> torch.Tensor: + b, s_q, h_q, d_qk = q.shape + d_v = self.head_dim + + def process_scope( + cur_blocked_k: torch.Tensor, + cur_indices: torch.Tensor, + cur_topk_length: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + cur_indices = cur_indices.reshape(b, s_q, -1) + topk = cur_indices.size(-1) + fixed_indices = torch.clamp_min(cur_indices, 0) + gathered_kv = ( + cur_blocked_k.view(-1, d_qk) + .index_select(0, fixed_indices.view(-1)) + .view(b, s_q, topk, d_qk) ) + invalid_mask = cur_indices == -1 + if cur_topk_length is not None: + cur_topk_length = cur_topk_length.reshape(b) + invalid_mask |= torch.arange(0, topk, device=invalid_mask.device).view( + 1, 1, topk + ) >= cur_topk_length.view(b, 1, 1) + return gathered_kv, invalid_mask + + gathered_kv, invalid_mask = process_scope( + blocked_k, indices_in_kvcache, topk_length + ) + if extra_blocked_k is not None: + assert extra_indices_in_kvcache is not None + gathered_kv1, invalid_mask1 = process_scope( + extra_blocked_k, extra_indices_in_kvcache, extra_topk_length + ) + gathered_kv = torch.cat([gathered_kv, gathered_kv1], 
dim=2) + invalid_mask = torch.cat([invalid_mask, invalid_mask1], dim=2) + + gathered_kv = gathered_kv.view(b * s_q, -1, d_qk).float() + gathered_kv[gathered_kv != gathered_kv] = 0.0 + qf = q.float().view(b * s_q, h_q, d_qk) + attn_weight = qf @ gathered_kv.transpose(-1, -2) + attn_weight *= self.scale + attn_weight[ + invalid_mask.view(b * s_q, 1, -1).expand(b * s_q, h_q, invalid_mask.size(-1)) + ] = float("-inf") + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.exp(attn_weight - lse.unsqueeze(-1)) + output = attn_weight @ gathered_kv[..., :d_v] + output = output.view(b, s_q, h_q, d_v) + lse = lse.view(b, s_q, h_q) + + if attn_sink is not None: + output *= ( + 1.0 / (1.0 + torch.exp(attn_sink.view(1, 1, h_q) - lse)) + ).unsqueeze(-1) + + lonely_q_mask = lse == float("-inf") + output[lonely_q_mask.unsqueeze(-1).expand_as(output)] = 0.0 + return output.squeeze(1).to(torch.bfloat16) class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase): @@ -1150,3 +1540,25 @@ def forward( use_fp4=self.use_fp4_kv, ) return self.indexer_op(hidden_states, q_quant, k, weights) + + def _quantize_indexer_q_torch( + self, + q: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + weights: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + q = apply_gptj_rope_ref(q, positions, cos_sin_cache, self.rope_dim).to( + torch.bfloat16 + ) + q = hadamard_transform_ref(q).to(torch.float32) + fp8_max = 224.0 if current_platform.is_fp8_fnuz() else 448.0 + q_scale = torch.abs(q).amax(dim=-1).clamp(min=1e-12) / fp8_max + q_quant = (q / q_scale.unsqueeze(-1)).to(current_platform.fp8_dtype()) + weights = ( + weights.to(torch.float32) + * q_scale + * self.softmax_scale + * (self.n_head**-0.5) + ) + return q_quant, weights diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py index c1423362d73..a1237bedb49 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ 
b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -217,6 +217,8 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]: TRTLLM MXFP8; SM90 falls through to Triton_unfused or Marlin (the backend-level ``is_supported_config`` check filters by device capability). """ + if current_platform.is_rocm(): + return [Mxfp4MoeBackend.AITER] _AVAILABLE_BACKENDS = [ Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, Mxfp4MoeBackend.DEEPGEMM_MXFP4, @@ -484,8 +486,22 @@ def _return_or_raise( activation_format, ) + # DeepSeek-V4 on ROCm is more accurate with the unfused Triton MXFP4 path + # than the default AITER path. Prefer Triton-unfused for this routing mode, + # while keeping AITER as a fallback if Triton-unfused rejects the config. + if ( + current_platform.is_rocm() + and config.routing_method == RoutingMethodType.DeepseekV4 + ): + priority_backends = [ + Mxfp4MoeBackend.TRITON_UNFUSED, + Mxfp4MoeBackend.AITER, + ] + else: + priority_backends = _get_priority_backends() + # Iterate priority backends: TRTLLM MXFP8, then Triton. 
- for backend in _get_priority_backends(): + for backend in priority_backends: activation_key = _backend_activation_key(backend) for k_cls in backend_to_kernel_cls(backend): supported, reason = k_cls.is_supported_config( @@ -1107,6 +1123,64 @@ def convert_weight_to_mxfp4_moe_kernel_format( w2_bias, ) + elif mxfp4_backend == Mxfp4MoeBackend.AITER: + from vllm._aiter_ops import rocm_aiter_ops + + if w13_bias is not None: + w13_bias = w13_bias.data.to(torch.float32) + if w2_bias is not None: + w2_bias = w2_bias.data.to(torch.float32) + + e, n, k = w13_weight.shape + + w13_weight.view(torch.uint8).copy_( + w13_weight.data.view(torch.uint8) + .view(e, n // 2, 2, k) + .permute(0, 2, 1, 3) + .contiguous() + .view(e, n, k) + ) + w13_weight_scale.data = ( + w13_weight_scale.data.view(e, n // 2, 2, -1) + .permute(0, 2, 1, 3) + .contiguous() + .view(e, n, -1) + ) + + w13_weight.data = w13_weight.data.view(torch.float4_e2m1fn_x2) + w2_weight.data = w2_weight.data.view(torch.float4_e2m1fn_x2) + + w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w13_weight, 16, True) + shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4( + w13_weight_scale.view(-1, w13_weight_scale.shape[-1]), + num_experts, + True, + ) + + w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w2_weight, 16, False) + shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4( + w2_weight_scale.view(-1, w2_weight_scale.shape[-1]), + num_experts, + False, + ) + + if w13_bias is not None: + w13_bias = ( + w13_bias.data.view(-1, n // 2, 2) + .permute(0, 2, 1) + .contiguous() + .view(-1, n) + ) + + return ( + w13_weight, + w2_weight, + shuffled_w13_scale, + shuffled_w2_scale, + w13_bias, + w2_bias, + ) + elif mxfp4_backend in TRITON_BACKENDS: from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig @@ -1162,7 +1236,7 @@ def shuffle_weight(w: torch.Tensor) -> torch.Tensor: else: raise ValueError( f"Unsupported mxfp4_backend for Mxfp4MoEMethod: {mxfp4_backend}. " - f"Expected TRTLLM or Triton backend." 
+ f"Expected TRTLLM, Triton, or AITER backend." ) diff --git a/vllm/model_executor/layers/mhc.py b/vllm/model_executor/layers/mhc.py index 1521a6b601b..ad1655e7ead 100644 --- a/vllm/model_executor/layers/mhc.py +++ b/vllm/model_executor/layers/mhc.py @@ -234,6 +234,40 @@ def mhc_pre( num_tokens = residual_flat.shape[0] fn_flat = fn + if current_platform.is_rocm(): + x = residual_flat.view(num_tokens, hc_mult * hidden_size).to(torch.float32) + mixes = torch.matmul(x, fn_flat.t()) + sqrsum = x.square().sum(dim=-1, keepdim=True) + mixes = mixes * torch.rsqrt(sqrsum / (hc_mult * hidden_size) + rms_eps) + + pre_logits = mixes[:, :hc_mult] * hc_scale[0] + hc_base[:hc_mult] + pre_mix = torch.sigmoid(pre_logits) + hc_pre_eps + + post_logits = ( + mixes[:, hc_mult : 2 * hc_mult] * hc_scale[1] + + hc_base[hc_mult : 2 * hc_mult] + ) + post_mix = torch.sigmoid(post_logits) * hc_post_mult_value + + comb_logits = ( + mixes[:, 2 * hc_mult :].view(num_tokens, hc_mult, hc_mult) * hc_scale[2] + + hc_base[2 * hc_mult :].view(1, hc_mult, hc_mult) + ) + comb_mix = torch.softmax(comb_logits, dim=-1) + hc_sinkhorn_eps + comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps) + for _ in range(sinkhorn_repeat - 1): + comb_mix = comb_mix / (comb_mix.sum(dim=-1, keepdim=True) + hc_sinkhorn_eps) + comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps) + + layer_input = torch.sum( + pre_mix.unsqueeze(-1) * residual_flat.to(torch.float32), dim=1 + ).to(torch.bfloat16) + return ( + post_mix.view(*outer_shape, hc_mult, 1), + comb_mix.view(*outer_shape, hc_mult, hc_mult), + layer_input.view(*outer_shape, hidden_size), + ) + # these number are from deepgemm kernel impl block_k = 64 block_m = 64 @@ -414,6 +448,14 @@ def mhc_post( post_layer_mix: torch.Tensor, comb_res_mix: torch.Tensor, ) -> torch.Tensor: + if current_platform.is_rocm(): + mixed_residual = torch.einsum( + "...ij,...ih->...jh", + comb_res_mix.to(torch.float32), + residual.to(torch.float32), 
+ ) + post_term = post_layer_mix.to(torch.float32) * x.unsqueeze(-2).to(torch.float32) + return (mixed_residual + post_term).to(residual.dtype) out = torch.empty_like(residual) mhc_post_tilelang( comb_res_mix, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 9613b11d35e..d9aab35c25f 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -843,6 +843,15 @@ def w8a8_triton_block_scaled_mm( assert len(block_size) == 2 block_n, block_k = block_size[0], block_size[1] + # Triton cannot currently bind E8M0 scale tensors directly. On ROCm, + # DeepSeek-V4 checkpoints store block scales in exponent-only E8M0 format, + # so decode them to fp32 before launching the kernel. + if current_platform.is_rocm(): + if As.dtype == torch.float8_e8m0fnu: + As = _upcast_e8m0_to_fp32(As).contiguous() + if Bs.dtype == torch.float8_e8m0fnu: + Bs = _upcast_e8m0_to_fp32(Bs).contiguous() + assert A.shape[-1] == B.shape[-1] assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index ca82f2feb7e..9ad881cdbe8 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -499,13 +499,31 @@ def forward_hip( k: torch.Tensor, weights: torch.Tensor, ): - assert not self.skip_k_cache_insert, ( - "AMD platform doesn't support skip cache insert yet" - ) assert not self.use_fp4_cache, "AMD platform doesn't support fp4 cache yet" assert isinstance(q_quant, torch.Tensor), ( "AMD sparse_attn_indexer expects a single FP8 q_quant tensor" ) + if self.skip_k_cache_insert or not rocm_aiter_ops.is_enabled(): + from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( + rocm_aiter_sparse_attn_indexer_native, + ) 
+ + return rocm_aiter_sparse_attn_indexer_native( + hidden_states, + _encode_layer_name(self.k_cache.prefix), + self.k_cache.kv_cache, + q_quant, + k, + weights, + self.quant_block_size, + self.scale_fmt, + self.topk_tokens, + self.head_dim, + self.max_model_len, + self.max_total_seq_len, + self.topk_indices_buffer, + skip_k_cache_insert=self.skip_k_cache_insert, + ) if rocm_aiter_ops.is_enabled(): return torch.ops.vllm.rocm_aiter_sparse_attn_indexer( hidden_states, @@ -522,8 +540,6 @@ def forward_hip( self.max_total_seq_len, self.topk_indices_buffer, ) - else: - raise RuntimeError( - "Sparse attention indexer ROCm custom op requires ROCm " - "Aiter ops to be enabled." - ) + raise RuntimeError( + "Sparse attention indexer ROCm path could not be selected." + ) diff --git a/vllm/model_executor/models/deepseek_v4.py b/vllm/model_executor/models/deepseek_v4.py index 5521a9764a9..0c337cb7ae0 100644 --- a/vllm/model_executor/models/deepseek_v4.py +++ b/vllm/model_executor/models/deepseek_v4.py @@ -66,7 +66,6 @@ _DEEPSEEK_V4_EXPERT_DTYPES = ("fp4", "fp8") - class DeepseekV4MLP(nn.Module): def __init__( self, diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 866b9ffd1a6..70534c106ff 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -409,6 +409,7 @@ class RocmPlatform(Platform): "gptq", "gptq_marlin", # will be overwritten with gptq "fp8", + "deepseek_v4_fp8", "compressed-tensors", "fbgemm_fp8", "gguf", diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 6b89f5c3320..f66818ab897 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -473,7 +473,11 @@ def tf32_hc_prenorm_gemm( """ _lazy_init() if _tf32_hc_prenorm_gemm_impl is None: - return _missing() + out.zero_() + sqrsum.zero_() + out[0].copy_(torch.matmul(x.to(torch.float32), fn.t().to(torch.float32))) + sqrsum[0].copy_(x.to(torch.float32).square().sum(dim=-1)) + return out return _tf32_hc_prenorm_gemm_impl( x, fn, diff --git 
a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py index b17fd5d3441..28564e6a97d 100644 --- a/vllm/v1/attention/backends/mla/sparse_swa.py +++ b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -7,6 +7,7 @@ from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.v1.attention.backend import ( AttentionBackend, @@ -360,7 +361,7 @@ def build_tile_scheduler( _LAYER_TYPE_C4A: None, _LAYER_TYPE_C128A: None, } - if num_decode_tokens == 0: + if num_decode_tokens == 0 or current_platform.is_rocm(): return out for layer_type in self._layer_types: # get_mla_metadata() is the official FlashMLA entry point that diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 81cc489db0d..ea75f19862e 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -5,6 +5,7 @@ from importlib.util import find_spec import torch +import torch.nn.functional as F from vllm.forward_context import get_forward_context from vllm.platforms import current_platform @@ -232,6 +233,39 @@ def fp8_paged_mqa_logits_torch( fp8_dtype = current_platform.fp8_dtype() batch_size, next_n, _, dim = q.size() + if next_n == 1: + block_size = kv_cache.shape[1] + logits = torch.full( + [batch_size, max_model_len], + float("-inf"), + device=q.device, + dtype=torch.float32, + ) + if context_lens.dim() > 1: + context_lens = context_lens.squeeze(-1) + kv_cache_flat = kv_cache.view(-1, block_size * (dim + 4)) + for i in range(batch_size): + q_i = q[i, 0].to(torch.float32) + q_scale = weights[i] + seq_len = int(context_lens[i].item()) + assert seq_len <= max_model_len + num_pages = cdiv(seq_len, block_size) + padded_seq_len = num_pages * block_size + pages = block_tables[i, :num_pages] 
+ cache = kv_cache_flat[pages] + scale_offset = block_size * dim + cache_value = cache[..., :scale_offset].view(dtype=fp8_dtype).to(torch.float32) + cache_scale = cache[..., scale_offset:].view(dtype=torch.float32).contiguous() + cache_value = cache_value.view(padded_seq_len, dim) + cache_scale = cache_scale.view(padded_seq_len) + score = F.linear(cache_value, q_i) + score = F.relu(score) + score *= q_scale[None, :] + score = score.sum(dim=1) + score *= cache_scale + logits[i, :seq_len] = score[:seq_len] + return logits + kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:] scale = scale.contiguous().view(torch.float) q = q.float() @@ -243,20 +277,30 @@ def fp8_paged_mqa_logits_torch( device=q.device, dtype=torch.float32, ) - context_lens = context_lens.tolist() for i in range(batch_size): context_len = context_lens[i] - q_offsets = torch.arange(context_len - next_n, context_len, device="cuda") + if context_len.ndim == 0: + context_len_i = int(context_len.item()) + q_offsets = torch.arange( + context_len_i - next_n, context_len_i, device=q.device + ) + context_limit = torch.full( + (next_n,), context_len_i, dtype=torch.int32, device=q.device + ) + else: + context_limit = context_len.to(device=q.device, dtype=torch.int32) + q_offsets = context_limit - 1 weight_slice = ( weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous() ) - for block_rk in range(cdiv(context_len, block_size)): + max_context_len = int(context_limit.max().item()) + for block_rk in range(cdiv(max_context_len, block_size)): block_idx = block_tables[i][block_rk] qx, kx = q[i], kv_cache[block_idx] k_offsets = torch.arange( - block_rk * block_size, (block_rk + 1) * block_size, device="cuda" + block_rk * block_size, (block_rk + 1) * block_size, device=q.device ) - mask = (k_offsets[None, :] < context_len) & ( + mask = (k_offsets[None, :] < context_limit[:, None]) & ( k_offsets[None, :] <= q_offsets[:, None] ) s = torch.where( @@ -461,6 +505,27 @@ def rocm_fp8_mqa_logits( return 
fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) +def _topk_indices_torch(logits: torch.Tensor, topk_tokens: int) -> torch.Tensor: + k = min(topk_tokens, logits.shape[-1]) + values, indices = torch.topk(logits, k=k, dim=-1) + indices = indices.to(torch.int32) + indices = torch.where( + values == float("-inf"), + torch.full_like(indices, -1, dtype=torch.int32), + indices, + ) + if k == topk_tokens: + return indices + padded = torch.full( + (logits.shape[0], topk_tokens), + -1, + dtype=torch.int32, + device=logits.device, + ) + padded[:, :k] = indices + return padded + + def rocm_aiter_sparse_attn_indexer_fake( hidden_states: torch.Tensor, k_cache_prefix: LayerNameType, @@ -479,8 +544,9 @@ def rocm_aiter_sparse_attn_indexer_fake( # profile run # NOTE(Chen): create the max possible flattened_kv. So that # profile_run can get correct memory usage. + device = hidden_states.device if k is None else k.device _flattened_kv = torch.empty( - [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8 + [total_seq_lens, head_dim + 4], device=device, dtype=torch.uint8 ) fp8_dtype = current_platform.fp8_dtype() _k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous() @@ -488,7 +554,7 @@ def rocm_aiter_sparse_attn_indexer_fake( return topk_indices_buffer -def rocm_aiter_sparse_attn_indexer( +def rocm_aiter_sparse_attn_indexer_native( hidden_states: torch.Tensor, k_cache_prefix: LayerNameType, kv_cache: torch.Tensor, @@ -502,6 +568,7 @@ def rocm_aiter_sparse_attn_indexer( max_model_len: int, total_seq_lens: int, topk_indices_buffer: torch.Tensor | None, + skip_k_cache_insert: bool = False, ) -> torch.Tensor: # careful! 
this will be None in dummy run attn_metadata = get_forward_context().attn_metadata @@ -534,19 +601,24 @@ def rocm_aiter_sparse_attn_indexer( has_decode = layer_attn_metadata.num_decodes > 0 has_prefill = layer_attn_metadata.num_prefills > 0 num_decode_tokens = layer_attn_metadata.num_decode_tokens + device = hidden_states.device if k is None else k.device # during speculative decoding, k may be padded to the CUDA graph batch # size while slot_mapping only covers actual tokens. num_tokens = slot_mapping.shape[0] - k = k[:num_tokens] + if k is not None: + k = k[:num_tokens] + elif not skip_k_cache_insert: + raise ValueError("k must be provided when skip_k_cache_insert is False") - ops.indexer_k_quant_and_cache( - k, - kv_cache, - slot_mapping, - quant_block_size, - scale_fmt, - ) + if not skip_k_cache_insert: + ops.indexer_k_quant_and_cache( + k, + kv_cache, + slot_mapping, + quant_block_size, + scale_fmt, + ) topk_indices_buffer[: hidden_states.shape[0]] = -1 if has_prefill: @@ -555,12 +627,12 @@ def rocm_aiter_sparse_attn_indexer( for chunk in prefill_metadata.chunks: k_fp8 = torch.empty( [chunk.total_seq_lens, head_dim], - device=k.device, + device=device, dtype=fp8_dtype, ) k_scale = torch.empty( [chunk.total_seq_lens, 4], - device=k.device, + device=device, dtype=torch.uint8, ) @@ -579,21 +651,10 @@ def rocm_aiter_sparse_attn_indexer( chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, ) - num_rows = logits.shape[0] - assert topk_tokens == 2048, "top_k_per_row assumes size 2048" topk_indices = topk_indices_buffer[ chunk.token_start : chunk.token_end, :topk_tokens ] - torch.ops._C.top_k_per_row_prefill( - logits, - chunk.cu_seqlen_ks, - chunk.cu_seqlen_ke, - topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), - topk_tokens, - ) + topk_indices.copy_(_topk_indices_torch(logits, topk_tokens)) if has_decode: decode_metadata = layer_attn_metadata.decode @@ -630,19 +691,8 @@ def rocm_aiter_sparse_attn_indexer( max_model_len=max_model_len, ) - num_rows = 
logits.shape[0] - assert topk_tokens == 2048, "top_k_per_row assumes size 2048" topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] - torch.ops._C.top_k_per_row_decode( - logits, - next_n, - decode_metadata.seq_lens, - topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), - topk_tokens, - ) + topk_indices.copy_(_topk_indices_torch(logits, topk_tokens)[:num_decode_tokens]) if decode_metadata.requires_padding: # if padded, we need to unpack @@ -656,3 +706,36 @@ def rocm_aiter_sparse_attn_indexer( ) return topk_indices_buffer + + +def rocm_aiter_sparse_attn_indexer( + hidden_states: torch.Tensor, + k_cache_prefix: LayerNameType, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: str | None, + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: torch.Tensor | None, +) -> torch.Tensor: + return rocm_aiter_sparse_attn_indexer_native( + hidden_states, + k_cache_prefix, + kv_cache, + q_fp8, + k, + weights, + quant_block_size, + scale_fmt, + topk_tokens, + head_dim, + max_model_len, + total_seq_lens, + topk_indices_buffer, + skip_k_cache_insert=False, + ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 522e7fdbf25..630df049c2c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4078,6 +4078,7 @@ def execute_model( batch_descriptor=batch_desc, ubatch_slices=ubatch_slices_padded, slot_mapping=slot_mappings, + additional_kwargs={"req_ids": self.input_batch.req_ids.copy()}, skip_compiled=has_encoder_input, ), record_function_or_nullcontext("gpu_model_runner: forward"), @@ -5570,6 +5571,7 @@ def _dummy_run( batch_descriptor=batch_desc, ubatch_slices=ubatch_slices_padded, slot_mapping=slot_mappings, + additional_kwargs={"req_ids": self.input_batch.req_ids.copy()}, ), ): outputs = self.model( From 8e55f0d169d8eed5dcf33a6cd36e4d0b512e3a3d Mon 
Sep 17 00:00:00 2001 From: whx-sjtu Date: Tue, 28 Apr 2026 10:15:56 +0000 Subject: [PATCH 03/17] remove _old_forward Signed-off-by: whx-sjtu --- .../layers/deepseek_compressor.py | 190 ------------------ vllm/v1/worker/gpu_model_runner.py | 2 - 2 files changed, 192 deletions(-) diff --git a/vllm/model_executor/layers/deepseek_compressor.py b/vllm/model_executor/layers/deepseek_compressor.py index e92d9173c84..715840adf65 100644 --- a/vllm/model_executor/layers/deepseek_compressor.py +++ b/vllm/model_executor/layers/deepseek_compressor.py @@ -289,9 +289,6 @@ def __init__( self._static_forward_context = ( vllm_config.compilation_config.static_forward_context ) - self._old_kv_state: dict[str, torch.Tensor] = {} - self._old_score_state: dict[str, torch.Tensor] = {} - self._old_need_hadamard = self.head_dim == 128 if self.head_dim == 512: assert not use_fp4_cache, ( @@ -427,193 +424,6 @@ def forward( num_warps=self._num_warps, ) - @property - def _state_width(self) -> int: - return self.coff * self.head_dim - - @property - def _state_len(self) -> int: - return self.compress_ratio * self.coff - - def _get_old_state(self, req_id: str, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: - kv_state = self._old_kv_state.get(req_id) - score_state = self._old_score_state.get(req_id) - if kv_state is None or score_state is None or kv_state.device != device: - kv_state = torch.zeros( - (self._state_len, self._state_width), - dtype=torch.float32, - device=device, - ) - score_state = torch.full( - (self._state_len, self._state_width), - float("-inf"), - dtype=torch.float32, - device=device, - ) - self._old_kv_state[req_id] = kv_state - self._old_score_state[req_id] = score_state - return kv_state, score_state - - def _clear_old_state(self, kv_state: torch.Tensor, score_state: torch.Tensor) -> None: - kv_state.zero_() - score_state.fill_(float("-inf")) - - def _overlap_transform(self, tensor: torch.Tensor, fill_value: Any) -> torch.Tensor: - assert tensor.dim() == 3 - 
assert tensor.shape[1:] == (self.compress_ratio, 2 * self.head_dim) - s, r, d = tensor.shape[0], self.compress_ratio, self.head_dim - new_tensor = tensor.new_full((s, 2 * r, d), fill_value) - new_tensor[:, r:] = tensor[:, :, d:] - new_tensor[1:, :r] = tensor[:-1, :, :d] - return new_tensor - - def _ref_rms_norm(self, x: torch.Tensor) -> torch.Tensor: - x = x.to(torch.float32) - x = x * torch.rsqrt(x.square().mean(dim=-1, keepdim=True) + self.rms_norm_eps) - return x * self.norm.weight.to(torch.float32) - - @staticmethod - def _compute_state_len(seq_len: int, ratio: int) -> int: - return seq_len % ratio + (ratio == 4) * ratio - - def _forward_old( - self, - x: torch.Tensor, - positions: torch.Tensor, - rotary_emb, - ) -> None: - num_tokens, _ = x.shape - kv_score = cublas_gemm_bf16_bf16_fp32(x, self.fused_wkv_wgate.weight) - kv, score = kv_score.split([self._state_width, self._state_width], dim=-1) - - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata - if not isinstance(attn_metadata, dict): - return - - state_metadata = attn_metadata[self.state_cache.prefix] - token_to_req_indices = state_metadata.token_to_req_indices[:num_tokens] - k_cache_metadata = attn_metadata[self.k_cache_prefix] - slot_mapping = k_cache_metadata.slot_mapping[:num_tokens] - kv_cache = self._static_forward_context[self.k_cache_prefix].kv_cache - req_ids = forward_context.additional_kwargs.get("req_ids") - if req_ids is None: - max_req_idx = int(token_to_req_indices.max().item()) + 1 if num_tokens > 0 else 0 - req_ids = [str(i) for i in range(max_req_idx)] - - _, counts = torch.unique_consecutive(token_to_req_indices.to(torch.int64), return_counts=True) - start = 0 - for req_idx_tensor, count_tensor in zip( - torch.unique_consecutive(token_to_req_indices.to(torch.int64)), counts - ): - req_idx = int(req_idx_tensor.item()) - count = int(count_tensor.item()) - end = start + count - req_id = req_ids[req_idx] - kv_state, score_state = 
self._get_old_state(req_id, x.device) - - positions_i = positions[start:end].to(torch.int64) - prefix_len = int(positions_i[0].item()) - query_len = end - start - if prefix_len == 0: - self._clear_old_state(kv_state, score_state) - - pre_state_len = self._compute_state_len(prefix_len, self.compress_ratio) - valid_kv_len = pre_state_len + query_len - - temp_kv = torch.empty( - (valid_kv_len, self._state_width), dtype=torch.float32, device=x.device - ) - temp_score = torch.empty_like(temp_kv) - if pre_state_len > 0: - temp_kv[:pre_state_len] = kv_state[:pre_state_len] - temp_score[:pre_state_len] = score_state[:pre_state_len] - temp_kv[pre_state_len:] = kv[start:end] - temp_score[pre_state_len:] = score[start:end] - - post_state_len = self._compute_state_len(valid_kv_len, self.compress_ratio) - kv_state[:post_state_len] = temp_kv[valid_kv_len - post_state_len : valid_kv_len] - score_state[:post_state_len] = temp_score[ - valid_kv_len - post_state_len : valid_kv_len - ] - if post_state_len < self._state_len: - kv_state[post_state_len:].zero_() - score_state[post_state_len:].fill_(float("-inf")) - - compress_len = valid_kv_len // self.compress_ratio * self.compress_ratio - if compress_len == 0: - start = end - continue - - kv_to_compress = temp_kv[:compress_len].view( - compress_len // self.compress_ratio, - self.compress_ratio, - self._state_width, - ) - score_to_compress = temp_score[:compress_len].view_as(kv_to_compress) - score_to_compress = score_to_compress + self.ape.unsqueeze(0) - - if self.overlap: - kv_to_compress = self._overlap_transform(kv_to_compress, 0.0) - score_to_compress = self._overlap_transform( - score_to_compress, float("-inf") - ) - kv_to_compress = kv_to_compress[1:] - score_to_compress = score_to_compress[1:] - if kv_to_compress.numel() == 0: - start = end - continue - - kv_compressed = ( - kv_to_compress * torch.softmax(score_to_compress, dim=1) - ).sum(dim=1) - kv_compressed = self._ref_rms_norm(kv_compressed) - - first_compressed_pos = 
prefix_len - first_compressed_pos += self.compress_ratio - 1 - first_compressed_pos % self.compress_ratio - compressed_positions = torch.arange( - first_compressed_pos, - prefix_len + query_len, - self.compress_ratio, - device=x.device, - dtype=torch.int64, - ) - if compressed_positions.numel() != kv_compressed.shape[0]: - raise RuntimeError( - f"Compressed positions mismatch for req {req_id}: " - f"{compressed_positions.numel()} vs {kv_compressed.shape[0]}" - ) - - kv_compressed = apply_gptj_rope_ref( - kv_compressed, compressed_positions, rotary_emb.cos_sin_cache, self.rope_head_dim - ).to(torch.bfloat16) - if self._old_need_hadamard: - kv_compressed = hadamard_transform_ref(kv_compressed) - - local_output = torch.zeros( - (query_len, self.head_dim), dtype=torch.bfloat16, device=x.device - ) - local_output[(compressed_positions - prefix_len).to(torch.long)] = kv_compressed - - if self.head_dim == 512: - quantize_and_insert_k_cache( - local_output, - kv_cache, - slot_mapping[start:end], - block_size=kv_cache.shape[1], - is_ue8m0=True, - ) - else: - ops.indexer_k_quant_and_cache( - local_output, - kv_cache, - slot_mapping[start:end], - self._quant_block, - "ue8m0", - ) - - start = end - @triton.jit def _save_partial_states_kernel( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 630df049c2c..522e7fdbf25 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4078,7 +4078,6 @@ def execute_model( batch_descriptor=batch_desc, ubatch_slices=ubatch_slices_padded, slot_mapping=slot_mappings, - additional_kwargs={"req_ids": self.input_batch.req_ids.copy()}, skip_compiled=has_encoder_input, ), record_function_or_nullcontext("gpu_model_runner: forward"), @@ -5571,7 +5570,6 @@ def _dummy_run( batch_descriptor=batch_desc, ubatch_slices=ubatch_slices_padded, slot_mapping=slot_mappings, - additional_kwargs={"req_ids": self.input_batch.req_ids.copy()}, ), ): outputs = self.model( From 
05580cb188d3a26c6a79246b3a587fc7f623d759 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Sat, 25 Apr 2026 05:52:09 +0000 Subject: [PATCH 04/17] support topk_softplus for all number of experts Signed-off-by: tjtanaa --- csrc/moe/topk_softplus_sqrt_kernels.cu | 57 +++++++++++--------- tests/kernels/moe/test_topk_softplus_sqrt.py | 6 +-- 2 files changed, 35 insertions(+), 28 deletions(-) diff --git a/csrc/moe/topk_softplus_sqrt_kernels.cu b/csrc/moe/topk_softplus_sqrt_kernels.cu index 7cbba6b6354..43d461a0179 100644 --- a/csrc/moe/topk_softplus_sqrt_kernels.cu +++ b/csrc/moe/topk_softplus_sqrt_kernels.cu @@ -60,19 +60,6 @@ __device__ __forceinline__ float toFloat(T value) { } } -#ifdef USE_ROCM - #define FINAL_MASK 0xffffffffffffffffULL -#else - #define FINAL_MASK 0xffffffff -#endif -template -__inline__ __device__ T warpReduceSum(T val) { -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val += __shfl_xor_sync(FINAL_MASK, val, mask, 32); - return val; -} - // ====================== TopK softplus_sqrt things // =============================== @@ -276,8 +263,14 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ } } // Compute per-thread scale (using warp reduction when renormalizing). + // THREADS_PER_ROW-parameterized butterfly works for both warp sizes (32 + // on CUDA, 64 on ROCm CDNA) and any THREADS_PER_ROW the dispatch picks. 
if (renormalize) { - selected_sum = warpReduceSum(selected_sum); +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + selected_sum += + VLLM_SHFL_XOR_SYNC_WIDTH(selected_sum, mask, THREADS_PER_ROW); + } } float scale = static_cast(routed_scaling_factor); if (renormalize) { @@ -548,7 +541,6 @@ void topkGatingSoftplusSqrtKernelLauncher( const IndType* tid2eid, cudaStream_t stream) { static constexpr int WARPS_PER_TB = 4; static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16; -#ifndef USE_ROCM // for bfloat16 dtype, we need 4 bytes loading to make sure num_experts // elements can be loaded by a warp static constexpr int BYTES_PER_LDG_MULTIPLE_64 = @@ -556,6 +548,19 @@ void topkGatingSoftplusSqrtKernelLauncher( std::is_same_v) ? 4 : 8; + // Narrower LDG (ELTS_PER_LDG=1) used by 192/320/448/576 on ROCm WARP_SIZE=64 + // where ELTS_PER_LDG=2 fails the EXPERTS%(ELTS_PER_LDG*WARP_SIZE)==0 check. + // On CUDA WARP_SIZE=32 the wider LDG already aligns, so the alias collapses + // back to BYTES_PER_LDG_MULTIPLE_64 — no behavioral change for CUDA. +#ifdef USE_ROCM + static constexpr int BYTES_PER_LDG_MULTIPLE_64_NARROW = + (std::is_same_v || + std::is_same_v) + ? 2 + : 4; +#else + static constexpr int BYTES_PER_LDG_MULTIPLE_64_NARROW = + BYTES_PER_LDG_MULTIPLE_64; #endif switch (num_experts) { case 1: @@ -588,27 +593,29 @@ void topkGatingSoftplusSqrtKernelLauncher( case 512: LAUNCH_SOFTPLUS_SQRT(512, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); break; - // (CUDA only) support multiples of 64 when num_experts is not power of 2. - // ROCm uses WARP_SIZE 64 so 8 bytes loading won't fit for some of - // num_experts, alternatively we can test 4 bytes loading and enable it in - // future. -#ifndef USE_ROCM + // Multiples of 64 that are not powers of 2. The kernel requires + // EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0. 
With ELTS_PER_LDG=2 + // (BYTES_PER_LDG_MULTIPLE_64), this holds for all five values on CUDA + // WARP_SIZE=32 but only for 384 on ROCm WARP_SIZE=64. The other four + // use BYTES_PER_LDG_MULTIPLE_64_NARROW (ELTS_PER_LDG=1), which + // satisfies the assertion for any multiple of 64 on either backend; + // on CUDA the narrow alias collapses back to the wider load, so CUDA + // behavior is unchanged. case 192: - LAUNCH_SOFTPLUS_SQRT(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); + LAUNCH_SOFTPLUS_SQRT(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64_NARROW); break; case 320: - LAUNCH_SOFTPLUS_SQRT(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); + LAUNCH_SOFTPLUS_SQRT(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64_NARROW); break; case 384: LAUNCH_SOFTPLUS_SQRT(384, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); break; case 448: - LAUNCH_SOFTPLUS_SQRT(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); + LAUNCH_SOFTPLUS_SQRT(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64_NARROW); break; case 576: - LAUNCH_SOFTPLUS_SQRT(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); + LAUNCH_SOFTPLUS_SQRT(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64_NARROW); break; -#endif default: { TORCH_CHECK(false, "Unsupported expert number: ", num_experts); } diff --git a/tests/kernels/moe/test_topk_softplus_sqrt.py b/tests/kernels/moe/test_topk_softplus_sqrt.py index 7f5aacb383d..7edbfc5316e 100644 --- a/tests/kernels/moe/test_topk_softplus_sqrt.py +++ b/tests/kernels/moe/test_topk_softplus_sqrt.py @@ -70,7 +70,7 @@ def test_sqrtsoftplus_bias_uses_deepseek_v4_routing_method(): @pytest.mark.skipif( - not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." + not current_platform.is_cuda_alike(), reason="This test is skipped on non-CUDA platform." 
) @pytest.mark.parametrize("num_tokens", [1, 33, 128]) @pytest.mark.parametrize("hidden_size", [1024, 2048]) @@ -125,7 +125,7 @@ def test_fused_topk_softplus_sqrt( @pytest.mark.skipif( - not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." + not current_platform.is_cuda_alike(), reason="This test is skipped on non-CUDA platform." ) @pytest.mark.parametrize("num_tokens", [1, 33, 128]) @pytest.mark.parametrize("hidden_size", [1024, 2048]) @@ -183,4 +183,4 @@ def test_fused_topk_softplus_sqrt_hash( sorted_w_ref = topk_weights_ref.gather(1, idx_ref) sorted_w = topk_weights.gather(1, idx_ops) - torch.testing.assert_close(sorted_w_ref, sorted_w, atol=2e-2, rtol=1e-2) + torch.testing.assert_close(sorted_w_ref, sorted_w, atol=2e-2, rtol=1e-2) \ No newline at end of file From e0df265c9b4041de8278f8511dcd054007372890 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Sat, 25 Apr 2026 06:02:52 +0000 Subject: [PATCH 05/17] fix formatting Signed-off-by: tjtanaa --- tests/kernels/moe/test_topk_softplus_sqrt.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/kernels/moe/test_topk_softplus_sqrt.py b/tests/kernels/moe/test_topk_softplus_sqrt.py index 7edbfc5316e..1b68213fafe 100644 --- a/tests/kernels/moe/test_topk_softplus_sqrt.py +++ b/tests/kernels/moe/test_topk_softplus_sqrt.py @@ -70,7 +70,8 @@ def test_sqrtsoftplus_bias_uses_deepseek_v4_routing_method(): @pytest.mark.skipif( - not current_platform.is_cuda_alike(), reason="This test is skipped on non-CUDA platform." + not current_platform.is_cuda_alike(), + reason="This test is skipped on non-CUDA platform.", ) @pytest.mark.parametrize("num_tokens", [1, 33, 128]) @pytest.mark.parametrize("hidden_size", [1024, 2048]) @@ -125,7 +126,8 @@ def test_fused_topk_softplus_sqrt( @pytest.mark.skipif( - not current_platform.is_cuda_alike(), reason="This test is skipped on non-CUDA platform." 
+ not current_platform.is_cuda_alike(), + reason="This test is skipped on non-CUDA platform.", ) @pytest.mark.parametrize("num_tokens", [1, 33, 128]) @pytest.mark.parametrize("hidden_size", [1024, 2048]) @@ -183,4 +185,4 @@ def test_fused_topk_softplus_sqrt_hash( sorted_w_ref = topk_weights_ref.gather(1, idx_ref) sorted_w = topk_weights.gather(1, idx_ops) - torch.testing.assert_close(sorted_w_ref, sorted_w, atol=2e-2, rtol=1e-2) \ No newline at end of file + torch.testing.assert_close(sorted_w_ref, sorted_w, atol=2e-2, rtol=1e-2) From 914e38744d668cd8844845340640a2dfdcf36f71 Mon Sep 17 00:00:00 2001 From: tjtanaavllm Date: Tue, 28 Apr 2026 13:01:26 +0000 Subject: [PATCH 06/17] enable topk_softplus_sqrt Signed-off-by: tjtanaavllm --- CMakeLists.txt | 6 +++--- csrc/moe/torch_bindings.cpp | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index bf4ac05e4f2..b1c87696a97 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1047,13 +1047,13 @@ endif() set(VLLM_MOE_EXT_SRC "csrc/moe/torch_bindings.cpp" "csrc/moe/moe_align_sum_kernels.cu" - "csrc/moe/topk_softmax_kernels.cu") + "csrc/moe/topk_softmax_kernels.cu" + "csrc/moe/topk_softplus_sqrt_kernels.cu") if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu" - "csrc/moe/grouped_topk_kernels.cu" - "csrc/moe/topk_softplus_sqrt_kernels.cu") + "csrc/moe/grouped_topk_kernels.cu") endif() if(VLLM_GPU_LANG STREQUAL "CUDA") diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index b737cb54353..8940e341cd0 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -16,14 +16,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "bias) -> ()"); m.impl("topk_sigmoid", torch::kCUDA, &topk_sigmoid); -#ifndef USE_ROCM m.def( "topk_softplus_sqrt(Tensor! topk_weights, Tensor! topk_indices, Tensor! " "token_expert_indices, Tensor gating_output, bool renormalize, float " "routed_scaling_factor, Tensor? " "bias, Tensor? 
input_ids, Tensor? tid2eid) -> ()"); m.impl("topk_softplus_sqrt", torch::kCUDA, &topk_softplus_sqrt); -#endif + // Calculate the result of moe by summing up the partial results // from all selected experts. m.def("moe_sum(Tensor input, Tensor! output) -> ()"); From 7e61709ab2eb0b3bca426619ec7f981156b27917 Mon Sep 17 00:00:00 2001 From: tjtanaavllm Date: Tue, 28 Apr 2026 13:26:37 +0000 Subject: [PATCH 07/17] enable fused_deepseek_v4_qnorm_rope_kv_insert Signed-off-by: tjtanaavllm --- CMakeLists.txt | 6 +++--- csrc/torch_bindings.cpp | 2 -- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b1c87696a97..13788fa8743 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -307,12 +307,12 @@ set(VLLM_EXT_SRC "csrc/quantization/activation_kernels.cu" "csrc/cuda_utils_kernels.cu" "csrc/custom_all_reduce.cu" - "csrc/torch_bindings.cpp") + "csrc/torch_bindings.cpp" + "csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu") if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_EXT_SRC - "csrc/minimax_reduce_rms_kernel.cu" - "csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu") + "csrc/minimax_reduce_rms_kernel.cu") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 8d8f7bed044..e695497fd88 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -183,7 +183,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "int forced_token_heads_per_warp=-1) -> ()"); ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope); -#ifndef USE_ROCM // Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and // GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one // kernel launch. 
@@ -194,7 +193,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "float eps, int cache_block_size) -> ()"); ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA, &fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert); -#endif // Apply repetition penalties to logits in-place ops.def( From 7d4236c01ca463a7109b5c3df0f1abe66c4da809 Mon Sep 17 00:00:00 2001 From: whx-sjtu Date: Wed, 29 Apr 2026 17:17:20 +0000 Subject: [PATCH 08/17] fix review Signed-off-by: whx-sjtu --- ...deepseek_v4_qnorm_rope_kv_insert_kernel.cu | 14 ++++---- .../layers/deepseek_compressor.py | 5 +-- .../layers/deepseek_v4_attention.py | 32 ++++++++++++------- .../layers/fused_moe/oracle/mxfp4.py | 1 + vllm/model_executor/layers/mhc.py | 7 ++-- .../layers/sparse_attn_indexer.py | 4 +-- vllm/model_executor/models/deepseek_v4.py | 1 + .../fused_inv_rope_fp8_quant.py | 4 ++- .../v1/attention/ops/rocm_aiter_mla_sparse.py | 10 ++++-- 9 files changed, 46 insertions(+), 32 deletions(-) diff --git a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu index 3e903080c7f..4f2ec6e7488 100644 --- a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu +++ b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu @@ -57,9 +57,9 @@ // ROCm-compatible FP8 conversion helpers __device__ __forceinline__ uint8_t rocm_cvt_float_to_fp8_e4m3(float val) { #if defined(HIP_FP8_TYPE_OCP) - __hip_fp8_e4m3 fp8_val(val); + __hip_fp8_e4m3 fp8_val(val); #else - __hip_fp8_e4m3_fnuz fp8_val(val); + __hip_fp8_e4m3_fnuz fp8_val(val); #endif return reinterpret_cast(fp8_val); } @@ -339,7 +339,7 @@ __global__ void fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel( __nv_cvt_float_to_fp8(scaled, __NV_SATFINITE, __NV_E4M3); out_bytes[i] = static_cast(s); #else - out_bytes[i] = rocm_cvt_float_to_fp8_e4m3(scaled); + out_bytes[i] = rocm_cvt_float_to_fp8_e4m3(scaled); #endif } // One 16-byte STG per lane. 
@@ -438,10 +438,10 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert( #else // ROCm: use standard kernel launch syntax (no PDL/stream serialization) fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel - <<>>( - q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, eps, - num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size, - kv_block_stride); + << > >( + q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, + eps, num_tokens_full, num_tokens_insert, num_heads_q, + cache_block_size, kv_block_stride); #endif } diff --git a/vllm/model_executor/layers/deepseek_compressor.py b/vllm/model_executor/layers/deepseek_compressor.py index 715840adf65..1d69e140b66 100644 --- a/vllm/model_executor/layers/deepseek_compressor.py +++ b/vllm/model_executor/layers/deepseek_compressor.py @@ -7,7 +7,6 @@ import torch from torch import nn -from vllm import _custom_ops as ops from vllm.config import VllmConfig, get_current_vllm_config from vllm.forward_context import get_forward_context from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase @@ -29,7 +28,6 @@ _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn, _fused_kv_compress_norm_rope_insert_sparse_attn, ) -from vllm.v1.attention.ops.deepseek_v4_ops.cache_utils import quantize_and_insert_k_cache from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import ( MXFP4_BLOCK_SIZE, ) @@ -350,6 +348,7 @@ def forward( state_cache = self.state_cache.kv_cache # kv_state stored in first half, score_state stored in second half state_width = state_cache.shape[-1] // 2 + pdl_kwargs = {} if current_platform.is_rocm() else {"launch_pdl": False} # Store the KV and score (with fused APE addition) in the state. 
# NOTE: PDL is disabled — both this kernel and _fused_kernel below @@ -374,6 +373,7 @@ def forward( TRITON_BLOCK_SIZE=triton.next_power_of_2(kv.shape[-1]), STATE_WIDTH=state_width, COMPRESS_RATIO=self.compress_ratio, + **pdl_kwargs, ) # Fused: compress → RMSNorm → RoPE → FP8 quant → KV cache write. @@ -422,6 +422,7 @@ def forward( SCALE_DIM=self._scale_dim, KV_BLOCK_STRIDE=kv_cache.stride(0), num_warps=self._num_warps, + **pdl_kwargs, ) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index f83be6c5af3..4707ad501c1 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -4,9 +4,8 @@ DeepseekV4 MLA Attention Layer """ -from collections.abc import Callable import math - +from collections.abc import Callable from dataclasses import dataclass from typing import TYPE_CHECKING, Any, cast @@ -58,11 +57,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, ) +from vllm.platforms import current_platform from vllm.utils.multi_stream_utils import ( execute_in_parallel, maybe_execute_in_parallel, ) -from vllm.platforms import current_platform from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata from vllm.v1.attention.backends.mla.flashmla_sparse import ( DeepseekV4FlashMLASparseBackend, @@ -1182,8 +1181,8 @@ def _dequantize_cache_rows(self, rows: torch.Tensor) -> torch.Tensor: scale_dim = fp8_dim // 64 fp8_vals = rows[:, :fp8_dim].contiguous().view(fp8_dtype) - rope_vals = rows[:, fp8_dim : fp8_dim + rope_bytes].contiguous().view( - torch.bfloat16 + rope_vals = ( + rows[:, fp8_dim : fp8_dim + rope_bytes].contiguous().view(torch.bfloat16) ) scale_bytes = rows[ :, fp8_dim + rope_bytes : fp8_dim + rope_bytes + scale_dim @@ -1200,7 +1199,9 @@ def _gather_dequantized_cache_tokens( block_size: int, ) -> torch.Tensor: if slot_ids.numel() == 0: - return torch.empty((0, self.head_dim), 
dtype=torch.bfloat16, device=cache.device) + return torch.empty( + (0, self.head_dim), dtype=torch.bfloat16, device=cache.device + ) slot_ids = slot_ids.to(torch.int64) rows = cache[slot_ids // block_size, slot_ids % block_size] return self._dequantize_cache_rows(rows).reshape(-1, self.head_dim) @@ -1263,9 +1264,9 @@ def _dequantize_blocked_k_cache(self, quant_k_cache: torch.Tensor) -> torch.Tens ..., tile_idx * tile_size : (tile_idx + 1) * tile_size ].to(torch.bfloat16) cur_scales = input_scale[:, :, tile_idx].to(torch.bfloat16).unsqueeze(-1) - result[ - ..., tile_idx * tile_size : (tile_idx + 1) * tile_size - ] = (cur_nope * cur_scales).unsqueeze(2) + result[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = ( + cur_nope * cur_scales + ).unsqueeze(2) return result def _ref_sparse_attn_prefill( @@ -1280,7 +1281,9 @@ def _ref_sparse_attn_prefill( topk = indices.shape[-1] s_kv = kv.shape[0] if topk_length is not None: - mask = torch.arange(topk, device=indices.device).unsqueeze(0) >= topk_length.unsqueeze(1) + mask = torch.arange(topk, device=indices.device).unsqueeze( + 0 + ) >= topk_length.unsqueeze(1) indices[mask] = -1 invalid_mask = (indices < 0) | (indices >= s_kv) indices[invalid_mask] = 0 @@ -1297,7 +1300,10 @@ def _ref_sparse_attn_prefill( lse_for_o = orig_lse if self.attn_sink is not None: lse_for_o = torch.logsumexp( - torch.stack([orig_lse, self.attn_sink[:h_q].view(1, h_q).expand_as(orig_lse)], dim=0), + torch.stack( + [orig_lse, self.attn_sink[:h_q].view(1, h_q).expand_as(orig_lse)], + dim=0, + ), dim=0, ) lse_for_o = lse_for_o.clone() @@ -1360,7 +1366,9 @@ def process_scope( attn_weight = qf @ gathered_kv.transpose(-1, -2) attn_weight *= self.scale attn_weight[ - invalid_mask.view(b * s_q, 1, -1).expand(b * s_q, h_q, invalid_mask.size(-1)) + invalid_mask.view(b * s_q, 1, -1).expand( + b * s_q, h_q, invalid_mask.size(-1) + ) ] = float("-inf") lse = attn_weight.logsumexp(dim=-1) attn_weight = torch.exp(attn_weight - lse.unsqueeze(-1)) diff --git 
a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py index a1237bedb49..da0c8a688a1 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, FusedMoEQuantDesc, + RoutingMethodType, mxfp4_mxfp8_moe_quant_config, mxfp4_w4a16_moe_quant_config, ocp_mx_moe_quant_config, diff --git a/vllm/model_executor/layers/mhc.py b/vllm/model_executor/layers/mhc.py index ad1655e7ead..68b47a89b1a 100644 --- a/vllm/model_executor/layers/mhc.py +++ b/vllm/model_executor/layers/mhc.py @@ -249,10 +249,9 @@ def mhc_pre( ) post_mix = torch.sigmoid(post_logits) * hc_post_mult_value - comb_logits = ( - mixes[:, 2 * hc_mult :].view(num_tokens, hc_mult, hc_mult) * hc_scale[2] - + hc_base[2 * hc_mult :].view(1, hc_mult, hc_mult) - ) + comb_logits = mixes[:, 2 * hc_mult :].view( + num_tokens, hc_mult, hc_mult + ) * hc_scale[2] + hc_base[2 * hc_mult :].view(1, hc_mult, hc_mult) comb_mix = torch.softmax(comb_logits, dim=-1) + hc_sinkhorn_eps comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps) for _ in range(sinkhorn_repeat - 1): diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index 9ad881cdbe8..4bf52a49c43 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -540,6 +540,4 @@ def forward_hip( self.max_total_seq_len, self.topk_indices_buffer, ) - raise RuntimeError( - "Sparse attention indexer ROCm path could not be selected." 
- ) + raise RuntimeError("Sparse attention indexer ROCm path could not be selected.") diff --git a/vllm/model_executor/models/deepseek_v4.py b/vllm/model_executor/models/deepseek_v4.py index 0c337cb7ae0..5521a9764a9 100644 --- a/vllm/model_executor/models/deepseek_v4.py +++ b/vllm/model_executor/models/deepseek_v4.py @@ -66,6 +66,7 @@ _DEEPSEEK_V4_EXPERT_DTYPES = ("fp4", "fp8") + class DeepseekV4MLP(nn.Module): def __init__( self, diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py index 84647d6120d..68d33f1aa10 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py @@ -9,6 +9,7 @@ import torch +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.torch_utils import direct_register_custom_op @@ -242,6 +243,7 @@ def _fused_inv_rope_fp8_quant_kernel_impl( (scale_inner * tma_aligned_T, 1, tma_aligned_T), ) grid = (tma_aligned_T, n_groups * heads_per_group) + pdl_kwargs = {} if current_platform.is_rocm() else {"launch_pdl": False} _fused_inv_rope_fp8_quant_per_head[grid]( o, positions, @@ -265,7 +267,7 @@ def _fused_inv_rope_fp8_quant_kernel_impl( HALF_ROPE=half_rope, TMA_ALIGNED_SCALES=tma_aligned_scales, num_stages=1, - launch_pdl=False, + **pdl_kwargs, num_warps=1, ) return fp8_buf, scale_buf diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index ea75f19862e..8278cb70415 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -241,7 +241,7 @@ def fp8_paged_mqa_logits_torch( device=q.device, dtype=torch.float32, ) - if context_lens.dim > 1: + if context_lens.dim() > 1: context_lens = context_lens.squeeze(-1) kv_cache_flat = kv_cache.view(-1, block_size * (dim + 4)) for i in range(batch_size): @@ -254,8 +254,12 @@ 
def fp8_paged_mqa_logits_torch( pages = block_tables[i, :num_pages] cache = kv_cache_flat[pages] scale_offset = block_size * dim - cache_value = cache[..., :scale_offset].view(dtype=fp8_dtype).to(torch.float32) - cache_scale = cache[..., scale_offset:].view(dtype=torch.float32).contiguous() + cache_value = ( + cache[..., :scale_offset].view(dtype=fp8_dtype).to(torch.float32) + ) + cache_scale = ( + cache[..., scale_offset:].view(dtype=torch.float32).contiguous() + ) cache_value = cache_value.view(padded_seq_len, dim) cache_scale = cache_scale.view(padded_seq_len) score = F.linear(cache_value, q_i) From e091b926740ac8ca2152b55bf374587a9613231f Mon Sep 17 00:00:00 2001 From: whx-sjtu Date: Thu, 30 Apr 2026 06:04:09 +0000 Subject: [PATCH 09/17] disable multistream for rocm Signed-off-by: whx-sjtu --- .../layers/deepseek_v4_attention.py | 31 ++++++++++++++----- vllm/model_executor/models/deepseek_v4.py | 7 ++++- vllm/model_executor/models/deepseek_v4_mtp.py | 8 +++-- 3 files changed, 36 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 4707ad501c1..a7164211b74 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -227,6 +227,7 @@ def __init__( + 1 # 1B pad ) + # Will be None on ROCm for now. self.aux_stream_list = mla_modules.aux_stream_list # [0]: GEMM start / post-GEMM event0. [1..3]: GEMM done events; # [1] doubles as post-GEMM event1. 
Reuse is safe: GEMM fully joins @@ -367,12 +368,15 @@ def forward( return self.wo_b(z.flatten(1)) def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]: - assert self.aux_stream_list is not None - assert len(self.aux_stream_list) >= 3 + aux_streams = self.aux_stream_list + if aux_streams is not None: + assert len(aux_streams) >= 3 + aux_streams = aux_streams[:3] # fused_wqa_wkv (heaviest) on default; the three lighter input GEMMs # on aux streams 0..2 when their owning module exists. ln_events[0] # is the fan-out start event; ln_events[1..3] are per-aux done events. + # On ROCm, aux_streams is None and execute_in_parallel runs serially. aux_fns: list[Callable[[], Any] | None] = [None, None, None] if self.compressor is not None: @@ -411,12 +415,19 @@ def fused_wqa_wkv() -> torch.Tensor: qr_kv, _ = self.fused_wqa_wkv(hidden_states) return qr_kv + if aux_streams is None: + qr_kv = fused_wqa_wkv() + kv_score = aux_fns[0]() if aux_fns[0] is not None else None + indexer_weights = aux_fns[1]() if aux_fns[1] is not None else None + indexer_kv_score = aux_fns[2]() if aux_fns[2] is not None else None + return qr_kv, kv_score, indexer_kv_score, indexer_weights + qr_kv, (kv_score, indexer_weights, indexer_kv_score) = execute_in_parallel( fused_wqa_wkv, aux_fns, self.ln_events[0], self.ln_events[1:4], - self.aux_stream_list[:3], + aux_streams, ) return qr_kv, kv_score, indexer_kv_score, indexer_weights @@ -448,8 +459,11 @@ def attention_impl( # downstream reads q on default). Indexer/compressor go on aux for # overlap with default's GEMM + cache write. if self.indexer is not None: - assert self.aux_stream_list is not None - aux_stream = self.aux_stream_list[0] + aux_stream = ( + self.aux_stream_list[0] + if self.aux_stream_list is not None + else None + ) indexer = self.indexer # Local ref so the closure keeps a non-None type for mypy. 
assert self.compressor is not None @@ -477,8 +491,11 @@ def wq_b_kv_insert_and_compress() -> torch.Tensor: ) elif self.compressor is not None: # wq_b + kv_insert on default, compressor on aux. - assert self.aux_stream_list is not None - aux_stream = self.aux_stream_list[0] + aux_stream = ( + self.aux_stream_list[0] + if self.aux_stream_list is not None + else None + ) compressor = self.compressor def wq_b_kv_insert() -> torch.Tensor: diff --git a/vllm/model_executor/models/deepseek_v4.py b/vllm/model_executor/models/deepseek_v4.py index 5521a9764a9..73aa3538aca 100644 --- a/vllm/model_executor/models/deepseek_v4.py +++ b/vllm/model_executor/models/deepseek_v4.py @@ -1240,7 +1240,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # DeepseekV4MultiHeadLatentAttentionWrapper.attn_gemm_parallel_execute # (compressor kv_score, indexer.weights_proj, indexer.compressor # kv_score). fused_wqa_wkv stays on the default stream. - aux_stream_list = [torch.cuda.Stream() for _ in range(3)] + # Disable them on ROCm because of hang issues. + aux_stream_list = ( + None + if current_platform.is_rocm() + else [torch.cuda.Stream() for _ in range(3)] + ) self.device = current_platform.device_type # Reserved topk indices buffer for all Indexer layers to reuse. diff --git a/vllm/model_executor/models/deepseek_v4_mtp.py b/vllm/model_executor/models/deepseek_v4_mtp.py index a3724e5ebe8..195709c9dac 100644 --- a/vllm/model_executor/models/deepseek_v4_mtp.py +++ b/vllm/model_executor/models/deepseek_v4_mtp.py @@ -167,8 +167,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) # Three aux streams shared across all MTP layers, mirroring - # DeepseekV4Model. - aux_stream_list = [torch.cuda.Stream() for _ in range(3)] + # DeepseekV4Model. ROCm runs the same work serially for now. 
+ aux_stream_list = ( + None + if current_platform.is_rocm() + else [torch.cuda.Stream() for _ in range(3)] + ) # to map the exact layer index from weights self.layers = torch.nn.ModuleDict( From df93d130e4eafb8621ffe65a1d3a9e54ec19dadf Mon Sep 17 00:00:00 2001 From: tjtanaavllm Date: Thu, 30 Apr 2026 03:47:20 +0000 Subject: [PATCH 10/17] add tilelang as dependency to ensure the mhc module can be imported Signed-off-by: tjtanaavllm --- requirements/rocm.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/requirements/rocm.txt b/requirements/rocm.txt index 0b472b90c02..037b20874b5 100644 --- a/requirements/rocm.txt +++ b/requirements/rocm.txt @@ -21,3 +21,6 @@ timm>=1.0.17 # amd-quark: required for Quark quantization on ROCm # To be consistent with test_quark.py amd-quark>=0.8.99 +# tilelang has to be installed for mhc module to be +# imported correctly. +tilelang==0.1.9 From b41cc3ef8ee6166bb76bf5b0af14a60301483fd5 Mon Sep 17 00:00:00 2001 From: whx-sjtu Date: Thu, 30 Apr 2026 12:29:59 +0000 Subject: [PATCH 11/17] revert forward_context Signed-off-by: whx-sjtu --- vllm/forward_context.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 54d565ef340..537a28a4252 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -257,7 +257,6 @@ def set_forward_context( batch_descriptor: BatchDescriptor | None = None, ubatch_slices: UBatchSlices | None = None, slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None, - additional_kwargs: dict[str, Any] | None = None, skip_compiled: bool = False, ): """A context manager that stores the current forward context, @@ -297,7 +296,7 @@ def set_forward_context( if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None: batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens) - platform_additional_kwargs = current_platform.set_additional_forward_context( + additional_kwargs = 
current_platform.set_additional_forward_context( attn_metadata=attn_metadata, vllm_config=vllm_config, dp_metadata=dp_metadata, @@ -307,9 +306,6 @@ def set_forward_context( batch_descriptor=batch_descriptor, ubatch_slices=ubatch_slices, ) - merged_additional_kwargs = dict(platform_additional_kwargs) - if additional_kwargs: - merged_additional_kwargs.update(additional_kwargs) forward_context = create_forward_context( attn_metadata, @@ -319,7 +315,7 @@ def set_forward_context( batch_descriptor, ubatch_slices, slot_mapping, - merged_additional_kwargs, + additional_kwargs, skip_compiled, ) From d040af06daf4416f0ae7431999118b95c6829434 Mon Sep 17 00:00:00 2001 From: whx-sjtu Date: Fri, 1 May 2026 03:30:22 +0000 Subject: [PATCH 12/17] fix lint Signed-off-by: whx-sjtu --- vllm/model_executor/layers/deepseek_v4_attention.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index a7164211b74..0edc7392c17 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -460,9 +460,7 @@ def attention_impl( # overlap with default's GEMM + cache write. if self.indexer is not None: aux_stream = ( - self.aux_stream_list[0] - if self.aux_stream_list is not None - else None + self.aux_stream_list[0] if self.aux_stream_list is not None else None ) indexer = self.indexer # Local ref so the closure keeps a non-None type for mypy. @@ -492,9 +490,7 @@ def wq_b_kv_insert_and_compress() -> torch.Tensor: elif self.compressor is not None: # wq_b + kv_insert on default, compressor on aux. 
aux_stream = ( - self.aux_stream_list[0] - if self.aux_stream_list is not None - else None + self.aux_stream_list[0] if self.aux_stream_list is not None else None ) compressor = self.compressor From 895262237e5fa9c866608c4f24014f9aab7446e4 Mon Sep 17 00:00:00 2001 From: ganyi Date: Fri, 1 May 2026 05:09:31 +0000 Subject: [PATCH 13/17] fix compile forever issue Signed-off-by: ganyi --- csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu index 4f2ec6e7488..b895763dc98 100644 --- a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu +++ b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu @@ -438,7 +438,8 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert( #else // ROCm: use standard kernel launch syntax (no PDL/stream serialization) fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel - << > >( + // clang-format off + <<>>( q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, eps, num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size, kv_block_stride); From e786a2dfcdc592ab82c5ce6d7e11cd6cb5fa139a Mon Sep 17 00:00:00 2001 From: ganyi Date: Fri, 1 May 2026 05:16:53 +0000 Subject: [PATCH 14/17] relocate clang-format off Signed-off-by: ganyi --- csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu index b895763dc98..2f2e7ecc182 100644 --- a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu +++ b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu @@ -437,8 +437,8 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert( kv_block_stride); #else // ROCm: use standard kernel launch syntax (no PDL/stream serialization) + // clang-format off fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel 
- // clang-format off <<>>( q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, eps, num_tokens_full, num_tokens_insert, num_heads_q, From 1be2b74f0aa314b17df57ec9303f2932813262e8 Mon Sep 17 00:00:00 2001 From: ganyi Date: Fri, 1 May 2026 06:31:16 +0000 Subject: [PATCH 15/17] mi300 support Signed-off-by: ganyi --- ...deepseek_v4_qnorm_rope_kv_insert_kernel.cu | 4 ++++ .../layers/deepseek_v4_attention.py | 2 ++ .../layers/quantization/utils/fp8_utils.py | 21 ++++++++++++++++--- .../ops/deepseek_v4_ops/cache_utils.py | 10 +++++++-- 4 files changed, 32 insertions(+), 5 deletions(-) diff --git a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu index 2f2e7ecc182..7b8d0796fa8 100644 --- a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu +++ b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu @@ -85,7 +85,11 @@ constexpr int kQuantBlock = 64; constexpr int kNumQuantBlocks = kNopeDim / kQuantBlock; // 7 constexpr int kScaleBytesPerToken = kNumQuantBlocks + 1; // 8 (7 real + 1 pad) constexpr int kTokenDataBytes = kNopeDim + kRopeDim * 2; // 448 + 128 = 576 +#if defined(USE_ROCM) && !defined(HIP_FP8_TYPE_OCP) +constexpr float kFp8Max = 240.0f; +#else constexpr float kFp8Max = 448.0f; +#endif // Per-warp layout: 32 lanes × 16 elems/lane = 512 elems = HEAD_DIM. 
constexpr int kNumLanes = 32; diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 0edc7392c17..eae9f81da8b 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -1124,6 +1124,7 @@ def _forward_prefill( block_table=block_table[chunk_start:chunk_end], block_size=attn_metadata.block_size // self.compress_ratio, offset=0, + use_fnuz=current_platform.is_fp8_fnuz(), ) # Gather SWA KV @@ -1136,6 +1137,7 @@ def _forward_prefill( block_table=swa_block_table[chunk_start:chunk_end], block_size=swa_metadata.block_size, offset=N, + use_fnuz=current_platform.is_fp8_fnuz(), ) # Combine the topk indices and SWA indices for gathered KV cache diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index d9aab35c25f..74736d75f0d 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -1289,9 +1289,24 @@ def process_fp8_weight_block_strategy( ) if current_platform.is_fp8_fnuz(): - weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=weight, weight_scale=weight_scale - ) + if weight_scale.dtype == torch.float8_e8m0fnu: + # e8m0 stores exponent-only values (2^(exp-127)). + # Doubling == incrementing the exponent byte by 1. 
+ weight_as_int8 = weight.view(torch.int8) + ROCM_FP8_NAN_AS_INT = -128 + weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0 + weight = weight_as_int8.view(torch.float8_e4m3fnuz) + exp_bytes = weight_scale.view(torch.uint8) + weight_scale = ( + (exp_bytes.to(torch.int16) + 1) + .clamp(max=254) + .to(torch.uint8) + .view(torch.float8_e8m0fnu) + ) + else: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, weight_scale=weight_scale + ) weight = _maybe_pad_fp8_weight(weight) return weight, weight_scale diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py index 69d20c107e1..ae2e300f98b 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py @@ -215,6 +215,7 @@ def _dequantize_and_gather_k_kernel( output_dim: tl.constexpr, # 512 fp8_max: tl.constexpr, n_quant_blocks: tl.constexpr, # 7 real blocks + use_fnuz: tl.constexpr = False, ): batch_idx = tl.program_id(0) worker_id = tl.program_id(1) @@ -272,8 +273,11 @@ def _dequantize_and_gather_k_kernel( # Load quantized fp8 values (stored as uint8) x_uint8 = tl.load(token_fp8_ptr + offsets, mask=mask, other=0) - # Bitcast uint8 back to fp8 - x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + # Bitcast uint8 back to fp8 (FNUZ on gfx942, OCP otherwise) + if use_fnuz: + x_fp8 = x_uint8.to(tl.float8e4b8, bitcast=True) + else: + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) # Convert fp8 to float32 for computation x_float = x_fp8.to(tl.float32) @@ -316,6 +320,7 @@ def dequantize_and_gather_k_cache( block_table: torch.Tensor, block_size: int, offset: int, + use_fnuz: bool = False, ) -> None: TOKEN_FP8_DIM = 448 TOKEN_BF16_DIM = 64 @@ -346,6 +351,7 @@ def dequantize_and_gather_k_cache( output_dim=512, fp8_max=FP8_MAX, n_quant_blocks=7, + use_fnuz=use_fnuz, ) From dde2fb080ae28aac471782f85ed722b0be03ef23 Mon Sep 17 00:00:00 2001 From: ganyi Date: Thu, 27 Nov 2025 
06:04:24 +0000 Subject: [PATCH 16/17] further optimize dsv3.2 Signed-off-by: ganyi --- docs/design/attention_backends.md | 2 +- vllm/model_executor/models/deepseek_v2.py | 61 +++-- vllm/v1/attention/backends/mla/indexer.py | 2 +- .../attention/backends/mla/rocm_aiter_mla.py | 4 + .../backends/mla/rocm_aiter_mla_sparse.py | 256 ++++++++++++++++-- .../v1/attention/ops/rocm_aiter_mla_sparse.py | 41 +-- 6 files changed, 293 insertions(+), 73 deletions(-) diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index bdbe46ad917..4a2e909d50b 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -216,7 +216,7 @@ configuration. | `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 512, 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x | | `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x | | `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | -| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A | +| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 1, 64 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | | `TRITON_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any | | `XPU_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | Any | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | Any | diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 54d2cb53b0d..15913c418b0 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -674,30 +674,45 @@ def forward( ) -> torch.Tensor: q, _ = self.wq_b(qr) q = q.view(-1, self.n_head, self.head_dim) - 
q_pe, q_nope = torch.split( - q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 - ) - # Fused wk + weights_proj: one GEMM, then split - kw, _ = self.wk_weights_proj(hidden_states) - k = kw[:, : self.head_dim] - weights = kw[:, self.head_dim :] - - k = self.k_norm(k) - k_pe, k_nope = torch.split( - k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 - ) - q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1)) - # Note: RoPE (NeoX) can introduce extra leading dimensions during compilation - # so we need to reshape back to token-flattened shapes - q_pe = q_pe.reshape(-1, self.n_head, self.rope_dim) - k_pe = k_pe.reshape(-1, 1, self.rope_dim) - - # `rotary_emb` is shape-preserving; `q_pe` is already - # [num_tokens, n_head, rope_dim]. - q = torch.cat([q_pe, q_nope], dim=-1) - # `k_pe` is [num_tokens, 1, rope_dim] (MQA). - k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1) + if current_platform.is_rocm(): + # This path should works on all platform, will remove extra + # branches in the future + # Fused wk + weights_proj: one GEMM, then split + kw, _ = self.wk_weights_proj(hidden_states) + k = kw[:, : self.head_dim] + weights = kw[:, self.head_dim :] + + k = self.k_norm(k) + + rotary_emb( + positions, q[..., : self.rope_dim], k[..., : self.rope_dim].unsqueeze(1) + ) + else: + q_pe, q_nope = torch.split( + q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 + ) + # Fused wk + weights_proj: one GEMM, then split + kw, _ = self.wk_weights_proj(hidden_states) + k = kw[:, : self.head_dim] + weights = kw[:, self.head_dim :] + + k = self.k_norm(k) + k_pe, k_nope = torch.split( + k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 + ) + + q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1)) + # Note: RoPE (NeoX) can introduce extra leading dimensions during + # compilation so we need to reshape back to token-flattened shapes + q_pe = q_pe.reshape(-1, self.n_head, self.rope_dim) + k_pe = k_pe.reshape(-1, 1, self.rope_dim) + + # `rotary_emb` is 
shape-preserving; `q_pe` is already + # [num_tokens, n_head, rope_dim]. + q = torch.cat([q_pe, q_nope], dim=-1) + # `k_pe` is [num_tokens, 1, rope_dim] (MQA). + k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1) # we only quant q here since k quant is fused with cache insertion q = q.view(-1, self.head_dim) diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 5d12d27e762..7c0715a9e8b 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -122,7 +122,7 @@ def get_name() -> str: @staticmethod def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: - return [1 if current_platform.is_rocm() else 64] + return [1, 64] if current_platform.is_rocm() else [64] @classmethod def get_supported_head_sizes(cls) -> list[int]: diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index a66a97311fb..2106226118e 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -396,6 +396,7 @@ class AiterMLAHelper: """ _AITER_MIN_MLA_HEADS: Final = 16 + _AITER_UNSUPPORTED_HEADS = [32] @staticmethod def check_num_heads_validity(num_heads: int): @@ -419,6 +420,9 @@ def get_actual_mla_num_heads(num_heads: int) -> int: @staticmethod def get_mla_padded_q(num_heads: int, q: torch.Tensor) -> torch.Tensor: + assert num_heads not in AiterMLAHelper._AITER_UNSUPPORTED_HEADS, ( + f"unsupported head_num: {num_heads}" + ) return ( q if num_heads >= AiterMLAHelper._AITER_MIN_MLA_HEADS diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py index 503bb509b10..dc343b639f6 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py @@ -7,6 +7,7 @@ import numpy as np import torch +from vllm import _custom_ops as ops from vllm._aiter_ops 
import rocm_aiter_ops from vllm.config import VllmConfig from vllm.config.cache import CacheDType @@ -14,6 +15,7 @@ from vllm.model_executor.layers.attention.mla_attention import ( get_mla_dims, ) +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.v1.attention.backend import ( AttentionBackend, @@ -25,9 +27,6 @@ MultipleOf, SparseMLAAttentionImpl, ) -from vllm.v1.attention.backends.mla.flashmla_sparse import ( - triton_convert_req_index_to_global_index, -) from vllm.v1.attention.backends.mla.rocm_aiter_mla import ( AiterMLAHelper, ) @@ -38,6 +37,188 @@ logger = init_logger(__name__) +@triton.jit +def _convert_req_index_to_global_index_kernel( + req_id_ptr, # int32 [num_tokens] + block_table_ptr, # int32 [num_requests, max_num_blocks_per_req] + token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] + cu_seqlens_ptr, # int32 [num_tokens + 1] + out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] + # shapes (compile-time where possible) + max_num_blocks_per_req: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, # tile width along columns + # strides (in elements) + bt_stride0, + bt_stride1, + ti_stride0, + ti_stride1, +): + # program_id(0) -> token_id (row) + # program_id(1) -> tile index along columns + token_id = tl.program_id(0) + tile_id = tl.program_id(1) + + # Each program covers BLOCK_N consecutive columns + indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N) + + # Load request id for this token (no mask: grid is exact) + req = tl.load(req_id_ptr + token_id) + + # Load cumulative sequence lengths to get starting index of this request + seq_start = tl.load(cu_seqlens_ptr + token_id) + seq_end = tl.load(cu_seqlens_ptr + token_id + 1) + + if tile_id * BLOCK_N + seq_start >= seq_end: + return + + # Load token indices for this tile + ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1 + tok = tl.load(ti_ptr) # int32 + + # Only token == -1 should propagate as -1 + is_invalid_tok = tok 
< 0 + + # Compute block id and in-block offset + block_id = tok // BLOCK_SIZE + inblock_off = tok % BLOCK_SIZE + + # Guard block_table access + valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0) + bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1 + base = tl.load(bt_ptr, mask=valid_block, other=0) + + # # If token == -1 OR block_id OOB, output 0; else base * BLOCK_SIZE + offset + out_val = tl.where( + is_invalid_tok | (~valid_block), 0, base * BLOCK_SIZE + inblock_off + ) + out_ptr_ij = out_ptr + seq_start + indice_id + out_ptr_ij_mask = (seq_start + indice_id) < seq_end + + # store the results with mask + tl.store(out_ptr_ij, out_val, mask=out_ptr_ij_mask) + + +def triton_convert_req_index_to_global_index( + req_id: torch.Tensor, # int32 [num_tokens] + block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req] + token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS] + cu_seqlens: torch.Tensor, # int32 [num_tokens + 1] + paged_kv_indices: torch.Tensor, # int32 [num_tokens * topk] out_buffer + BLOCK_SIZE: int = 64, + NUM_TOPK_TOKENS: int = 2048, + BLOCK_N: int = 128, # tile width along columns +): + """ + out[token_id, indice_id] = + block_table[req_id[token_id], + token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE + + token_indices[token_id, indice_id] % BLOCK_SIZE + + Only when token_indices[token_id, indice_id] == -1 do we output -1. + For safety, we also output -1 if the derived block_id would be + out-of-bounds. 
+ """ + assert req_id.dtype == torch.int32 + assert block_table.dtype == torch.int32 + assert token_indices.dtype == torch.int32 + assert token_indices.shape[1] == NUM_TOPK_TOKENS + assert NUM_TOPK_TOKENS % BLOCK_N == 0, ( + f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible byBLOCK_N ({BLOCK_N})" + ) + # print("req_id: ", req_id, flush=True) + num_tokens = req_id.shape[0] + _, max_num_blocks_per_req = block_table.shape + tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N + + # Ensure contiguous tensors on the same device + req_id_c = req_id.contiguous() + block_table_c = block_table.contiguous() + token_indices_c = token_indices.contiguous() + + # Strides in elements + bt_stride0, bt_stride1 = block_table_c.stride() + ti_stride0, ti_stride1 = token_indices_c.stride() + + # Exact 2D grid: tokens × column tiles + grid = (num_tokens, tiles_per_row) + + _convert_req_index_to_global_index_kernel[grid]( + req_id_c, + block_table_c, + token_indices_c, + cu_seqlens, + paged_kv_indices, + # shapes / constexprs + max_num_blocks_per_req, + BLOCK_SIZE, + BLOCK_N, + # strides + bt_stride0, + bt_stride1, + ti_stride0, + ti_stride1, + ) + return + + +@triton.jit +def generate_sparse_seqlen_kernel( + seq_len_ptr, # [num_seq] + cu_query_lens_ptr, # [num_seq] + out_ptr, # [num_query_tokens] + topk_token: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + seq_id = tl.program_id(0) + query_offset = tl.program_id(1) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + query_start = tl.load(cu_query_lens_ptr + seq_id) + query_end = tl.load(cu_query_lens_ptr + seq_id + 1) + if query_start + tl.program_id(1) * BLOCK_SIZE > query_end: + return + query_len = query_end - query_start + query_mask = query_offset + query_start < query_end + seq_len = tl.load(seq_len_ptr + seq_id) + # Just return since the out_ptr is zero initialized. 
+ if seq_len == 0: + return + context_start_point = seq_len - query_len + sparse_seqlen = context_start_point + query_offset + sparse_seqlen_masked = tl.where( + sparse_seqlen + 1 < topk_token, sparse_seqlen + 1, topk_token + ) + tl.store( + out_ptr + query_start + query_offset, sparse_seqlen_masked, mask=query_mask + ) + + +def generate_sparse_seqlen_triton( + query_lens: torch.Tensor, + seq_lens: torch.Tensor, + cu_query_lens: torch.Tensor, + topk_token: int, + num_tokens: int, + max_query_len: int, +): + num_seqs = query_lens.size(0) + # zero initialize the tensor to make sure invalid positions will be zero + out = torch.zeros([num_tokens], dtype=torch.int32, device=query_lens.device) + block_size = 64 + num_block_per_row = triton.cdiv(max_query_len, block_size) + grid = ( + num_seqs, + num_block_per_row, + ) + generate_sparse_seqlen_kernel[grid]( + seq_lens, + cu_query_lens, + out, + topk_token, + block_size, + ) + return out + + @triton.jit def fetch_id_to_ragged_kernel( in_tensor_ptr, # [num_seq, topk] @@ -86,11 +267,13 @@ class ROCMAiterMLASparseBackend(AttentionBackend): "auto", "float16", "bfloat16", + "fp8", + "fp8_e4m3", ] @staticmethod def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: - return [1] + return [1, 64] @staticmethod def get_name() -> str: @@ -144,7 +327,7 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata): paged_kv_last_page_len: torch.Tensor paged_kv_indices: torch.Tensor paged_kv_indptr: torch.Tensor - paged_kv_indptr_rest: torch.Tensor + attn_out_dtype: torch.dtype block_size: int = 1 topk_tokens: int = 2048 @@ -167,6 +350,7 @@ def __init__( ): self.kv_cache_spec = kv_cache_spec self.model_config = vllm_config.model_config + self.model_dtype = vllm_config.model_config.dtype parallel_config = vllm_config.parallel_config self.device = device max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens @@ -174,9 +358,6 @@ def __init__( self.num_heads = 
self.model_config.get_num_attention_heads(parallel_config) self.mla_dims = get_mla_dims(self.model_config) self.topk_tokens = vllm_config.model_config.hf_config.index_topk - self.topk_tokens_tensor = torch.tensor( - [self.topk_tokens], device=device, dtype=torch.int32 - ) self.max_model_len_tensor = torch.tensor( [self.model_config.max_model_len], device=device, dtype=torch.int32 ) @@ -222,18 +403,33 @@ def build( ) # Zero-fill for cudagraphs self.req_id_per_token_buffer.fill_(0) + self.paged_kv_indices.fill_(0) + self.paged_kv_indptr.fill_(0) self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_( torch.from_numpy(req_id_per_token), non_blocking=True ) - self.paged_kv_indices.fill_(0) - self.paged_kv_indptr.fill_(0) + query_lens = ( + common_attn_metadata.query_start_loc[1:] + - common_attn_metadata.query_start_loc[:-1] + ) + seq_lens = common_attn_metadata.seq_lens + sparse_seqlen = generate_sparse_seqlen_triton( + query_lens, + seq_lens, + common_attn_metadata.query_start_loc, + self.topk_tokens, + num_tokens, + common_attn_metadata.max_query_len, + ) + + torch.cumsum(sparse_seqlen, dim=0, out=self.paged_kv_indptr[1 : num_tokens + 1]) + self.paged_kv_indptr[num_tokens + 1 :].fill_(self.paged_kv_indptr[num_tokens]) req_id_per_token = self.req_id_per_token_buffer[:num_tokens] qo_indptr = self.qo_indptr[: num_tokens + 1] paged_kv_last_page_len = self.paged_kv_last_page_len[:num_tokens] - paged_kv_indices = self.paged_kv_indices[: num_tokens * self.topk_tokens] paged_kv_indptr = self.paged_kv_indptr[: num_tokens + 1] - paged_kv_indptr_rest = self.paged_kv_indptr[num_tokens + 1 :] + paged_kv_indices = self.paged_kv_indices[: num_tokens * self.topk_tokens] metadata = ROCMAiterMLASparseMetadata( num_reqs=common_attn_metadata.num_reqs, @@ -245,12 +441,12 @@ def build( block_table=common_attn_metadata.block_table_tensor, req_id_per_token=req_id_per_token, block_size=self.kv_cache_spec.block_size, + attn_out_dtype=self.model_dtype, topk_tokens=self.topk_tokens, 
qo_indptr=qo_indptr, paged_kv_last_page_len=paged_kv_last_page_len, paged_kv_indices=paged_kv_indices, paged_kv_indptr=paged_kv_indptr, - paged_kv_indptr_rest=paged_kv_indptr_rest, ) return metadata @@ -314,29 +510,20 @@ def __init__( assert indexer is not None self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer - def _forward_bf16_kv( + def _forward_mla( self, + layer: AttentionLayer, q: torch.Tensor, # [sq, heads, d_qk] kv_c_and_k_pe_cache: torch.Tensor, # [blocks, heads, d_qk] - topk_indices: torch.Tensor, # [sq, topk] attn_metadata: ROCMAiterMLASparseMetadata, ) -> torch.Tensor: num_tokens = q.shape[0] mla_num_heads = AiterMLAHelper.get_actual_mla_num_heads(self.num_heads) output = torch.empty( [num_tokens, mla_num_heads, self.kv_lora_rank], - dtype=q.dtype, + dtype=attn_metadata.attn_out_dtype, device=q.device, ) - seq_len = (topk_indices != -1).sum(dim=-1) - torch.cumsum(seq_len, dim=0, out=attn_metadata.paged_kv_indptr[1:]) - attn_metadata.paged_kv_indptr_rest.fill_(attn_metadata.paged_kv_indptr[-1]) - fetch_id_to_ragged_triton( - topk_indices, - attn_metadata.paged_kv_indptr, - attn_metadata.paged_kv_indices, - attn_metadata.topk_tokens, - ) rocm_aiter_ops.mla_decode_fwd( q, @@ -348,6 +535,8 @@ def _forward_bf16_kv( attn_metadata.paged_kv_indptr, attn_metadata.paged_kv_indices, attn_metadata.paged_kv_last_page_len, + q_scale=layer._q_scale, + kv_scale=layer._k_scale, ) return AiterMLAHelper.get_mla_unpadded_o(self.num_heads, output) @@ -366,23 +555,32 @@ def forward_mqa( if isinstance(q, tuple): q = torch.cat(q, dim=-1) - num_actual_toks = q.shape[0] + num_actual_toks = attn_metadata.num_actual_tokens # Get topk indices assert self.topk_indices_buffer is not None topk_indices = self.topk_indices_buffer[:num_actual_toks] - topk_indices_global = triton_convert_req_index_to_global_index( + triton_convert_req_index_to_global_index( attn_metadata.req_id_per_token, attn_metadata.block_table, topk_indices, + attn_metadata.paged_kv_indptr, + 
attn_metadata.paged_kv_indices, BLOCK_SIZE=attn_metadata.block_size, NUM_TOPK_TOKENS=attn_metadata.topk_tokens, ) + # write the latent and rope to kv cache + fp8_attention = self.kv_cache_dtype.startswith("fp8") + if fp8_attention: + original_q_shape = q.shape + kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(current_platform.fp8_dtype()) + q, _ = ops.scaled_fp8_quant(q.view(q.shape[0], -1), layer._q_scale) + q = q.view(original_q_shape) mla_padded_q = AiterMLAHelper.get_mla_padded_q(self.num_heads, q) - attn_out = self._forward_bf16_kv( - mla_padded_q, kv_c_and_k_pe_cache, topk_indices_global, attn_metadata + attn_out = self._forward_mla( + layer, mla_padded_q, kv_c_and_k_pe_cache, attn_metadata ) return attn_out, None diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 8278cb70415..ffa22b61ef3 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -14,9 +14,6 @@ from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton -if current_platform.is_cuda_alike(): - from vllm import _custom_ops as ops - @triton.jit def _indexer_k_quant_and_cache_kernel( @@ -98,7 +95,8 @@ def indexer_k_quant_and_cache_triton( # In real layout, we store the first portion as kv cache value # and second portion as kv cache scale kv_cache = kv_cache.view(num_blocks, -1) - kv_cache_value = kv_cache[:, : block_size * head_dim] + fp8_dtype = current_platform.fp8_dtype() + kv_cache_value = kv_cache[:, : block_size * head_dim].view(fp8_dtype) kv_cache_scale = kv_cache[:, block_size * head_dim :].view(torch.float32) head_tile_size = head_tile_size // kv_cache.element_size() grid = (num_tokens,) @@ -112,7 +110,7 @@ def indexer_k_quant_and_cache_triton( block_size, num_tokens, head_dim, - "NHD", + "SHUFFLE", block_tile_size, head_tile_size, IS_FNUZ=current_platform.fp8_dtype() == 
torch.float8_e4m3fnuz, @@ -213,7 +211,7 @@ def cp_gather_indexer_k_quant_cache_triton( block_table_stride, k_cache_value.stride(0), k_cache_scale.stride(0), - "NHD", + "SHUFFLE", head_dim, block_tile_size, head_tile_size, @@ -373,33 +371,38 @@ def rocm_fp8_paged_mqa_logits( from vllm._aiter_ops import rocm_aiter_ops aiter_paged_mqa_logits_module = None + # if rocm_aiter_ops.is_enabled(): + batch_size, next_n, heads, head_dim = q_fp8.shape + num_blocks, block_size, _, _ = kv_cache_fp8.shape + if rocm_aiter_ops.is_enabled(): aiter_paged_mqa_logits_module = paged_mqa_logits_module() if aiter_paged_mqa_logits_module is not None: - deepgemm_fp8_paged_mqa_logits_stage1 = ( - aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits_stage1 + deepgemm_fp8_paged_mqa_logits = ( + aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits ) batch_size, next_n, heads, _ = q_fp8.shape - out_qk = torch.full( - (heads, batch_size * next_n, max_model_len), + out_logits = torch.full( + [batch_size * next_n, max_model_len], float("-inf"), device="cuda", dtype=torch.float32, ) - # TODO: 1. Replace _stage1 and out_qk.sum with another fused variant; - # 2. 
Remove ChunkQ when AITER PR #2891 merged - deepgemm_fp8_paged_mqa_logits_stage1( + deepgemm_fp8_paged_mqa_logits( q_fp8, kv_cache_fp8, weights, - out_qk, + out_logits, context_lens, block_tables, max_model_len, - ChunkQ=heads, + ChunkK=256, + Preshuffle=block_size == 64, + KVBlockSize=block_size, + WavePerEU=2, ) - return out_qk.sum(dim=0) + return out_logits else: return fp8_paged_mqa_logits_torch( q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len @@ -616,7 +619,7 @@ def rocm_aiter_sparse_attn_indexer_native( raise ValueError("k must be provided when skip_k_cache_insert is False") if not skip_k_cache_insert: - ops.indexer_k_quant_and_cache( + indexer_k_quant_and_cache_triton( k, kv_cache, slot_mapping, @@ -639,13 +642,13 @@ def rocm_aiter_sparse_attn_indexer_native( device=device, dtype=torch.uint8, ) - - ops.cp_gather_indexer_k_quant_cache( + cp_gather_indexer_k_quant_cache_triton( kv_cache, k_fp8, k_scale, chunk.block_table, chunk.cu_seq_lens, + chunk.token_to_seq, ) logits = rocm_fp8_mqa_logits( From 80fed5b250feeda8229c1300168a3f73d1139bca Mon Sep 17 00:00:00 2001 From: ganyi Date: Fri, 1 May 2026 07:14:20 +0000 Subject: [PATCH 17/17] accuracy right Signed-off-by: ganyi --- csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu index 7b8d0796fa8..f328e3acb79 100644 --- a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu +++ b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu @@ -56,7 +56,11 @@ #ifdef USE_ROCM // ROCm-compatible FP8 conversion helpers __device__ __forceinline__ uint8_t rocm_cvt_float_to_fp8_e4m3(float val) { - #if defined(HIP_FP8_TYPE_OCP) + // HIP defines HIP_FP8_TYPE_OCP based on HIP version, not GPU arch. 
On gfx942 + // mfma only supports FNUZ fp8, and the rest of vLLM's gfx942 path (Triton + // indexer / current_platform.fp8_dtype()) uses FNUZ. Gate OCP on __gfx950__ + // so the K cache encoding matches what the reader expects. + #if defined(HIP_FP8_TYPE_OCP) && defined(__gfx950__) __hip_fp8_e4m3 fp8_val(val); #else __hip_fp8_e4m3_fnuz fp8_val(val); @@ -85,7 +89,9 @@ constexpr int kQuantBlock = 64; constexpr int kNumQuantBlocks = kNopeDim / kQuantBlock; // 7 constexpr int kScaleBytesPerToken = kNumQuantBlocks + 1; // 8 (7 real + 1 pad) constexpr int kTokenDataBytes = kNopeDim + kRopeDim * 2; // 448 + 128 = 576 -#if defined(USE_ROCM) && !defined(HIP_FP8_TYPE_OCP) +// Match the encoding chosen in rocm_cvt_float_to_fp8_e4m3: FNUZ on gfx942 +// (max 240), OCP on gfx950 (max 448). +#if defined(USE_ROCM) && (!defined(HIP_FP8_TYPE_OCP) || !defined(__gfx950__)) constexpr float kFp8Max = 240.0f; #else constexpr float kFp8Max = 448.0f;