Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion vllm/config/speculative.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"longcat_flash_mtp",
"mtp",
"pangu_ultra_moe_mtp",
"pangu_pro_moe_mtp",
"step3p5_mtp",
]
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
Expand Down Expand Up @@ -199,7 +200,9 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
)
if hf_config.model_type in ("pangu_ultra_moe"):
hf_config.model_type = "pangu_ultra_moe_mtp"
if hf_config.model_type == "pangu_ultra_moe_mtp":
if hf_config.model_type in ("PanguProMoE"):
hf_config.model_type = "pangu_pro_moe_mtp"
if hf_config.model_type in ["pangu_ultra_moe_mtp", "pangu_pro_moe_mtp"]:
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{"n_predict": n_predict, "architectures": ["OpenPanguMTPModel"]}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def create_static_sink_attention_backend(
underlying_builder = underlying_attn_backend.get_builder_cls()

class StaticSinkAttentionBuilder(underlying_builder): # type: ignore
supports_update_block_table: bool = False

def __init__(
self,
kv_cache_spec: AttentionSpec,
Expand Down Expand Up @@ -122,6 +124,7 @@ def __init__(
cache_config: CacheConfig | None = None,
**kwargs,
):
CustomOp.__init__(self)
dtype = torch.get_default_dtype()

if cache_config is not None:
Expand Down Expand Up @@ -150,7 +153,6 @@ def __init__(
attn_backend=attn_backend,
**kwargs,
)
CustomOp.__init__(self)

self.sink_len = sink_len
self.block_size = block_size
Expand Down
16 changes: 16 additions & 0 deletions vllm/model_executor/models/openpangu_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
("gate_up_proj", "up_proj", 1),
("fused_qkv_a_proj", "q_a_proj", 0),
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
]

expert_params_mapping = FusedMoE.make_expert_params_mapping(
Expand Down Expand Up @@ -218,6 +221,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
if name.endswith(".bias") and name not in params_dict:
continue

if name.endswith("e_score_correction_bias"):
name = name.replace(
"e_score_correction_bias", "gate.e_score_correction_bias"
)
if (
spec_layer != self.model.mtp_start_layer_idx
and ".layers" not in name
Expand All @@ -230,8 +237,17 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
)
weight_loader(param, loaded_weight)
loaded_params.add(name)

self.post_weight_load()
return loaded_params

def post_weight_load(self) -> None:
    """Run post-load finalization hooks on all submodules.

    Walks every module registered under ``self`` and, for each one that
    defines its own ``post_weight_load``, invokes it. ``self`` is
    explicitly excluded so this dispatcher never re-enters itself.
    """
    submodules = (m for _, m in self.named_modules() if m is not self)
    for submodule in submodules:
        if hasattr(submodule, "post_weight_load"):
            submodule.post_weight_load()

def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
"""
Rewrite the weight name to match the format of the original model.
Expand Down
1 change: 1 addition & 0 deletions vllm/transformers_utils/model_arch_config_convertor.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,5 +456,6 @@ def get_num_hidden_layers(self) -> int:
"glm_ocr_mtp": GLM4MoeMTPModelArchConfigConvertor,
"ernie_mtp": ErnieMTPModelArchConfigConvertor,
"pangu_ultra_moe_mtp": PanguUltraMoeMTPModelArchConfigConvertor,
"pangu_pro_moe_mtp": PanguUltraMoeMTPModelArchConfigConvertor,
"longcat_flash_mtp": LongCatFlashMTPModelArchConfigConvertor,
}
4 changes: 2 additions & 2 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1919,9 +1919,9 @@ def _build_attn_group_metadata(
if self.speculative_config and spec_decode_common_attn_metadata is None:
if isinstance(self.drafter, EagleProposer):
if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
spec_decode_common_attn_metadata = cm
spec_decode_common_attn_metadata = copy(cm)
else:
spec_decode_common_attn_metadata = cm
spec_decode_common_attn_metadata = copy(cm)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we move this copy into StaticSinkAttentionBuilder?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately the builder is not responsible for building spec_decode_common_attn_metadata; it is constructed by gpu_model_runner, outside the builder.


for attn_gid in range(len(self.attn_groups[kv_cache_gid])):
if ubatch_slices is not None:
Expand Down