diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index c2bced7842d3..8d8f6cf1ffe6 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -43,6 +43,7 @@ "longcat_flash_mtp", "mtp", "pangu_ultra_moe_mtp", + "pangu_pro_moe_mtp", "step3p5_mtp", ] EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes] @@ -199,7 +200,9 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: ) if hf_config.model_type in ("pangu_ultra_moe"): hf_config.model_type = "pangu_ultra_moe_mtp" - if hf_config.model_type == "pangu_ultra_moe_mtp": + if hf_config.model_type == "PanguProMoE": + hf_config.model_type = "pangu_pro_moe_mtp" + if hf_config.model_type in ["pangu_ultra_moe_mtp", "pangu_pro_moe_mtp"]: n_predict = getattr(hf_config, "num_nextn_predict_layers", None) hf_config.update( {"n_predict": n_predict, "architectures": ["OpenPanguMTPModel"]} diff --git a/vllm/model_executor/layers/attention/static_sink_attention.py b/vllm/model_executor/layers/attention/static_sink_attention.py index 49d83823b512..76ab659ee69a 100644 --- a/vllm/model_executor/layers/attention/static_sink_attention.py +++ b/vllm/model_executor/layers/attention/static_sink_attention.py @@ -40,6 +40,8 @@ def create_static_sink_attention_backend( underlying_builder = underlying_attn_backend.get_builder_cls() class StaticSinkAttentionBuilder(underlying_builder): # type: ignore + supports_update_block_table: bool = False + def __init__( self, kv_cache_spec: AttentionSpec, @@ -122,6 +124,7 @@ def __init__( cache_config: CacheConfig | None = None, **kwargs, ): + CustomOp.__init__(self) dtype = torch.get_default_dtype() if cache_config is not None: @@ -150,7 +153,6 @@ def __init__( attn_backend=attn_backend, **kwargs, ) - CustomOp.__init__(self) self.sink_len = sink_len self.block_size = block_size diff --git a/vllm/model_executor/models/openpangu_mtp.py b/vllm/model_executor/models/openpangu_mtp.py index 91b454a4bc38..162216958807 --- 
a/vllm/model_executor/models/openpangu_mtp.py +++ b/vllm/model_executor/models/openpangu_mtp.py @@ -145,6 +145,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ("gate_up_proj", "up_proj", 1), ("fused_qkv_a_proj", "q_a_proj", 0), ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), ] expert_params_mapping = FusedMoE.make_expert_params_mapping( @@ -218,6 +221,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if name.endswith(".bias") and name not in params_dict: continue + if name.endswith("e_score_correction_bias"): + name = name.replace( + "e_score_correction_bias", "gate.e_score_correction_bias" + ) if ( spec_layer != self.model.mtp_start_layer_idx and ".layers" not in name @@ -230,8 +237,17 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ) weight_loader(param, loaded_weight) loaded_params.add(name) + + self.post_weight_load() return loaded_params + def post_weight_load(self) -> None: + for name, module in self.named_modules(): + if module is self: + continue + if hasattr(module, "post_weight_load"): + module.post_weight_load() + def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: """ Rewrite the weight name to match the format of the original model. 
diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py index 5fc737e8ee90..3d26936e0bcc 100644 --- a/vllm/transformers_utils/model_arch_config_convertor.py +++ b/vllm/transformers_utils/model_arch_config_convertor.py @@ -456,5 +456,6 @@ def get_num_hidden_layers(self) -> int: "glm_ocr_mtp": GLM4MoeMTPModelArchConfigConvertor, "ernie_mtp": ErnieMTPModelArchConfigConvertor, "pangu_ultra_moe_mtp": PanguUltraMoeMTPModelArchConfigConvertor, + "pangu_pro_moe_mtp": PanguUltraMoeMTPModelArchConfigConvertor, "longcat_flash_mtp": LongCatFlashMTPModelArchConfigConvertor, } diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d82b83b8c496..3e21b9371b30 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1919,9 +1919,9 @@ def _build_attn_group_metadata( if self.speculative_config and spec_decode_common_attn_metadata is None: if isinstance(self.drafter, EagleProposer): if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names: - spec_decode_common_attn_metadata = cm + spec_decode_common_attn_metadata = copy(cm) else: - spec_decode_common_attn_metadata = cm + spec_decode_common_attn_metadata = copy(cm) for attn_gid in range(len(self.attn_groups[kv_cache_gid])): if ubatch_slices is not None: