diff --git a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu index e96017d86dad..3e903080c7f6 100644 --- a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu +++ b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu @@ -29,7 +29,11 @@ */ #include -#include <cuda_fp8.h> +#ifndef USE_ROCM + #include <cuda_fp8.h> +#else + #include <hip/hip_fp8.h> +#endif #include #include @@ -42,7 +46,23 @@ #include "type_convert.cuh" #ifndef FINAL_MASK - #define FINAL_MASK 0xffffffffu + #ifdef USE_ROCM + #define FINAL_MASK 0xffffffffffffffffULL + #else + #define FINAL_MASK 0xffffffffu + #endif +#endif + +#ifdef USE_ROCM +// ROCm-compatible FP8 conversion helpers +__device__ __forceinline__ uint8_t rocm_cvt_float_to_fp8_e4m3(float val) { + #if defined(HIP_FP8_TYPE_OCP) + __hip_fp8_e4m3 fp8_val(val); + #else + __hip_fp8_e4m3_fnuz fp8_val(val); + #endif + return reinterpret_cast<uint8_t&>(fp8_val); +} #endif namespace vllm { @@ -314,9 +334,13 @@ __global__ void fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel( for (int i = 0; i < kElemsPerLane; i++) { float scaled = elements[i] * inv_scale; scaled = fminf(fmaxf(scaled, -kFp8Max), kFp8Max); +#ifndef USE_ROCM __nv_fp8_storage_t s = __nv_cvt_float_to_fp8(scaled, __NV_SATFINITE, __NV_E4M3); out_bytes[i] = static_cast<uint8_t>(s); +#else + out_bytes[i] = rocm_cvt_float_to_fp8_e4m3(scaled); +#endif } // One 16-byte STG per lane. *reinterpret_cast<uint4*>(token_fp8_ptr + dim_base) = @@ -384,6 +408,7 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert( // PDL: enable programmatic stream serialization whenever the hardware // supports it (SM90+). On pre-Hopper GPUs the attribute is unavailable, // so leave numAttrs = 0 and launch as a regular kernel. +#ifndef USE_ROCM static int const sm_version = getSMVersion(); // Host-side guard: the device kernel body is compiled as a no-op for // bf16 on pre-Ampere (sm_70/sm_75) because _typeConvert is @@ -410,6 +435,14 @@ q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, eps, num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size, kv_block_stride); +#else + // ROCm: use standard kernel launch syntax (no PDL/stream serialization) + fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel + <<>>( + q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, eps, + num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size, + kv_block_stride); +#endif } } // namespace deepseek_v4_fused_ops diff --git a/csrc/moe/topk_softplus_sqrt_kernels.cu b/csrc/moe/topk_softplus_sqrt_kernels.cu index 50a8540a7374..7cbba6b6354e 100644 --- a/csrc/moe/topk_softplus_sqrt_kernels.cu +++ b/csrc/moe/topk_softplus_sqrt_kernels.cu @@ -60,7 +60,11 @@ __device__ __forceinline__ float toFloat(T value) { } } -#define FINAL_MASK 0xffffffff +#ifdef USE_ROCM + #define FINAL_MASK 0xffffffffffffffffULL +#else + #define FINAL_MASK 0xffffffff +#endif template <typename T> __inline__ __device__ T warpReduceSum(T val) { #pragma unroll diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3d5fbd18d6e8..8653e1d3a83c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2451,7 +2451,11 @@ def moe_wna16_gemm( def router_gemm_bf16_fp32(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: """bf16 x bf16 -> fp32 GEMM via cuBLAS. 
weight shape: (N, K).""" - return torch.ops._moe_C.router_gemm_bf16_fp32(input, weight) + if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "router_gemm_bf16_fp32"): + return torch.ops._moe_C.router_gemm_bf16_fp32(input, weight) + + # Native fallback for platforms/builds without the custom MoE GEMM op. + return torch.matmul(input.to(torch.float32), weight.to(torch.float32).t()) if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "router_gemm_bf16_fp32"): diff --git a/vllm/config/kernel.py b/vllm/config/kernel.py index 93fb4c54b7f1..3de89eacdd34 100644 --- a/vllm/config/kernel.py +++ b/vllm/config/kernel.py @@ -115,6 +115,7 @@ def with_default( "flashinfer_cutlass", "flashinfer_cutedsl", "marlin", + "triton_unfused", "aiter", "emulation", ] @@ -145,6 +146,7 @@ class KernelConfig: - "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels - "flashinfer_cutedsl": Use FlashInfer with CuteDSL kernels (FP4 only) - "marlin": Use Marlin kernels (weight-only quantization) + - "triton_unfused": Use Triton unfused MoE kernels - "aiter": Use AMD AITer kernels (ROCm only) - "emulation": use BF16/FP16 GEMM, dequantizing weights and running QDQ on activations. diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 537a28a42526..54d565ef340a 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -257,6 +257,7 @@ def set_forward_context( batch_descriptor: BatchDescriptor | None = None, ubatch_slices: UBatchSlices | None = None, slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None, + additional_kwargs: dict[str, Any] | None = None, skip_compiled: bool = False, ): """A context manager that stores the current forward context, @@ -296,7 +297,7 @@ def set_forward_context( if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None: batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens) - additional_kwargs = current_platform.set_additional_forward_context( + platform_additional_kwargs = current_platform.set_additional_forward_context( attn_metadata=attn_metadata, vllm_config=vllm_config, dp_metadata=dp_metadata, @@ -306,6 +307,9 @@ def set_forward_context( batch_descriptor=batch_descriptor, ubatch_slices=ubatch_slices, ) + merged_additional_kwargs = dict(platform_additional_kwargs) + if additional_kwargs: + merged_additional_kwargs.update(additional_kwargs) forward_context = create_forward_context( attn_metadata, @@ -315,7 +319,7 @@ def set_forward_context( batch_descriptor, ubatch_slices, slot_mapping, - additional_kwargs, + merged_additional_kwargs, skip_compiled, ) diff --git a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py index 8a8650d22135..5ded5ca798ad 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py @@ -312,6 +312,21 @@ def apply_block_scaled_mm( As: torch.Tensor, Bs: torch.Tensor, ) -> torch.Tensor: + if As.dtype != Bs.dtype: + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _upcast_e8m0_to_fp32, + ) + + if As.dtype == torch.float8_e8m0fnu: + As = _upcast_e8m0_to_fp32(As).contiguous() + else: + As = As.to(torch.float32) + + if Bs.dtype == torch.float8_e8m0fnu: + Bs = _upcast_e8m0_to_fp32(Bs).contiguous() + else: + Bs = Bs.to(torch.float32) + out_dtype = self.config.out_dtype if self.use_triton: gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale diff --git a/vllm/model_executor/layers/activation.py 
b/vllm/model_executor/layers/activation.py index 59cc95f18c58..df9459012ae8 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -169,7 +169,9 @@ class SiluAndMulWithClamp(CustomOp): def __init__(self, swiglu_limit: float, *, compile_native: bool = True): super().__init__(compile_native=compile_native) self.swiglu_limit = float(swiglu_limit) - if current_platform.is_cuda_alike() or current_platform.is_xpu(): + if current_platform.is_rocm(): + self._forward_method = self.forward_native + elif current_platform.is_cuda_alike() or current_platform.is_xpu(): self.op = torch.ops._C.silu_and_mul_with_clamp elif current_platform.is_cpu(): self._forward_method = self.forward_native diff --git a/vllm/model_executor/layers/deepseek_compressor.py b/vllm/model_executor/layers/deepseek_compressor.py index af2783f604da..cfeb76aa0c62 100644 --- a/vllm/model_executor/layers/deepseek_compressor.py +++ b/vllm/model_executor/layers/deepseek_compressor.py @@ -7,6 +7,7 @@ import torch from torch import nn +from vllm import _custom_ops as ops from vllm.config import VllmConfig, get_current_vllm_config from vllm.forward_context import get_forward_context from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase @@ -29,6 +30,7 @@ _fused_kv_compress_norm_rope_insert_indexer_mxfp4_attn, _fused_kv_compress_norm_rope_insert_sparse_attn, ) +from vllm.v1.attention.ops.deepseek_v4_ops.cache_utils import quantize_and_insert_k_cache from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import ( MXFP4_BLOCK_SIZE, ) @@ -175,6 +177,54 @@ def get_attn_backend(self) -> type[AttentionBackend]: return CompressorBackend +def hadamard_transform_ref(x: torch.Tensor) -> torch.Tensor: + hidden_size = x.shape[-1] + assert hidden_size > 0 and (hidden_size & (hidden_size - 1)) == 0, ( + f"Hidden size must be a power of 2, got {hidden_size}" + ) + dtype = x.dtype + y = x.to(torch.float32).reshape(-1, hidden_size) + h = 1 + while h < hidden_size: + y = y.view(-1, hidden_size // (2 * h), 2, h) + a = y[:, :, 0, :] + b = y[:, :, 1, :] + y = torch.cat((a + b, a - b), dim=-1) + h *= 2 + y = y.view(*x.shape) * (hidden_size**-0.5) + return y.to(dtype) + + +def apply_gptj_rope_ref( + x: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + rope_dim: int, +) -> torch.Tensor: + if rope_dim == 0 or x.numel() == 0: + return x + half_rot = rope_dim // 2 + nope_dim = x.shape[-1] - rope_dim + dtype = x.dtype + x = x.to(torch.float32) + cache = cos_sin_cache.index_select(0, positions.to(torch.long)) + cos = cache[:, :half_rot].to(torch.float32) + sin = cache[:, half_rot : 2 * half_rot].to(torch.float32) + view_shape = (positions.shape[0],) + (1,) * (x.dim() - 2) + (half_rot,) + cos = cos.view(view_shape) + sin = sin.view(view_shape) + rope = x[..., nope_dim:] + x_even = rope[..., 0::2] + x_odd = rope[..., 1::2] + rope_out = torch.stack( + (x_even * cos - x_odd * sin, x_odd * cos + x_even * sin), + dim=-1, + ).flatten(-2) + x = x.clone() + x[..., nope_dim:] = rope_out + return x.to(dtype) + + class DeepseekCompressor(nn.Module): def __init__( self, @@ -240,6 +290,9 @@ def __init__( self._static_forward_context = ( vllm_config.compilation_config.static_forward_context ) + self._old_kv_state: dict[str, torch.Tensor] = {} + self._old_score_state: dict[str, torch.Tensor] = {} + self._old_need_hadamard = self.head_dim == 128 if self.head_dim == 512: assert not use_fp4_cache, ( @@ -277,6 +330,10 @@ def forward( positions: torch.Tensor, rotary_emb, ) -> None: + if 
current_platform.is_rocm(): + self._forward_old(x, positions, rotary_emb) + return None + num_tokens, _ = x.shape # bf16 weights/activations but fp32 output for numerical stability of # the downstream compressor math. @@ -329,7 +386,6 @@ def forward( TRITON_BLOCK_SIZE=triton.next_power_of_2(kv.shape[-1]), STATE_WIDTH=state_width, COMPRESS_RATIO=self.compress_ratio, - launch_pdl=False, ) # Fused: compress → RMSNorm → RoPE → FP8 quant → KV cache write. @@ -378,9 +434,195 @@ def forward( SCALE_DIM=self._scale_dim, KV_BLOCK_STRIDE=kv_cache.stride(0), num_warps=self._num_warps, - launch_pdl=False, ) + @property + def _state_width(self) -> int: + return self.coff * self.head_dim + + @property + def _state_len(self) -> int: + return self.compress_ratio * self.coff + + def _get_old_state(self, req_id: str, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: + kv_state = self._old_kv_state.get(req_id) + score_state = self._old_score_state.get(req_id) + if kv_state is None or score_state is None or kv_state.device != device: + kv_state = torch.zeros( + (self._state_len, self._state_width), + dtype=torch.float32, + device=device, + ) + score_state = torch.full( + (self._state_len, self._state_width), + float("-inf"), + dtype=torch.float32, + device=device, + ) + self._old_kv_state[req_id] = kv_state + self._old_score_state[req_id] = score_state + return kv_state, score_state + + def _clear_old_state(self, kv_state: torch.Tensor, score_state: torch.Tensor) -> None: + kv_state.zero_() + score_state.fill_(float("-inf")) + + def _overlap_transform(self, tensor: torch.Tensor, fill_value: Any) -> torch.Tensor: + assert tensor.dim() == 3 + assert tensor.shape[1:] == (self.compress_ratio, 2 * self.head_dim) + s, r, d = tensor.shape[0], self.compress_ratio, self.head_dim + new_tensor = tensor.new_full((s, 2 * r, d), fill_value) + new_tensor[:, r:] = tensor[:, :, d:] + new_tensor[1:, :r] = tensor[:-1, :, :d] + return new_tensor + + def _ref_rms_norm(self, x: torch.Tensor) -> torch.Tensor: + x = x.to(torch.float32) + x = x * torch.rsqrt(x.square().mean(dim=-1, keepdim=True) + self.rms_norm_eps) + return x * self.norm.weight.to(torch.float32) + + @staticmethod + def _compute_state_len(seq_len: int, ratio: int) -> int: + return seq_len % ratio + (ratio == 4) * ratio + + def _forward_old( + self, + x: torch.Tensor, + positions: torch.Tensor, + rotary_emb, + ) -> None: + num_tokens, _ = x.shape + kv_score = cublas_gemm_bf16_bf16_fp32(x, self.fused_wkv_wgate.weight) + kv, score = kv_score.split([self._state_width, self._state_width], dim=-1) + + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + if not isinstance(attn_metadata, dict): + return + + state_metadata = attn_metadata[self.state_cache.prefix] + token_to_req_indices = state_metadata.token_to_req_indices[:num_tokens] + k_cache_metadata = attn_metadata[self.k_cache_prefix] + slot_mapping = k_cache_metadata.slot_mapping[:num_tokens] + kv_cache = self._static_forward_context[self.k_cache_prefix].kv_cache + req_ids = forward_context.additional_kwargs.get("req_ids") + if req_ids is None: + max_req_idx = int(token_to_req_indices.max().item()) + 1 if num_tokens > 0 else 0 + req_ids = [str(i) for i in range(max_req_idx)] + + _, counts = torch.unique_consecutive(token_to_req_indices.to(torch.int64), return_counts=True) + start = 0 + for req_idx_tensor, count_tensor in zip( + torch.unique_consecutive(token_to_req_indices.to(torch.int64)), counts + ): + req_idx = int(req_idx_tensor.item()) + count = 
int(count_tensor.item()) + end = start + count + req_id = req_ids[req_idx] + kv_state, score_state = self._get_old_state(req_id, x.device) + + positions_i = positions[start:end].to(torch.int64) + prefix_len = int(positions_i[0].item()) + query_len = end - start + if prefix_len == 0: + self._clear_old_state(kv_state, score_state) + + pre_state_len = self._compute_state_len(prefix_len, self.compress_ratio) + valid_kv_len = pre_state_len + query_len + + temp_kv = torch.empty( + (valid_kv_len, self._state_width), dtype=torch.float32, device=x.device + ) + temp_score = torch.empty_like(temp_kv) + if pre_state_len > 0: + temp_kv[:pre_state_len] = kv_state[:pre_state_len] + temp_score[:pre_state_len] = score_state[:pre_state_len] + temp_kv[pre_state_len:] = kv[start:end] + temp_score[pre_state_len:] = score[start:end] + + post_state_len = self._compute_state_len(valid_kv_len, self.compress_ratio) + kv_state[:post_state_len] = temp_kv[valid_kv_len - post_state_len : valid_kv_len] + score_state[:post_state_len] = temp_score[ + valid_kv_len - post_state_len : valid_kv_len + ] + if post_state_len < self._state_len: + kv_state[post_state_len:].zero_() + score_state[post_state_len:].fill_(float("-inf")) + + compress_len = valid_kv_len // self.compress_ratio * self.compress_ratio + if compress_len == 0: + start = end + continue + + kv_to_compress = temp_kv[:compress_len].view( + compress_len // self.compress_ratio, + self.compress_ratio, + self._state_width, + ) + score_to_compress = temp_score[:compress_len].view_as(kv_to_compress) + score_to_compress = score_to_compress + self.ape.unsqueeze(0) + + if self.overlap: + kv_to_compress = self._overlap_transform(kv_to_compress, 0.0) + score_to_compress = self._overlap_transform( + score_to_compress, float("-inf") + ) + kv_to_compress = kv_to_compress[1:] + score_to_compress = score_to_compress[1:] + if kv_to_compress.numel() == 0: + start = end + continue + + kv_compressed = ( + kv_to_compress * torch.softmax(score_to_compress, dim=1) + ).sum(dim=1) + kv_compressed = self._ref_rms_norm(kv_compressed) + + first_compressed_pos = prefix_len + first_compressed_pos += self.compress_ratio - 1 - first_compressed_pos % self.compress_ratio + compressed_positions = torch.arange( + first_compressed_pos, + prefix_len + query_len, + self.compress_ratio, + device=x.device, + dtype=torch.int64, + ) + if compressed_positions.numel() != kv_compressed.shape[0]: + raise RuntimeError( + f"Compressed positions mismatch for req {req_id}: " + f"{compressed_positions.numel()} vs {kv_compressed.shape[0]}" + ) + + kv_compressed = apply_gptj_rope_ref( + kv_compressed, compressed_positions, rotary_emb.cos_sin_cache, self.rope_head_dim + ).to(torch.bfloat16) + if self._old_need_hadamard: + kv_compressed = hadamard_transform_ref(kv_compressed) + + local_output = torch.zeros( + (query_len, self.head_dim), dtype=torch.bfloat16, device=x.device + ) + local_output[(compressed_positions - prefix_len).to(torch.long)] = kv_compressed + + if self.head_dim == 512: + quantize_and_insert_k_cache( + local_output, + kv_cache, + slot_mapping[start:end], + block_size=kv_cache.shape[1], + is_ue8m0=True, + ) + else: + ops.indexer_k_quant_and_cache( + local_output, + kv_cache, + slot_mapping[start:end], + self._quant_block, + "ue8m0", + ) + + start = end + @triton.jit def _save_partial_states_kernel( diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 43242eddb5b2..e82cf5a5c342 100644 --- 
a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -4,6 +4,8 @@ DeepseekV4 MLA Attention Layer """ +import math + from dataclasses import dataclass from typing import TYPE_CHECKING, cast @@ -42,7 +44,11 @@ from vllm.logger import init_logger from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.model_executor.layers.deepseek_compressor import DeepseekCompressor +from vllm.model_executor.layers.deepseek_compressor import ( + DeepseekCompressor, + apply_gptj_rope_ref, + hadamard_transform_ref, +) from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.input_quant_fp8 import ( @@ -51,6 +57,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, ) +from vllm.platforms import current_platform from vllm.utils.multi_stream_utils import maybe_execute_in_parallel from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata from vllm.v1.attention.backends.mla.flashmla_sparse import ( @@ -193,8 +200,6 @@ def __init__( # Pick fp8_einsum recipe based on GPU arch: # SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128 # SM100: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1 - from vllm.platforms import current_platform - cap = current_platform.get_device_capability() assert cap is not None, "DeepseekV4 attention requires a CUDA device" self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128) @@ -205,6 +210,8 @@ self.topk_indices_buffer = mla_modules.topk_indices_buffer self.indexer = mla_modules.indexer + # Keep ROCm on the BF16 reference wo_a path until the fused kernel is ready. 
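Reviewer note on the flag set just below: the reference wo_a path replaces the fused FP8 einsum with dequantize-then-einsum in plain PyTorch. A minimal shape sketch of that grouped low-rank O-projection, with hypothetical sizes chosen only to make the dimensions visible (fp32 so the snippet runs anywhere; the real path dequantizes to bf16):

```python
import torch

num_tokens, n_groups, hidden_dim, o_lora_rank = 4, 2, 64, 16

o = torch.randn(num_tokens, n_groups, hidden_dim)
# One (o_lora_rank x hidden_dim) low-rank factor per head group.
wo_a = torch.randn(n_groups, o_lora_rank, hidden_dim)

# "tgd,grd->tgr": contract over the hidden dim d, independently per group g.
z = torch.einsum("tgd,grd->tgr", o, wo_a)
assert z.shape == (num_tokens, n_groups, o_lora_rank)

# wo_b then consumes the flattened (num_tokens, n_groups * o_lora_rank) input.
assert z.flatten(1).shape == (num_tokens, n_groups * o_lora_rank)
```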
+ self.use_ref_wo_a_path = current_platform.is_rocm() # Per-head RMS normalization for Q (no learnable weights) self.q_head_norm = RMSNorm(head_dim, eps=self.eps, has_weight=False) @@ -300,6 +307,32 @@ def forward( ) o = o_padded[:, : self.n_local_heads, :] + if self.use_ref_wo_a_path: + o_ref = _apply_inv_rope_ref( + self.rotary_emb, o, positions, self.rope_head_dim + ).to(torch.bfloat16) + o_ref = o_ref.view(num_tokens, self.n_local_groups, -1) + + hidden_dim = o_ref.shape[-1] + if hasattr(self.wo_a, "weight_scale_inv"): + wo_a_weight = self.wo_a.weight.view( + self.n_local_groups, self.o_lora_rank, hidden_dim + ).to(torch.float32) + wo_a_scale = _expand_2d_block_scales( + self.wo_a.weight_scale_inv.view( + self.n_local_groups, -1, self.wo_a.weight_scale_inv.shape[-1] + ), + self.o_lora_rank, + hidden_dim, + ) + wo_a_weight = (wo_a_weight * wo_a_scale).to(torch.bfloat16) + else: + wo_a_weight = self.wo_a.weight.view( + self.n_local_groups, self.o_lora_rank, hidden_dim + ).to(torch.bfloat16) + z = torch.einsum("tgd,grd->tgr", o_ref, wo_a_weight) + return self.wo_b(z.flatten(1)) + # O projection: inverse RoPE + FP8 quant + einsum + wo_b o_fp8, o_scale = fused_inv_rope_fp8_quant( o, @@ -494,7 +527,128 @@ def deepseek_v4_fp8_einsum( equation: str, recipe: list[int], ) -> None: - fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe)) + try: + fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe)) + except RuntimeError as exc: + if "DeepGEMM backend is not available or outdated" not in str(exc): + raise + _deepseek_v4_fp8_einsum_fallback(a, a_scale, b, b_scale, out, equation) + + +def _decode_e8m0_scales(scale: torch.Tensor) -> torch.Tensor: + if scale.dtype == torch.float8_e8m0fnu: + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _upcast_e8m0_to_fp32, + ) + + return _upcast_e8m0_to_fp32(scale).contiguous() + return scale.to(torch.float32) + + +def _expand_last_dim_scales(scale: torch.Tensor, last_dim: int) -> torch.Tensor: + scale = _decode_e8m0_scales(scale) + block = math.ceil(last_dim / scale.shape[-1]) + return torch.repeat_interleave(scale, block, dim=-1)[..., :last_dim] + + +def _expand_2d_block_scales( + scale: torch.Tensor, + rows: int, + cols: int, +) -> torch.Tensor: + scale = _decode_e8m0_scales(scale) + row_blocks, col_blocks = scale.shape[-2:] + row_block = math.ceil(rows / row_blocks) + col_block = math.ceil(cols / col_blocks) + scale = torch.repeat_interleave(scale, row_block, dim=-2)[..., :rows, :] + scale = torch.repeat_interleave(scale, col_block, dim=-1)[..., :, :cols] + return scale + + +def _apply_gptj_inv_rope_ref( + x: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + rope_dim: int, +) -> torch.Tensor: + if rope_dim == 0 or x.numel() == 0: + return x + half_rot = rope_dim // 2 + nope_dim = x.shape[-1] - rope_dim + dtype = x.dtype + x = x.to(torch.float32) + cache = cos_sin_cache.index_select(0, positions.to(torch.long)) + cos = cache[:, :half_rot].to(torch.float32) + sin = cache[:, half_rot : 2 * half_rot].to(torch.float32) + view_shape = (positions.shape[0],) + (1,) * (x.dim() - 2) + (half_rot,) + cos = cos.view(view_shape) + sin = sin.view(view_shape) + rope = x[..., nope_dim:] + y_even = rope[..., 0::2] + y_odd = rope[..., 1::2] + rope_out = torch.stack( + (y_even * cos + y_odd * sin, y_odd * cos - y_even * sin), + dim=-1, + ).flatten(-2) + x = x.clone() + x[..., nope_dim:] = rope_out + return x.to(dtype) + + +def _apply_inv_rope_ref( + rotary_emb: torch.nn.Module, + x: 
torch.Tensor, + positions: torch.Tensor, + rope_dim: int, +) -> torch.Tensor: + if hasattr(rotary_emb, "forward_native"): + try: + query, _ = rotary_emb.forward_native( + positions, + x.clone(), + None, + inverse=True, + ) + return query + except TypeError: + pass + return _apply_gptj_inv_rope_ref(x, positions, rotary_emb.cos_sin_cache, rope_dim) + + +def _deepseek_v4_fp8_einsum_fallback( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, + equation: str, +) -> None: + if equation != "bhr,hdr->bhd": + raise RuntimeError(f"Unsupported fallback equation: {equation}") + + num_groups = a.shape[1] + hidden_dim = a.shape[2] + output_dim = b.shape[0] // num_groups + + if b.shape[0] % num_groups != 0: + raise RuntimeError( + f"Cannot reshape weight of shape {tuple(b.shape)} into " + f"({num_groups}, {output_dim}, {hidden_dim})." + ) + + a_deq = (a.to(torch.float32) * _expand_last_dim_scales(a_scale, hidden_dim)).to( + torch.bfloat16 + ) + + b_deq = b.view(num_groups, output_dim, hidden_dim).to(torch.float32) + b_scale_deq = _expand_2d_block_scales( + b_scale.view(num_groups, -1, b_scale.shape[-1]), + output_dim, + hidden_dim, + ) + b_deq = (b_deq * b_scale_deq).to(torch.bfloat16) + + out.copy_(torch.einsum(equation, a_deq, b_deq).to(out.dtype)) def deepseek_v4_fp8_einsum_fake( @@ -564,6 +718,10 @@ def __init__( self.aux_stream = aux_stream self.ln_events = [torch.cuda.Event(), torch.cuda.Event()] + # AITER MLA decode scratch buffers (lazy-init on first call) + self._aiter_scratch: "AiterSparseScratch | None" = None + self._aiter_extra_scratch: "AiterSparseScratch | None" = None + # Determine padded head count for FlashMLA if num_heads not in self.SUPPORTED_HEAD_COUNTS: if num_heads < 64: @@ -591,8 +749,11 @@ def __init__( vllm_config.scheduler_config.max_num_batched_tokens ) self.max_model_len = vllm_config.model_config.max_model_len - # DeepseekV4 only supports fp8 kv-cache format for now + # DeepseekV4 only supports fp8 kv-cache format for now. Treat "auto" + # as the model default and normalize it before the fp8-only checks. kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "fp8" + if kv_cache_dtype == "auto": + kv_cache_dtype = "fp8" assert kv_cache_dtype.startswith("fp8"), ( f"DeepseekV4 only supports fp8 kv-cache format for now, " @@ -739,6 +900,21 @@ def _forward_decode( swa_indices = swa_metadata.decode_swa_indices swa_lens = swa_metadata.decode_swa_lens + if current_platform.is_rocm(): + # On ROCm, DeepSeek sparse attention is only supported via AITER, + # so we always dispatch to the AITER decode path. + self._forward_decode_aiter( + q=q, + kv_cache=kv_cache, + swa_only=swa_only, + topk_indices=topk_indices, + topk_lens=topk_lens, + swa_indices=swa_indices, + swa_lens=swa_lens, + output=output, + ) + return + # We treat queries in the same seq as different queries # and later we only attend by generated indices. # q arrives pre-padded to self.padded_heads by the outer wrapper. 
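The `_deepseek_v4_fp8_einsum_fallback` above hinges on expanding a coarse grid of block scales back to elementwise resolution so dequantization becomes a plain multiply. A minimal sketch of that idea, assuming 128x128 scale blocks as in `_expand_2d_block_scales` (random tensors stand in for real fp8 data):

```python
import math

import torch

rows, cols, block = 256, 320, 128
scale = torch.rand(math.ceil(rows / block), math.ceil(cols / block))

# Broadcast the per-block scales up with repeat_interleave, then trim the
# overhang from the ceil-division so the result matches the weight exactly.
expanded = torch.repeat_interleave(scale, block, dim=-2)[:rows, :]
expanded = torch.repeat_interleave(expanded, block, dim=-1)[:, :cols]
assert expanded.shape == (rows, cols)

w_q = torch.randn(rows, cols)  # stand-in for the quantized weight
w_deq = (w_q * expanded).to(torch.bfloat16)
```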
@@ -903,15 +1079,284 @@ def _forward_prefill( N, ) - output_chunk, _, _ = flash_mla_sparse_fwd( - q=q[query_start:query_end], - kv=kv.view(-1, 1, q.shape[-1]), - indices=combined_indices.unsqueeze(1), - sm_scale=self.scale, - attn_sink=self.attn_sink, - topk_length=combined_lens, - out=output[query_start:query_end], + if current_platform.is_rocm(): + output_chunk = self._ref_sparse_attn_prefill( + q=q[query_start:query_end], + kv=kv.view(-1, 1, q.shape[-1]), + indices=combined_indices.unsqueeze(1), + topk_length=combined_lens, + ) + output[query_start:query_end].copy_(output_chunk.to(output.dtype)) + else: + output_chunk, _, _ = flash_mla_sparse_fwd( + q=q[query_start:query_end], + kv=kv.view(-1, 1, q.shape[-1]), + indices=combined_indices.unsqueeze(1), + sm_scale=self.scale, + attn_sink=self.attn_sink, + topk_length=combined_lens, + out=output[query_start:query_end], + ) + + def _decode_e8m0_scales(self, scale: torch.Tensor) -> torch.Tensor: + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _upcast_e8m0_to_fp32, + ) + + return _upcast_e8m0_to_fp32(scale.contiguous()) + + def _dequantize_cache_rows(self, rows: torch.Tensor) -> torch.Tensor: + rows = rows.reshape(-1, rows.shape[-1]) + fp8_dtype = current_platform.fp8_dtype() + fp8_dim = self.nope_head_dim + rope_bytes = self.rope_head_dim * 2 + scale_dim = fp8_dim // 64 + + fp8_vals = rows[:, :fp8_dim].contiguous().view(fp8_dtype) + rope_vals = rows[:, fp8_dim : fp8_dim + rope_bytes].contiguous().view( + torch.bfloat16 + ) + scale_bytes = rows[ + :, fp8_dim + rope_bytes : fp8_dim + rope_bytes + scale_dim + ].contiguous() + scales = self._decode_e8m0_scales(scale_bytes) + scales = torch.repeat_interleave(scales, 64, dim=-1) + nope = fp8_vals.to(torch.float32) * scales + return torch.cat([nope, rope_vals.to(torch.float32)], dim=-1).to(torch.bfloat16) + + def _gather_dequantized_cache_tokens( + self, + cache: torch.Tensor, + slot_ids: torch.Tensor, + block_size: int, + ) -> torch.Tensor: + if slot_ids.numel() == 0: + return torch.empty((0, self.head_dim), dtype=torch.bfloat16, device=cache.device) + slot_ids = slot_ids.to(torch.int64) + rows = cache[slot_ids // block_size, slot_ids % block_size] + return self._dequantize_cache_rows(rows).reshape(-1, self.head_dim) + + def _forward_decode_aiter( + self, + q: torch.Tensor, + kv_cache: torch.Tensor | None, + swa_only: bool, + topk_indices: torch.Tensor | None, + topk_lens: torch.Tensor | None, + swa_indices: torch.Tensor, + swa_lens: torch.Tensor, + output: torch.Tensor, + ) -> None: + """AITER-accelerated sparse MLA decode (ROCm/MI355X). + + Uses aiter.mla.mla_decode_fwd with FP8 persistent-mode ASM kernels + for ~2-3x decode speedup over the torch reference at high batch sizes. + This is the only sparse-attention decode path on ROCm. 
+ """ + from vllm.v1.attention.ops.rocm_aiter_dsv4_decode import ( + AiterSparseScratch, + aiter_sparse_attn_decode, + ) + + if self._aiter_scratch is None: + self._aiter_scratch = AiterSparseScratch() + if self._aiter_extra_scratch is None: + self._aiter_extra_scratch = AiterSparseScratch() + + blocked_swa = self._dequantize_blocked_k_cache( + self.swa_cache_layer.kv_cache) + blocked_extra = ( + None if swa_only + else self._dequantize_blocked_k_cache(kv_cache) + ) + + attn_out = aiter_sparse_attn_decode( + q=q.unsqueeze(1), + blocked_k=blocked_swa, + indices_in_kvcache=swa_indices.unsqueeze(1), + topk_length=swa_lens, + attn_sink=self.attn_sink[:q.shape[1]], + scale=self.scale, + head_dim=self.head_dim, + extra_blocked_k=blocked_extra, + extra_indices_in_kvcache=topk_indices, + extra_topk_length=topk_lens, + scratch=self._aiter_scratch, + extra_scratch=self._aiter_extra_scratch, + ) + output.copy_(attn_out.to(output.dtype)) + + def _forward_decode_fallback( + self, + q: torch.Tensor, + kv_cache: torch.Tensor | None, + swa_metadata: "DeepseekSparseSWAMetadata", + swa_only: bool, + topk_indices: torch.Tensor | None, + topk_lens: torch.Tensor | None, + swa_indices: torch.Tensor, + swa_lens: torch.Tensor, + output: torch.Tensor, + ) -> None: + blocked_swa = self._dequantize_blocked_k_cache(self.swa_cache_layer.kv_cache) + blocked_extra = None if swa_only else self._dequantize_blocked_k_cache(kv_cache) + attn_out = self._ref_sparse_attn_decode( + q=q.unsqueeze(1), + blocked_k=blocked_swa, + indices_in_kvcache=swa_indices.unsqueeze(1), + topk_length=swa_lens, + attn_sink=self.attn_sink[: q.shape[1]], + extra_blocked_k=blocked_extra, + extra_indices_in_kvcache=topk_indices, + extra_topk_length=topk_lens, + ) + output.copy_(attn_out.to(output.dtype)) + + def _dequantize_blocked_k_cache(self, quant_k_cache: torch.Tensor) -> torch.Tensor: + fp8_dtype = current_platform.fp8_dtype() + d = self.head_dim + d_nope = self.nope_head_dim + d_rope = self.rope_head_dim + tile_size = 64 + num_tiles = d_nope // tile_size + + num_blocks, block_size, _ = quant_k_cache.shape + quant_k_cache = quant_k_cache.view(num_blocks, -1) + input_nope_rope = quant_k_cache[:, : block_size * (d_nope + 2 * d_rope)].view( + num_blocks, block_size, d_nope + 2 * d_rope + ) + input_nope = input_nope_rope[:, :, :d_nope].view(fp8_dtype) + input_rope = input_nope_rope[:, :, d_nope:].view(torch.bfloat16) + input_scale = ( + quant_k_cache[:, block_size * (d_nope + 2 * d_rope) :] + .view(num_blocks, block_size, 8)[:, :, :num_tiles] + .view(torch.float8_e8m0fnu) + ) + + result = torch.empty( + (num_blocks, block_size, 1, d), + dtype=torch.bfloat16, + device=quant_k_cache.device, + ) + result[..., d_nope:] = input_rope.unsqueeze(2) + for tile_idx in range(num_tiles): + cur_nope = input_nope[ + ..., tile_idx * tile_size : (tile_idx + 1) * tile_size + ].to(torch.bfloat16) + cur_scales = input_scale[:, :, tile_idx].to(torch.bfloat16).unsqueeze(-1) + result[ + ..., tile_idx * tile_size : (tile_idx + 1) * tile_size + ] = (cur_nope * cur_scales).unsqueeze(2) + return result + + def _ref_sparse_attn_prefill( + self, + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + topk_length: torch.Tensor | None, + ) -> torch.Tensor: + indices = indices.clone().squeeze(1) + s_q, h_q, d_qk = q.shape + topk = indices.shape[-1] + s_kv = kv.shape[0] + if topk_length is not None: + mask = torch.arange(topk, device=indices.device).unsqueeze(0) >= topk_length.unsqueeze(1) + indices[mask] = -1 + invalid_mask = (indices < 0) | (indices >= s_kv) + 
indices[invalid_mask] = 0 + + qf = q.float() + gathered_kv = ( + kv.index_select(0, indices.flatten()).reshape(s_q, topk, d_qk).float() + ) + scores = qf @ gathered_kv.transpose(1, 2) + scores *= self.scale + scores[invalid_mask.unsqueeze(1).expand_as(scores)] = float("-inf") + + orig_lse = torch.logsumexp(scores, dim=-1) + lse_for_o = orig_lse + if self.attn_sink is not None: + lse_for_o = torch.logsumexp( + torch.stack([orig_lse, self.attn_sink[:h_q].view(1, h_q).expand_as(orig_lse)], dim=0), + dim=0, + ) + lse_for_o = lse_for_o.clone() + lse_for_o[lse_for_o == float("-inf")] = float("+inf") + probs = torch.exp(scores - lse_for_o.unsqueeze(-1)) + out = probs @ gathered_kv[..., : self.head_dim] + lonely_q_mask = orig_lse == float("-inf") + out[lonely_q_mask.unsqueeze(-1).expand_as(out)] = 0.0 + return out.to(torch.bfloat16) + + def _ref_sparse_attn_decode( + self, + q: torch.Tensor, + blocked_k: torch.Tensor, + indices_in_kvcache: torch.Tensor, + topk_length: torch.Tensor | None, + attn_sink: torch.Tensor | None, + extra_blocked_k: torch.Tensor | None = None, + extra_indices_in_kvcache: torch.Tensor | None = None, + extra_topk_length: torch.Tensor | None = None, + ) -> torch.Tensor: + b, s_q, h_q, d_qk = q.shape + d_v = self.head_dim + + def process_scope( + cur_blocked_k: torch.Tensor, + cur_indices: torch.Tensor, + cur_topk_length: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + cur_indices = cur_indices.reshape(b, s_q, -1) + topk = cur_indices.size(-1) + fixed_indices = torch.clamp_min(cur_indices, 0) + gathered_kv = ( + cur_blocked_k.view(-1, d_qk) + .index_select(0, fixed_indices.view(-1)) + .view(b, s_q, topk, d_qk) + ) + invalid_mask = cur_indices == -1 + if cur_topk_length is not None: + cur_topk_length = cur_topk_length.reshape(b) + invalid_mask |= torch.arange(0, topk, device=invalid_mask.device).view( + 1, 1, topk + ) >= cur_topk_length.view(b, 1, 1) + return gathered_kv, invalid_mask + + gathered_kv, invalid_mask = process_scope( + blocked_k, indices_in_kvcache, topk_length + ) + if extra_blocked_k is not None: + assert extra_indices_in_kvcache is not None + gathered_kv1, invalid_mask1 = process_scope( + extra_blocked_k, extra_indices_in_kvcache, extra_topk_length ) + gathered_kv = torch.cat([gathered_kv, gathered_kv1], dim=2) + invalid_mask = torch.cat([invalid_mask, invalid_mask1], dim=2) + + gathered_kv = gathered_kv.view(b * s_q, -1, d_qk).float() + gathered_kv[gathered_kv != gathered_kv] = 0.0 + qf = q.float().view(b * s_q, h_q, d_qk) + attn_weight = qf @ gathered_kv.transpose(-1, -2) + attn_weight *= self.scale + attn_weight[ + invalid_mask.view(b * s_q, 1, -1).expand(b * s_q, h_q, invalid_mask.size(-1)) + ] = float("-inf") + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.exp(attn_weight - lse.unsqueeze(-1)) + output = attn_weight @ gathered_kv[..., :d_v] + output = output.view(b, s_q, h_q, d_v) + lse = lse.view(b, s_q, h_q) + + if attn_sink is not None: + output *= ( + 1.0 / (1.0 + torch.exp(attn_sink.view(1, 1, h_q) - lse)) + ).unsqueeze(-1) + + lonely_q_mask = lse == float("-inf") + output[lonely_q_mask.unsqueeze(-1).expand_as(output)] = 0.0 + return output.squeeze(1).to(torch.bfloat16) class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase): @@ -1064,13 +1509,40 @@ def forward( q = q.view(-1, self.n_head, self.head_dim) k = self.compressor(hidden_states, positions, rotary_emb) weights, _ = self.weights_proj(hidden_states) - q_quant, weights = fused_indexer_q_rope_quant( - positions, - q, - rotary_emb.cos_sin_cache, - weights, - 
self.softmax_scale, - self.n_head**-0.5, - use_fp4=self.use_fp4_kv, - ) + if current_platform.is_rocm() and not self.use_fp4_kv: + q_quant, weights = self._quantize_indexer_q_torch( + q, positions, rotary_emb.cos_sin_cache, weights + ) + else: + q_quant, weights = fused_indexer_q_rope_quant( + positions, + q, + rotary_emb.cos_sin_cache, + weights, + self.softmax_scale, + self.n_head**-0.5, + use_fp4=self.use_fp4_kv, + ) return self.indexer_op(hidden_states, q_quant, k, weights) + + def _quantize_indexer_q_torch( + self, + q: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + weights: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + q = apply_gptj_rope_ref(q, positions, cos_sin_cache, self.rope_dim).to( + torch.bfloat16 + ) + q = hadamard_transform_ref(q).to(torch.float32) + fp8_max = 224.0 if current_platform.is_fp8_fnuz() else 448.0 + q_scale = torch.abs(q).amax(dim=-1).clamp(min=1e-12) / fp8_max + q_quant = (q / q_scale.unsqueeze(-1)).to(current_platform.fp8_dtype()) + weights = ( + weights.to(torch.float32) + * q_scale + * self.softmax_scale + * (self.n_head**-0.5) + ) + return q_quant, weights diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py index f476d980d555..9fa5aac6fb85 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -18,6 +18,7 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, FusedMoEQuantDesc, + RoutingMethodType, mxfp4_mxfp8_moe_quant_config, mxfp4_w4a16_moe_quant_config, ocp_mx_moe_quant_config, @@ -217,6 +218,8 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]: TRTLLM MXFP8; SM90 falls through to Triton_unfused or Marlin (the backend-level ``is_supported_config`` check filters by device capability). """ + if current_platform.is_rocm(): + return [Mxfp4MoeBackend.AITER] _AVAILABLE_BACKENDS = [ Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, Mxfp4MoeBackend.DEEPGEMM_MXFP4, @@ -484,8 +487,22 @@ def _return_or_raise( activation_format, ) + # DeepSeek-V4 on ROCm is more accurate with the unfused Triton MXFP4 path + # than the default AITER path. Prefer Triton-unfused for this routing mode, + # while keeping AITER as a fallback if Triton-unfused rejects the config. + if ( + current_platform.is_rocm() + and config.routing_method == RoutingMethodType.DeepseekV4 + ): + priority_backends = [ + Mxfp4MoeBackend.TRITON_UNFUSED, + Mxfp4MoeBackend.AITER, + ] + else: + priority_backends = _get_priority_backends() + # Iterate priority backends: TRTLLM MXFP8, then Triton. 
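For context on the loop changed just below: backend selection is a first-match scan over a priority list, so prepending Triton-unfused for DeepseekV4 routing on ROCm still leaves AITER reachable when Triton-unfused rejects the config. A hedged, self-contained sketch of that pattern (`Backend`, `pick_backend`, and the `supported` flag are illustrative, not the vLLM API):

```python
from dataclasses import dataclass


@dataclass
class Backend:
    name: str
    supported: bool  # stands in for is_supported_config()


def pick_backend(priority: list[Backend]) -> Backend:
    for backend in priority:
        if backend.supported:
            return backend
    raise ValueError("no MXFP4 backend supports this config")


# ROCm + DeepseekV4 routing: Triton-unfused first, AITER as the safety net.
chosen = pick_backend([Backend("triton_unfused", False), Backend("aiter", True)])
assert chosen.name == "aiter"
```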
- for backend in _get_priority_backends(): + for backend in priority_backends: activation_key = _backend_activation_key(backend) for k_cls in backend_to_kernel_cls(backend): supported, reason = k_cls.is_supported_config( @@ -1107,6 +1124,64 @@ def convert_weight_to_mxfp4_moe_kernel_format( w2_bias, ) + elif mxfp4_backend == Mxfp4MoeBackend.AITER: + from vllm._aiter_ops import rocm_aiter_ops + + if w13_bias is not None: + w13_bias = w13_bias.data.to(torch.float32) + if w2_bias is not None: + w2_bias = w2_bias.data.to(torch.float32) + + e, n, k = w13_weight.shape + + w13_weight.view(torch.uint8).copy_( + w13_weight.data.view(torch.uint8) + .view(e, n // 2, 2, k) + .permute(0, 2, 1, 3) + .contiguous() + .view(e, n, k) + ) + w13_weight_scale.data = ( + w13_weight_scale.data.view(e, n // 2, 2, -1) + .permute(0, 2, 1, 3) + .contiguous() + .view(e, n, -1) + ) + + w13_weight.data = w13_weight.data.view(torch.float4_e2m1fn_x2) + w2_weight.data = w2_weight.data.view(torch.float4_e2m1fn_x2) + + w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w13_weight, 16, True) + shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4( + w13_weight_scale.view(-1, w13_weight_scale.shape[-1]), + num_experts, + True, + ) + + w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w2_weight, 16, False) + shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4( + w2_weight_scale.view(-1, w2_weight_scale.shape[-1]), + num_experts, + False, + ) + + if w13_bias is not None: + w13_bias = ( + w13_bias.data.view(-1, n // 2, 2) + .permute(0, 2, 1) + .contiguous() + .view(-1, n) + ) + + return ( + w13_weight, + w2_weight, + shuffled_w13_scale, + shuffled_w2_scale, + w13_bias, + w2_bias, + ) + elif mxfp4_backend in TRITON_BACKENDS: from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig @@ -1162,7 +1237,7 @@ def shuffle_weight(w: torch.Tensor) -> torch.Tensor: else: raise ValueError( f"Unsupported mxfp4_backend for Mxfp4MoEMethod: {mxfp4_backend}. " - f"Expected TRTLLM or Triton backend." + f"Expected TRTLLM, Triton, or AITER backend." 
) diff --git a/vllm/model_executor/layers/mhc.py b/vllm/model_executor/layers/mhc.py index 1521a6b601bf..ad1655e7eadb 100644 --- a/vllm/model_executor/layers/mhc.py +++ b/vllm/model_executor/layers/mhc.py @@ -234,6 +234,40 @@ def mhc_pre( num_tokens = residual_flat.shape[0] fn_flat = fn + if current_platform.is_rocm(): + x = residual_flat.view(num_tokens, hc_mult * hidden_size).to(torch.float32) + mixes = torch.matmul(x, fn_flat.t()) + sqrsum = x.square().sum(dim=-1, keepdim=True) + mixes = mixes * torch.rsqrt(sqrsum / (hc_mult * hidden_size) + rms_eps) + + pre_logits = mixes[:, :hc_mult] * hc_scale[0] + hc_base[:hc_mult] + pre_mix = torch.sigmoid(pre_logits) + hc_pre_eps + + post_logits = ( + mixes[:, hc_mult : 2 * hc_mult] * hc_scale[1] + + hc_base[hc_mult : 2 * hc_mult] + ) + post_mix = torch.sigmoid(post_logits) * hc_post_mult_value + + comb_logits = ( + mixes[:, 2 * hc_mult :].view(num_tokens, hc_mult, hc_mult) * hc_scale[2] + + hc_base[2 * hc_mult :].view(1, hc_mult, hc_mult) + ) + comb_mix = torch.softmax(comb_logits, dim=-1) + hc_sinkhorn_eps + comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps) + for _ in range(sinkhorn_repeat - 1): + comb_mix = comb_mix / (comb_mix.sum(dim=-1, keepdim=True) + hc_sinkhorn_eps) + comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps) + + layer_input = torch.sum( + pre_mix.unsqueeze(-1) * residual_flat.to(torch.float32), dim=1 + ).to(torch.bfloat16) + return ( + post_mix.view(*outer_shape, hc_mult, 1), + comb_mix.view(*outer_shape, hc_mult, hc_mult), + layer_input.view(*outer_shape, hidden_size), + ) + # these number are from deepgemm kernel impl block_k = 64 block_m = 64 @@ -414,6 +448,14 @@ def mhc_post( post_layer_mix: torch.Tensor, comb_res_mix: torch.Tensor, ) -> torch.Tensor: + if current_platform.is_rocm(): + mixed_residual = torch.einsum( + "...ij,...ih->...jh", + comb_res_mix.to(torch.float32), + residual.to(torch.float32), + ) + post_term = post_layer_mix.to(torch.float32) * x.unsqueeze(-2).to(torch.float32) + return (mixed_residual + post_term).to(residual.dtype) out = torch.empty_like(residual) mhc_post_tilelang( comb_res_mix, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 9613b11d35e2..d9aab35c25f4 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -843,6 +843,15 @@ def w8a8_triton_block_scaled_mm( assert len(block_size) == 2 block_n, block_k = block_size[0], block_size[1] + # Triton cannot currently bind E8M0 scale tensors directly. On ROCm, + # DeepSeek-V4 checkpoints store block scales in exponent-only E8M0 format, + # so decode them to fp32 before launching the kernel. 
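For reviewers unfamiliar with E8M0: `float8_e8m0fnu` stores only an 8-bit exponent, so the upcast referenced above is just `2.0 ** (e - 127)` per byte. A minimal reference version, assuming the OCP MX E8M0 encoding (an illustration, not the source of vLLM's `_upcast_e8m0_to_fp32`; `e == 255` encodes NaN and is ignored here for brevity):

```python
import torch


def upcast_e8m0_to_fp32_ref(raw_exponents: torch.Tensor) -> torch.Tensor:
    # Exponent-only decode: value = 2^(e - 127), no sign or mantissa bits.
    return torch.pow(2.0, raw_exponents.to(torch.float32) - 127.0)


raw = torch.tensor([127, 128, 126], dtype=torch.uint8)
assert torch.equal(upcast_e8m0_to_fp32_ref(raw), torch.tensor([1.0, 2.0, 0.5]))
```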
+ if current_platform.is_rocm(): + if As.dtype == torch.float8_e8m0fnu: + As = _upcast_e8m0_to_fp32(As).contiguous() + if Bs.dtype == torch.float8_e8m0fnu: + Bs = _upcast_e8m0_to_fp32(Bs).contiguous() + assert A.shape[-1] == B.shape[-1] assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index ca82f2feb7ef..9ad881cdbe85 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -499,13 +499,31 @@ def forward_hip( k: torch.Tensor, weights: torch.Tensor, ): - assert not self.skip_k_cache_insert, ( - "AMD platform doesn't support skip cache insert yet" - ) assert not self.use_fp4_cache, "AMD platform doesn't support fp4 cache yet" assert isinstance(q_quant, torch.Tensor), ( "AMD sparse_attn_indexer expects a single FP8 q_quant tensor" ) + if self.skip_k_cache_insert or not rocm_aiter_ops.is_enabled(): + from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( + rocm_aiter_sparse_attn_indexer_native, + ) + + return rocm_aiter_sparse_attn_indexer_native( + hidden_states, + _encode_layer_name(self.k_cache.prefix), + self.k_cache.kv_cache, + q_quant, + k, + weights, + self.quant_block_size, + self.scale_fmt, + self.topk_tokens, + self.head_dim, + self.max_model_len, + self.max_total_seq_len, + self.topk_indices_buffer, + skip_k_cache_insert=self.skip_k_cache_insert, + ) if rocm_aiter_ops.is_enabled(): return torch.ops.vllm.rocm_aiter_sparse_attn_indexer( hidden_states, @@ -522,8 +540,6 @@ def forward_hip( self.max_total_seq_len, self.topk_indices_buffer, ) - else: - raise RuntimeError( - "Sparse attention indexer ROCm custom op requires ROCm " - "Aiter ops to be enabled." - ) + raise RuntimeError( + "Sparse attention indexer ROCm path could not be selected." 
+ ) diff --git a/vllm/model_executor/models/deepseek_v4.py b/vllm/model_executor/models/deepseek_v4.py index 97f755240a4c..5c23bbdcd899 100644 --- a/vllm/model_executor/models/deepseek_v4.py +++ b/vllm/model_executor/models/deepseek_v4.py @@ -67,7 +67,6 @@ _DEEPSEEK_V4_EXPERT_DTYPES = ("fp4", "fp8") - class DeepseekV4MLP(nn.Module): def __init__( self, diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 866b9ffd1a6d..70534c106ffe 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -409,6 +409,7 @@ class RocmPlatform(Platform): "gptq", "gptq_marlin", # will be overwritten with gptq "fp8", + "deepseek_v4_fp8", "compressed-tensors", "fbgemm_fp8", "gguf", diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 6b89f5c33203..f66818ab8972 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -473,7 +473,11 @@ def tf32_hc_prenorm_gemm( """ _lazy_init() if _tf32_hc_prenorm_gemm_impl is None: - return _missing() + out.zero_() + sqrsum.zero_() + out[0].copy_(torch.matmul(x.to(torch.float32), fn.t().to(torch.float32))) + sqrsum[0].copy_(x.to(torch.float32).square().sum(dim=-1)) + return out return _tf32_hc_prenorm_gemm_impl( x, fn, diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py index b17fd5d34418..28564e6a97d3 100644 --- a/vllm/v1/attention/backends/mla/sparse_swa.py +++ b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -7,6 +7,7 @@ from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.v1.attention.backend import ( AttentionBackend, @@ -360,7 +361,7 @@ def build_tile_scheduler( _LAYER_TYPE_C4A: None, _LAYER_TYPE_C128A: None, } - if num_decode_tokens == 0: + if num_decode_tokens == 0 or current_platform.is_rocm(): return out for layer_type in self._layer_types: # get_mla_metadata() is the official FlashMLA entry point that diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py index 97c9538889a1..de8367e0dd65 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py @@ -224,7 +224,6 @@ def fused_inv_rope_fp8_quant( HALF_ROPE=rope_dim // 2, TMA_ALIGNED_SCALES=tma_aligned_scales, num_stages=1, - launch_pdl=False, ) grid = (tma_aligned_T, n_groups * heads_per_group) diff --git a/vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py b/vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py new file mode 100644 index 000000000000..2ca6adbffbba --- /dev/null +++ b/vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py @@ -0,0 +1,495 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +AITER-accelerated sparse MLA decode for DeepSeek V4 on ROCm (MI355X / gfx950). + +Drop-in replacement for `DeepseekV4MLAAttention._ref_sparse_attn_decode`. +Uses `aiter.mla.mla_decode_fwd` with persistent-mode ASM kernels and FP8 +inputs to get 2-3x speedup at high batch sizes. + +Key design decisions (validated on MI355X, see benchmarks/dsv4_mi355/PLAN.md §12): + - FP8/FP8 path only: gfx950 persistent-mode ASM kernels with return_lse=True + exist ONLY for FP8/FP8 (not BF16). We need LSE for attn_sink correction. 
+ - Fixed-stride kv_indices: persistent-mode expects (total_q * topk) flat + layout with -1 sentinels for invalid entries; NOT ragged. + - Per-scope scratch: SWA and extra scopes have different topk, requiring + independent AITER metadata buffers. + - Cudagraph-safe: all per-step indexing tensors, the FP8 query buffer, the + output buffer, and the constant scale tensors are preallocated in + `AiterSparseScratch` and rewritten in-place so cudagraph capture sees + stable memory layouts. +""" + +from __future__ import annotations + +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class AiterSparseScratch: + """Cached per-step AITER persistent-mode scratch buffers (cudagraph-safe). + + Allocate once per `(total_q, nhead, topk, d_qk, d_v, dtype, kvtype)` key + and reuse across all 61 DSv4 attention layers in the same decode step. + Buffers fall into three groups: + + * AITER persistent metadata (`work_*`, `reduce_*`) — sized at rebuild + from `get_mla_metadata_info_v1` (purely shape-determined), but rewritten + in-place by `aiter.get_mla_metadata_v1` every step, because the work + plan encodes the *actual* kv lengths and the persistent ASM kernel + reads out-of-bounds if it is left stale. + * Per-step indexing/IO buffers (`qo_indptr`, `kv_indptr`, `kv_indices_2d`, + `kv_last_page_lens`, `valid_mask`, `valid_lens`, `col_arange`, `q_fp8`, + `out_buf`) — written in-place each step. + * Constant scale tensors (`q_scale`, `kv_scale`) — initialised once. + + The metadata buffers, the per-step buffers and the scale tensors all keep + stable `data_ptr()`s across the entire lifetime of a shape key, so a + HIP/CUDA graph captured around the second or later step replays correctly. + """ + + __slots__ = ( + # AITER persistent metadata buffers + "work_meta_data", + "work_indptr", + "work_info_set", + "reduce_indptr", + "reduce_final_map", + "reduce_partial_map", + # Per-step indexing buffers + "qo_indptr", + "kv_indptr", + "kv_indices_2d", + "kv_last_page_lens", + "valid_mask", + "valid_lens", + "col_arange", + # FP8 query buffer + output buffer + "q_fp8", + "out_buf", + # Constant scale tensors (always 1.0 for our quantization scheme) + "q_scale", + "kv_scale", + # GQA ratios captured at rebuild time so per-step refresh can call + # `get_mla_metadata_v1` with the same parameters every time. + "_gqa_ratio", + "_nhead_kv", + "_page_size", + "_topk", + "_dtype", + "_kvtype", + "_max_split_per_batch", + # Identity key for cache lookups + "_key", + ) + + def __init__(self) -> None: + for slot in self.__slots__: + setattr(self, slot, None) + self._key = () + + def matches( + self, + total_q: int, + nhead: int, + topk: int, + d_qk: int, + d_v: int, + dtype: torch.dtype, + kvtype: torch.dtype, + ) -> bool: + return self._key == (total_q, nhead, topk, d_qk, d_v, dtype, kvtype) + + def rebuild( + self, + *, + total_q: int, + nhead: int, + nhead_kv: int, + topk: int, + d_qk: int, + d_v: int, + page_size: int, + dtype: torch.dtype, + kvtype: torch.dtype, + device: torch.device, + max_split_per_batch: int = 256, + ) -> None: + """Allocate every persistent buffer for the given shape key. + + Buffer sizes returned by `get_mla_metadata_info_v1` are determined by + shapes and `max_split_per_batch` only, so they are large enough for + any kv-length distribution. 
The actual work plan is computed on the + per-step path by `refresh_metadata`, which writes these buffers + in-place using the freshly populated `qo_indptr`/`kv_indptr`/ + `kv_last_page_lens` -- those pointers stay stable for the lifetime + of this scratch. + """ + import aiter + + # ---- AITER persistent metadata buffers (sizes only) ------------- # + ( + (wmd_size, wmd_type), + (wi_size, wi_type), + (wis_size, wis_type), + (ri_size, ri_type), + (rfm_size, rfm_type), + (rpm_size, rpm_type), + ) = aiter.get_mla_metadata_info_v1( + total_q, + 1, + nhead, + dtype, + kvtype, + is_sparse=True, + fast_mode=True, + num_kv_splits=max_split_per_batch, + ) + self.work_meta_data = torch.empty(wmd_size, dtype=wmd_type, device=device) + self.work_indptr = torch.empty(wi_size, dtype=wi_type, device=device) + self.work_info_set = torch.empty(wis_size, dtype=wis_type, device=device) + self.reduce_indptr = torch.empty(ri_size, dtype=ri_type, device=device) + self.reduce_final_map = torch.empty(rfm_size, dtype=rfm_type, device=device) + self.reduce_partial_map = torch.empty(rpm_size, dtype=rpm_type, device=device) + + # ---- Per-step indexing buffers ---------------------------------- # + # qo_indptr is always [0, 1, 2, ..., total_q] for one query per token. + self.qo_indptr = torch.arange( + total_q + 1, dtype=torch.int32, device=device + ) + self.kv_indptr = torch.zeros( + total_q + 1, dtype=torch.int32, device=device + ) + self.kv_indices_2d = torch.empty( + (total_q, topk), dtype=torch.int32, device=device + ) + # kv_last_page_lens is always all-ones for our 1-token-per-page layout. + self.kv_last_page_lens = torch.ones( + total_q, dtype=torch.int32, device=device + ) + self.valid_mask = torch.empty( + (total_q, topk), dtype=torch.bool, device=device + ) + self.valid_lens = torch.empty( + total_q, dtype=torch.int32, device=device + ) + self.col_arange = torch.arange(topk, dtype=torch.int32, device=device) + + # ---- FP8 query + bf16 output buffers --------------------------- # + self.q_fp8 = torch.empty( + (total_q, nhead, d_qk), dtype=dtype, device=device + ) + self.out_buf = torch.empty( + (total_q, nhead, d_v), dtype=torch.bfloat16, device=device + ) + + # ---- Constant scale tensors (1.0 for our quant scheme) --------- # + self.q_scale = torch.ones(1, dtype=torch.float32, device=device) + self.kv_scale = torch.ones(1, dtype=torch.float32, device=device) + + # Cache parameters for `refresh_metadata`. + self._gqa_ratio = nhead // nhead_kv + self._nhead_kv = nhead_kv + self._page_size = page_size + self._topk = topk + self._dtype = dtype + self._kvtype = kvtype + self._max_split_per_batch = max_split_per_batch + + self._key = (total_q, nhead, topk, d_qk, d_v, dtype, kvtype) + + def refresh_metadata(self) -> None: + """Re-run `aiter.get_mla_metadata_v1` against the current + `kv_indptr` / `kv_last_page_lens`, writing the new work plan into the + same `work_*` / `reduce_*` buffers in-place. + + Must be called every step *after* `kv_indptr` is updated and *before* + `aiter.mla.mla_decode_fwd`. The persistent ASM kernel reads + out-of-bounds if it is left with a stale work plan, so this call is + not optional even when shapes are unchanged. 
+ """ + import aiter + + aiter.get_mla_metadata_v1( + self.qo_indptr, + self.kv_indptr, + self.kv_last_page_lens, + self._gqa_ratio, + self._nhead_kv, + True, + self.work_meta_data, + self.work_info_set, + self.work_indptr, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + page_size=self._page_size, + kv_granularity=max(self._page_size, 16), + max_seqlen_qo=1, + uni_seqlen_qo=1, + fast_mode=True, + max_split_per_batch=self._max_split_per_batch, + topk=self._topk, + dtype_q=self._dtype, + dtype_kv=self._kvtype, + ) + + +def aiter_sparse_attn_decode( + *, + q: torch.Tensor, + blocked_k: torch.Tensor, + indices_in_kvcache: torch.Tensor, + topk_length: torch.Tensor | None, + attn_sink: torch.Tensor | None, + scale: float, + head_dim: int, + extra_blocked_k: torch.Tensor | None = None, + extra_indices_in_kvcache: torch.Tensor | None = None, + extra_topk_length: torch.Tensor | None = None, + scratch: AiterSparseScratch | None = None, + extra_scratch: AiterSparseScratch | None = None, +) -> torch.Tensor: + """AITER-backed replacement for _ref_sparse_attn_decode. + + Args: + q: (b, 1, h_q, d_qk) bf16 — unsqueezed decode query + blocked_k: (n_blk, blk_sz, 1, d_qk) bf16 — dequantized SWA K cache + indices_in_kvcache: (b, 1, topk_swa) int32 + topk_length: (b,) int32 or None + attn_sink: (h_q,) fp32 or None + scale: softmax scale (1/sqrt(d_qk)) + head_dim: kv_lora_rank (d_v, typically 512) + extra_blocked_k: optional second scope K cache + extra_indices_in_kvcache: optional second scope indices + extra_topk_length: optional second scope lengths + scratch: persistent scratch for SWA scope + extra_scratch: persistent scratch for extra scope + + Returns: (b, h_q, d_v) bf16 + """ + b, s_q, h_q, d_qk = q.shape + d_v = head_dim + + # Head-split workaround for the AITER persistent ASM kernel + # `mla_a8w8_qh64_qseqlen4_gqaratio16_lse_ps` (selected when h_q==64, e.g. + # TP=2 on a 128-head model or TP=4 on a 256-head model). With + # `return_lse=True` on gfx950 it returns `lse=+inf` for some (batch, head) + # rows, which propagates NaN through the dual-scope LSE merge. The + # `qh16_qseqlen1` kernel selected for h_q<=32 is numerically clean + # (validated cosine > 0.999 vs PyTorch reference on TP4 / TP8). Recurse + # at the outer level here — split q + attn_sink in half, run a complete + # `aiter_sparse_attn_decode` per half, then concatenate the final outputs. + # Doing the split here (instead of inside `_aiter_decode_one_scope`) keeps + # each recursive call self-contained: LSE merging and sink correction + # never cross the h_q boundary, so the result does not depend on the + # kernel's `lse` shape convention (which has differed across aiter + # versions). 2x kernel launches per decode call in the split case, but + # those configs are rarely used in production; correctness first. + # TODO: drop once AITER fixes the qh64_qseqlen4_gqaratio16_lse_ps kernel. 
+ if h_q > 32: + h_half = h_q // 2 + sink_a = attn_sink[..., :h_half] if attn_sink is not None else None + sink_b = attn_sink[..., h_half:] if attn_sink is not None else None + out_a = aiter_sparse_attn_decode( + q=q[..., :h_half, :].contiguous(), + blocked_k=blocked_k, + indices_in_kvcache=indices_in_kvcache, + topk_length=topk_length, + attn_sink=sink_a, + scale=scale, + head_dim=head_dim, + extra_blocked_k=extra_blocked_k, + extra_indices_in_kvcache=extra_indices_in_kvcache, + extra_topk_length=extra_topk_length, + scratch=AiterSparseScratch(), + extra_scratch=AiterSparseScratch(), + ) + out_b = aiter_sparse_attn_decode( + q=q[..., h_half:, :].contiguous(), + blocked_k=blocked_k, + indices_in_kvcache=indices_in_kvcache, + topk_length=topk_length, + attn_sink=sink_b, + scale=scale, + head_dim=head_dim, + extra_blocked_k=extra_blocked_k, + extra_indices_in_kvcache=extra_indices_in_kvcache, + extra_topk_length=extra_topk_length, + scratch=AiterSparseScratch(), + extra_scratch=AiterSparseScratch(), + ) + return torch.cat([out_a, out_b], dim=-2) + + if scratch is None: + scratch = AiterSparseScratch() + if extra_scratch is None: + extra_scratch = AiterSparseScratch() + + out_swa, lse_swa = _aiter_decode_one_scope( + q=q, + blocked_k=blocked_k, + indices=indices_in_kvcache, + lens=topk_length, + sm_scale=scale, + d_v=d_v, + scratch=scratch, + ) + + if extra_blocked_k is not None: + assert extra_indices_in_kvcache is not None + out_ext, lse_ext = _aiter_decode_one_scope( + q=q, + blocked_k=extra_blocked_k, + indices=extra_indices_in_kvcache, + lens=extra_topk_length, + sm_scale=scale, + d_v=d_v, + scratch=extra_scratch, + ) + lse_total = torch.logsumexp( + torch.stack([lse_swa, lse_ext], dim=0), dim=0) + w_swa = (lse_swa - lse_total).exp().unsqueeze(-1) + w_ext = (lse_ext - lse_total).exp().unsqueeze(-1) + out = w_swa * out_swa + w_ext * out_ext + lse = lse_total + else: + out = out_swa + lse = lse_swa + + if attn_sink is not None: + sink = attn_sink.view(1, 1, h_q).to(lse.dtype) + correction = 1.0 / (1.0 + (sink - lse).exp()) + out = out * correction.unsqueeze(-1).to(out.dtype) + + lonely = lse == float("-inf") + if lonely.any(): + out = out.masked_fill(lonely.unsqueeze(-1), 0.0) + + return out.view(b, s_q, h_q, d_v).squeeze(1).to(torch.bfloat16) + + +def _aiter_decode_one_scope( + *, + q: torch.Tensor, + blocked_k: torch.Tensor, + indices: torch.Tensor, + lens: torch.Tensor | None, + sm_scale: float, + d_v: int, + scratch: AiterSparseScratch, +) -> tuple[torch.Tensor, torch.Tensor]: + """Single-scope AITER mla_decode_fwd call. + + Writes into `scratch`'s preallocated buffers in-place so the entire + decode call is cudagraph-capture friendly. 
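+
+    The main per-call allocations are the FP8 copy of `blocked_k` and the
+    LSE tensor that AITER returns; all shape-dependent buffers live in
+    `scratch` and are rebuilt only when the shape/dtype key changes.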
+ + Returns (output, lse) where: + output: (total_q, h_q, d_v) bf16 — alias of `scratch.out_buf` + lse: (total_q, h_q) fp32 — allocated by AITER kernel + """ + import aiter + + b, s_q, h_q, d_qk = q.shape + device = q.device + total_q = b * s_q + + indices_2d = indices.reshape(total_q, -1).contiguous() + topk_max = indices_2d.size(-1) + + fp8_dtype = torch.float8_e4m3fn + if not scratch.matches( + total_q, h_q, topk_max, d_qk, d_v, fp8_dtype, fp8_dtype + ): + scratch.rebuild( + total_q=total_q, + nhead=h_q, + nhead_kv=1, + topk=topk_max, + d_qk=d_qk, + d_v=d_v, + page_size=1, + dtype=fp8_dtype, + kvtype=fp8_dtype, + device=device, + ) + + # ---- Build valid_mask + valid_lens directly into scratch ----------- # + if lens is not None: + if lens.numel() == b and s_q > 1: + lens_per_tok = lens.repeat_interleave(s_q) + else: + lens_per_tok = lens.reshape(-1) + # valid_mask[i, j] = (j < lens_per_tok[i]) AND (indices_2d[i, j] >= 0) + torch.lt( + scratch.col_arange.unsqueeze(0), + lens_per_tok.unsqueeze(1), + out=scratch.valid_mask, + ) + scratch.valid_mask &= indices_2d >= 0 + else: + torch.ge(indices_2d, 0, out=scratch.valid_mask) + + torch.sum( + scratch.valid_mask, dim=-1, dtype=torch.int32, out=scratch.valid_lens + ) + + # ---- Compute kv_indptr in place (cumsum(min(valid_lens, topk))) ---- # + scratch.kv_indptr[0] = 0 + torch.cumsum( + scratch.valid_lens.clamp(max=topk_max), + dim=0, + out=scratch.kv_indptr[1:], + ) + + # ---- Fill kv_indices_2d in-place: keep valid, sentinel -1 elsewhere - # + scratch.kv_indices_2d.copy_(indices_2d) + scratch.kv_indices_2d.masked_fill_(~scratch.valid_mask, -1) + + # ---- Cast q to FP8 in-place into the preallocated buffer ----------- # + scratch.q_fp8.copy_(q.reshape(total_q, h_q, d_qk)) + + # ---- Cast blocked_k to FP8 (per-layer-sized, not in scratch) ------- # + # Note: kv cache dequantize + FP8 cast itself is the bigger perf concern + # tracked separately. Here we just cast to satisfy the kernel's dtype. + kv_fp8 = blocked_k.to(fp8_dtype) + kv_view = kv_fp8.view(-1, 1, 1, d_qk) + + # ---- Refresh AITER work plan against the current kv_indptr --------- # + # The persistent ASM kernel encodes per-batch lengths into work_*; if we + # leave that stale, the kernel reads out of bounds. Rewrite into the same + # buffers in place so pointers stay stable for cudagraph capture. 
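+    # For reference, a toy example of the buffers the work plan is computed
+    # from (total_q=2, topk=4, lens=[3, 1], all indices >= 0):
+    #   valid_mask  = [[T, T, T, F], [T, F, F, F]]
+    #   valid_lens  = [3, 1]
+    #   kv_indptr   = [0, 3, 4]
+    #   kv_indices_2d keeps the valid indices in place and -1 elsewhere.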
+    scratch.refresh_metadata()
+
+    # ---- Persistent-mode FP8 mla_decode_fwd ---------------------------- #
+    _, lse = aiter.mla.mla_decode_fwd(
+        scratch.q_fp8,
+        kv_view,
+        scratch.out_buf,
+        scratch.qo_indptr,
+        scratch.kv_indptr,
+        scratch.kv_indices_2d.view(-1),
+        scratch.kv_last_page_lens,
+        1,
+        1,
+        1,
+        sm_scale,
+        num_kv_splits=256,
+        q_scale=scratch.q_scale,
+        kv_scale=scratch.kv_scale,
+        work_meta_data=scratch.work_meta_data,
+        work_indptr=scratch.work_indptr,
+        work_info_set=scratch.work_info_set,
+        reduce_indptr=scratch.reduce_indptr,
+        reduce_final_map=scratch.reduce_final_map,
+        reduce_partial_map=scratch.reduce_partial_map,
+        return_lse=True,
+    )
+
+    if lse is None:
+        raise RuntimeError("aiter.mla.mla_decode_fwd returned no LSE")
+
+    return scratch.out_buf, lse
diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
index 81cc489db0d8..ea75f19862e8 100644
--- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
+++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
@@ -5,6 +5,7 @@
 from importlib.util import find_spec
 
 import torch
+import torch.nn.functional as F
 
 from vllm.forward_context import get_forward_context
 from vllm.platforms import current_platform
@@ -232,6 +233,39 @@ def fp8_paged_mqa_logits_torch(
     fp8_dtype = current_platform.fp8_dtype()
     batch_size, next_n, _, dim = q.size()
 
+    # Fast path for single-token decode (next_n == 1): per request, gather
+    # the request's pages, compute relu(k @ q_h) per head, take the
+    # weights-weighted sum over heads, and apply the per-token fp8 dequant
+    # scale last (equivalent, since the scale multiplies whole rows of k).
+    if next_n == 1:
+        block_size = kv_cache.shape[1]
+        logits = torch.full(
+            [batch_size, max_model_len],
+            float("-inf"),
+            device=q.device,
+            dtype=torch.float32,
+        )
+        if context_lens.dim() > 1:
+            context_lens = context_lens.squeeze(-1)
+        kv_cache_flat = kv_cache.view(-1, block_size * (dim + 4))
+        for i in range(batch_size):
+            q_i = q[i, 0].to(torch.float32)
+            q_scale = weights[i]
+            seq_len = int(context_lens[i].item())
+            assert seq_len <= max_model_len
+            num_pages = cdiv(seq_len, block_size)
+            padded_seq_len = num_pages * block_size
+            pages = block_tables[i, :num_pages]
+            cache = kv_cache_flat[pages]
+            scale_offset = block_size * dim
+            cache_value = cache[..., :scale_offset].view(dtype=fp8_dtype).to(torch.float32)
+            cache_scale = cache[..., scale_offset:].view(dtype=torch.float32).contiguous()
+            cache_value = cache_value.view(padded_seq_len, dim)
+            cache_scale = cache_scale.view(padded_seq_len)
+            score = F.linear(cache_value, q_i)
+            score = F.relu(score)
+            score *= q_scale[None, :]
+            score = score.sum(dim=1)
+            score *= cache_scale
+            logits[i, :seq_len] = score[:seq_len]
+        return logits
+
     kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:]
     scale = scale.contiguous().view(torch.float)
     q = q.float()
@@ -243,20 +277,30 @@
         device=q.device,
         dtype=torch.float32,
     )
-    context_lens = context_lens.tolist()
     for i in range(batch_size):
         context_len = context_lens[i]
-        q_offsets = torch.arange(context_len - next_n, context_len, device="cuda")
+        if context_len.ndim == 0:
+            context_len_i = int(context_len.item())
+            q_offsets = torch.arange(
+                context_len_i - next_n, context_len_i, device=q.device
+            )
+            context_limit = torch.full(
+                (next_n,), context_len_i, dtype=torch.int32, device=q.device
+            )
+        else:
+            context_limit = context_len.to(device=q.device, dtype=torch.int32)
+            q_offsets = context_limit - 1
         weight_slice = (
             weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
         )
-        for block_rk in range(cdiv(context_len, block_size)):
+        max_context_len = int(context_limit.max().item())
+        for block_rk in range(cdiv(max_context_len, block_size)):
             block_idx = block_tables[i][block_rk]
             qx, kx = q[i], kv_cache[block_idx]
             k_offsets = torch.arange(
-                block_rk * block_size, (block_rk + 1) * block_size, device="cuda"
+                block_rk * block_size, (block_rk + 1) * block_size, device=q.device
             )
-            mask = (k_offsets[None, :] < context_len) & (
+            mask = (k_offsets[None, :] < context_limit[:, None]) & (
                 k_offsets[None, :] <= q_offsets[:, None]
             )
             s = torch.where(
@@ -461,6 +505,27 @@
     return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
 
 
+def _topk_indices_torch(logits: torch.Tensor, topk_tokens: int) -> torch.Tensor:
+    k = min(topk_tokens, logits.shape[-1])
+    values, indices = torch.topk(logits, k=k, dim=-1)
+    indices = indices.to(torch.int32)
+    indices = torch.where(
+        values == float("-inf"),
+        torch.full_like(indices, -1, dtype=torch.int32),
+        indices,
+    )
+    if k == topk_tokens:
+        return indices
+    padded = torch.full(
+        (logits.shape[0], topk_tokens),
+        -1,
+        dtype=torch.int32,
+        device=logits.device,
+    )
+    padded[:, :k] = indices
+    return padded
+
+
 def rocm_aiter_sparse_attn_indexer_fake(
     hidden_states: torch.Tensor,
     k_cache_prefix: LayerNameType,
@@ -479,8 +544,9 @@
     # profile run
     # NOTE(Chen): create the max possible flattened_kv. So that
     # profile_run can get correct memory usage.
+    device = hidden_states.device if k is None else k.device
     _flattened_kv = torch.empty(
-        [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8
+        [total_seq_lens, head_dim + 4], device=device, dtype=torch.uint8
     )
     fp8_dtype = current_platform.fp8_dtype()
     _k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous()
@@ -488,7 +554,7 @@
     return topk_indices_buffer
 
 
-def rocm_aiter_sparse_attn_indexer(
+def rocm_aiter_sparse_attn_indexer_native(
     hidden_states: torch.Tensor,
     k_cache_prefix: LayerNameType,
     kv_cache: torch.Tensor,
@@ -502,6 +568,7 @@
     max_model_len: int,
     total_seq_lens: int,
     topk_indices_buffer: torch.Tensor | None,
+    skip_k_cache_insert: bool = False,
 ) -> torch.Tensor:
     # careful! this will be None in dummy run
     attn_metadata = get_forward_context().attn_metadata
@@ -534,19 +601,24 @@
     has_decode = layer_attn_metadata.num_decodes > 0
     has_prefill = layer_attn_metadata.num_prefills > 0
     num_decode_tokens = layer_attn_metadata.num_decode_tokens
+    device = hidden_states.device if k is None else k.device
 
     # during speculative decoding, k may be padded to the CUDA graph batch
     # size while slot_mapping only covers actual tokens.
num_tokens = slot_mapping.shape[0] - k = k[:num_tokens] + if k is not None: + k = k[:num_tokens] + elif not skip_k_cache_insert: + raise ValueError("k must be provided when skip_k_cache_insert is False") - ops.indexer_k_quant_and_cache( - k, - kv_cache, - slot_mapping, - quant_block_size, - scale_fmt, - ) + if not skip_k_cache_insert: + ops.indexer_k_quant_and_cache( + k, + kv_cache, + slot_mapping, + quant_block_size, + scale_fmt, + ) topk_indices_buffer[: hidden_states.shape[0]] = -1 if has_prefill: @@ -555,12 +627,12 @@ def rocm_aiter_sparse_attn_indexer( for chunk in prefill_metadata.chunks: k_fp8 = torch.empty( [chunk.total_seq_lens, head_dim], - device=k.device, + device=device, dtype=fp8_dtype, ) k_scale = torch.empty( [chunk.total_seq_lens, 4], - device=k.device, + device=device, dtype=torch.uint8, ) @@ -579,21 +651,10 @@ def rocm_aiter_sparse_attn_indexer( chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, ) - num_rows = logits.shape[0] - assert topk_tokens == 2048, "top_k_per_row assumes size 2048" topk_indices = topk_indices_buffer[ chunk.token_start : chunk.token_end, :topk_tokens ] - torch.ops._C.top_k_per_row_prefill( - logits, - chunk.cu_seqlen_ks, - chunk.cu_seqlen_ke, - topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), - topk_tokens, - ) + topk_indices.copy_(_topk_indices_torch(logits, topk_tokens)) if has_decode: decode_metadata = layer_attn_metadata.decode @@ -630,19 +691,8 @@ def rocm_aiter_sparse_attn_indexer( max_model_len=max_model_len, ) - num_rows = logits.shape[0] - assert topk_tokens == 2048, "top_k_per_row assumes size 2048" topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] - torch.ops._C.top_k_per_row_decode( - logits, - next_n, - decode_metadata.seq_lens, - topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), - topk_tokens, - ) + topk_indices.copy_(_topk_indices_torch(logits, topk_tokens)[:num_decode_tokens]) if decode_metadata.requires_padding: # if padded, we need to unpack @@ -656,3 +706,36 @@ def rocm_aiter_sparse_attn_indexer( ) return topk_indices_buffer + + +def rocm_aiter_sparse_attn_indexer( + hidden_states: torch.Tensor, + k_cache_prefix: LayerNameType, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: str | None, + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: torch.Tensor | None, +) -> torch.Tensor: + return rocm_aiter_sparse_attn_indexer_native( + hidden_states, + k_cache_prefix, + kv_cache, + q_fp8, + k, + weights, + quant_block_size, + scale_fmt, + topk_tokens, + head_dim, + max_model_len, + total_seq_lens, + topk_indices_buffer, + skip_k_cache_insert=False, + ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a0ba47f945a7..00e77c081857 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4040,6 +4040,7 @@ def execute_model( batch_descriptor=batch_desc, ubatch_slices=ubatch_slices_padded, slot_mapping=slot_mappings, + additional_kwargs={"req_ids": self.input_batch.req_ids.copy()}, skip_compiled=has_encoder_input, ), record_function_or_nullcontext("gpu_model_runner: forward"), @@ -5532,6 +5533,7 @@ def _dummy_run( batch_descriptor=batch_desc, ubatch_slices=ubatch_slices_padded, slot_mapping=slot_mappings, + additional_kwargs={"req_ids": self.input_batch.req_ids.copy()}, ), ): outputs = self.model(