From 2b2028a9a6bf8710ed155d1349b4290730c85b26 Mon Sep 17 00:00:00 2001 From: Li Date: Tue, 31 Mar 2026 15:57:25 -0700 Subject: [PATCH 1/4] [ROCm] Enable dual-stream MoE shared experts and GLM-5 MXFP4 Quark support Enable dual-stream shared expert overlap on ROCm by using is_cuda_alike() instead of is_cuda() in the MoE forward path. This allows shared experts and routed experts to execute concurrently on separate HIP streams, matching the optimization already available on CUDA. Also add GLM-5 (glm_moe_dsa) to the Quark dynamic MXFP4 model types so that its attention projections use the same dynamic re-quantization path as DeepSeek-V3 family models. Co-authored-by: Claude Signed-off-by: Chuan Li Made-with: Cursor --- vllm/model_executor/layers/fused_moe/runner/shared_experts.py | 2 +- vllm/model_executor/layers/quantization/quark/quark.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/runner/shared_experts.py b/vllm/model_executor/layers/fused_moe/runner/shared_experts.py index 6d2189cb49b3..10261f6a940a 100644 --- a/vllm/model_executor/layers/fused_moe/runner/shared_experts.py +++ b/vllm/model_executor/layers/fused_moe/runner/shared_experts.py @@ -117,7 +117,7 @@ def _determine_shared_experts_order( return SharedExpertsOrder.MK_INTERNAL_OVERLAPPED should_run_shared_in_aux_stream = ( - current_platform.is_cuda() + current_platform.is_cuda_alike() and not self._use_dp_chunking and self._stream is not None and hidden_states.shape[0] diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index d0362cedcf2b..29d9ce908c2b 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -46,8 +46,8 @@ logger = init_logger(__name__) # model_type values that use dynamic MXFP4 re-quantization for -# OCP MX fp4 Quark checkpoints -_DEEPSEEK_V3_FAMILY_MODEL_TYPES = frozenset({"deepseek_v3"}) +# OCP MX fp4 Quark checkpoints (DSA-MoE architecture family) +_DEEPSEEK_V3_FAMILY_MODEL_TYPES = frozenset({"deepseek_v3", "glm_moe_dsa"}) class QuarkConfig(QuantizationConfig): From ba2e2f20b65c39fb0b98d549b682e04ea1188203 Mon Sep 17 00:00:00 2001 From: Li Date: Tue, 31 Mar 2026 20:08:00 -0700 Subject: [PATCH 2/4] [ROCm] Work around AITER sparse MLA ZeroDivisionError for < 16 heads AITER's deepgemm_fp8_paged_mqa_logits_stage1 kernel computes TileQCount from num_heads; when heads < 16 (e.g. GLM-5 with TP=8 giving 8 heads per GPU), TileQCount becomes 0, causing ZeroDivisionError. Guard both rocm_fp8_paged_mqa_logits and rocm_fp8_mqa_logits to fall back to the PyTorch reference implementation when num_heads < 16, with a one-time warning log. 
Tracked upstream: https://github.com/ROCm/aiter/issues/2563 Co-authored-by: Claude Made-with: Cursor --- .../v1/attention/ops/rocm_aiter_mla_sparse.py | 24 ++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 9d1da5b53be5..91e90f4585fc 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -6,11 +6,16 @@ import torch from vllm.forward_context import get_forward_context +from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton +logger = init_logger(__name__) + +_AITER_MQA_SMALL_HEADS_WARNED = False + if current_platform.is_cuda_alike(): from vllm import _custom_ops as ops @@ -322,17 +327,29 @@ def rocm_fp8_paged_mqa_logits( Logits tensor of shape [B * next_n, max_model_len], dtype `torch.float32`. """ + global _AITER_MQA_SMALL_HEADS_WARNED from vllm._aiter_ops import rocm_aiter_ops + batch_size, next_n, heads, _ = q_fp8.shape + + # AITER's deepgemm_fp8_paged_mqa_logits_stage1 computes TileQCount + # from num_heads; when heads < 16 (e.g. GLM-5 with TP=8 → 8 heads) + # TileQCount becomes 0, causing ZeroDivisionError. + # Tracked: https://github.com/ROCm/aiter/issues/2563 aiter_paged_mqa_logits_module = None - if rocm_aiter_ops.is_enabled(): + if rocm_aiter_ops.is_enabled() and heads >= 16: aiter_paged_mqa_logits_module = paged_mqa_logits_module() + elif rocm_aiter_ops.is_enabled() and not _AITER_MQA_SMALL_HEADS_WARNED: + logger.warning( + "AITER paged MQA logits kernel does not support %d heads " + "(requires >= 16). Falling back to PyTorch reference. " + "See https://github.com/ROCm/aiter/issues/2563", heads) + _AITER_MQA_SMALL_HEADS_WARNED = True if aiter_paged_mqa_logits_module is not None: deepgemm_fp8_paged_mqa_logits_stage1 = ( aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits_stage1 ) - batch_size, next_n, heads, _ = q_fp8.shape out_qk = torch.full( (heads, batch_size * next_n, max_model_len), float("-inf"), @@ -449,8 +466,9 @@ def rocm_fp8_mqa_logits( # path after aiter merge this kernel into main from vllm._aiter_ops import rocm_aiter_ops + heads = q.shape[1] aiter_mqa_logits_module = None - if rocm_aiter_ops.is_enabled(): + if rocm_aiter_ops.is_enabled() and heads >= 16: aiter_mqa_logits_module = mqa_logits_module() if aiter_mqa_logits_module is not None: From 2ad2f33e5b19cc8b3af0484306149b5d7d195167 Mon Sep 17 00:00:00 2001 From: Li Date: Wed, 1 Apr 2026 08:18:53 -0700 Subject: [PATCH 3/4] Fix ruff format in rocm_aiter_mla_sparse.py Signed-off-by: Li Co-authored-by: Claude Made-with: Cursor --- vllm/v1/attention/ops/rocm_aiter_mla_sparse.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 91e90f4585fc..baf15636a0c0 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -343,7 +343,9 @@ def rocm_fp8_paged_mqa_logits( logger.warning( "AITER paged MQA logits kernel does not support %d heads " "(requires >= 16). Falling back to PyTorch reference. 
" - "See https://github.com/ROCm/aiter/issues/2563", heads) + "See https://github.com/ROCm/aiter/issues/2563", + heads, + ) _AITER_MQA_SMALL_HEADS_WARNED = True if aiter_paged_mqa_logits_module is not None: From e40e395e4fa77ce90c10602fb3a3d9e056f32939 Mon Sep 17 00:00:00 2001 From: Li Date: Fri, 3 Apr 2026 14:02:42 -0700 Subject: [PATCH 4/4] [ROCm] Fix GLM-5-FP8 weight loading for fused indexer wk_weights_proj GLM-5-FP8 checkpoints quantize the fused wk_weights_proj tensor with FP8 block quantization (weight + weight_scale_inv). Resolve merge conflict with upstream indexer refactor (#38684/#38870) by always using fused MergedColumnParallelLinear with quant_config: - FP4: quant_config=None (weights_proj should not be quantized) - Non-FP4: quant_config=quant_config (supports FP8 weight_scale_inv) Add fallback in load_weights to handle both fused and separate checkpoint formats gracefully via stacked_params_mapping. Also reverts glm_moe_dsa from _DEEPSEEK_V3_FAMILY_MODEL_TYPES per review feedback (will be submitted as a standalone PR). Co-authored-by: Claude Signed-off-by: Chuan Li Made-with: Cursor --- .../layers/quantization/quark/quark.py | 4 +- vllm/model_executor/models/deepseek_mtp.py | 21 +++--- vllm/model_executor/models/deepseek_v2.py | 73 ++++++++----------- 3 files changed, 42 insertions(+), 56 deletions(-) diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index 29d9ce908c2b..d0362cedcf2b 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -46,8 +46,8 @@ logger = init_logger(__name__) # model_type values that use dynamic MXFP4 re-quantization for -# OCP MX fp4 Quark checkpoints (DSA-MoE architecture family) -_DEEPSEEK_V3_FAMILY_MODEL_TYPES = frozenset({"deepseek_v3", "glm_moe_dsa"}) +# OCP MX fp4 Quark checkpoints +_DEEPSEEK_V3_FAMILY_MODEL_TYPES = frozenset({"deepseek_v3"}) class QuarkConfig(QuantizationConfig): diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 126efb6f88e6..88d13c430406 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -248,13 +248,14 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), ] - if self.is_fp4_ckpt: - # Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj) - indexer_fused_mapping = [ - ("wk_weights_proj", "wk", 0), - ("wk_weights_proj", "weights_proj", 1), - ] - stacked_params_mapping.extend(indexer_fused_mapping) + # Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj). + # Always included; the fallback check in the loading loop handles + # checkpoints that already have fused wk_weights_proj tensors. + indexer_fused_mapping = [ + ("wk_weights_proj", "wk", 0), + ("wk_weights_proj", "weights_proj", 1), + ] + stacked_params_mapping.extend(indexer_fused_mapping) expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( self, @@ -297,10 +298,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: continue name_mapped = name.replace(weight_name, param_name) - # QKV fusion is optional, fall back to normal - # weight loading if it's not enabled + # QKV fusion and indexer fusion are optional — fall back to + # direct weight loading when the mapped name doesn't exist. 
if ( - param_name == "fused_qkv_a_proj" + param_name in ("fused_qkv_a_proj", "wk_weights_proj") ) and name_mapped not in params_dict: continue else: diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index cfeb36f4af25..3f50181a0eee 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -644,36 +644,20 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.wq_b", ) - if self.is_fp4_ckpt: - # Fused wk + weights_proj: single GEMM producing [head_dim + n_head]. - # weights_proj does not get quantized, - # so we run both with quant_config=None - # wk may be upcasted from the default quant; - # experiments show fusion is always faster unless WK proj is in FP4, - # which is not the case for all known quants. - self.wk_weights_proj = MergedColumnParallelLinear( - hidden_size, - [self.head_dim, self.n_head], - bias=False, - quant_config=None, - disable_tp=True, - prefix=f"{prefix}.wk_weights_proj", - ) - else: - self.wk = ReplicatedLinear( - hidden_size, - self.head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.wk", - ) - self.weights_proj = ReplicatedLinear( - hidden_size, - self.n_head, - bias=False, - quant_config=None, - prefix=f"{prefix}.weights_proj", - ) + # Fused wk + weights_proj: single GEMM producing [head_dim + n_head]. + # FP4 checkpoints don't quantize weights_proj, so use quant_config=None. + # Other quantized checkpoints (e.g. GLM-5-FP8) may quantize the fused + # tensor, so pass quant_config to create weight_scale_inv parameters. + # Checkpoints with separate wk/weights_proj tensors are handled by the + # stacked_params_mapping in load_weights. + self.wk_weights_proj = MergedColumnParallelLinear( + hidden_size, + [self.head_dim, self.n_head], + bias=False, + quant_config=None if self.is_fp4_ckpt else quant_config, + disable_tp=True, + prefix=f"{prefix}.wk_weights_proj", + ) self.k_norm = LayerNorm(self.head_dim, eps=1e-6) self.softmax_scale = self.head_dim**-0.5 @@ -714,14 +698,9 @@ def forward( q_pe, q_nope = torch.split( q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 ) - if self.is_fp4_ckpt: - # Fused wk + weights_proj: one GEMM, then split - kw, _ = self.wk_weights_proj(hidden_states) - k = kw[:, : self.head_dim] - weights = kw[:, self.head_dim :] - else: - k, _ = self.wk(hidden_states) - weights, _ = self.weights_proj(hidden_states) + kw, _ = self.wk_weights_proj(hidden_states) + k = kw[:, : self.head_dim] + weights = kw[:, self.head_dim :] k = self.k_norm(k) k_pe, k_nope = torch.split( @@ -1469,8 +1448,13 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), ] - if self.is_fp4_ckpt: - # Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj) + if self.is_v32: + # Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj). + # For checkpoints with separate wk/weights_proj tensors, this mapping + # loads them into the fused MergedColumnParallelLinear shards. + # For checkpoints that already have fused wk_weights_proj (e.g. + # GLM-5-FP8), the substring match is a false positive and the + # fallback check in the loading loop skips it gracefully. 
indexer_fused_mapping = [ ("wk_weights_proj", "wk", 0), ("wk_weights_proj", "weights_proj", 1), @@ -1528,11 +1512,12 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: continue name_mapped = name.replace(weight_name, param_name) - # QKV fusion is optional, fall back to normal - # weight loading if it's not enabled - # if go with fusion option, then update name + # QKV fusion and indexer fusion are optional — fall back to + # direct weight loading when the mapped name doesn't exist + # (e.g. fused checkpoints where "wk" falsely matches + # "wk_weights_proj", or when QKV fusion is disabled). if ( - param_name == "fused_qkv_a_proj" + param_name in ("fused_qkv_a_proj", "wk_weights_proj") ) and name_mapped not in params_dict: continue else:
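
Note on PATCH 1: the one-line is_cuda_alike() change only flips the predicate that gates the existing overlap path; the sketch below shows the general shape of the shared-expert/routed-expert overlap it enables, assuming only the public torch.cuda stream APIs (which map to HIP streams on ROCm builds of PyTorch). The function and argument names are illustrative and are not vLLM's actual runner code.

    import torch

    def overlapped_moe(hidden_states, shared_experts, routed_experts, aux_stream):
        # The auxiliary stream must wait for whatever produced hidden_states
        # on the current stream before it starts reading the tensor.
        aux_stream.wait_stream(torch.cuda.current_stream())
        # Shared-expert GEMMs are enqueued on the auxiliary stream ...
        with torch.cuda.stream(aux_stream):
            shared_out = shared_experts(hidden_states)
        # ... while the routed experts run on the current (default) stream,
        # so the two sets of kernels can execute concurrently on the device.
        routed_out = routed_experts(hidden_states)
        # Re-join: the current stream waits for the shared-expert kernels
        # before the two outputs are combined.
        torch.cuda.current_stream().wait_stream(aux_stream)
        return routed_out + shared_out

A real implementation also has to account for caching-allocator stream ownership (e.g. Tensor.record_stream) and for CUDA/HIP graph capture; the sketch omits both.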
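
Note on PATCH 4: the load_weights fallback hinges on substring matching against stacked_params_mapping plus a params_dict membership check. A standalone sketch of just that control flow, using toy parameter names rather than vLLM's loader API:

    STACKED_PARAMS_MAPPING = [
        # (fused param name, checkpoint shard name, shard id)
        ("wk_weights_proj", "wk", 0),
        ("wk_weights_proj", "weights_proj", 1),
    ]

    def resolve_target(name: str, params_dict: set[str]) -> tuple[str, int | None]:
        """Map a checkpoint tensor name to a model parameter and shard id."""
        for param_name, weight_name, shard_id in STACKED_PARAMS_MAPPING:
            if weight_name not in name:
                continue
            name_mapped = name.replace(weight_name, param_name)
            # Fusion is optional: when the mapped name is not an actual model
            # parameter (e.g. the checkpoint already ships a fused tensor and
            # "wk" was a false-positive substring match), fall through to a
            # direct load instead of failing.
            if name_mapped not in params_dict:
                continue
            return name_mapped, shard_id
        return name, None  # direct (non-stacked) load

    params = {"indexer.wk_weights_proj.weight"}
    # Separate-tensor checkpoints land in shards of the fused parameter:
    resolve_target("indexer.wk.weight", params)            # -> (fused name, shard 0)
    resolve_target("indexer.weights_proj.weight", params)  # -> (fused name, shard 1)
    # An already-fused checkpoint tensor falls through to a direct load:
    resolve_target("indexer.wk_weights_proj.weight", params)  # -> (same name, None)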