From c692771b44feffb4ccc58db60f8a573d5f33db98 Mon Sep 17 00:00:00 2001 From: yuantao <2422264527@qq.com> Date: Sat, 17 Jan 2026 15:55:28 +0800 Subject: [PATCH] Add MTP for openpangu_pro_moe model, fix an initialization bug in StaticSinkAttention Signed-off-by: yuantao <2422264527@qq.com> --- vllm/config/speculative.py | 5 ++++- .../layers/attention/static_sink_attention.py | 4 +++- vllm/model_executor/models/openpangu_mtp.py | 16 ++++++++++++++++ .../model_arch_config_convertor.py | 1 + vllm/v1/worker/gpu_model_runner.py | 4 ++-- 5 files changed, 26 insertions(+), 4 deletions(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 36e6447124f8..6db35f748a05 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -39,6 +39,7 @@ "longcat_flash_mtp", "mtp", "pangu_ultra_moe_mtp", + "pangu_pro_moe_mtp", ] EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes] SpeculativeMethod = Literal[ @@ -179,7 +180,9 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: ) if hf_config.model_type in ("pangu_ultra_moe"): hf_config.model_type = "pangu_ultra_moe_mtp" - if hf_config.model_type == "pangu_ultra_moe_mtp": + if hf_config.model_type in ("PanguProMoE",): + hf_config.model_type = "pangu_pro_moe_mtp" + if hf_config.model_type in ["pangu_ultra_moe_mtp", "pangu_pro_moe_mtp"]: n_predict = getattr(hf_config, "num_nextn_predict_layers", None) hf_config.update( {"n_predict": n_predict, "architectures": ["OpenPanguMTPModel"]} diff --git a/vllm/model_executor/layers/attention/static_sink_attention.py b/vllm/model_executor/layers/attention/static_sink_attention.py index a869226ea182..ad509e5469d2 100644 --- a/vllm/model_executor/layers/attention/static_sink_attention.py +++ b/vllm/model_executor/layers/attention/static_sink_attention.py @@ -40,6 +40,8 @@ def create_static_sink_attention_backend( underlying_builder = underlying_attn_backend.get_builder_cls() class StaticSinkAttentionBuilder(underlying_builder): # type: 
ignore + supports_update_block_table: bool = False + def __init__( self, kv_cache_spec: AttentionSpec, @@ -122,6 +124,7 @@ def __init__( cache_config: CacheConfig | None = None, **kwargs, ): + CustomOp.__init__(self) dtype = torch.get_default_dtype() if cache_config is not None: @@ -150,7 +153,6 @@ def __init__( attn_backend=attn_backend, **kwargs, ) - CustomOp.__init__(self) self.sink_len = sink_len self.block_size = block_size diff --git a/vllm/model_executor/models/openpangu_mtp.py b/vllm/model_executor/models/openpangu_mtp.py index 273351051797..385cc20d40d2 100644 --- a/vllm/model_executor/models/openpangu_mtp.py +++ b/vllm/model_executor/models/openpangu_mtp.py @@ -145,6 +145,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ("gate_up_proj", "up_proj", 1), ("fused_qkv_a_proj", "q_a_proj", 0), ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), ] expert_params_mapping = FusedMoE.make_expert_params_mapping( @@ -218,6 +221,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if name.endswith(".bias") and name not in params_dict: continue + if name.endswith("e_score_correction_bias"): + name = name.replace( + "e_score_correction_bias", "gate.e_score_correction_bias" + ) if ( spec_layer != self.model.mtp_start_layer_idx and ".layers" not in name @@ -230,8 +237,17 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ) weight_loader(param, loaded_weight) loaded_params.add(name) + + self.post_weight_load() return loaded_params + def post_weight_load(self) -> None: + for name, module in self.named_modules(): + if module is self: + continue + if hasattr(module, "post_weight_load"): + module.post_weight_load() + def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: """ Rewrite the weight name to match the format of the original model. 
diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py index 6df4bb64dceb..a14c2a694f7f 100644 --- a/vllm/transformers_utils/model_arch_config_convertor.py +++ b/vllm/transformers_utils/model_arch_config_convertor.py @@ -398,5 +398,6 @@ def get_num_hidden_layers(self) -> int: "glm4_moe_mtp": GLM4MoeMTPModelArchConfigConvertor, "ernie_mtp": ErnieMTPModelArchConfigConvertor, "pangu_ultra_moe_mtp": PanguUltraMoeMTPModelArchConfigConvertor, + "pangu_pro_moe_mtp": PanguUltraMoeMTPModelArchConfigConvertor, "longcat_flash_mtp": LongCatFlashMTPModelArchConfigConvertor, } diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5691a76983b6..cc81eb9bf647 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1762,9 +1762,9 @@ def _build_attn_group_metadata( if self.speculative_config and spec_decode_common_attn_metadata is None: if isinstance(self.drafter, EagleProposer): if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names: - spec_decode_common_attn_metadata = cm + spec_decode_common_attn_metadata = copy(cm) else: - spec_decode_common_attn_metadata = cm + spec_decode_common_attn_metadata = copy(cm) for attn_gid in range(len(self.attn_groups[kv_cache_gid])): if ubatch_slices is not None: