From 0fad57a6124b43dd4da189eb22faf527c32cfd8c Mon Sep 17 00:00:00 2001
From: yjc9696 <616530803@qq.com>
Date: Thu, 21 Aug 2025 23:31:42 +0800
Subject: [PATCH 1/2] hardcode norm_topk_prob

---
 .../models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py | 4 ----
 .../models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py      | 4 +---
 .../models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py       | 4 +---
 3 files changed, 2 insertions(+), 10 deletions(-)

diff --git a/src/transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py
index 28d114af0525..fb8cba72bdfc 100644
--- a/src/transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py
+++ b/src/transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py
@@ -99,8 +99,6 @@ class HunYuanMoEV1Config(PretrainedConfig):
             The dropout ratio for the attention probabilities.
         num_experts (`int` or `List`, *optional*, defaults to 1):
            The number of experts for moe. If it is a list, it will be used as the number of experts for each layer.
-        norm_topk_prob (`bool`, *optional*, defaults to `True`):
-            Whether to normalize the topk probabilities.
         moe_topk (int or List, *optional*, defaults to 1):
             Number of experts selected per token (Top-K routing). List form enables layer-wise customization.
         head_dim (`int`, *optional*, defaults to 128):
@@ -135,7 +133,6 @@ def __init__(
         attention_bias=False,
         attention_dropout=0.0,
         num_experts: Union[int, list] = 1,
-        norm_topk_prob=True,
         moe_topk: Union[int, list] = 1,
         head_dim=None,
         **kwargs,
@@ -147,7 +144,6 @@ def __init__(
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.num_experts = num_experts
-        self.norm_topk_prob = norm_topk_prob
         self.moe_topk = moe_topk
         self.head_dim = head_dim
 
diff --git a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py
index 8bfdab6d159b..0b2df41ba59e 100644
--- a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py
+++ b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py
@@ -252,7 +252,6 @@ def __init__(self, config: HunYuanMoEV1Config, layer_idx: Optional[int] = None):
         self.layer_idx = layer_idx
         self.num_experts = config.num_experts if isinstance(config.num_experts, int) else config.num_experts[layer_idx]
         self.top_k = config.moe_topk if isinstance(config.moe_topk, int) else config.moe_topk[layer_idx]
-        self.norm_topk_prob = config.norm_topk_prob
         self.gate = HunYuanMoEV1Gate(config, layer_idx=layer_idx)
         # self.wg = nn.Linear(config.hidden_size, config.num_experts, bias=False, dtype=torch.float32)
         self.experts = nn.ModuleList(
@@ -270,8 +269,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
         routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
-        if self.norm_topk_prob:
-            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
         # we cast back to the input dtype
         routing_weights = routing_weights.to(hidden_states.dtype)
 
diff --git a/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py
index a8b7f92f9941..66801569efb8 100644
--- a/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py
+++ b/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py
@@ -137,7 +137,6 @@ def __init__(self, config: HunYuanMoEV1Config, layer_idx: Optional[int] = None):
         self.layer_idx = layer_idx
         self.num_experts = config.num_experts if isinstance(config.num_experts, int) else config.num_experts[layer_idx]
         self.top_k = config.moe_topk if isinstance(config.moe_topk, int) else config.moe_topk[layer_idx]
-        self.norm_topk_prob = config.norm_topk_prob
         self.gate = HunYuanMoEV1Gate(config, layer_idx=layer_idx)
         # self.wg = nn.Linear(config.hidden_size, config.num_experts, bias=False, dtype=torch.float32)
         self.experts = nn.ModuleList(
@@ -155,8 +154,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
         routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
-        if self.norm_topk_prob:
-            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
         # we cast back to the input dtype
         routing_weights = routing_weights.to(hidden_states.dtype)
 

From 70cf3468371cc9faf92412a32cfeea1254f7a740 Mon Sep 17 00:00:00 2001
From: yjc9696 <616530803@qq.com>
Date: Thu, 21 Aug 2025 23:57:43 +0800
Subject: [PATCH 2/2] fix testcase

---
 .../hunyuan_v1_moe/test_modeling_hunyuan_v1_moe.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/tests/models/hunyuan_v1_moe/test_modeling_hunyuan_v1_moe.py b/tests/models/hunyuan_v1_moe/test_modeling_hunyuan_v1_moe.py
index 237ed5ae8ba3..6194d4eec8c8 100644
--- a/tests/models/hunyuan_v1_moe/test_modeling_hunyuan_v1_moe.py
+++ b/tests/models/hunyuan_v1_moe/test_modeling_hunyuan_v1_moe.py
@@ -113,8 +113,8 @@ def tearDown(self):
     def test_model_generation(self):
         # we will compele this when model file change over
         # pass
-        EXPECTED_ANSWER = "\nRegular exercise offers numerous physical, mental, and emotional benefits. It improves cardiovascular health, strengthens muscles and bones, boosts metabolism, and helps"
-        prompt = "Write a short summary of the benefits of regular exercise "
+        EXPECTED_ANSWER = "\nOkay, I need to write a short summary about the benefits of regular exercise. Let me start by recalling what I know. First,"
+        prompt = "Write a short summary of the benefits of regular exercise"
         tokenizer = AutoTokenizer.from_pretrained("tencent/Hunyuan-A13B-Instruct")
         model = AutoModelForCausalLM.from_pretrained("tencent/Hunyuan-A13B-Instruct", device_map="auto")
         messages = [
@@ -125,9 +125,8 @@ def test_model_generation(self):
             tokenize=True,
             add_generation_prompt=True,
             return_tensors="pt",
-            enable_thinking=False,  # Toggle thinking mode (default: True)
         )
         generated_ids = model.generate(tokenized_chat.to(model.device), max_new_tokens=30, top_k=1)
-        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
-        answer = text.split("<answer>")[1]
-        self.assertEqual(EXPECTED_ANSWER, answer)
+        text = tokenizer.decode(generated_ids[0])
+        output = text.split("<think>")[1]
+        self.assertEqual(EXPECTED_ANSWER, output)
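
Note on PATCH 1/2: with norm_topk_prob removed, the renormalization of the
selected top-k routing probabilities is always applied rather than being
gated by the config. A minimal standalone sketch of that routing step
(tensor shapes and values here are illustrative, not taken from the
repository):

    import torch
    import torch.nn.functional as F

    # Example router output: 4 tokens routed over 8 experts (made-up shapes).
    router_logits = torch.randn(4, 8)
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    # Keep the top-2 experts per token (values and their expert indices).
    routing_weights, selected_experts = torch.topk(routing_weights, 2, dim=-1)
    # Previously gated by config.norm_topk_prob; now unconditional:
    # rescale the kept probabilities so they sum to 1 for each token.
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)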