From a8d08662880759581c32f925639cab4745fa225c Mon Sep 17 00:00:00 2001 From: lcskrishna Date: Tue, 28 Apr 2026 10:55:03 +0000 Subject: [PATCH 1/8] DeepSeekV4 enablement on ROCm with torch-fallback & Triton for FP8 --- vllm/config/compilation.py | 1 + .../layers/deepseek_compressor.py | 9 +- .../layers/deepseek_v4_attention.py | 221 +++++- .../router/fused_topk_bias_router.py | 122 +++- vllm/model_executor/layers/mhc.py | 180 ++++- .../layers/sparse_attn_indexer.py | 37 +- vllm/model_executor/layers/utils.py | 11 +- vllm/model_executor/models/deepseek_v4.py | 12 +- vllm/platforms/rocm.py | 1 + vllm/triton_utils/__init__.py | 34 +- vllm/utils/deep_gemm.py | 9 +- .../fused_inv_rope_fp8_quant.py | 6 +- vllm/v1/attention/ops/flashmla.py | 25 + .../v1/attention/ops/rocm_flash_mla_sparse.py | 648 ++++++++++++++++++ .../attention/ops/rocm_sparse_attn_indexer.py | 549 +++++++++++++++ 15 files changed, 1808 insertions(+), 57 deletions(-) create mode 100644 vllm/v1/attention/ops/rocm_flash_mla_sparse.py create mode 100644 vllm/v1/attention/ops/rocm_sparse_attn_indexer.py diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 0f02a92681c1..d5fa087a329a 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -749,6 +749,7 @@ class CompilationConfig: "vllm::kda_attention", "vllm::sparse_attn_indexer", "vllm::rocm_aiter_sparse_attn_indexer", + "vllm::rocm_sparse_attn_indexer_no_insert", "vllm::deepseek_v4_attention", ] diff --git a/vllm/model_executor/layers/deepseek_compressor.py b/vllm/model_executor/layers/deepseek_compressor.py index af2783f604da..1bf4a4ac52e0 100644 --- a/vllm/model_executor/layers/deepseek_compressor.py +++ b/vllm/model_executor/layers/deepseek_compressor.py @@ -16,7 +16,7 @@ ) from vllm.model_executor.layers.utils import cublas_gemm_bf16_bf16_fp32 from vllm.platforms import current_platform -from vllm.triton_utils import tl, triton +from vllm.triton_utils import maybe_launch_pdl, tl, triton from 
vllm.v1.attention.backend import ( AttentionBackend, AttentionCGSupport, @@ -329,7 +329,10 @@ def forward( TRITON_BLOCK_SIZE=triton.next_power_of_2(kv.shape[-1]), STATE_WIDTH=state_width, COMPRESS_RATIO=self.compress_ratio, - launch_pdl=False, + # PDL is an NVIDIA Hopper-only Triton launch attribute; omit + # on other backends (e.g. ROCm) to avoid KeyError in + # JITKernel. See note above re: read-after-write race. + **maybe_launch_pdl(), ) # Fused: compress → RMSNorm → RoPE → FP8 quant → KV cache write. @@ -378,7 +381,7 @@ def forward( SCALE_DIM=self._scale_dim, KV_BLOCK_STRIDE=kv_cache.stride(0), num_warps=self._num_warps, - launch_pdl=False, + **maybe_launch_pdl(), ) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 43242eddb5b2..1e574bfe7646 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -25,6 +25,7 @@ fused_indexer_q_rope_quant, fused_inv_rope_fp8_quant, fused_q_kv_rmsnorm, + quantize_and_insert_k_cache, ) if TYPE_CHECKING: @@ -72,6 +73,86 @@ logger = init_logger(__name__) + +def _apply_rope_gptj_last_dims( + x: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + rope_dim: int, +) -> torch.Tensor: + """GPT-J-style (interleaved-pair) RoPE on the last rope_dim elements. + + Numerically matches ``fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert`` / + tests in ``test_fused_deepseek_v4_qnorm_rope_kv_insert``. 
+ """ + half = rope_dim // 2 + head_dim = x.shape[-1] + nope_dim = head_dim - rope_dim + + cs = cos_sin_cache[positions].to(torch.float32) + cos = cs[..., :half] + sin = cs[..., half:] + + rope = x[..., nope_dim:].float() + shape = rope.shape + rope = rope.reshape(*shape[:-1], half, 2) + even = rope[..., 0] + odd = rope[..., 1] + + for _ in range(rope.ndim - 3): + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) + + new_even = even * cos - odd * sin + new_odd = even * sin + odd * cos + rope_rotated = torch.stack((new_even, new_odd), dim=-1).reshape(shape) + + out = x.clone().float() + out[..., nope_dim:] = rope_rotated + return out.to(x.dtype) + + +def _deepseek_v4_qnorm_rope_kv_insert_reference( + q: torch.Tensor, + kv: torch.Tensor, + k_cache: torch.Tensor, + slot_mapping: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + eps: float, + cache_block_size: int, + q_head_norm: nn.Module, + rope_dim: int, +) -> None: + """PyTorch/Triton reference for ROCm builds where the fused CUDA op is absent.""" + head_dim = q.shape[-1] + + q.copy_( + _apply_rope_gptj_last_dims( + q_head_norm(q.reshape(-1, head_dim)).view_as(q), + positions, + cos_sin_cache, + rope_dim, + ) + ) + + num_tokens_insert = slot_mapping.shape[0] + if num_tokens_insert == 0: + return + + kv_slice = kv[:num_tokens_insert] + pos_slice = positions[:num_tokens_insert] + kv_roped = _apply_rope_gptj_last_dims( + kv_slice, pos_slice, cos_sin_cache, rope_dim + ) + quantize_and_insert_k_cache( + kv_roped, + k_cache, + slot_mapping, + block_size=cache_block_size, + ) + + # Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather # workspace allocated at _forward_prefill (and the matching profile-time # reservation in attention_impl's dummy-run branch). 
@@ -441,16 +522,31 @@ def _fused_qnorm_rope_kv_insert( # Q side: q_head_norm (per-head RMSNorm, no weight) + GPT-J RoPE # KV side: GPT-J RoPE + UE8M0 FP8 quant + paged cache insert # kv is unchanged; mla_attn reads kv solely via swa_kv_cache. - torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( - q, - kv, - swa_kv_cache_2d, - swa_metadata.slot_mapping, - positions.to(torch.int64), - self.rotary_emb.cos_sin_cache, - self.eps, - swa_metadata.block_size, + fused_op = getattr( + torch.ops._C, + "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", + None, ) + pos_i64 = positions.to(torch.int64) + cos_sin = self.rotary_emb.cos_sin_cache + block_sz = swa_metadata.block_size + slot_map = swa_metadata.slot_mapping + + if fused_op is not None: + fused_op(q, kv, swa_kv_cache_2d, slot_map, pos_i64, cos_sin, self.eps, block_sz) + else: + _deepseek_v4_qnorm_rope_kv_insert_reference( + q, + kv, + swa_kv_cache_2d, + slot_map, + pos_i64, + cos_sin, + self.eps, + block_sz, + self.q_head_norm, + self.rope_head_dim, + ) def deepseek_v4_attention( @@ -485,6 +581,99 @@ def deepseek_v4_attention_fake( ) +def _fp8_einsum_torch_fallback( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, + equation: str, + recipe: tuple[int, int, int], +) -> None: + """Pure-torch reference for DeepseekV4's block-scaled FP8 einsum. + + Used when DeepGEMM's ``fp8_einsum`` is unavailable (e.g. ROCm). Slow + but correct: dequantizes both operands to bf16 using the per-block + scales implied by ``recipe`` and runs the einsum natively. + + Only the einsum/recipe combinations actually emitted by + ``DeepseekV4MLAAttention.forward`` are supported; anything else + raises ``NotImplementedError`` so we fail loudly rather than + silently produce wrong results. + """ + if equation != "bhr,hdr->bhd": + raise NotImplementedError( + f"FP8 einsum torch fallback only supports 'bhr,hdr->bhd' " + f"(DeepseekV4 wo_a projection); got '{equation}'." 
+ ) + + m_block, n_block, k_block = recipe + + # Recover the logical (H, D, R) layout for ``b``. ``wo_a`` is a + # ColumnParallelLinear with ``out_features = n_groups * o_lora_rank`` + # marked ``is_bmm=True``: the weight is stored 2-D as ``(H*D, R)`` with + # H = n_local_groups (the leading group dim) and D = o_lora_rank, and + # the FP8 GEMM treats the leading H slices as batched. Same trick for + # the per-block weight scale ``(H*D/n_block, R/k_block)``. + h_groups = a.shape[-2] + d_out = out.shape[-1] + r_contract = a.shape[-1] + + b_3d = b + if b.dim() == 2: + if b.shape[0] != h_groups * d_out or b.shape[1] != r_contract: + raise RuntimeError( + f"Unexpected wo_a weight shape {tuple(b.shape)}; " + f"expected ({h_groups * d_out}, {r_contract}) for " + f"H={h_groups}, D={d_out}, R={r_contract}." + ) + b_3d = b.view(h_groups, d_out, r_contract) + elif b.dim() != 3: + raise RuntimeError( + f"Expected wo_a weight to be 2-D or 3-D, got {b.dim()}-D" + ) + + n_d_scale = (d_out + n_block - 1) // n_block if n_block > 1 else d_out + n_r_scale = ( + (r_contract + k_block - 1) // k_block if k_block > 1 else r_contract + ) + + b_scale_3d = b_scale + if b_scale.dim() == 2: + if b_scale.shape != (h_groups * n_d_scale, n_r_scale): + raise RuntimeError( + f"Unexpected wo_a scale shape {tuple(b_scale.shape)}; " + f"expected ({h_groups * n_d_scale}, {n_r_scale})." 
+ ) + b_scale_3d = b_scale.view(h_groups, n_d_scale, n_r_scale) + + a_f32 = a.to(torch.float32) + b_f32 = b_3d.to(torch.float32) + a_scale_f32 = a_scale.to(torch.float32).contiguous() + b_scale_f32 = b_scale_3d.to(torch.float32).contiguous() + + # a: (B, H, R) a_scale: (B, H, R // k_block) + a_scale_r = a_scale_f32 + if k_block > 1: + a_scale_r = a_scale_r.repeat_interleave(k_block, dim=-1) + a_scale_r = a_scale_r[..., :r_contract] + if m_block > 1: + a_scale_r = a_scale_r.repeat_interleave(m_block, dim=0)[: a_f32.shape[0]] + a_bf16 = (a_f32 * a_scale_r).to(torch.bfloat16) + + # b: (H, D, R) b_scale: (H, D // n_block, R // k_block) + b_scale_dr = b_scale_f32 + if k_block > 1: + b_scale_dr = b_scale_dr.repeat_interleave(k_block, dim=-1) + if n_block > 1: + b_scale_dr = b_scale_dr.repeat_interleave(n_block, dim=-2) + b_scale_dr = b_scale_dr[..., :d_out, :r_contract] + b_bf16 = (b_f32 * b_scale_dr).to(torch.bfloat16) + + result = torch.einsum(equation, a_bf16, b_bf16) + out.copy_(result.to(out.dtype)) + + def deepseek_v4_fp8_einsum( a: torch.Tensor, a_scale: torch.Tensor, @@ -494,7 +683,19 @@ def deepseek_v4_fp8_einsum( equation: str, recipe: list[int], ) -> None: - fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe)) + # DeepGEMM's fp8_einsum is the canonical fast path on NVIDIA. On + # platforms without it (e.g. ROCm), fall back to a torch dequant + + # einsum reference. The choice is made at call time (not import) so + # this op stays usable in unit tests that mock current_platform. 
+ from vllm.platforms import current_platform + + if current_platform.is_cuda(): + fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe)) + return + + _fp8_einsum_torch_fallback( + a, a_scale, b, b_scale, out, equation, tuple(recipe) + ) def deepseek_v4_fp8_einsum_fake( diff --git a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py index 84eaad7f65e6..0de5983881be 100644 --- a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py +++ b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py @@ -57,6 +57,87 @@ def vllm_topk_sigmoid( return topk_weights, topk_indices +def _topk_softplus_sqrt_torch( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool, + routed_scaling_factor: float, + e_score_correction_bias: torch.Tensor | None, + input_tokens: torch.Tensor | None, + hash_indices_table: torch.Tensor | None, +) -> None: + # Reference implementation of csrc/moe/topk_softplus_sqrt_kernels.cu used + # on platforms where the fused kernel is unavailable (e.g. ROCm). Math + # mirrors the kernel exactly: weight_base = sqrt(softplus(x)) per expert, + # bias is added only for ranking (subtracted back from output), then + # optional renormalize + routed_scaling_factor. + num_tokens, num_experts = gating_output.shape + topk = topk_weights.shape[-1] + + # softplus(x) with beta=1 and the same numerical-stability cutoff used by + # the kernel ((val_b > 20) ? val : log1p(exp(val_b)) / beta). 
+ x_f32 = gating_output.to(torch.float32) + softplus_x = torch.nn.functional.softplus(x_f32, beta=1.0, threshold=20.0) + weights_base = torch.sqrt(softplus_x) # (T, E) + + use_hash = ( + input_tokens is not None and hash_indices_table is not None + ) + + if use_hash: + # tid2eid: (V, k); input_tokens: (T,) -> selected_experts: (T, k) + tid2eid = hash_indices_table + selected_experts = tid2eid[input_tokens.to(torch.long)] + selected_weights = torch.gather( + weights_base, -1, selected_experts.to(torch.long) + ) + if renormalize: + denom = selected_weights.sum(dim=-1, keepdim=True) + denom = torch.where( + denom > 0, denom, torch.ones_like(denom) + ) + selected_weights = selected_weights / denom + selected_weights = selected_weights * routed_scaling_factor + + topk_weights.copy_(selected_weights.to(topk_weights.dtype)) + topk_indices.copy_(selected_experts.to(topk_indices.dtype)) + # The CUDA kernel leaves token_expert_indices untouched in the hash + # path, so we mirror that (caller treats it as scratch in this case). + return + + if e_score_correction_bias is not None: + ranking = weights_base + e_score_correction_bias.to(torch.float32) + else: + ranking = weights_base + + _, topk_ids = torch.topk(ranking, topk, dim=-1) + out_weights = torch.gather(weights_base, -1, topk_ids) + if renormalize: + denom = out_weights.sum(dim=-1, keepdim=True) + denom = torch.where(denom > 0, denom, torch.ones_like(denom)) + out_weights = out_weights / denom + out_weights = out_weights * routed_scaling_factor + + topk_weights.copy_(out_weights.to(topk_weights.dtype)) + topk_indices.copy_(topk_ids.to(topk_indices.dtype)) + + # token_expert_indices[t, k_idx] = k_idx * T + t (matches kernel's + # source_rows write at line 388 of topk_softplus_sqrt_kernels.cu). 
+ arange_t = torch.arange( + num_tokens, + device=gating_output.device, + dtype=token_expert_indices.dtype, + ).unsqueeze(-1) + arange_k = torch.arange( + topk, + device=gating_output.device, + dtype=token_expert_indices.dtype, + ).unsqueeze(0) + token_expert_indices.copy_(arange_k * num_tokens + arange_t) + + def vllm_topk_softplus_sqrt( topk_weights: torch.Tensor, topk_indices: torch.Tensor, @@ -68,17 +149,36 @@ def vllm_topk_softplus_sqrt( hash_indices_table: torch.Tensor | None = None, routed_scaling_factor: float = 1.0, ) -> tuple[torch.Tensor, ...]: - ops.topk_hash_softplus_sqrt( - topk_weights, - topk_indices, - token_expert_indices, - gating_output, - renormalize, - routed_scaling_factor, - e_score_correction_bias, - input_tokens, - hash_indices_table, - ) + # The fused topk_softplus_sqrt CUDA kernel is gated behind #ifndef USE_ROCM + # in csrc/moe/torch_bindings.cpp and the .cu source isn't added to + # VLLM_MOE_EXT_SRC for ROCm builds (CMakeLists.txt). Fall back to a torch + # reference on platforms that don't ship the symbol. 
+ from vllm.platforms import current_platform + + if current_platform.is_cuda(): + ops.topk_hash_softplus_sqrt( + topk_weights, + topk_indices, + token_expert_indices, + gating_output, + renormalize, + routed_scaling_factor, + e_score_correction_bias, + input_tokens, + hash_indices_table, + ) + else: + _topk_softplus_sqrt_torch( + topk_weights, + topk_indices, + token_expert_indices, + gating_output, + renormalize, + routed_scaling_factor, + e_score_correction_bias, + input_tokens, + hash_indices_table, + ) return topk_weights, topk_indices diff --git a/vllm/model_executor/layers/mhc.py b/vllm/model_executor/layers/mhc.py index 1521a6b601bf..96f894c790c3 100644 --- a/vllm/model_executor/layers/mhc.py +++ b/vllm/model_executor/layers/mhc.py @@ -6,23 +6,38 @@ import torch +from vllm.logger import logger from vllm.platforms import current_platform from vllm.utils.import_utils import has_tilelang from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import direct_register_custom_op -# tilelang is only available on CUDA platforms -if TYPE_CHECKING or current_platform.is_cuda_alike(): - if not has_tilelang(): - raise ImportError( - "tilelang is required for mhc but is not installed. Install it with " - "`pip install tilelang`." - ) +# tilelang only ships kernels for NVIDIA CUDA targets and the mHC kernels +# in this file additionally rely on Hopper-only PDL primitives +# (T.pdl_sync/T.pdl_trigger) and PTXAS register tuning. On non-CUDA +# platforms (e.g. ROCm), fall back to a torch reference implementation. +_USE_TILELANG = ( + TYPE_CHECKING or current_platform.is_cuda() +) and has_tilelang() + +if _USE_TILELANG: import tilelang import tilelang.language as T else: tilelang = None # type: ignore[assignment] T = None # type: ignore[assignment] + if current_platform.is_cuda() and not has_tilelang(): + # Preserve the previous CUDA-only requirement: tilelang is the + # canonical fast path on NVIDIA. 
Surface the missing dependency + # loudly there so users do not silently fall onto the slow path. + raise ImportError( + "tilelang is required for mhc but is not installed. Install it with " + "`pip install tilelang`." + ) + logger.info_once( + "tilelang is unavailable on this platform; using torch reference " + "implementation for DeepSeek-V4 mHC pre/post blocks." + ) @cache @@ -38,12 +53,27 @@ def compute_num_split(block_k: int, k: int | None, grid_size: int) -> int: return split_k -@tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, - }, +def _tilelang_jit(*args, **kwargs): + """Decorator that becomes a no-op when tilelang is unavailable.""" + if _USE_TILELANG: + return tilelang.jit(*args, **kwargs) + + def _decorator(fn): + return fn + + return _decorator + + +@_tilelang_jit( + pass_configs=( + { + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, + } + if _USE_TILELANG + else {} + ), ) def mhc_pre_big_fuse_tilelang( gemm_out_mul, @@ -178,6 +208,74 @@ def mhc_pre_big_fuse_tilelang( T.pdl_trigger() +def _mhc_pre_torch( + residual: torch.Tensor, + fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + rms_eps: float, + hc_pre_eps: float, + hc_sinkhorn_eps: float, + hc_post_mult_value: float, + sinkhorn_repeat: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Pure-torch reference for ``mhc_pre``. + + Mirrors ``mhc_pre_big_fuse_tilelang`` (RMS-norm scaling of a fused + GEMM, then sigmoid+bias for pre/post mixes, softmax+Sinkhorn for the + comb mix, and a residual-blend to produce ``layer_input``). Used on + platforms without a working tilelang/DeepGEMM stack (e.g. ROCm). 
+ """ + hc_mult = residual.shape[-2] + hidden_size = residual.shape[-1] + hc_mult2 = hc_mult * hc_mult + hc_mult3 = hc_mult * 2 + hc_mult2 + outer_shape = residual.shape[:-2] + + residual_flat = residual.reshape(-1, hc_mult, hidden_size) + num_tokens = residual_flat.shape[0] + + x_f32 = residual_flat.reshape(num_tokens, hc_mult * hidden_size).to( + torch.float32 + ) + mixes = torch.matmul(x_f32, fn.t()) + sqrsum = x_f32.square().sum(dim=-1) + + rms = torch.rsqrt(sqrsum / (hc_mult * hidden_size) + rms_eps) + mixes = mixes * rms.unsqueeze(-1) + + pre_part = mixes[:, :hc_mult] + post_part = mixes[:, hc_mult : 2 * hc_mult] + comb_part = mixes[:, 2 * hc_mult :].reshape(num_tokens, hc_mult, hc_mult) + + post_base = hc_base[hc_mult : 2 * hc_mult] + post_mix = ( + torch.sigmoid(post_part * hc_scale[1] + post_base) * hc_post_mult_value + ) + + comb_base = hc_base[2 * hc_mult :].reshape(hc_mult, hc_mult) + cm = comb_part * hc_scale[2] + comb_base + cm = torch.softmax(cm, dim=-1) + hc_sinkhorn_eps + cm = cm / (cm.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps) + for _ in range(max(0, sinkhorn_repeat - 1)): + cm = cm / (cm.sum(dim=-1, keepdim=True) + hc_sinkhorn_eps) + cm = cm / (cm.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps) + comb_mix_flat = cm.reshape(num_tokens, hc_mult2) + + pre_base = hc_base[:hc_mult] + pre_mix = torch.sigmoid(pre_part * hc_scale[0] + pre_base) + hc_pre_eps + + layer_input_f32 = torch.einsum( + "bn,bnh->bh", pre_mix, residual_flat.to(torch.float32) + ) + layer_input = layer_input_f32.to(torch.bfloat16) + + post_mix = post_mix.view(*outer_shape, hc_mult, 1) + comb_mix = comb_mix_flat.view(*outer_shape, hc_mult, hc_mult) + layer_input = layer_input.view(*outer_shape, hidden_size) + return post_mix, comb_mix, layer_input + + def mhc_pre( residual: torch.Tensor, fn: torch.Tensor, @@ -228,6 +326,19 @@ def mhc_pre( assert hc_scale.shape == (3,) assert hc_base.shape == (hc_mult3,) + if not _USE_TILELANG: + return _mhc_pre_torch( + residual, + fn, + hc_scale, + 
hc_base, + rms_eps, + hc_pre_eps, + hc_sinkhorn_eps, + hc_post_mult_value, + sinkhorn_repeat, + ) + outer_shape = residual.shape[:-2] residual_flat = residual.view(-1, hc_mult, hidden_size) @@ -349,12 +460,16 @@ def _mhc_pre_fake( return post_mix, comb_mix, layer_input -@tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, - }, +@_tilelang_jit( + pass_configs=( + { + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, + } + if _USE_TILELANG + else {} + ), ) def mhc_post_tilelang( a, @@ -366,7 +481,7 @@ def mhc_post_tilelang( hidden: int, n_thr: int = 128, h_blk: int = 1024, -) -> tilelang.JITKernel: +): # rename for shorter code n = T.dynamic("num_tokens") h = hidden @@ -408,12 +523,37 @@ def mhc_post_tilelang( T.pdl_trigger() +def _mhc_post_torch( + x: torch.Tensor, + residual: torch.Tensor, + post_layer_mix: torch.Tensor, + comb_res_mix: torch.Tensor, +) -> torch.Tensor: + """Pure-torch reference for ``mhc_post``. + + Mirrors ``mhc_post_tilelang``: + out[..., i_hco, h] = post_layer_mix[..., i_hco, 0] * x[..., h] + + sum_{i_hci}(comb_res_mix[..., i_hci, i_hco] + * residual[..., i_hci, h]) + + Equivalently: ``post * x + comb.transpose(-1,-2) @ residual``. 
+ """ + x_f32 = x.to(torch.float32).unsqueeze(-2) + residual_f32 = residual.to(torch.float32) + term1 = post_layer_mix * x_f32 + term2 = torch.matmul(comb_res_mix.transpose(-1, -2), residual_f32) + return (term1 + term2).to(torch.bfloat16) + + def mhc_post( x: torch.Tensor, residual: torch.Tensor, post_layer_mix: torch.Tensor, comb_res_mix: torch.Tensor, ) -> torch.Tensor: + if not _USE_TILELANG: + return _mhc_post_torch(x, residual, post_layer_mix, comb_res_mix) + out = torch.empty_like(residual) mhc_post_tilelang( comb_res_mix, diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index ca82f2feb7ef..3332f26c8c48 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -32,6 +32,13 @@ elif current_platform.is_xpu(): from vllm._xpu_ops import xpu_ops +# Registers `vllm::rocm_sparse_attn_indexer_no_insert` (the V4 layout where the +# compressor pre-inserts K and the indexer is called with k=None). +# Keep this import at module scope so the op is visible at compile time, not +# just on the first forward. +if current_platform.is_rocm(): + import vllm.v1.attention.ops.rocm_sparse_attn_indexer # noqa: F401 + logger = init_logger(__name__) RADIX_TOPK_WORKSPACE_SIZE = 1024 * 1024 @@ -499,13 +506,34 @@ def forward_hip( k: torch.Tensor, weights: torch.Tensor, ): - assert not self.skip_k_cache_insert, ( - "AMD platform doesn't support skip cache insert yet" - ) assert not self.use_fp4_cache, "AMD platform doesn't support fp4 cache yet" assert isinstance(q_quant, torch.Tensor), ( "AMD sparse_attn_indexer expects a single FP8 q_quant tensor" ) + + if self.skip_k_cache_insert: + # DeepSeek-V4 layout: the compressor has already inserted the + # compressed K into the indexer's KV cache and passes k=None. + # The AITER op below always issues its own + # ``ops.indexer_k_quant_and_cache(k, ...)`` and dereferences ``k``, + # so it can't be reused here. 
Use the dedicated no-insert ROCm + # path that uses only ROCm-available helpers (and a Triton MQA + # kernel under the hood). + return torch.ops.vllm.rocm_sparse_attn_indexer_no_insert( + hidden_states, + _encode_layer_name(self.k_cache.prefix), + self.k_cache.kv_cache, + q_quant, + weights, + self.quant_block_size, + self.scale_fmt, + self.topk_tokens, + self.head_dim, + self.max_model_len, + self.max_total_seq_len, + self.topk_indices_buffer, + ) + if rocm_aiter_ops.is_enabled(): return torch.ops.vllm.rocm_aiter_sparse_attn_indexer( hidden_states, @@ -525,5 +553,6 @@ def forward_hip( else: raise RuntimeError( "Sparse attention indexer ROCm custom op requires ROCm " - "Aiter ops to be enabled." + "Aiter ops to be enabled (or skip_k_cache_insert=True for " + "the V4 layout)." ) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index e26b511de4ce..3aa0474c340c 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -303,7 +303,16 @@ def cublas_gemm_bf16_bf16_fp32( x: torch.Tensor, weight: torch.Tensor, ): - return ops.router_gemm_bf16_fp32(x, weight) + # The fused C++ op (csrc/moe/router_gemm.cu, registered via + # torch_bindings.cpp's `router_gemm_bf16_fp32`) is gated behind + # `#ifndef USE_ROCM` and is only compiled into _moe_C.so on CUDA builds. + # On other backends (e.g. ROCm) we fall back to a torch GEMM with the + # same bf16-in / fp32-out contract. rocBLAS already does fp32 accumulation + # internally for bf16 GEMMs on MI300X, so casting the bf16 output to fp32 + # matches the cuBLAS bf16 x bf16 -> fp32 path up to bf16 output rounding. 
+ if current_platform.is_cuda(): + return ops.router_gemm_bf16_fp32(x, weight) + return torch.nn.functional.linear(x, weight).to(torch.float32) def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]: diff --git a/vllm/model_executor/models/deepseek_v4.py b/vllm/model_executor/models/deepseek_v4.py index 97f755240a4c..0e7662579d24 100644 --- a/vllm/model_executor/models/deepseek_v4.py +++ b/vllm/model_executor/models/deepseek_v4.py @@ -153,8 +153,10 @@ def expert_dtype(self) -> str: except Exception: # vllm_config not yet set; defer the decision until a # later call lands inside set_current_vllm_config. - return "fp4" - expert_dtype = getattr(hf_config, "expert_dtype", "fp4") + #return "fp4" + return "fp8" + #expert_dtype = getattr(hf_config, "expert_dtype", "fp4") + expert_dtype = getattr(hf_config, "expert_dtype", "fp8") if expert_dtype not in _DEEPSEEK_V4_EXPERT_DTYPES: raise ValueError( f"Unsupported DeepSeek V4 expert_dtype={expert_dtype!r}; " @@ -1507,14 +1509,16 @@ class DeepseekV4ForCausalLM(nn.Module): # Default mapper assumes the original FP4-expert checkpoint layout. # Overridden per-instance in __init__ when expert_dtype != "fp4". 
- hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp4") + #hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp4") + hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp8") def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config self.config = config - expert_dtype = getattr(config, "expert_dtype", "fp4") + #expert_dtype = getattr(config, "expert_dtype", "fp4") + expert_dtype = "fp8" if expert_dtype != "fp4": self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper(expert_dtype) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 866b9ffd1a6d..e4d17eeaa969 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -425,6 +425,7 @@ class RocmPlatform(Platform): "fp8_per_block", "online", "gpt_oss_mxfp4", + "deepseek_v4_fp8", ] @classmethod diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py index f4866a702dd9..4bfa2bf76996 100644 --- a/vllm/triton_utils/__init__.py +++ b/vllm/triton_utils/__init__.py @@ -20,4 +20,36 @@ LOG2E = 1.4426950408889634 LOGE2 = 0.6931471805599453 -__all__ = ["HAS_TRITON", "triton", "tl", "tldevice", "LOG2E", "LOGE2"] + +def maybe_launch_pdl(value: bool = False) -> dict: + """Return launch metadata for Triton kernel calls that may use PDL. + + The ``launch_pdl`` launch attribute (Programmatic Dependent Launch) is + an NVIDIA Hopper SM90+ feature exposed by NVIDIA's Triton runtime. + Other Triton backends (notably ROCm/HIP) do not recognize this kwarg + and raise ``KeyError`` from ``JITKernel._pack_args``. Use this helper + at the kernel call site: + + kernel[grid](..., **maybe_launch_pdl()) + + so the attribute is only forwarded on platforms whose Triton runtime + supports it. + """ + # Lazy import to avoid pulling in the full platform stack at module + # import time of vllm.triton_utils. 
+ from vllm.platforms import current_platform + + if current_platform.is_cuda(): + return {"launch_pdl": value} + return {} + + +__all__ = [ + "HAS_TRITON", + "triton", + "tl", + "tldevice", + "LOG2E", + "LOGE2", + "maybe_launch_pdl", +] diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 6b89f5c33203..c9d8d12c621f 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -469,7 +469,14 @@ def tf32_hc_prenorm_gemm( out = x.float() @ fn.T sqrsum = x.float().square().sum(-1) - See the caller function for shape requirement + See the caller function for shape requirement. + + The DeepGEMM kernel splits the K dimension into ``num_split`` partial + sums for parallelism (``out`` has a leading ``num_split`` axis and the + consumer reduces over it). When DeepGEMM is not available (e.g. on + ROCm), fall back to a single-shot torch matmul written into split 0 + while zeroing the remaining splits, which is mathematically equivalent + after the consumer's reduction. """ _lazy_init() if _tf32_hc_prenorm_gemm_impl is None: diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py index 97c9538889a1..d9ad22ae0556 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py @@ -9,7 +9,7 @@ import torch -from vllm.triton_utils import tl, triton +from vllm.triton_utils import maybe_launch_pdl, tl, triton @triton.jit @@ -224,7 +224,9 @@ def fused_inv_rope_fp8_quant( HALF_ROPE=rope_dim // 2, TMA_ALIGNED_SCALES=tma_aligned_scales, num_stages=1, - launch_pdl=False, + # PDL is an NVIDIA Hopper-only Triton launch attribute; omit on + # other backends (e.g. ROCm) to avoid KeyError in JITKernel. 
+ **maybe_launch_pdl(), ) grid = (tma_aligned_T, n_groups * heads_per_group) diff --git a/vllm/v1/attention/ops/flashmla.py b/vllm/v1/attention/ops/flashmla.py index df04f5bf2289..5e4442642e5a 100644 --- a/vllm/v1/attention/ops/flashmla.py +++ b/vllm/v1/attention/ops/flashmla.py @@ -93,6 +93,31 @@ def _raise_flashmla_unavailable(*_args, **_kwargs): flash_mla_with_kvcache, get_mla_metadata, ) +elif current_platform.is_rocm(): + # ROCm path: substitute the V4-relevant entry points with pure-torch + + # Triton fallbacks. The dense / varlen variants have no V4 caller and + # remain hard errors so misuse is loud. + from vllm.v1.attention.ops.rocm_flash_mla_sparse import ( + _FlashMLASchedMetaStub as FlashMLASchedMeta, # noqa: F401 + ) + from vllm.v1.attention.ops.rocm_flash_mla_sparse import ( + flash_mla_sparse_fwd_rocm as flash_mla_sparse_fwd, + ) + from vllm.v1.attention.ops.rocm_flash_mla_sparse import ( + flash_mla_with_kvcache_rocm as flash_mla_with_kvcache, + ) + from vllm.v1.attention.ops.rocm_flash_mla_sparse import ( + get_mla_metadata_rocm as get_mla_metadata, + ) + + flash_attn_varlen_func = _raise_flashmla_unavailable # type: ignore[assignment] + flash_attn_varlen_kvpacked_func = _raise_flashmla_unavailable # type: ignore[assignment] + flash_attn_varlen_qkvpacked_func = _raise_flashmla_unavailable # type: ignore[assignment] + logger.info_once( + "FlashMLA C extension unavailable on ROCm; using pure-torch + Triton " + "sparse-attention fallback for DeepSeek-V4 (flash_mla_sparse_fwd, " + "flash_mla_with_kvcache, get_mla_metadata)." 
+ ) else: class FlashMLASchedMeta: # type: ignore[no-redef] diff --git a/vllm/v1/attention/ops/rocm_flash_mla_sparse.py b/vllm/v1/attention/ops/rocm_flash_mla_sparse.py new file mode 100644 index 000000000000..e356e64d2274 --- /dev/null +++ b/vllm/v1/attention/ops/rocm_flash_mla_sparse.py @@ -0,0 +1,648 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""ROCm fallback for DeepSeek-V4's FlashMLA sparse attention kernels. + +The official FlashMLA kernels (``flash_mla_sparse_fwd`` for prefill and the +V4-extended ``flash_mla_with_kvcache`` for decode) are NVIDIA-only — they live +in the ``vllm._flashmla_C`` extension which is not built on ROCm. The wrapper in +``vllm/v1/attention/ops/flashmla.py`` raises ``RuntimeError`` for both calls on +non-CUDA platforms, which crashes DeepSeek-V4 inference at the first generation +step. + +This module provides ROCm-friendly equivalents: + +* ``flash_mla_sparse_fwd_rocm`` — sparse attention over a *bf16* KV pool. The + V4 prefill path pre-dequantizes the FP8 cache via + :func:`vllm.v1.attention.ops.deepseek_v4_ops.dequantize_and_gather_k_cache` + (Triton, works on ROCm), then feeds bf16 ``kv`` into FlashMLA. We can run the + same sparse softmax+gemm in chunked online-softmax form on top of the + dequantized KV without needing the FP8-aware kernel. + +* ``flash_mla_with_kvcache_rocm`` — decode path. Here FlashMLA reads the + FP8 ``swa_cache`` (and optionally a global compressed ``extra_k_cache``) + directly via ``is_fp8_kvcache=True``. We dequantize the requested slots on + the fly with a small Triton kernel (mirroring + ``_dequantize_and_gather_k_kernel`` but indexed by arbitrary global slot ids + instead of a block table), then run the same chunked sparse attention. + +* ``get_mla_metadata_rocm`` — returns an empty ``FlashMLASchedMeta`` stub so + the V4 SWA metadata builder can populate ``tile_sched_*`` fields without + crashing. 
The metadata is unused by our fallback path. + +Both attention paths use *online softmax* with a bounded ``chunk_topk`` over +the candidate axis so peak intermediate memory stays manageable even with +many query tokens × thousands of selected positions. + +Numerics notes +-------------- +* The softmax includes the per-head ``attn_sink`` logit as an extra column + whose value is dropped before the ``attn @ V`` reduction (matches FlashMLA + semantics: sink mass affects the partition function only). +* Invalid ``indices == -1`` entries are masked with ``-inf`` so they never + contribute, regardless of what we (safely) dequantize at slot 0. +* Rows where every candidate is invalid AND ``attn_sink == -inf`` produce a + zero output (we trap the all-``-inf`` case to avoid NaNs from ``exp(0)/0``). +""" +from __future__ import annotations + +import torch + +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.triton_utils import HAS_TRITON, tl, triton + +logger = init_logger(__name__) + +# --------------------------------------------------------------------------- +# Cache layout constants — must mirror +# vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py. +# --------------------------------------------------------------------------- +_FP8_DIM = 448 +_BF16_DIM = 64 +_SCALE_DIM = 8 +_QUANT_BLOCK_SIZE = 64 +_TOKEN_DATA_SIZE = _FP8_DIM + _BF16_DIM * 2 # 576 +_HEAD_DIM = _FP8_DIM + _BF16_DIM # 512 +_N_QUANT_BLOCKS = 7 # 7 real (448 // 64), 1 pad slot at index 7 + +# Chunk size for online-softmax over the candidate axis. 128 keeps memory +# small (~64 MiB for T_q=512, head_dim=512, bf16) while letting the matmul +# inside torch see enough work to be efficient. +_DEFAULT_CHUNK_TOPK = 128 + + +# --------------------------------------------------------------------------- +# FP8 slot dequantization (decode path). 
# ---------------------------------------------------------------------------
if HAS_TRITON and current_platform.is_cuda_alike():

    @triton.jit
    def _gather_dequant_slots_kernel(
        out_ptr,  # (N, head_dim) bf16
        out_stride_n,
        indices_ptr,  # (N,) int32, -1 = invalid (still safely dequant slot 0)
        k_cache_ptr,  # uint8 byte buffer
        block_stride,  # bytes per block
        cache_block_size: tl.constexpr,
        fp8_dim: tl.constexpr,
        bf16_dim: tl.constexpr,
        scale_dim: tl.constexpr,
        quant_block: tl.constexpr,
        token_data_size: tl.constexpr,
        head_dim: tl.constexpr,
        n_quant_blocks: tl.constexpr,
        N,
    ):
        # One program per requested slot: dequantize the FP8 part and copy
        # the bf16 part of that token into row ``pid`` of the output.
        pid = tl.program_id(0)
        if pid >= N:
            return

        raw_slot = tl.load(indices_ptr + pid)
        # Always dequant slot >= 0 to keep the kernel branch-free; the
        # caller masks invalid indices in the attention softmax.
        slot = tl.maximum(raw_slot, 0)

        out_row_ptr = out_ptr + pid * out_stride_n

        # int64 block index so ``block_idx * block_stride`` cannot wrap.
        block_idx = (slot // cache_block_size).to(tl.int64)
        pos_in_block = slot % cache_block_size

        # Block layout (bytes): all token data first, then all token scales.
        cache_block_ptr = k_cache_ptr + block_idx * block_stride
        token_data_ptr = cache_block_ptr + pos_in_block * token_data_size
        token_scale_ptr = (
            cache_block_ptr
            + cache_block_size * token_data_size
            + pos_in_block * scale_dim
        )
        token_fp8_ptr = token_data_ptr
        token_bf16_ptr = token_data_ptr + fp8_dim

        # Dequantize the 448 FP8 dims in 7 blocks of 64.
        for qblock_idx in tl.static_range(n_quant_blocks):
            qblock_start = qblock_idx * quant_block
            if qblock_start < fp8_dim:
                offsets = qblock_start + tl.arange(0, quant_block)
                mask = offsets < fp8_dim
                x_uint8 = tl.load(token_fp8_ptr + offsets, mask=mask, other=0)
                x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True)
                x_float = x_fp8.to(tl.float32)
                # One scale byte per quant block: scale = 2 ** (byte - 127).
                encoded_scale = tl.load(token_scale_ptr + qblock_idx)
                exponent = encoded_scale.to(tl.float32) - 127.0
                scale = tl.exp2(exponent)
                x_dequant = x_float * scale
                tl.store(
                    out_row_ptr + offsets,
                    x_dequant.to(tl.bfloat16),
                    mask=mask,
                )

        # Copy the trailing 64 bf16 dims unchanged.
        bf16_out_ptr = out_row_ptr + fp8_dim
        bf16_cache_bf16_ptr = token_bf16_ptr.to(tl.pointer_type(tl.bfloat16))
        for j in tl.static_range(bf16_dim // 16):
            chunk_offsets = j * 16 + tl.arange(0, 16)
            bf16_vals = tl.load(bf16_cache_bf16_ptr + chunk_offsets)
            tl.store(bf16_out_ptr + chunk_offsets, bf16_vals)
else:
    _gather_dequant_slots_kernel = None  # type: ignore[assignment]


def _cache_block_geometry(k_cache: torch.Tensor) -> tuple[int, int]:
    """Return ``(tokens_per_block, block_stride_in_bytes)`` for ``k_cache``.

    ``k_cache`` is a uint8 buffer, so element strides are byte strides.

    * ``dim >= 2``: dim 1 is the token count per block and ``stride(0)`` is
      the distance between consecutive blocks.
    * ``dim == 1``: there is no shape to read the geometry from. We assume 64
      tokens per block (the default DeepSeek block size) stored back to back,
      i.e. ``64 * (token_data_size + scale_dim)`` bytes per block.
      ``stride(0)`` of a 1-D buffer is 1 and must NOT be used as the block
      stride.
    """
    if k_cache.dim() >= 2:
        return k_cache.shape[1], k_cache.stride(0)
    cache_block_size = 64
    return cache_block_size, cache_block_size * (_TOKEN_DATA_SIZE + _SCALE_DIM)


def _gather_dequant_slots_triton(
    indices: torch.Tensor,  # (N,) int32 — global slot ids, -1 allowed
    k_cache: torch.Tensor,  # uint8 (num_blocks, ...) — byte buffer
    out: torch.Tensor,  # (N, head_dim) bf16 output buffer
) -> None:
    """Triton gather + UE8M0 FP8 dequant for arbitrary global slot ids.

    Launches one program per entry of ``indices`` and writes row ``i`` of
    ``out`` from cache slot ``max(indices[i], 0)``. Entries equal to ``-1``
    therefore read slot 0; callers are expected to mask them afterwards.
    """
    assert _gather_dequant_slots_kernel is not None
    assert k_cache.dtype == torch.uint8, (
        f"k_cache must be uint8 byte buffer, got {k_cache.dtype}"
    )
    assert out.dtype == torch.bfloat16
    assert out.shape == (indices.shape[0], _HEAD_DIM)
    assert indices.is_contiguous()
    assert out.is_contiguous()

    n = indices.shape[0]
    if n == 0:
        return

    cache_block_size, block_stride = _cache_block_geometry(k_cache)

    _gather_dequant_slots_kernel[(n,)](
        out,
        out.stride(0),
        indices,
        k_cache,
        block_stride,
        cache_block_size=cache_block_size,
        fp8_dim=_FP8_DIM,
        bf16_dim=_BF16_DIM,
        scale_dim=_SCALE_DIM,
        quant_block=_QUANT_BLOCK_SIZE,
        token_data_size=_TOKEN_DATA_SIZE,
        head_dim=_HEAD_DIM,
        n_quant_blocks=_N_QUANT_BLOCKS,
        N=n,
    )


def _gather_dequant_slots_torch(
    indices: torch.Tensor,
    k_cache: torch.Tensor,
    out: torch.Tensor,
) -> None:
    """Pure-torch reference for ``_gather_dequant_slots_triton``.

    Slow but correct — useful for environments without a Triton runtime and
    for unit-style sanity checks. Implements the same UE8M0 FP8 dequant + bf16
    copy as the Triton kernel.

    The cache is addressed exactly like the kernel addresses it: block ``b``
    starts ``b * block_stride`` bytes after the start of ``k_cache`` and the
    bytes inside one block are read densely. This keeps the reference valid
    for caches whose ``stride(0)`` is larger than the packed block size, and
    avoids copying the whole cache.
    """
    assert k_cache.dtype == torch.uint8
    assert out.dtype == torch.bfloat16
    n = indices.shape[0]
    if n == 0:
        return

    cache_block_size, block_stride = _cache_block_geometry(k_cache)
    block_bytes = cache_block_size * (_TOKEN_DATA_SIZE + _SCALE_DIM)
    if k_cache.dim() >= 2:
        num_blocks = k_cache.shape[0]
    else:
        num_blocks = k_cache.numel() // block_stride
    # (num_blocks, block_bytes) byte view sharing storage with ``k_cache``.
    blocks = torch.as_strided(
        k_cache, (num_blocks, block_bytes), (block_stride, 1)
    )

    device = indices.device
    safe = indices.clamp(min=0).to(torch.int64)  # -1 -> slot 0, masked later
    block_idx = (safe // cache_block_size).unsqueeze(-1)  # (N, 1)
    pos_in_block = safe % cache_block_size  # (N,)

    # Per-token byte offsets *inside* the block for the data / scale regions.
    data_off = (pos_in_block * _TOKEN_DATA_SIZE).unsqueeze(-1)  # (N, 1)
    scale_off = (
        cache_block_size * _TOKEN_DATA_SIZE + pos_in_block * _SCALE_DIM
    ).unsqueeze(-1)  # (N, 1)

    def _span(width: int) -> torch.Tensor:
        return torch.arange(width, device=device, dtype=torch.int64)

    # ---- FP8 NoPE (448 dims) ----
    fp8_bytes = blocks[block_idx, data_off + _span(_FP8_DIM)]  # (N, 448)
    fp8_vals = fp8_bytes.view(torch.float8_e4m3fn).to(torch.float32)

    # 7 UE8M0 scales, 1 byte each: scale = 2 ** (byte - 127).
    scale_bytes = blocks[block_idx, scale_off + _span(_N_QUANT_BLOCKS)]
    scales = torch.exp2(scale_bytes.to(torch.float32) - 127.0)  # (N, 7)
    # Repeat each scale across its 64-element block.
    scales_per_dim = scales.repeat_interleave(_QUANT_BLOCK_SIZE, dim=-1)
    nope = (fp8_vals * scales_per_dim).to(torch.bfloat16)

    # ---- BF16 RoPE (64 dims, stored as 128 raw bytes after the FP8 part) ----
    rope_bytes = blocks[block_idx, data_off + _FP8_DIM + _span(_BF16_DIM * 2)]
    rope = rope_bytes.view(torch.bfloat16)  # (N, 64)

    out.copy_(torch.cat([nope, rope], dim=-1))


def _gather_dequant_slots(
    indices: torch.Tensor,
    k_cache: torch.Tensor,
    out: torch.Tensor,
) -> None:
    """Dispatch to Triton when available, otherwise pure torch."""
    if _gather_dequant_slots_kernel is not None and indices.is_cuda:
        _gather_dequant_slots_triton(indices, k_cache, out)
    else:
        _gather_dequant_slots_torch(indices, k_cache, out)


# ---------------------------------------------------------------------------
# Sparse attention with online softmax (chunked over the candidate axis).
# ---------------------------------------------------------------------------
def _online_softmax_init(
    t_q: int,
    num_heads: int,
    head_dim_v: int,
    attn_sink: torch.Tensor | None,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Seed the (m, l, O) running state with the per-head ``attn_sink`` logit.

    The sink contributes mass exp(sink) to the partition function but no V
    contribution, so we initialize:
        m = sink            (or -inf if no sink)
        l = exp(sink - m) = 1   (or 0 if sink is not finite)
        O = 0

    Returns:
        ``(m, l, O)`` as fp32 tensors of shape ``(t_q, num_heads)``,
        ``(t_q, num_heads)`` and ``(t_q, num_heads, head_dim_v)``.
    """
    if attn_sink is not None:
        sink = attn_sink.to(torch.float32).view(1, num_heads)
        m = sink.expand(t_q, num_heads).contiguous()
    else:
        m = torch.full(
            (t_q, num_heads), float("-inf"), dtype=torch.float32, device=device
        )

    finite_sink = torch.isfinite(m)
    l = torch.where(finite_sink, torch.ones_like(m), torch.zeros_like(m))
    O = torch.zeros(
        (t_q, num_heads, head_dim_v), dtype=torch.float32, device=device
    )
    return m, l, O


def _online_softmax_update(
    m: torch.Tensor,  # (T_q, H) running max
    l: torch.Tensor,  # (T_q, H) running denominator
    O: torch.Tensor,  # (T_q, H, head_dim_v) running output (fp32)
    scores: torch.Tensor,  # (T_q, H, c) new logits (fp32, -inf for invalid)
    V_chunk: torch.Tensor,  # (T_q, c, head_dim_v) bf16/fp32 V values
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """One online-softmax step; returns the updated ``(m, l, O)``.

    Numerical care: when a row's running max stays ``-inf`` (no candidate yet
    finite) we keep ``O = 0`` and ``l = 0`` and just track the new max so the
    next chunk can rebase from it.
    """
    chunk_max = scores.amax(dim=-1)  # (T_q, H)
    new_m = torch.maximum(m, chunk_max)  # (T_q, H)

    # Avoid -inf - -inf = nan when both old and new max are still -inf.
    finite_old = torch.isfinite(m) & torch.isfinite(new_m)  # (T_q, H)
    scale_old = torch.where(
        finite_old,
        torch.exp(m - torch.where(finite_old, new_m, m)),
        torch.zeros_like(m),
    )  # (T_q, H)

    # Only subtract new_m from the scores where it is finite; rows whose max
    # is still non-finite contribute exactly zero.
    finite_new_2d = torch.isfinite(new_m)  # (T_q, H)
    safe_new_m = torch.where(
        finite_new_2d, new_m, torch.zeros_like(new_m)
    ).unsqueeze(-1)  # (T_q, H, 1)
    finite_new_3d = finite_new_2d.unsqueeze(-1)  # (T_q, H, 1)
    e_scores = torch.where(
        finite_new_3d & torch.isfinite(scores),
        torch.exp(scores - safe_new_m),
        torch.zeros_like(scores),
    )  # (T_q, H, c)

    l_new = l * scale_old + e_scores.sum(dim=-1)  # (T_q, H)
    # O_new = scale_old * O + e_scores @ V_chunk
    O_new = O * scale_old.unsqueeze(-1) + torch.einsum(
        "thc,tcd->thd", e_scores, V_chunk.to(torch.float32)
    )  # (T_q, H, head_dim_v)
    return new_m, l_new, O_new


def _sparse_attn_chunked(
    q: torch.Tensor,  # (T_q, H, head_dim) bf16/fp32
    indices: torch.Tensor,  # (T_q, max_topk) int32, -1 for invalid
    K_provider,  # callable: (idx_chunk: (T_q, c) int32) -> (T_q, c, head_dim) bf16
    sm_scale: float,
    attn_sink: torch.Tensor | None,
    head_dim_v: int,
    chunk_topk: int = _DEFAULT_CHUNK_TOPK,
) -> torch.Tensor:
    """Generic sparse attention with online softmax.

    ``K_provider`` is a callable that returns the dequantized K (bf16) for a
    chunk of candidate indices. This lets the same attention loop drive both
    the prefill path (already-dequantized bf16 KV pool, simple ``K_full[idx]``
    gather) and the decode path (per-slot Triton FP8 dequant).

    V is taken as the first ``head_dim_v`` columns of K.

    Returns:
        fp32 tensor of shape ``(T_q, H, head_dim_v)``; rows whose partition
        function is zero are returned as zeros.
    """
    t_q, num_heads, _ = q.shape
    max_topk = indices.shape[-1]
    device = q.device

    m, l, O = _online_softmax_init(t_q, num_heads, head_dim_v, attn_sink, device)
    q_f = q.to(torch.float32)

    for cs in range(0, max_topk, chunk_topk):
        ce = min(cs + chunk_topk, max_topk)
        idx_chunk = indices[:, cs:ce].contiguous()  # (T_q, c)
        valid = idx_chunk >= 0  # (T_q, c)
        # Skip fully padded chunks (this check synchronizes with the device).
        if not valid.any():
            continue

        K_chunk = K_provider(idx_chunk)  # (T_q, c, head_dim) bf16

        scores = torch.einsum(
            "thd,tcd->thc", q_f, K_chunk.to(torch.float32)
        ) * sm_scale  # (T_q, H, c)
        scores = scores.masked_fill(~valid.unsqueeze(1), float("-inf"))

        V_chunk = K_chunk[..., :head_dim_v]
        m, l, O = _online_softmax_update(m, l, O, scores, V_chunk)

    # Finalize: divide by total partition function.
    finite_l = l > 0
    out_f = torch.where(
        finite_l.unsqueeze(-1),
        O / l.clamp_min(1e-30).unsqueeze(-1),
        torch.zeros_like(O),
    )
    return out_f


# ---------------------------------------------------------------------------
# Prefill: K is already dequantized to bf16 by the caller.
# ---------------------------------------------------------------------------
def flash_mla_sparse_fwd_rocm(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    attn_sink: torch.Tensor | None = None,
    topk_length: torch.Tensor | None = None,
    out: torch.Tensor | None = None,
    head_dim_v: int | None = None,
    chunk_topk: int = _DEFAULT_CHUNK_TOPK,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """ROCm fallback for ``flash_mla_sparse_fwd``.

    Args:
        q: ``(s_q, h_q, d)`` bf16 query.
        kv: ``(s_kv, 1, d)`` bf16 KV pool (already dequantized + gathered).
        indices: ``(s_q, 1, topk)`` int32 with -1 sentinel for invalid slots.
        sm_scale: softmax scale factor.
        attn_sink: optional ``(h_q,)`` per-head sink logit (fp32).
        topk_length: kept for API parity; we use the -1 sentinel for masking.
        out: optional ``(s_q, h_q, d_v_or_d)`` bf16 output buffer. Columns
            past ``head_dim_v`` are zero-filled.
        head_dim_v: V head dim (default = ``out.shape[-1]`` or ``d``); clamped
            to ``d``.
        chunk_topk: number of candidates processed per online-softmax step.

    Returns ``(out, max_logits, lse)`` matching the upstream signature; the
    optional aux outputs are ``None`` since the caller only reads ``out``.
    """
    del topk_length  # unused: -1 sentinel masking is sufficient
    assert kv.dim() == 3 and kv.shape[1] == 1, (
        f"kv must be (s_kv, 1, d), got {kv.shape}"
    )
    assert indices.dim() == 3 and indices.shape[1] == 1, (
        f"indices must be (s_q, 1, topk), got {indices.shape}"
    )

    t_q, num_heads, head_dim = q.shape
    if head_dim_v is None:
        head_dim_v = out.shape[-1] if out is not None else head_dim
    head_dim_v = min(head_dim_v, head_dim)

    K = kv.squeeze(1)  # (N_kv, d)
    idx_2d = indices.squeeze(1)  # (T_q, max_topk)

    def K_provider(idx_chunk: torch.Tensor) -> torch.Tensor:
        # -1 entries gather row 0; they are masked to -inf by the caller.
        safe = idx_chunk.clamp(min=0).to(torch.int64)
        return K[safe]

    out_f = _sparse_attn_chunked(
        q=q,
        indices=idx_2d,
        K_provider=K_provider,
        sm_scale=sm_scale,
        attn_sink=attn_sink,
        head_dim_v=head_dim_v,
        chunk_topk=chunk_topk,
    )

    if out is None:
        out = torch.empty(
            t_q, num_heads, head_dim_v, dtype=q.dtype, device=q.device
        )
    out[..., :head_dim_v].copy_(out_f.to(out.dtype))
    if out.shape[-1] > head_dim_v:
        out[..., head_dim_v:].zero_()
    return out, None, None


# ---------------------------------------------------------------------------
# Decode: K cache is FP8-packed; dequantize requested slots on the fly.
# ---------------------------------------------------------------------------
def _gather_chunk_to_bf16(
    idx_chunk: torch.Tensor,  # (T_q, c) int32
    k_cache: torch.Tensor,  # uint8 byte buffer
) -> torch.Tensor:
    """Dequantize `(T_q, c)` cache slots into a `(T_q, c, head_dim)` bf16
    tensor.

    The indices are flattened to one int32 vector, dequantized row by row via
    ``_gather_dequant_slots`` and reshaped back.
    """
    t_q, c = idx_chunk.shape
    flat_idx = idx_chunk.reshape(-1).to(torch.int32).contiguous()
    flat_out = torch.empty(
        (flat_idx.shape[0], _HEAD_DIM),
        dtype=torch.bfloat16,
        device=idx_chunk.device,
    )
    _gather_dequant_slots(flat_idx, k_cache, flat_out)
    return flat_out.view(t_q, c, _HEAD_DIM)


def flash_mla_with_kvcache_rocm(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor | None = None,
    head_dim_v: int = _HEAD_DIM,
    tile_scheduler_metadata: object | None = None,
    cache_seqlens: torch.Tensor | None = None,
    is_fp8_kvcache: bool = True,
    indices: torch.Tensor | None = None,
    topk_length: torch.Tensor | None = None,
    softmax_scale: float | None = None,
    attn_sink: torch.Tensor | None = None,
    extra_k_cache: torch.Tensor | None = None,
    extra_indices_in_kvcache: torch.Tensor | None = None,
    extra_topk_length: torch.Tensor | None = None,
    out: torch.Tensor | None = None,
    causal: bool = False,
    chunk_topk: int = _DEFAULT_CHUNK_TOPK,
    **_unused_kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """ROCm fallback for V4-extended ``flash_mla_with_kvcache``.

    Decodes one query token per batch position by sparse attention over up to
    two FP8-packed caches:

    * ``k_cache`` + ``indices`` (SWA)
    * ``extra_k_cache`` + ``extra_indices_in_kvcache`` (global compressed
      cache, optional — only present on layers with ``compress_ratio > 1``)

    The two pools are processed one after the other, chunk by chunk, against a
    single shared online-softmax state that is seeded with the per-head
    ``attn_sink``. The result is therefore one softmax over the union of both
    candidate sets; no combined index tensor is built.

    Args mirror the V4 call site in ``deepseek_v4_attention._forward_decode``.
    ``tile_scheduler_metadata``, ``cache_seqlens``, ``block_table``,
    ``causal``, ``topk_length`` and ``extra_topk_length`` are accepted for API
    compatibility and ignored; invalid candidates must be marked with ``-1``.
    Any additional keyword arguments are ignored as well.

    Returns:
        ``(out, None)``. ``out`` has shape ``(batch, 1, num_heads, d_out)``;
        when a caller-provided ``out`` is wider than ``head_dim_v`` the extra
        columns are zero-filled. The softmax LSE is not computed.
    """
    del tile_scheduler_metadata, cache_seqlens, block_table, causal
    del topk_length, extra_topk_length  # -1 sentinel masking is sufficient

    assert is_fp8_kvcache, (
        "rocm flash_mla_with_kvcache fallback requires is_fp8_kvcache=True "
        "(DeepSeek-V4 always quantizes KV cache to UE8M0 FP8)"
    )
    assert indices is not None, "SWA indices must be provided for V4 decode"
    assert q.dim() == 4 and q.shape[1] == 1, (
        f"q must be (batch, 1, num_heads, head_dim), got {q.shape}"
    )
    assert indices.dim() == 3 and indices.shape[1] == 1, (
        f"indices must be (batch, 1, max_swa_topk), got {indices.shape}"
    )

    batch_size, _, num_heads, head_dim = q.shape
    if softmax_scale is None:
        softmax_scale = head_dim ** -0.5
    head_dim_v = min(head_dim_v, head_dim)

    q_f = q.squeeze(1).to(torch.float32)  # (batch, H, head_dim)

    # (index tensor, cache) pairs, each index tensor is (batch, topk).
    pools: list[tuple[torch.Tensor, torch.Tensor]] = [
        (indices.squeeze(1), k_cache)
    ]
    if extra_k_cache is not None:
        assert extra_indices_in_kvcache is not None
        # Same layout contract as ``indices``; without it ``squeeze(1)`` would
        # silently leave a 3-D tensor behind.
        assert (
            extra_indices_in_kvcache.dim() == 3
            and extra_indices_in_kvcache.shape[1] == 1
        ), (
            "extra_indices_in_kvcache must be (batch, 1, max_extra_topk), "
            f"got {extra_indices_in_kvcache.shape}"
        )
        pools.append((extra_indices_in_kvcache.squeeze(1), extra_k_cache))

    running_max, denom, acc = _online_softmax_init(
        batch_size, num_heads, head_dim_v, attn_sink, q.device
    )

    # Pool 0 is the SWA cache, pool 1 (if any) the global compressed cache.
    for pool_idx, pool_cache in pools:
        pool_topk = pool_idx.shape[-1]
        for cs in range(0, pool_topk, chunk_topk):
            ce = min(cs + chunk_topk, pool_topk)
            idx_chunk = pool_idx[:, cs:ce].contiguous()  # (batch, c)
            valid = idx_chunk >= 0
            # Skip fully padded chunks (this check synchronizes with the
            # device).
            if not valid.any():
                continue
            K_chunk = _gather_chunk_to_bf16(idx_chunk, pool_cache)
            scores = torch.einsum(
                "thd,tcd->thc", q_f, K_chunk.to(torch.float32)
            ) * softmax_scale  # (batch, H, c)
            scores = scores.masked_fill(~valid.unsqueeze(1), float("-inf"))
            # V is the first ``head_dim_v`` columns of the dequantized K.
            V_chunk = K_chunk[..., :head_dim_v]
            running_max, denom, acc = _online_softmax_update(
                running_max, denom, acc, scores, V_chunk
            )

    # Finalize: rows with an empty partition function produce zeros.
    has_mass = denom > 0
    out_f = torch.where(
        has_mass.unsqueeze(-1),
        acc / denom.clamp_min(1e-30).unsqueeze(-1),
        torch.zeros_like(acc),
    )

    if out is None:
        out = torch.empty(
            (batch_size, 1, num_heads, head_dim_v),
            dtype=q.dtype,
            device=q.device,
        )
    out_view = out.squeeze(1)
    out_view[..., :head_dim_v].copy_(out_f.to(out.dtype))
    if out_view.shape[-1] > head_dim_v:
        out_view[..., head_dim_v:].zero_()

    # Upstream returns (out, softmax_lse). LSE isn't consumed by the V4 caller.
    return out, None


# ---------------------------------------------------------------------------
# Stubs for FlashMLA's planner-side helpers.
+# --------------------------------------------------------------------------- +class _FlashMLASchedMetaStub: + """Placeholder ``FlashMLASchedMeta`` for ROCm. + + The real CUDA struct holds tile-scheduler tensors that are populated by + the in-kernel planner on first use. Our fallback ignores it but the V4 + metadata builder still allocates one per layer type. + """ + + have_initialized: bool = False + tile_scheduler_metadata: torch.Tensor | None = None + num_splits: torch.Tensor | None = None + + +def get_mla_metadata_rocm(*_args, **_kwargs) -> tuple[_FlashMLASchedMetaStub, None]: + """ROCm stub for FlashMLA's ``get_mla_metadata``. + + Returns a fresh empty scheduler-metadata struct so the V4 + ``DeepseekSparseSWAMetadataBuilder.build_tile_scheduler`` can populate + its per-layer-type cache without crashing on platforms without FlashMLA. + """ + return _FlashMLASchedMetaStub(), None + + +__all__ = [ + "flash_mla_sparse_fwd_rocm", + "flash_mla_with_kvcache_rocm", + "get_mla_metadata_rocm", +] diff --git a/vllm/v1/attention/ops/rocm_sparse_attn_indexer.py b/vllm/v1/attention/ops/rocm_sparse_attn_indexer.py new file mode 100644 index 000000000000..5935323d1f80 --- /dev/null +++ b/vllm/v1/attention/ops/rocm_sparse_attn_indexer.py @@ -0,0 +1,549 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""ROCm path for DeepSeek-V4's sparse attention indexer when the compressor +has already inserted the (compressed) K into the indexer's KV cache, i.e. +``skip_k_cache_insert=True`` and the call site passes ``k=None``. + +The CUDA implementation in ``vllm/model_executor/layers/sparse_attn_indexer.py`` +relies on DeepGEMM's ``fp8_fp4_mqa_logits`` / ``fp8_fp4_paged_mqa_logits`` +which are NVIDIA-only. 
The existing ROCm AITER op +(``rocm_aiter_sparse_attn_indexer`` in ``rocm_aiter_mla_sparse.py``) always +performs its own ``indexer_k_quant_and_cache`` call and dereferences ``k``, +so it can't be reused for the V4 layout where the compressor pre-inserts K +and returns ``None``. + +This module fills that gap with: + * A streaming Triton MQA-logits kernel that runs on gfx9xx, computing + logits without materializing the (H, M, N) intermediate that the torch + reference does (which would OOM at long context). + * A torch fallback (``_mqa_logits_torch_inplace``) used for smoke tests and + on platforms without a usable Triton runtime. + * The orchestration (``rocm_sparse_attn_indexer_no_insert``) that mirrors + the CUDA ``sparse_attn_indexer`` body but skips the K-insert and uses + only ROCm-available helper ops (``cp_gather_indexer_k_quant_cache``, + ``top_k_per_row_prefill`` / ``top_k_per_row_decode``). +""" +from __future__ import annotations + +import torch + +import vllm.envs as envs +from vllm.forward_context import get_forward_context +from vllm.platforms import current_platform +from vllm.triton_utils import HAS_TRITON, tl, triton +from vllm.utils.torch_utils import ( + LayerNameType, + _resolve_layer_name, + direct_register_custom_op, +) +from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata +from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton +from vllm.v1.worker.workspace import current_workspace_manager + +if current_platform.is_cuda_alike(): + from vllm import _custom_ops as ops + + +# Reuse the gather-workspace helper from the CUDA module so the workspace +# layout (and therefore the size estimate during profile_run) is shared. 
+def _gather_workspace_shapes_fp8( + total_seq_lens: int, + head_dim: int, + fp8_dtype: torch.dtype, +) -> tuple[ + tuple[tuple[int, int], torch.dtype], tuple[tuple[int, int], torch.dtype] +]: + """FP8 path layout used by ``cp_gather_indexer_k_quant_cache``: a flat + ``(T, head_dim)`` FP8 values buffer and a ``(T, 4)`` uint8 buffer that + aliases ``(T, 1)`` float32 dequant scales (one scale per token block). + Mirrors the FP8 branch of ``_gather_workspace_shapes`` in + ``sparse_attn_indexer.py``. + """ + return ( + ((total_seq_lens, head_dim), fp8_dtype), + ((total_seq_lens, 4), torch.uint8), + ) + + +# --------------------------------------------------------------------------- +# Triton MQA-logits kernel (prefill / chunked path). +# +# Computes `logits[m, n] = scale[n] * sum_h weights[m, h] * relu(q[m,h,:] . k[n,:])` +# without materializing the (H, M, N) intermediate. Streams over heads so the +# only per-program memory is (BLOCK_N,) accumulator + (D,) Q + (BLOCK_N, D) K. +# --------------------------------------------------------------------------- +if HAS_TRITON: + + @triton.jit + def _mqa_logits_prefill_kernel( + q_ptr, # (M, H, D) fp8 + weights_ptr, # (M, H) fp32 + k_ptr, # (N, D) fp8 + k_scale_ptr, # (N,) fp32 + cu_seqlen_ks_ptr, # (M,) int32 + cu_seqlen_ke_ptr, # (M,) int32 + logits_ptr, # (M, N) fp32 (output) + stride_qm, + stride_qh, + stride_qd, + stride_wm, + stride_wh, + stride_kn, + stride_kd, + stride_lm, + stride_ln, + M, + N, + H: tl.constexpr, + D: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + n_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + in_bounds = n_offsets < N + + ks = tl.load(cu_seqlen_ks_ptr + pid_m) + ke = tl.load(cu_seqlen_ke_ptr + pid_m) + valid = in_bounds & (n_offsets >= ks) & (n_offsets < ke) + + d_offsets = tl.arange(0, D) + + # Load K block once and reuse across heads: (BLOCK_N, D) fp32. 
+ k_block = tl.load( + k_ptr + + n_offsets[:, None] * stride_kn + + d_offsets[None, :] * stride_kd, + mask=valid[:, None], + other=0.0, + ).to(tl.float32) + + accum = tl.zeros([BLOCK_N], dtype=tl.float32) + + for h in range(H): + q = tl.load( + q_ptr + + pid_m * stride_qm + + h * stride_qh + + d_offsets * stride_qd, + ).to(tl.float32) + w = tl.load(weights_ptr + pid_m * stride_wm + h * stride_wh).to( + tl.float32 + ) + + score = tl.sum(k_block * q[None, :], axis=1) + accum += w * tl.maximum(score, 0.0) + + k_scale = tl.load(k_scale_ptr + n_offsets, mask=valid, other=0.0) + logits = accum * k_scale + logits = tl.where(valid, logits, float("-inf")) + + tl.store( + logits_ptr + pid_m * stride_lm + n_offsets * stride_ln, + logits, + mask=in_bounds, + ) + + +def _mqa_logits_triton( + q_fp8: torch.Tensor, # (M, H, D) + k_fp8: torch.Tensor, # (N, D) + k_scale: torch.Tensor, # (N,) fp32 + weights: torch.Tensor, # (M, H) fp32 + cu_seqlen_ks: torch.Tensor, # (M,) int32 + cu_seqlen_ke: torch.Tensor, # (M,) int32 +) -> torch.Tensor: + M, H, D = q_fp8.shape + N = k_fp8.shape[0] + assert k_fp8.shape[1] == D + assert weights.shape == (M, H) + assert k_scale.shape == (N,) + + logits = torch.empty((M, N), dtype=torch.float32, device=q_fp8.device) + BLOCK_N = 64 + + grid = (M, triton.cdiv(N, BLOCK_N)) + _mqa_logits_prefill_kernel[grid]( + q_fp8, + weights, + k_fp8, + k_scale, + cu_seqlen_ks, + cu_seqlen_ke, + logits, + q_fp8.stride(0), + q_fp8.stride(1), + q_fp8.stride(2), + weights.stride(0), + weights.stride(1), + k_fp8.stride(0), + k_fp8.stride(1), + logits.stride(0), + logits.stride(1), + M, + N, + H=H, + D=D, + BLOCK_N=BLOCK_N, + ) + return logits + + +def _mqa_logits_torch( + q_fp8: torch.Tensor, # (M, H, D) + k_fp8: torch.Tensor, # (N, D) + k_scale: torch.Tensor, # (N,) fp32 + weights: torch.Tensor, # (M, H) fp32 + cu_seqlen_ks: torch.Tensor, # (M,) int32 + cu_seqlen_ke: torch.Tensor, # (M,) int32 +) -> torch.Tensor: + """Reference impl mirroring ``fp8_mqa_logits_torch`` 
(DeepGEMM test). Only + used for unit tests; production should always go through the Triton path + because this materializes a (H, M, N) fp32 intermediate. + """ + N = k_fp8.shape[0] + q = q_fp8.to(torch.bfloat16) + k = k_fp8.to(torch.bfloat16) + + arange_n = torch.arange(N, device=q.device) + mask = (arange_n[None, :] >= cu_seqlen_ks[:, None]) & ( + arange_n[None, :] < cu_seqlen_ke[:, None] + ) + + # (H, M, N) fp32; relu must be applied per-head BEFORE the weighted sum. + score = torch.einsum("mhd,nd->hmn", q, k).float() * k_scale + logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) + logits = logits.masked_fill(~mask, float("-inf")) + return logits + + +def _mqa_logits( + q_fp8: torch.Tensor, + k_fp8: torch.Tensor, + k_scale: torch.Tensor, + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +) -> torch.Tensor: + """Dispatch to the Triton kernel when available; fall back to torch + reference for environments without a working Triton runtime.""" + if HAS_TRITON: + return _mqa_logits_triton( + q_fp8, k_fp8, k_scale, weights, cu_seqlen_ks, cu_seqlen_ke + ) + return _mqa_logits_torch( + q_fp8, k_fp8, k_scale, weights, cu_seqlen_ks, cu_seqlen_ke + ) + + +def _mqa_logits_paged_torch( + q_fp8: torch.Tensor, # (B, next_n, H, D) + kv_cache_4d: torch.Tensor, # (num_blocks, block_size, 1, D + scale_pad) + weights: torch.Tensor, # (B*next_n, H) fp32 + context_lens: torch.Tensor, # (B,) int32 (or (B, next_n)) + block_tables: torch.Tensor, # (B, max_blocks) int32 + max_model_len: int, + head_dim: int, +) -> torch.Tensor: + """Per-batch torch implementation of the paged MQA-logits compute. Walks + each batch element's block_table, dequantizes the FP8 K-cache slot, and + accumulates per-head relu-weighted logits. Slow but correct, and only + materializes one block's worth of intermediate at a time. + + Mirrors ``fp8_paged_mqa_logits_torch`` in ``rocm_aiter_mla_sparse.py`` but + keeps the (H, ...) 
intermediate scoped to a single block. + """ + from vllm.utils.math_utils import cdiv + + fp8_dtype = current_platform.fp8_dtype() + batch_size, next_n, H, D = q_fp8.shape + + # Cache layout: last dim = D fp8 + 4 byte (1 fp32) scale per token. + kv_values = kv_cache_4d[..., :head_dim] # uint8 + kv_scale = kv_cache_4d[..., head_dim:] # uint8 (4 bytes per slot) + + num_block, block_size, _, _ = kv_values.size() + + logits = torch.full( + (batch_size * next_n, max_model_len), + float("-inf"), + device=q_fp8.device, + dtype=torch.float32, + ) + + # Normalize context_lens to (B,). + if context_lens.dim() == 2: + context_lens_b = context_lens[:, 0] + else: + context_lens_b = context_lens + ctx_lens = context_lens_b.tolist() + + q_bf16 = q_fp8.to(torch.bfloat16) + weights_f32 = weights.to(torch.float32) + + for i in range(batch_size): + ctx_len = ctx_lens[i] + if ctx_len <= 0: + continue + # Per-token weight slice for this batch element. + # weight_slice shape: (H, next_n) + weight_slice = ( + weights_f32[i * next_n : (i + 1) * next_n, :] + .transpose(0, 1) + .contiguous() + ) + + for block_rk in range(cdiv(ctx_len, block_size)): + phys_block = int(block_tables[i, block_rk].item()) + # K block: (block_size, D) bf16 = fp8 dequant * fp32 scale. + k_fp8_block = ( + kv_values[phys_block, :, 0, :] + .view(fp8_dtype) + .to(torch.bfloat16) + ) + k_scale_block = ( + kv_scale[phys_block, :, 0, :].contiguous().view(torch.float32) + ) # (block_size, 1) + k_block_bf16 = k_fp8_block * k_scale_block.to(torch.bfloat16) + + # Compute (H, next_n, block_size) scores in fp32. + qx = q_bf16[i] # (next_n, H, D) + score = ( + torch.einsum("nhd,sd->hns", qx, k_block_bf16).float() + ) + + # Per-head relu before weighting. weight_slice: (H, next_n) + score = score.relu() * weight_slice.unsqueeze(-1) + block_logits = score.sum(dim=0) # (next_n, block_size) + + # Mask k positions beyond ctx_len within this block. 
def rocm_sparse_attn_indexer_no_insert(
    hidden_states: torch.Tensor,
    k_cache_prefix: LayerNameType,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
    """Select top-k KV positions per query token without inserting K.

    Reads the per-layer indexer metadata from the forward context, computes
    FP8 MQA logits for the prefill chunks (gather + ``_mqa_logits``) and for
    the decode tokens (``_mqa_logits_paged_torch``), and writes the selected
    indices into ``topk_indices_buffer`` through the ``top_k_per_row_*`` ops.

    Args:
        hidden_states: Only its leading dimension and device are read here.
        k_cache_prefix: Layer name (or encoded form) resolved with
            ``_resolve_layer_name`` and used to index ``attn_metadata``.
        kv_cache: Indexer KV cache; read by the gather op and the paged
            logits helper, never written in this function.
        q_fp8: Quantized queries, sliced per prefill chunk / decode prefix.
        weights: Per-token, per-head weights, sliced the same way as q_fp8.
        quant_block_size: Not referenced in this body.
        scale_fmt: Not referenced in this body.
        topk_tokens: Number of indices selected per row.
        head_dim: Width of the FP8 value part of a cache slot.
        max_model_len: Column count of the decode logits matrix.
        total_seq_lens: Row count requested for the gather workspace.
        topk_indices_buffer: Output buffer, mutated in place. May be None
            only on the profile-run path.

    Returns:
        ``topk_indices_buffer``; on the profile-run path with a None buffer,
        a fresh uninitialized ``(num_tokens, topk_tokens)`` int32 tensor.
    """
    attn_metadata = get_forward_context().attn_metadata
    fp8_dtype = current_platform.fp8_dtype()
    k_cache_prefix = _resolve_layer_name(k_cache_prefix)

    # Profile-run path: no real attn_metadata; just reserve workspace and
    # the dummy logits buffer for the memory profiler (matches the shape /
    # dtype the runtime path will actually use).
    if not isinstance(attn_metadata, dict):
        values_spec, scales_spec = _gather_workspace_shapes_fp8(
            total_seq_lens, head_dim, fp8_dtype
        )
        current_workspace_manager().get_simultaneous(values_spec, scales_spec)
        # The env var is expressed in MiB; the uint8 buffer below therefore
        # occupies exactly that many bytes.
        max_logits_elems = (
            envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
        )
        _ = torch.empty(
            max_logits_elems,
            dtype=torch.uint8,
            device=hidden_states.device,
        )
        if topk_indices_buffer is None:
            return torch.empty(
                (hidden_states.shape[0], topk_tokens),
                dtype=torch.int32,
                device=hidden_states.device,
            )
        return topk_indices_buffer

    layer_attn_metadata = attn_metadata[k_cache_prefix]
    assert isinstance(layer_attn_metadata, DeepseekV32IndexerMetadata)
    assert topk_indices_buffer is not None

    has_decode = layer_attn_metadata.num_decodes > 0
    has_prefill = layer_attn_metadata.num_prefills > 0
    num_decode_tokens = layer_attn_metadata.num_decode_tokens

    # NOTE: K-cache insert is INTENTIONALLY skipped here. DeepSeek-V4's
    # compressor (DeepseekCompressor.forward) writes the compressed K to the
    # indexer's KV cache via its fused triton kernel before this op is called,
    # and the call site passes k=None.

    # Reset the rows owned by this step to -1 before any top-k result is
    # written into them.
    topk_indices_buffer[: hidden_states.shape[0]] = -1

    if has_prefill:
        prefill_metadata = layer_attn_metadata.prefill
        assert prefill_metadata is not None
        for chunk in prefill_metadata.chunks:
            # Reuse the workspace to gather the FP8 K + scale for this chunk.
            workspace_manager = current_workspace_manager()
            values_spec, scales_spec = _gather_workspace_shapes_fp8(
                total_seq_lens, head_dim, fp8_dtype
            )
            k_quant_full, k_scale_full = (
                workspace_manager.get_simultaneous(values_spec, scales_spec)
            )
            # Only the first chunk.total_seq_lens rows are used by this chunk.
            k_quant = k_quant_full[: chunk.total_seq_lens]
            k_scale = k_scale_full[: chunk.total_seq_lens]

            ops.cp_gather_indexer_k_quant_cache(
                kv_cache,
                k_quant,
                k_scale,
                chunk.block_table,
                chunk.cu_seq_lens,
            )

            q_slice = q_fp8[chunk.token_start : chunk.token_end]
            w_slice = weights[chunk.token_start : chunk.token_end]
            # The (T, 4) uint8 scale buffer is reinterpreted as (T, 1) fp32
            # and flattened to (T,), the shape _mqa_logits_triton asserts.
            k_scale_f32 = k_scale.view(torch.float32).squeeze(-1)

            logits = _mqa_logits(
                q_slice,
                k_quant,
                k_scale_f32,
                w_slice,
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
            )

            num_rows = logits.shape[0]
            # View into the output buffer for this chunk's token rows; it is
            # passed to the top-k op as the destination argument.
            topk_indices = topk_indices_buffer[
                chunk.token_start : chunk.token_end, :topk_tokens
            ]
            torch.ops._C.top_k_per_row_prefill(
                logits,
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
                topk_indices,
                num_rows,
                logits.stride(0),
                logits.stride(1),
                topk_tokens,
            )

    if has_decode:
        decode_metadata = layer_attn_metadata.decode
        assert decode_metadata is not None

        # The kv_cache stored shape is (num_blocks, block_size, head_dim+pad);
        # paged-mqa-logits expects an extra "n_head" singleton dim.
        kv_cache_4d = kv_cache.unsqueeze(-2)

        decode_lens = decode_metadata.decode_lens
        if decode_metadata.requires_padding:
            # Pack decode queries into a (batch, next_n, ...) layout using
            # the per-request lengths in decode_lens.
            padded_q_fp8_decode_tokens = pack_seq_triton(
                q_fp8[:num_decode_tokens], decode_lens
            )
        else:
            # No padding needed: split the decode rows evenly across the
            # decode_lens.shape[0] requests.
            padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
                decode_lens.shape[0], -1, *q_fp8.shape[1:]
            )

        batch_size = padded_q_fp8_decode_tokens.shape[0]
        next_n = padded_q_fp8_decode_tokens.shape[1]
        assert batch_size == decode_metadata.seq_lens.shape[0]
        num_padded_tokens = batch_size * next_n

        # Slow-but-correct paged compute. Future Triton kernel TODO: walk the
        # block_table on-device to avoid the per-batch python loop and the
        # per-block (H, next_n, block_size) intermediate.
        logits = _mqa_logits_paged_torch(
            padded_q_fp8_decode_tokens,
            kv_cache_4d,
            weights[:num_padded_tokens],
            decode_metadata.seq_lens,
            decode_metadata.block_table,
            max_model_len,
            head_dim,
        )

        num_rows = logits.shape[0]
        topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
        torch.ops._C.top_k_per_row_decode(
            logits,
            next_n,
            decode_metadata.seq_lens,
            topk_indices,
            num_rows,
            logits.stride(0),
            logits.stride(1),
            topk_tokens,
        )

        if decode_metadata.requires_padding:
            # Drop the padded rows again, then copy the compacted result to
            # the front of the output buffer.
            topk_indices = unpack_seq_triton(
                topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
                decode_lens,
            )
            topk_indices_buffer[
                : topk_indices.shape[0], : topk_indices.shape[-1]
            ] = topk_indices

    return topk_indices_buffer
+ fp8_dtype = current_platform.fp8_dtype() + _flattened_kv = torch.empty( + [total_seq_lens, head_dim + 4], + device=q_fp8.device, + dtype=torch.uint8, + ) + _ = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous() + _ = _flattened_kv[..., head_dim:].view(torch.float32).contiguous() + if topk_indices_buffer is None: + return torch.empty( + (hidden_states.shape[0], topk_tokens), + dtype=torch.int32, + device=q_fp8.device, + ) + return topk_indices_buffer + + +# Register as a vllm custom op so vllm's compile / dispatch infrastructure +# treats it the same as the existing sparse_attn_indexer ops. +direct_register_custom_op( + op_name="rocm_sparse_attn_indexer_no_insert", + op_func=rocm_sparse_attn_indexer_no_insert, + mutates_args=["topk_indices_buffer"], + fake_impl=rocm_sparse_attn_indexer_no_insert_fake, + dispatch_key=current_platform.dispatch_key, +) From 91967c76d4d99920949e7d9aac491da813e6117f Mon Sep 17 00:00:00 2001 From: lcskrishna Date: Tue, 28 Apr 2026 17:59:03 +0000 Subject: [PATCH 2/8] mori-io connector changes for PD Disaggregation --- .../v1/moriio/moriio_connector.py | 119 ++++++++++++++++-- 1 file changed, 106 insertions(+), 13 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py index 15aca3e571cc..baedea775cdf 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py @@ -408,9 +408,18 @@ def update_state_after_alloc( ) else: - # WRITE mode: prefill scheduler notifies the decode side that - # blocks are ready. Parse the decode's host/notify_port from - # the request_id + # WRITE mode: this branch only fires on the decode-side + # scheduler (the toy proxy sets do_remote_prefill=True only on + # decode-bound requests). 
The decode tells the prefill which + # blocks to RDMA-write into, so we need the *prefill's* + # host/notify_port from the request_id. + # get_peer_zmq_from_request_id() takes the *caller's* role and + # returns the peer's address; passing self.is_producer=False + # on the decode side resolves to the prefill address. + # Hardcoding True here used to make the decode send the + # block-notify message to its own notify port, where the + # consumer-role assertion in + # MoRIIOWrapper._handle_structured_message would fail. assert request.kv_transfer_params is not None, ( "kv_transfer_params should not be None" ) @@ -418,7 +427,7 @@ def update_state_after_alloc( remote_dp_rank = request.kv_transfer_params.get("remote_dp_rank", 0) peer_zmq = get_peer_zmq_from_request_id( - request.request_id, is_producer=True + request.request_id, is_producer=self.is_producer ) remote_host, _, remote_notify_port = parse_moriio_zmq_address(peer_zmq) @@ -770,17 +779,83 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.use_mla = self.model_config.use_mla self.built_session = False self.built_write_session: defaultdict[str, list] = defaultdict(list) - backend = get_attn_backend( - self.model_config.get_head_size(), - self.model_config.dtype, - self.cache_config.cache_dtype, - use_mla=self.use_mla, - ) + # DeepSeek V4 sparse attention switches the cache layout to + # "fp8_ds_mla" inside the attention layer. The platform's attention + # selector only exposes the sparse-MLA backend (which understands that + # cache dtype) when use_sparse=True is requested. Detect that case from + # the configured cache dtype so the connector's backend probe matches + # what the model actually instantiated. + self.use_sparse = self.cache_config.cache_dtype == "fp8_ds_mla" self.transfer_id_to_request_id: dict[TransferId, ReqId] = {} + # The platform selector cannot describe every backend a model may pick + # for itself. 
@staticmethod
def _infer_backend_name_from_model(vllm_config: VllmConfig) -> str:
    """Best-effort lookup of the attention backend name from live layers.

    Iterates the values of
    ``vllm_config.compilation_config.static_forward_context`` and returns
    the first non-empty string produced by
    ``layer.get_attn_backend().get_name()``. Layers without a
    ``get_attn_backend`` attribute, layers whose probe raises, and probes
    that yield something other than a non-empty ``str`` are skipped.

    Returns:
        The discovered backend name, or ``"UNREGISTERED"`` when the forward
        context is missing or no layer yields a usable name.
    """
    fallback_name = "UNREGISTERED"
    try:
        layers = vllm_config.compilation_config.static_forward_context
    except AttributeError:
        return fallback_name

    for attn_layer in layers.values():
        backend_getter = getattr(attn_layer, "get_attn_backend", None)
        if backend_getter is not None:
            try:
                candidate = backend_getter().get_name()
            except Exception:
                # Deliberately tolerant: one misbehaving layer must not stop
                # the search through the remaining layers.
                continue
            if isinstance(candidate, str) and candidate:
                return candidate
    return fallback_name
+ if kv_cache.is_contiguous(): + register_target = kv_cache + else: + register_target = torch.empty( + 0, dtype=torch.uint8, device=kv_cache.device + ).set_(kv_cache.untyped_storage()) + + moriio_mem_metadata = self.moriio_wrapper.register_local_tensor( + register_target + ) self.layer_name_to_local_kv_cache_metadata[layer_name].append( moriio_mem_metadata ) From 1eb6385b3df0e25cb5a7efbf2f196dd6388501f5 Mon Sep 17 00:00:00 2001 From: lcskrishna Date: Thu, 7 May 2026 10:31:56 +0000 Subject: [PATCH 3/8] fix accuracy issues with fused_deepseek_v4_qnorm_rope_kv_insert_kernel --- vllm/config/compilation.py | 1 + vllm/envs.py | 27 + .../layers/deepseek_v4_attention.py | 58 +- .../layers/sparse_attn_indexer.py | 27 + .../attention/ops/rocm_sparse_attn_indexer.py | 549 ++++++++++++++++++ 5 files changed, 653 insertions(+), 9 deletions(-) create mode 100644 vllm/v1/attention/ops/rocm_sparse_attn_indexer.py diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 0f02a92681c1..d5fa087a329a 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -749,6 +749,7 @@ class CompilationConfig: "vllm::kda_attention", "vllm::sparse_attn_indexer", "vllm::rocm_aiter_sparse_attn_indexer", + "vllm::rocm_sparse_attn_indexer_no_insert", "vllm::deepseek_v4_attention", ] diff --git a/vllm/envs.py b/vllm/envs.py index ded474dc085a..73e4b147f88b 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -126,6 +126,7 @@ VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT: bool = False + VLLM_ROCM_USE_V4_TRITON_FALLBACK: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_DISABLE_COMPILE_CACHE: bool = False @@ -1077,6 +1078,32 @@ def _get_or_set_default() -> str: "VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT": lambda: ( os.getenv("VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT", "False").lower() in ("true", "1") ), + # Master switch for the pre-rebase ROCm-native code paths used by + # 
DeepSeek-V4 (DSv4-Flash-FP8). When True (default on ROCm) the model + # selects the validated pre-rebase implementations at four call sites: + # + # 1. SWA K-cache writer: torch reference + # (``_deepseek_v4_qnorm_rope_kv_insert_reference``) instead of + # upstream's HIPified ``fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert`` + # C++ kernel, whose FP8 dtype is selected at compile time + # (``HIP_FP8_TYPE_OCP``) and silently corrupts every K byte on + # MI300X (FNUZ-only). This is the regression fix; the other three + # below are kept for defense in depth and bisection. + # 2. MLA decode: ``flash_mla_with_kvcache_rocm`` Triton kernel + # (95% GSM8K validated) instead of upstream's + # ``rocm_forward_decode_fallback``. + # 3. MLA sparse prefill: ``flash_mla_sparse_fwd_rocm`` Triton kernel + # instead of upstream's ``rocm_sparse_attn_prefill``. + # 4. Sparse indexer: recovered ``rocm_sparse_attn_indexer_no_insert`` + # orchestration instead of upstream's + # ``rocm_aiter_sparse_attn_indexer_native``. + # + # Set to "0" to opt back into the upstream paths for bisection / perf + # comparison (note: requires the SWA writer fix below to also be in place + # — flipping this alone reproduces the deterministic-garbage regression). 
+ "VLLM_ROCM_USE_V4_TRITON_FALLBACK": lambda: ( + os.getenv("VLLM_ROCM_USE_V4_TRITON_FALLBACK", "True").lower() in ("true", "1") + ), # Custom quick allreduce kernel for MI3* cards # Choice of quantization level: FP, INT8, INT6, INT4 or NONE # Recommended for large models to get allreduce diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index f32c25649dab..92dc0b6dc12b 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -636,7 +636,21 @@ def _fused_qnorm_rope_kv_insert( block_sz = swa_metadata.block_size slot_map = swa_metadata.slot_mapping - if fused_op is not None: + # Commit 628c43630 wired the + # kernel into the ROCm build, but its FP8 dtype is selected at + # *compile time* via ``HIP_FP8_TYPE_OCP`` whereas MI300X (gfx942) + # is FNUZ-only at runtime — a mismatch silently corrupts every K + # byte written to the SWA cache. Force the Python reference on + # ROCm under ``VLLM_ROCM_USE_V4_TRITON_FALLBACK`` so we match the + # pre-rebase numerics; flip the env var to "0" to opt back into + # the upstream C++ kernel for bisection. + # TODO: fix in the next commit. + use_torch_ref = ( + current_platform.is_rocm() + and envs.VLLM_ROCM_USE_V4_TRITON_FALLBACK + ) + + if fused_op is not None and not use_torch_ref: fused_op(q, kv, swa_kv_cache_2d, slot_map, pos_i64, cos_sin, self.eps, block_sz) else: _deepseek_v4_qnorm_rope_kv_insert_reference( @@ -1040,7 +1054,16 @@ def _forward_decode( swa_indices = swa_metadata.decode_swa_indices swa_lens = swa_metadata.decode_swa_lens - if current_platform.is_rocm(): + # When VLLM_ROCM_USE_V4_TRITON_FALLBACK is enabled (default on ROCm), + # we deliberately skip the upstream `rocm_forward_decode_fallback` and + # let the standard `flash_mla_with_kvcache` call below run. 
That call + # is mapped by `vllm.v1.attention.ops.flashmla` to our pre-rebase + # `flash_mla_with_kvcache_rocm` Triton/online-softmax fallback, which + # is the path that produced 95% GSM8K accuracy. The upstream torch + # reference (`rocm_ref_sparse_attn_decode`) has its own bugs that + # collapse generation to the base-model prior, so we keep it gated as + # an opt-in fallback for bisection only. + if current_platform.is_rocm() and not envs.VLLM_ROCM_USE_V4_TRITON_FALLBACK: rocm_forward_decode_fallback( q=q, kv_cache=kv_cache, @@ -1088,12 +1111,19 @@ def _forward_decode( f"Unsupported compress_ratio={self.compress_ratio}; " "expected 1, 4, or 128." ) - assert tile_metadata is not None, ( - "swa_metadata missing tile_sched entry for " - f"compress_ratio={self.compress_ratio}; " - "DeepseekSparseSWAMetadataBuilder.build_tile_scheduler did not " - "allocate one for this layer type." - ) + # FlashMLA's tile-scheduler metadata is an NVIDIA-only planner state + # consumed by the C++/CUDA kernel. Our ROCm fallback + # (`flash_mla_with_kvcache_rocm`) discards `tile_scheduler_metadata` + # entirely, and `DeepseekSparseSWAMetadataBuilder.build_tile_scheduler` + # (correctly) skips allocating it on ROCm — so a `None` here is + # expected on AMD and only an error on CUDA. + if not current_platform.is_rocm(): + assert tile_metadata is not None, ( + "swa_metadata missing tile_sched entry for " + f"compress_ratio={self.compress_ratio}; " + "DeepseekSparseSWAMetadataBuilder.build_tile_scheduler did " + "not allocate one for this layer type." + ) out, _ = flash_mla_with_kvcache( q=q, @@ -1223,7 +1253,17 @@ def _forward_prefill( N, ) - if current_platform.is_rocm(): + # See the matching comment in `_forward_decode`: by default + # (VLLM_ROCM_USE_V4_TRITON_FALLBACK=True) we send the prefill + # forward through `flash_mla_sparse_fwd`, which on ROCm is bound + # to our pre-rebase `flash_mla_sparse_fwd_rocm` chunked-online- + # softmax kernel via `vllm.v1.attention.ops.flashmla`. 
Set the env + # var to "0" to opt back into upstream's `rocm_sparse_attn_prefill` + # torch reference (kept for bisection / regression testing). + if ( + current_platform.is_rocm() + and not envs.VLLM_ROCM_USE_V4_TRITON_FALLBACK + ): rocm_sparse_attn_prefill( q=q[query_start:query_end], kv=kv.view(-1, 1, q.shape[-1]), diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index 4bf52a49c43f..d73590638090 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -503,6 +503,33 @@ def forward_hip( assert isinstance(q_quant, torch.Tensor), ( "AMD sparse_attn_indexer expects a single FP8 q_quant tensor" ) + + # We only take this path when the + # compressor has already inserted K (skip_k_cache_insert=True), AITER + # is off, and the env-var gate is on (default). Falls through to the + # upstream native path otherwise. + if ( + self.skip_k_cache_insert + and not rocm_aiter_ops.is_enabled() + and envs.VLLM_ROCM_USE_V4_TRITON_FALLBACK + ): + # Import lazily so non-ROCm builds don't pay the import cost. 
+ import vllm.v1.attention.ops.rocm_sparse_attn_indexer # noqa: F401 + + return torch.ops.vllm.rocm_sparse_attn_indexer_no_insert( + hidden_states, + _encode_layer_name(self.k_cache.prefix), + self.k_cache.kv_cache, + q_quant, + weights, + self.quant_block_size, + self.scale_fmt, + self.topk_tokens, + self.head_dim, + self.max_model_len, + self.max_total_seq_len, + self.topk_indices_buffer, + ) if self.skip_k_cache_insert or not rocm_aiter_ops.is_enabled(): from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( rocm_aiter_sparse_attn_indexer_native, diff --git a/vllm/v1/attention/ops/rocm_sparse_attn_indexer.py b/vllm/v1/attention/ops/rocm_sparse_attn_indexer.py new file mode 100644 index 000000000000..5935323d1f80 --- /dev/null +++ b/vllm/v1/attention/ops/rocm_sparse_attn_indexer.py @@ -0,0 +1,549 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""ROCm path for DeepSeek-V4's sparse attention indexer when the compressor +has already inserted the (compressed) K into the indexer's KV cache, i.e. +``skip_k_cache_insert=True`` and the call site passes ``k=None``. + +The CUDA implementation in ``vllm/model_executor/layers/sparse_attn_indexer.py`` +relies on DeepGEMM's ``fp8_fp4_mqa_logits`` / ``fp8_fp4_paged_mqa_logits`` +which are NVIDIA-only. The existing ROCm AITER op +(``rocm_aiter_sparse_attn_indexer`` in ``rocm_aiter_mla_sparse.py``) always +performs its own ``indexer_k_quant_and_cache`` call and dereferences ``k``, +so it can't be reused for the V4 layout where the compressor pre-inserts K +and returns ``None``. + +This module fills that gap with: + * A streaming Triton MQA-logits kernel that runs on gfx9xx, computing + logits without materializing the (H, M, N) intermediate that the torch + reference does (which would OOM at long context). + * A torch fallback (``_mqa_logits_torch_inplace``) used for smoke tests and + on platforms without a usable Triton runtime. 
def _gather_workspace_shapes_fp8(
    total_seq_lens: int,
    head_dim: int,
    fp8_dtype: torch.dtype,
) -> tuple[
    tuple[tuple[int, int], torch.dtype], tuple[tuple[int, int], torch.dtype]
]:
    """Describe the two workspace buffers used by the FP8 K gather.

    Returns a pair of ``(shape, dtype)`` specs:

    * values: ``(total_seq_lens, head_dim)`` in ``fp8_dtype``;
    * scales: ``(total_seq_lens, 4)`` in ``torch.uint8`` -- four bytes per
      row, which callers reinterpret as one float32 dequant scale.

    Mirrors the FP8 branch of ``_gather_workspace_shapes`` in
    ``sparse_attn_indexer.py``.
    """
    values_spec = ((total_seq_lens, head_dim), fp8_dtype)
    # 4 uint8 bytes per row alias a single fp32 scale.
    scales_spec = ((total_seq_lens, 4), torch.uint8)
    return values_spec, scales_spec
if HAS_TRITON:

    # One program handles one query row (program_id 0) and one BLOCK_N-wide
    # tile of key positions (program_id 1). For every key n in the tile it
    # produces
    #     logits[m, n] = k_scale[n] * sum_h weights[m, h] * relu(q[m, h] . k[n])
    # and stores -inf for keys outside the row's [ks, ke) window.
    # D is used as a tl.arange extent, so it has to be a value tl.arange
    # accepts. M is part of the signature but is not read in the body.
    @triton.jit
    def _mqa_logits_prefill_kernel(
        q_ptr,  # (M, H, D) fp8
        weights_ptr,  # (M, H) fp32
        k_ptr,  # (N, D) fp8
        k_scale_ptr,  # (N,) fp32
        cu_seqlen_ks_ptr,  # (M,) int32
        cu_seqlen_ke_ptr,  # (M,) int32
        logits_ptr,  # (M, N) fp32 (output)
        stride_qm,
        stride_qh,
        stride_qd,
        stride_wm,
        stride_wh,
        stride_kn,
        stride_kd,
        stride_lm,
        stride_ln,
        M,
        N,
        H: tl.constexpr,
        D: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)

        # Key positions covered by this tile; the last tile may overhang N.
        n_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        in_bounds = n_offsets < N

        # Per-row half-open key window [ks, ke).
        ks = tl.load(cu_seqlen_ks_ptr + pid_m)
        ke = tl.load(cu_seqlen_ke_ptr + pid_m)
        valid = in_bounds & (n_offsets >= ks) & (n_offsets < ke)

        d_offsets = tl.arange(0, D)

        # Load K block once and reuse across heads: (BLOCK_N, D) fp32.
        # Masked-out rows are read as 0 and are overwritten with -inf below.
        k_block = tl.load(
            k_ptr
            + n_offsets[:, None] * stride_kn
            + d_offsets[None, :] * stride_kd,
            mask=valid[:, None],
            other=0.0,
        ).to(tl.float32)

        accum = tl.zeros([BLOCK_N], dtype=tl.float32)

        # Stream over heads so only one (D,) query vector is live at a time.
        for h in range(H):
            q = tl.load(
                q_ptr
                + pid_m * stride_qm
                + h * stride_qh
                + d_offsets * stride_qd,
            ).to(tl.float32)
            w = tl.load(weights_ptr + pid_m * stride_wm + h * stride_wh).to(
                tl.float32
            )

            # ReLU is applied to each head's score before the weighted sum.
            score = tl.sum(k_block * q[None, :], axis=1)
            accum += w * tl.maximum(score, 0.0)

        # The per-key dequant scale is applied once, after the head sum.
        k_scale = tl.load(k_scale_ptr + n_offsets, mask=valid, other=0.0)
        logits = accum * k_scale
        logits = tl.where(valid, logits, float("-inf"))

        # Store every in-range column (not just the valid window) so that
        # positions outside [ks, ke) are explicitly set to -inf.
        tl.store(
            logits_ptr + pid_m * stride_lm + n_offsets * stride_ln,
            logits,
            mask=in_bounds,
        )
def _mqa_logits_paged_torch(
    q_fp8: torch.Tensor,  # (B, next_n, H, D)
    kv_cache_4d: torch.Tensor,  # (num_blocks, block_size, 1, D + scale_pad)
    weights: torch.Tensor,  # (B*next_n, H) fp32
    context_lens: torch.Tensor,  # (B,) int32 (or (B, next_n))
    block_tables: torch.Tensor,  # (B, max_blocks) int32
    max_model_len: int,
    head_dim: int,
) -> torch.Tensor:
    """Torch implementation of the paged (decode) MQA-logits compute.

    For each request, walks its block table one KV block at a time, reads
    the FP8 K values and their fp32 scale from the cache slot, and fills

        logits[b * next_n + j, n] = scale[n] * sum_h w[j, h] * relu(q[j, h] . k[n])

    for every key position ``n < context_len``. Only one block's
    ``(H, next_n, block_size)`` intermediate is alive at any time.

    The dot product runs on the unscaled FP8 values and the scale is applied
    to the fp32 scores, the same order ``_mqa_logits_torch`` and the Triton
    prefill kernel use; the scale is never rounded to bf16.

    Args:
        q_fp8: Queries, indexed as ``(B, next_n, H, D)``.
        kv_cache_4d: Paged cache; the last dim holds ``head_dim`` value
            bytes followed by the scale bytes.
        weights: Per-row, per-head weights; row ``b * next_n + j`` belongs
            to query ``j`` of request ``b``.
        context_lens: Per-request key count. For a 2-D tensor only column 0
            is used.
        block_tables: Physical block id per (request, logical block).
        max_model_len: Number of logit columns allocated per row.
        head_dim: Number of value bytes per cache slot.

    Returns:
        ``(B * next_n, max_model_len)`` float32 logits; positions that were
        not computed hold ``-inf``.
    """
    from vllm.utils.math_utils import cdiv

    fp8_dtype = current_platform.fp8_dtype()
    # Unpacking four names keeps the implicit rank-4 check on both inputs.
    batch_size, next_n, _, _ = q_fp8.shape

    # Cache layout: last dim = D fp8 + 4 byte (1 fp32) scale per token.
    kv_values = kv_cache_4d[..., :head_dim]  # uint8
    kv_scale = kv_cache_4d[..., head_dim:]  # uint8 (4 bytes per slot)
    _, block_size, _, _ = kv_values.shape

    logits = torch.full(
        (batch_size * next_n, max_model_len),
        float("-inf"),
        device=q_fp8.device,
        dtype=torch.float32,
    )

    # Normalize context_lens to (B,).
    # NOTE(review): for a (B, next_n) input this keeps column 0 only; if the
    # columns are per-query lengths that grow with j, later queries would
    # miss their newest keys -- confirm against the metadata builder.
    if context_lens.dim() == 2:
        context_lens_b = context_lens[:, 0]
    else:
        context_lens_b = context_lens
    ctx_lens = context_lens_b.tolist()

    q_bf16 = q_fp8.to(torch.bfloat16)
    weights_f32 = weights.to(torch.float32)

    for i in range(batch_size):
        ctx_len = ctx_lens[i]
        if ctx_len <= 0:
            continue

        row_start = i * next_n
        row_end = row_start + next_n
        # (H, next_n, 1): broadcasts over the key axis of the score tensor.
        weight_slice = (
            weights_f32[row_start:row_end, :]
            .transpose(0, 1)
            .contiguous()
            .unsqueeze(-1)
        )
        qx = q_bf16[i]  # (next_n, H, D)

        # Fetch this request's physical block ids with a single host sync
        # instead of one .item() round trip per block.
        num_ctx_blocks = cdiv(ctx_len, block_size)
        phys_blocks = block_tables[i, :num_ctx_blocks].tolist()

        # Indexing by position (rather than iterating the list) keeps the
        # IndexError if the block table is shorter than the context needs.
        for block_rk in range(num_ctx_blocks):
            phys_block = phys_blocks[block_rk]

            # FP8 -> bf16 is exact, so K stays unscaled for the matmul.
            k_block = (
                kv_values[phys_block, :, 0, :]
                .view(fp8_dtype)
                .to(torch.bfloat16)
            )  # (block_size, D)
            k_scale_block = (
                kv_scale[phys_block, :, 0, :]
                .contiguous()
                .view(torch.float32)
                .squeeze(-1)
            )  # (block_size,)

            # (H, next_n, block_size) scores, scaled per key in fp32.
            score = (
                torch.einsum("nhd,sd->hns", qx, k_block).float()
                * k_scale_block
            )

            # Per-head relu before weighting, then reduce over heads.
            block_logits = (score.relu() * weight_slice).sum(dim=0)

            # Keep only key positions below ctx_len; block_rk is less than
            # cdiv(ctx_len, block_size), so the slice is never empty.
            n_start = block_rk * block_size
            n_end = min(n_start + block_size, ctx_len)
            logits[row_start:row_end, n_start:n_end] = block_logits[
                :, : n_end - n_start
            ]

    return logits
+ if not isinstance(attn_metadata, dict): + values_spec, scales_spec = _gather_workspace_shapes_fp8( + total_seq_lens, head_dim, fp8_dtype + ) + current_workspace_manager().get_simultaneous(values_spec, scales_spec) + max_logits_elems = ( + envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024 + ) + _ = torch.empty( + max_logits_elems, + dtype=torch.uint8, + device=hidden_states.device, + ) + if topk_indices_buffer is None: + return torch.empty( + (hidden_states.shape[0], topk_tokens), + dtype=torch.int32, + device=hidden_states.device, + ) + return topk_indices_buffer + + layer_attn_metadata = attn_metadata[k_cache_prefix] + assert isinstance(layer_attn_metadata, DeepseekV32IndexerMetadata) + assert topk_indices_buffer is not None + + has_decode = layer_attn_metadata.num_decodes > 0 + has_prefill = layer_attn_metadata.num_prefills > 0 + num_decode_tokens = layer_attn_metadata.num_decode_tokens + + # NOTE: K-cache insert is INTENTIONALLY skipped here. DeepSeek-V4's + # compressor (DeepseekCompressor.forward) writes the compressed K to the + # indexer's KV cache via its fused triton kernel before this op is called, + # and the call site passes k=None. + + topk_indices_buffer[: hidden_states.shape[0]] = -1 + + if has_prefill: + prefill_metadata = layer_attn_metadata.prefill + assert prefill_metadata is not None + for chunk in prefill_metadata.chunks: + # Reuse the workspace to gather the FP8 K + scale for this chunk. 
+ workspace_manager = current_workspace_manager() + values_spec, scales_spec = _gather_workspace_shapes_fp8( + total_seq_lens, head_dim, fp8_dtype + ) + k_quant_full, k_scale_full = ( + workspace_manager.get_simultaneous(values_spec, scales_spec) + ) + k_quant = k_quant_full[: chunk.total_seq_lens] + k_scale = k_scale_full[: chunk.total_seq_lens] + + ops.cp_gather_indexer_k_quant_cache( + kv_cache, + k_quant, + k_scale, + chunk.block_table, + chunk.cu_seq_lens, + ) + + q_slice = q_fp8[chunk.token_start : chunk.token_end] + w_slice = weights[chunk.token_start : chunk.token_end] + k_scale_f32 = k_scale.view(torch.float32).squeeze(-1) + + logits = _mqa_logits( + q_slice, + k_quant, + k_scale_f32, + w_slice, + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + ) + + num_rows = logits.shape[0] + topk_indices = topk_indices_buffer[ + chunk.token_start : chunk.token_end, :topk_tokens + ] + torch.ops._C.top_k_per_row_prefill( + logits, + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + + if has_decode: + decode_metadata = layer_attn_metadata.decode + assert decode_metadata is not None + + # The kv_cache stored shape is (num_blocks, block_size, head_dim+pad); + # paged-mqa-logits expects an extra "n_head" singleton dim. + kv_cache_4d = kv_cache.unsqueeze(-2) + + decode_lens = decode_metadata.decode_lens + if decode_metadata.requires_padding: + padded_q_fp8_decode_tokens = pack_seq_triton( + q_fp8[:num_decode_tokens], decode_lens + ) + else: + padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape( + decode_lens.shape[0], -1, *q_fp8.shape[1:] + ) + + batch_size = padded_q_fp8_decode_tokens.shape[0] + next_n = padded_q_fp8_decode_tokens.shape[1] + assert batch_size == decode_metadata.seq_lens.shape[0] + num_padded_tokens = batch_size * next_n + + # Slow-but-correct paged compute. 
Future Triton kernel TODO: walk the + # block_table on-device to avoid the per-batch python loop and the + # per-block (H, next_n, block_size) intermediate. + logits = _mqa_logits_paged_torch( + padded_q_fp8_decode_tokens, + kv_cache_4d, + weights[:num_padded_tokens], + decode_metadata.seq_lens, + decode_metadata.block_table, + max_model_len, + head_dim, + ) + + num_rows = logits.shape[0] + topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens] + torch.ops._C.top_k_per_row_decode( + logits, + next_n, + decode_metadata.seq_lens, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + + if decode_metadata.requires_padding: + topk_indices = unpack_seq_triton( + topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), + decode_lens, + ) + topk_indices_buffer[ + : topk_indices.shape[0], : topk_indices.shape[-1] + ] = topk_indices + + return topk_indices_buffer + + +def rocm_sparse_attn_indexer_no_insert_fake( + hidden_states: torch.Tensor, + k_cache_prefix: LayerNameType, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: str | None, + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: torch.Tensor | None, +) -> torch.Tensor: + # Mirror rocm_aiter_sparse_attn_indexer_fake's profile-run estimate so + # vllm's memory profiler accounts for the gather workspace. 
+ fp8_dtype = current_platform.fp8_dtype() + _flattened_kv = torch.empty( + [total_seq_lens, head_dim + 4], + device=q_fp8.device, + dtype=torch.uint8, + ) + _ = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous() + _ = _flattened_kv[..., head_dim:].view(torch.float32).contiguous() + if topk_indices_buffer is None: + return torch.empty( + (hidden_states.shape[0], topk_tokens), + dtype=torch.int32, + device=q_fp8.device, + ) + return topk_indices_buffer + + +# Register as a vllm custom op so vllm's compile / dispatch infrastructure +# treats it the same as the existing sparse_attn_indexer ops. +direct_register_custom_op( + op_name="rocm_sparse_attn_indexer_no_insert", + op_func=rocm_sparse_attn_indexer_no_insert, + mutates_args=["topk_indices_buffer"], + fake_impl=rocm_sparse_attn_indexer_no_insert_fake, + dispatch_key=current_platform.dispatch_key, +) From ca40f349aa0d52e1bcc1f4ce459f6bdc130fff41 Mon Sep 17 00:00:00 2001 From: lcskrishna Date: Thu, 7 May 2026 10:49:05 +0000 Subject: [PATCH 4/8] clean up un-necessary files --- vllm/triton_utils/__init__.py | 10 +--------- vllm/utils/deep_gemm.py | 9 +-------- 2 files changed, 2 insertions(+), 17 deletions(-) diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py index 669a6c3c37ad..f4866a702dd9 100644 --- a/vllm/triton_utils/__init__.py +++ b/vllm/triton_utils/__init__.py @@ -20,12 +20,4 @@ LOG2E = 1.4426950408889634 LOGE2 = 0.6931471805599453 - -__all__ = [ - "HAS_TRITON", - "triton", - "tl", - "tldevice", - "LOG2E", - "LOGE2", -] +__all__ = ["HAS_TRITON", "triton", "tl", "tldevice", "LOG2E", "LOGE2"] diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index c9d8d12c621f..6b89f5c33203 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -469,14 +469,7 @@ def tf32_hc_prenorm_gemm( out = x.float() @ fn.T sqrsum = x.float().square().sum(-1) - See the caller function for shape requirement. 
- - The DeepGEMM kernel splits the K dimension into ``num_split`` partial - sums for parallelism (``out`` has a leading ``num_split`` axis and the - consumer reduces over it). When DeepGEMM is not available (e.g. on - ROCm), fall back to a single-shot torch matmul written into split 0 - while zeroing the remaining splits, which is mathematically equivalent - after the consumer's reduction. + See the caller function for shape requirement """ _lazy_init() if _tf32_hc_prenorm_gemm_impl is None: From 38f22c9e28e3522d5c7e80d494ba77c1374b3125 Mon Sep 17 00:00:00 2001 From: lcskrishna Date: Thu, 7 May 2026 11:27:44 +0000 Subject: [PATCH 5/8] remove un-necessary fallbacks --- .../layers/deepseek_v4_attention.py | 107 +-------------- .../router/fused_topk_bias_router.py | 122 ++---------------- 2 files changed, 12 insertions(+), 217 deletions(-) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 92dc0b6dc12b..bf044ba5e7d6 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -695,99 +695,6 @@ def deepseek_v4_attention_fake( ) -def _fp8_einsum_torch_fallback( - a: torch.Tensor, - a_scale: torch.Tensor, - b: torch.Tensor, - b_scale: torch.Tensor, - out: torch.Tensor, - equation: str, - recipe: tuple[int, int, int], -) -> None: - """Pure-torch reference for DeepseekV4's block-scaled FP8 einsum. - - Used when DeepGEMM's ``fp8_einsum`` is unavailable (e.g. ROCm). Slow - but correct: dequantizes both operands to bf16 using the per-block - scales implied by ``recipe`` and runs the einsum natively. - - Only the einsum/recipe combinations actually emitted by - ``DeepseekV4MLAAttention.forward`` are supported; anything else - raises ``NotImplementedError`` so we fail loudly rather than - silently produce wrong results. 
- """ - if equation != "bhr,hdr->bhd": - raise NotImplementedError( - f"FP8 einsum torch fallback only supports 'bhr,hdr->bhd' " - f"(DeepseekV4 wo_a projection); got '{equation}'." - ) - - m_block, n_block, k_block = recipe - - # Recover the logical (H, D, R) layout for ``b``. ``wo_a`` is a - # ColumnParallelLinear with ``out_features = n_groups * o_lora_rank`` - # marked ``is_bmm=True``: the weight is stored 2-D as ``(H*D, R)`` with - # H = n_local_groups (the leading group dim) and D = o_lora_rank, and - # the FP8 GEMM treats the leading H slices as batched. Same trick for - # the per-block weight scale ``(H*D/n_block, R/k_block)``. - h_groups = a.shape[-2] - d_out = out.shape[-1] - r_contract = a.shape[-1] - - b_3d = b - if b.dim() == 2: - if b.shape[0] != h_groups * d_out or b.shape[1] != r_contract: - raise RuntimeError( - f"Unexpected wo_a weight shape {tuple(b.shape)}; " - f"expected ({h_groups * d_out}, {r_contract}) for " - f"H={h_groups}, D={d_out}, R={r_contract}." - ) - b_3d = b.view(h_groups, d_out, r_contract) - elif b.dim() != 3: - raise RuntimeError( - f"Expected wo_a weight to be 2-D or 3-D, got {b.dim()}-D" - ) - - n_d_scale = (d_out + n_block - 1) // n_block if n_block > 1 else d_out - n_r_scale = ( - (r_contract + k_block - 1) // k_block if k_block > 1 else r_contract - ) - - b_scale_3d = b_scale - if b_scale.dim() == 2: - if b_scale.shape != (h_groups * n_d_scale, n_r_scale): - raise RuntimeError( - f"Unexpected wo_a scale shape {tuple(b_scale.shape)}; " - f"expected ({h_groups * n_d_scale}, {n_r_scale})." 
- ) - b_scale_3d = b_scale.view(h_groups, n_d_scale, n_r_scale) - - a_f32 = a.to(torch.float32) - b_f32 = b_3d.to(torch.float32) - a_scale_f32 = a_scale.to(torch.float32).contiguous() - b_scale_f32 = b_scale_3d.to(torch.float32).contiguous() - - # a: (B, H, R) a_scale: (B, H, R // k_block) - a_scale_r = a_scale_f32 - if k_block > 1: - a_scale_r = a_scale_r.repeat_interleave(k_block, dim=-1) - a_scale_r = a_scale_r[..., :r_contract] - if m_block > 1: - a_scale_r = a_scale_r.repeat_interleave(m_block, dim=0)[: a_f32.shape[0]] - a_bf16 = (a_f32 * a_scale_r).to(torch.bfloat16) - - # b: (H, D, R) b_scale: (H, D // n_block, R // k_block) - b_scale_dr = b_scale_f32 - if k_block > 1: - b_scale_dr = b_scale_dr.repeat_interleave(k_block, dim=-1) - if n_block > 1: - b_scale_dr = b_scale_dr.repeat_interleave(n_block, dim=-2) - b_scale_dr = b_scale_dr[..., :d_out, :r_contract] - b_bf16 = (b_f32 * b_scale_dr).to(torch.bfloat16) - - result = torch.einsum(equation, a_bf16, b_bf16) - out.copy_(result.to(out.dtype)) - - def deepseek_v4_fp8_einsum( a: torch.Tensor, a_scale: torch.Tensor, @@ -797,19 +704,7 @@ def deepseek_v4_fp8_einsum( equation: str, recipe: list[int], ) -> None: - # DeepGEMM's fp8_einsum is the canonical fast path on NVIDIA. On - # platforms without it (e.g. ROCm), fall back to a torch dequant + - # einsum reference. The choice is made at call time (not import) so - # this op stays usable in unit tests that mock current_platform. 
- from vllm.platforms import current_platform - - if current_platform.is_cuda(): - fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe)) - return - - _fp8_einsum_torch_fallback( - a, a_scale, b, b_scale, out, equation, tuple(recipe) - ) + fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe)) def deepseek_v4_fp8_einsum_fake( diff --git a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py index 0de5983881be..84eaad7f65e6 100644 --- a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py +++ b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py @@ -57,87 +57,6 @@ def vllm_topk_sigmoid( return topk_weights, topk_indices -def _topk_softplus_sqrt_torch( - topk_weights: torch.Tensor, - topk_indices: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool, - routed_scaling_factor: float, - e_score_correction_bias: torch.Tensor | None, - input_tokens: torch.Tensor | None, - hash_indices_table: torch.Tensor | None, -) -> None: - # Reference implementation of csrc/moe/topk_softplus_sqrt_kernels.cu used - # on platforms where the fused kernel is unavailable (e.g. ROCm). Math - # mirrors the kernel exactly: weight_base = sqrt(softplus(x)) per expert, - # bias is added only for ranking (subtracted back from output), then - # optional renormalize + routed_scaling_factor. - num_tokens, num_experts = gating_output.shape - topk = topk_weights.shape[-1] - - # softplus(x) with beta=1 and the same numerical-stability cutoff used by - # the kernel ((val_b > 20) ? val : log1p(exp(val_b)) / beta). 
- x_f32 = gating_output.to(torch.float32) - softplus_x = torch.nn.functional.softplus(x_f32, beta=1.0, threshold=20.0) - weights_base = torch.sqrt(softplus_x) # (T, E) - - use_hash = ( - input_tokens is not None and hash_indices_table is not None - ) - - if use_hash: - # tid2eid: (V, k); input_tokens: (T,) -> selected_experts: (T, k) - tid2eid = hash_indices_table - selected_experts = tid2eid[input_tokens.to(torch.long)] - selected_weights = torch.gather( - weights_base, -1, selected_experts.to(torch.long) - ) - if renormalize: - denom = selected_weights.sum(dim=-1, keepdim=True) - denom = torch.where( - denom > 0, denom, torch.ones_like(denom) - ) - selected_weights = selected_weights / denom - selected_weights = selected_weights * routed_scaling_factor - - topk_weights.copy_(selected_weights.to(topk_weights.dtype)) - topk_indices.copy_(selected_experts.to(topk_indices.dtype)) - # The CUDA kernel leaves token_expert_indices untouched in the hash - # path, so we mirror that (caller treats it as scratch in this case). - return - - if e_score_correction_bias is not None: - ranking = weights_base + e_score_correction_bias.to(torch.float32) - else: - ranking = weights_base - - _, topk_ids = torch.topk(ranking, topk, dim=-1) - out_weights = torch.gather(weights_base, -1, topk_ids) - if renormalize: - denom = out_weights.sum(dim=-1, keepdim=True) - denom = torch.where(denom > 0, denom, torch.ones_like(denom)) - out_weights = out_weights / denom - out_weights = out_weights * routed_scaling_factor - - topk_weights.copy_(out_weights.to(topk_weights.dtype)) - topk_indices.copy_(topk_ids.to(topk_indices.dtype)) - - # token_expert_indices[t, k_idx] = k_idx * T + t (matches kernel's - # source_rows write at line 388 of topk_softplus_sqrt_kernels.cu). 
- arange_t = torch.arange( - num_tokens, - device=gating_output.device, - dtype=token_expert_indices.dtype, - ).unsqueeze(-1) - arange_k = torch.arange( - topk, - device=gating_output.device, - dtype=token_expert_indices.dtype, - ).unsqueeze(0) - token_expert_indices.copy_(arange_k * num_tokens + arange_t) - - def vllm_topk_softplus_sqrt( topk_weights: torch.Tensor, topk_indices: torch.Tensor, @@ -149,36 +68,17 @@ def vllm_topk_softplus_sqrt( hash_indices_table: torch.Tensor | None = None, routed_scaling_factor: float = 1.0, ) -> tuple[torch.Tensor, ...]: - # The fused topk_softplus_sqrt CUDA kernel is gated behind #ifndef USE_ROCM - # in csrc/moe/torch_bindings.cpp and the .cu source isn't added to - # VLLM_MOE_EXT_SRC for ROCm builds (CMakeLists.txt). Fall back to a torch - # reference on platforms that don't ship the symbol. - from vllm.platforms import current_platform - - if current_platform.is_cuda(): - ops.topk_hash_softplus_sqrt( - topk_weights, - topk_indices, - token_expert_indices, - gating_output, - renormalize, - routed_scaling_factor, - e_score_correction_bias, - input_tokens, - hash_indices_table, - ) - else: - _topk_softplus_sqrt_torch( - topk_weights, - topk_indices, - token_expert_indices, - gating_output, - renormalize, - routed_scaling_factor, - e_score_correction_bias, - input_tokens, - hash_indices_table, - ) + ops.topk_hash_softplus_sqrt( + topk_weights, + topk_indices, + token_expert_indices, + gating_output, + renormalize, + routed_scaling_factor, + e_score_correction_bias, + input_tokens, + hash_indices_table, + ) return topk_weights, topk_indices From e30a6c8836202ea327ab1578412f6e714f61a99e Mon Sep 17 00:00:00 2001 From: lcskrishna Date: Thu, 7 May 2026 11:45:51 +0000 Subject: [PATCH 6/8] use expert_dtype appropriately for fp8 base model --- vllm/model_executor/models/deepseek_v4.py | 58 +++++++++++++++++++---- 1 file changed, 49 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v4.py 
b/vllm/model_executor/models/deepseek_v4.py index 36266dff9c15..2a0c8c8772c1 100644 --- a/vllm/model_executor/models/deepseek_v4.py +++ b/vllm/model_executor/models/deepseek_v4.py @@ -65,6 +65,50 @@ _DEEPSEEK_V4_EXPERT_DTYPES = ("fp4", "fp8") +# Map ``quantization_config.quant_method`` values that may be observed in the +# wild to the DeepSeek V4 expert layout they imply. Used as a secondary signal +# when ``hf_config.expert_dtype`` is missing (some FP8 checkpoints, including +# DeepSeek-V4-Flash-Base-FP8, ship without it). +_QUANT_METHOD_TO_EXPERT_DTYPE = { + "fp8": "fp8", + "deepseek_v4_fp8": "fp8", + "mxfp4": "fp4", + "fp4": "fp4", + "nvfp4": "fp4", +} + + +def _resolve_deepseek_v4_expert_dtype(hf_config) -> str: + """Return the DeepSeek V4 expert layout, inferring it when needed. + + Resolution order: + + 1. Honor ``hf_config.expert_dtype`` if present (authoritative). + 2. Otherwise, peek at ``hf_config.quantization_config.quant_method`` + and map it via ``_QUANT_METHOD_TO_EXPERT_DTYPE`` so an FP8 + checkpoint without an explicit ``expert_dtype`` field still + picks the FP8 expert dispatch instead of falling through to the + MXFP4 path. + 3. Fall back to ``"fp4"`` (matches upstream's historical default + for legacy DSv4 checkpoints). + """ + explicit = getattr(hf_config, "expert_dtype", None) + if explicit is not None: + return explicit + + qcfg = getattr(hf_config, "quantization_config", None) + if qcfg is not None: + if isinstance(qcfg, dict): + quant_method = qcfg.get("quant_method") + else: + quant_method = getattr(qcfg, "quant_method", None) + if quant_method is not None: + inferred = _QUANT_METHOD_TO_EXPERT_DTYPE.get(quant_method) + if inferred is not None: + return inferred + + return "fp4" + class DeepseekV4MLP(nn.Module): def __init__( @@ -151,10 +195,8 @@ def expert_dtype(self) -> str: except Exception: # vllm_config not yet set; defer the decision until a # later call lands inside set_current_vllm_config. 
- #return "fp4" - return "fp8" - #expert_dtype = getattr(hf_config, "expert_dtype", "fp4") - expert_dtype = getattr(hf_config, "expert_dtype", "fp8") + return "fp4" + expert_dtype = _resolve_deepseek_v4_expert_dtype(hf_config) if expert_dtype not in _DEEPSEEK_V4_EXPERT_DTYPES: raise ValueError( f"Unsupported DeepSeek V4 expert_dtype={expert_dtype!r}; " @@ -740,7 +782,7 @@ def __init__( raise NotImplementedError( "DeepSeek V4 MegaMoE currently supports sqrtsoftplus routing only." ) - if self.use_mega_moe and getattr(config, "expert_dtype", "fp4") != "fp4": + if self.use_mega_moe and _resolve_deepseek_v4_expert_dtype(config) != "fp4": raise NotImplementedError( "DeepSeek V4 MegaMoE only supports fp4 experts; got expert_dtype=" f"{config.expert_dtype!r}. Drop --kernel-config moe_backend=" @@ -1532,16 +1574,14 @@ class DeepseekV4ForCausalLM(nn.Module): # Default mapper assumes the original FP4-expert checkpoint layout. # Overridden per-instance in __init__ when expert_dtype != "fp4". - #hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp4") - hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp8") + hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp4") def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config self.config = config - #expert_dtype = getattr(config, "expert_dtype", "fp4") - expert_dtype = "fp8" + expert_dtype = _resolve_deepseek_v4_expert_dtype(config) if expert_dtype != "fp4": self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper(expert_dtype) From e46c4cf10d17c57b1738c020ec0608c7716c5bd4 Mon Sep 17 00:00:00 2001 From: lcskrishna Date: Tue, 12 May 2026 05:39:07 +0000 Subject: [PATCH 7/8] mhc: add ROCm fallback for the fused mhc_post_pre op MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Upstream-added ``mhc_fused_post_pre`` calls three tilelang kernels (``mhc_fused_tilelang``, ``mhc_post_tilelang``, 
``mhc_pre_big_fuse_tilelang``) that all use Program Dependent Launch (PDL — Hopper-only). On ROCm tilelang's ``MarkCudaSyncCalls`` raises ``PDL is not supported`` at JIT-compile time, taking down every TP worker during profile_run: [TileLang:...]: TileLang begins to compile kernel `mhc_post_tilelang` tvm.error.InternalError: Check failed: ... PDL is not supported The non-fused ``mhc_pre`` and ``mhc_post`` already carry torch ROCm fallbacks; this commit composes them to back the fused op on ROCm, matching the contract (4-tuple of residual_cur / post_mix_cur / comb_mix_cur / layer_input_cur with the exact same shapes and dtypes as the tilelang path). The CUDA path is untouched. This unblocks DSv4-Flash-Base-FP8 profile_run on MI300X after the upstream merge that wired the fused op into the layer forward path. Co-authored-by: Cursor --- vllm/model_executor/layers/mhc.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/vllm/model_executor/layers/mhc.py b/vllm/model_executor/layers/mhc.py index 22193a12b74f..b3ea3b5e98e6 100644 --- a/vllm/model_executor/layers/mhc.py +++ b/vllm/model_executor/layers/mhc.py @@ -647,6 +647,30 @@ def mhc_fused_post_pre( assert n_splits in (1, 2, 4, 8) assert hidden_size % n_splits == 0 + if current_platform.is_rocm(): + # tilelang ships only CUDA codegen and the fused kernels here + # additionally use PDL (Hopper-only). Compose the existing torch + # fallbacks of ``mhc_post`` + ``mhc_pre`` instead — both already + # have a ROCm branch and produce the exact output shapes/dtypes + # the fused op contracts on. 
+ if post_layer_mix.ndim == residual.ndim - 1: + post_layer_mix_3d = post_layer_mix.unsqueeze(-1) + else: + post_layer_mix_3d = post_layer_mix + residual_cur = mhc_post(x, residual, post_layer_mix_3d, comb_res_mix) + post_mix_cur, comb_mix_cur, layer_input_cur = mhc_pre( + residual_cur, + fn, + hc_scale, + hc_base, + rms_eps, + hc_pre_eps, + hc_sinkhorn_eps, + hc_post_mult_value, + sinkhorn_repeat, + ) + return residual_cur, post_mix_cur, comb_mix_cur, layer_input_cur + residual_flat = residual.view(-1, hc_mult, hidden_size) num_tokens = residual_flat.shape[0] x_flat = x.view(num_tokens, hidden_size) From 75a8e1bcc9974c9a71e698a43c5904ec2c425928 Mon Sep 17 00:00:00 2001 From: lcskrishna Date: Tue, 12 May 2026 06:03:50 +0000 Subject: [PATCH 8/8] add dsv4 flash mla triton validation --- vllm/envs.py | 29 +++++++++++-------- .../layers/deepseek_v4_attention.py | 20 +++++++++++-- 2 files changed, 35 insertions(+), 14 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 41b4d1f5187e..ddd6254b3b92 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1077,9 +1077,9 @@ def _get_or_set_default() -> str: "VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT": lambda: ( os.getenv("VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT", "False").lower() in ("true", "1") ), - # Master switch for the pre-rebase ROCm-native code paths used by + # Master switch for the ROCm-native code paths used by # DeepSeek-V4 (DSv4-Flash-FP8). When True (default on ROCm) the model - # selects the validated pre-rebase implementations at two call sites: + # selects the triton/torch fallbacks at three call sites: # # 1. SWA K-cache writer: torch reference # (``_deepseek_v4_qnorm_rope_kv_insert_reference``) instead of @@ -1087,19 +1087,24 @@ def _get_or_set_default() -> str: # C++ kernel, whose FP8 dtype is selected at compile time # (``HIP_FP8_TYPE_OCP``) and silently corrupts every K byte on # MI300X (FNUZ-only). This is the regression fix. - # 2. Sparse indexer: recovered ``rocm_sparse_attn_indexer_no_insert`` + # 2. 
Sparse indexer: ``rocm_sparse_attn_indexer_no_insert`` # orchestration instead of upstream's # ``rocm_aiter_sparse_attn_indexer_native``. + # 3. MLA sparse backend dispatch: route through the unified + # ``DeepseekV4FlashMLASparseBackend`` (whose ROCm kernels are + # supplied by ``flash_mla_with_kvcache_rocm`` / + # ``flash_mla_sparse_fwd_rocm`` via ``flashmla.py``) instead of + # ``DeepseekV4ROCMAiterMLASparseBackend`` / + # ``Impl`` (whose ``_sparse_attn_decode_ragged_kernel`` Triton + # kernel currently hard-codes the SM89 ``tl.float8e4b15`` dtype + # in the ``IS_FNUZ`` branch and crashes JIT-compile on + # gfx942 — see logs/0512/server_log2.txt). # - # NOTE: the MLA decode/sparse-prefill paths are not gated by this env - # var any more — upstream unified the call sites, and our ROCm - # ``flash_mla_with_kvcache_rocm`` / ``flash_mla_sparse_fwd_rocm`` - # Triton kernels are always dispatched on ROCm via - # ``vllm.v1.attention.ops.flashmla``. - # - # Set to "0" to opt back into the upstream paths for bisection (note: - # the SWA-writer C++ kernel still produces deterministic garbage on - # MI300X, so site 1 is only useful for kernel debugging at present). + # Set to "0" to opt back into the upstream AITER + native paths for + # bisection (note: the SWA-writer C++ kernel still produces + # deterministic garbage on MI300X, and the AITER Triton kernel has the + # ``fp8e4b15`` bug above, so env=0 is only useful for kernel debugging + # at present). 
"VLLM_ROCM_USE_V4_TRITON_FALLBACK": lambda: ( os.getenv("VLLM_ROCM_USE_V4_TRITON_FALLBACK", "True").lower() in ("true", "1") ), diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 955eac088557..6e5c04817b3b 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -831,7 +831,16 @@ def __init__( self.kv_cache = torch.tensor([]) def get_attn_backend(self) -> type[AttentionBackend]: - if current_platform.is_rocm(): + if ( + current_platform.is_rocm() + and not envs.VLLM_ROCM_USE_V4_TRITON_FALLBACK + ): + # Opt-in (env=0) routes ROCm through the new AITER sparse + # MLA backend. Default (env=1) falls through to the unified + # FlashMLASparse backend; the actual kernels are then + # supplied by ``vllm.v1.attention.ops.flashmla`` which on + # ROCm hands off to our Triton fallbacks + # (``flash_mla_with_kvcache_rocm`` / ``flash_mla_sparse_fwd_rocm``). from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import ( DeepseekV4ROCMAiterMLASparseBackend, ) @@ -869,7 +878,14 @@ def forward( f"output buffer dtype {output.dtype} must match q dtype {q.dtype}" ) - if current_platform.is_rocm(): + if ( + current_platform.is_rocm() + and not envs.VLLM_ROCM_USE_V4_TRITON_FALLBACK + ): + # See the matching gate in ``get_attn_backend``: env=0 opts + # into the AITER sparse MLA impl. Default (env=1) falls + # through to the unified path below, which routes ROCm to our + # Triton kernels via ``flashmla.py``. from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import ( DeepseekV4ROCMAiterMLASparseImpl, )