diff --git a/vllm/model_executor/layers/fused_moe/runner/shared_experts.py b/vllm/model_executor/layers/fused_moe/runner/shared_experts.py index 6d2189cb49b3..10261f6a940a 100644 --- a/vllm/model_executor/layers/fused_moe/runner/shared_experts.py +++ b/vllm/model_executor/layers/fused_moe/runner/shared_experts.py @@ -117,7 +117,7 @@ def _determine_shared_experts_order( return SharedExpertsOrder.MK_INTERNAL_OVERLAPPED should_run_shared_in_aux_stream = ( - current_platform.is_cuda() + current_platform.is_cuda_alike() and not self._use_dp_chunking and self._stream is not None and hidden_states.shape[0] diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 126efb6f88e6..88d13c430406 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -248,13 +248,14 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), ] - if self.is_fp4_ckpt: - # Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj) - indexer_fused_mapping = [ - ("wk_weights_proj", "wk", 0), - ("wk_weights_proj", "weights_proj", 1), - ] - stacked_params_mapping.extend(indexer_fused_mapping) + # Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj). + # Always included; the fallback check in the loading loop handles + # checkpoints that already have fused wk_weights_proj tensors. + indexer_fused_mapping = [ + ("wk_weights_proj", "wk", 0), + ("wk_weights_proj", "weights_proj", 1), + ] + stacked_params_mapping.extend(indexer_fused_mapping) expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( self, @@ -297,10 +298,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: continue name_mapped = name.replace(weight_name, param_name) - # QKV fusion is optional, fall back to normal - # weight loading if it's not enabled + # QKV fusion and indexer fusion are optional — fall back to + # direct weight loading when the mapped name doesn't exist. if ( - param_name == "fused_qkv_a_proj" + param_name in ("fused_qkv_a_proj", "wk_weights_proj") ) and name_mapped not in params_dict: continue else: diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index cfeb36f4af25..3f50181a0eee 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -644,36 +644,20 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.wq_b", ) - if self.is_fp4_ckpt: - # Fused wk + weights_proj: single GEMM producing [head_dim + n_head]. - # weights_proj does not get quantized, - # so we run both with quant_config=None - # wk may be upcasted from the default quant; - # experiments show fusion is always faster unless WK proj is in FP4, - # which is not the case for all known quants. - self.wk_weights_proj = MergedColumnParallelLinear( - hidden_size, - [self.head_dim, self.n_head], - bias=False, - quant_config=None, - disable_tp=True, - prefix=f"{prefix}.wk_weights_proj", - ) - else: - self.wk = ReplicatedLinear( - hidden_size, - self.head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.wk", - ) - self.weights_proj = ReplicatedLinear( - hidden_size, - self.n_head, - bias=False, - quant_config=None, - prefix=f"{prefix}.weights_proj", - ) + # Fused wk + weights_proj: single GEMM producing [head_dim + n_head]. + # FP4 checkpoints don't quantize weights_proj, so use quant_config=None. + # Other quantized checkpoints (e.g. GLM-5-FP8) may quantize the fused + # tensor, so pass quant_config to create weight_scale_inv parameters. + # Checkpoints with separate wk/weights_proj tensors are handled by the + # stacked_params_mapping in load_weights. + self.wk_weights_proj = MergedColumnParallelLinear( + hidden_size, + [self.head_dim, self.n_head], + bias=False, + quant_config=None if self.is_fp4_ckpt else quant_config, + disable_tp=True, + prefix=f"{prefix}.wk_weights_proj", + ) self.k_norm = LayerNorm(self.head_dim, eps=1e-6) self.softmax_scale = self.head_dim**-0.5 @@ -714,14 +698,9 @@ def forward( q_pe, q_nope = torch.split( q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 ) - if self.is_fp4_ckpt: - # Fused wk + weights_proj: one GEMM, then split - kw, _ = self.wk_weights_proj(hidden_states) - k = kw[:, : self.head_dim] - weights = kw[:, self.head_dim :] - else: - k, _ = self.wk(hidden_states) - weights, _ = self.weights_proj(hidden_states) + kw, _ = self.wk_weights_proj(hidden_states) + k = kw[:, : self.head_dim] + weights = kw[:, self.head_dim :] k = self.k_norm(k) k_pe, k_nope = torch.split( @@ -1469,8 +1448,13 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), ] - if self.is_fp4_ckpt: - # Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj) + if self.is_v32: + # Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj). + # For checkpoints with separate wk/weights_proj tensors, this mapping + # loads them into the fused MergedColumnParallelLinear shards. + # For checkpoints that already have fused wk_weights_proj (e.g. + # GLM-5-FP8), the substring match is a false positive and the + # fallback check in the loading loop skips it gracefully. indexer_fused_mapping = [ ("wk_weights_proj", "wk", 0), ("wk_weights_proj", "weights_proj", 1), @@ -1528,11 +1512,12 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: continue name_mapped = name.replace(weight_name, param_name) - # QKV fusion is optional, fall back to normal - # weight loading if it's not enabled - # if go with fusion option, then update name + # QKV fusion and indexer fusion are optional — fall back to + # direct weight loading when the mapped name doesn't exist + # (e.g. fused checkpoints where "wk" falsely matches + # "wk_weights_proj", or when QKV fusion is disabled). if ( - param_name == "fused_qkv_a_proj" + param_name in ("fused_qkv_a_proj", "wk_weights_proj") ) and name_mapped not in params_dict: continue else: diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 9d1da5b53be5..baf15636a0c0 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -6,11 +6,16 @@ import torch from vllm.forward_context import get_forward_context +from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton +logger = init_logger(__name__) + +_AITER_MQA_SMALL_HEADS_WARNED = False + if current_platform.is_cuda_alike(): from vllm import _custom_ops as ops @@ -322,17 +327,31 @@ def rocm_fp8_paged_mqa_logits( Logits tensor of shape [B * next_n, max_model_len], dtype `torch.float32`. """ + global _AITER_MQA_SMALL_HEADS_WARNED from vllm._aiter_ops import rocm_aiter_ops + batch_size, next_n, heads, _ = q_fp8.shape + + # AITER's deepgemm_fp8_paged_mqa_logits_stage1 computes TileQCount + # from num_heads; when heads < 16 (e.g. GLM-5 with TP=8 → 8 heads) + # TileQCount becomes 0, causing ZeroDivisionError. + # Tracked: https://github.com/ROCm/aiter/issues/2563 aiter_paged_mqa_logits_module = None - if rocm_aiter_ops.is_enabled(): + if rocm_aiter_ops.is_enabled() and heads >= 16: aiter_paged_mqa_logits_module = paged_mqa_logits_module() + elif rocm_aiter_ops.is_enabled() and not _AITER_MQA_SMALL_HEADS_WARNED: + logger.warning( + "AITER paged MQA logits kernel does not support %d heads " + "(requires >= 16). Falling back to PyTorch reference. " + "See https://github.com/ROCm/aiter/issues/2563", + heads, + ) + _AITER_MQA_SMALL_HEADS_WARNED = True if aiter_paged_mqa_logits_module is not None: deepgemm_fp8_paged_mqa_logits_stage1 = ( aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits_stage1 ) - batch_size, next_n, heads, _ = q_fp8.shape out_qk = torch.full( (heads, batch_size * next_n, max_model_len), float("-inf"), @@ -449,8 +468,9 @@ def rocm_fp8_mqa_logits( # path after aiter merge this kernel into main from vllm._aiter_ops import rocm_aiter_ops + heads = q.shape[1] aiter_mqa_logits_module = None - if rocm_aiter_ops.is_enabled(): + if rocm_aiter_ops.is_enabled() and heads >= 16: aiter_mqa_logits_module = mqa_logits_module() if aiter_mqa_logits_module is not None: