
Commit 8f09332

[BugFix] Fix qwen3 moe bug
Signed-off-by: wangxiyuan <[email protected]>
1 parent 4b3a210 commit 8f09332

File tree: 2 files changed, 93 additions and 85 deletions


vllm_ascend/models/qwen3_moe.py

Lines changed: 93 additions & 6 deletions
@@ -17,34 +17,113 @@
 # This file is a part of the vllm-ascend project.
 from typing import Optional

+import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig
+from vllm.config import CacheConfig, ParallelConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
+                                             get_tp_group)
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
                                                   Qwen3MoeDecoderLayer,
                                                   Qwen3MoeForCausalLM,
-                                                  Qwen3MoeMLP, Qwen3MoeModel)
+                                                  Qwen3MoeMLP, Qwen3MoeModel,
+                                                  Qwen3MoeSparseMoeBlock)
 from vllm.model_executor.models.utils import (
     extract_layer_index, make_empty_intermediate_tensors_factory, make_layers,
     maybe_prefix)

-from vllm_ascend.ops.fused_moe import AscendSparseMoeBlock
+from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.platform import VllmConfig


+class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        nn.Module.__init__(self)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        if self.tp_size > config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_experts}.")
+
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.num_experts,
+            bias=False,
+            quant_config=None,
+            prefix=f"{prefix}.gate",
+        )
+
+        self.experts = AscendFusedMoE(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts",
+        )
+
+        self.top_k = config.num_experts_per_tok
+
+        self.dp_size = get_dp_group().world_size
+
+        self.tp_group = get_tp_group().device_group
+        self.tp_rank = get_tp_group().rank_in_group
+        self.ep_group = get_ep_group()
+
+        self.params_dtype = torch.get_default_dtype()
+
+    def forward(
+        self,
+        hidden_states,
+        attn_metadata=None,
+    ):
+        if attn_metadata is None:
+            attn_metadata = get_forward_context().attn_metadata
+        # when profile runs, force experts to load balanced tokens
+        # to avoid high memory consumption on a single rank.
+        enable_force_load_balance = get_forward_context().in_profile_run
+        is_prefill = get_forward_context().with_prefill
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+
+        hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            is_prefill=is_prefill,
+            top_k=self.top_k,
+            enable_force_load_balance=enable_force_load_balance,
+            shared_experts=None,
+        )
+
+        return hidden_states
+
+
 class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):

     def __init__(
         self,
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        parallel_config: Optional[ParallelConfig] = None,
         prefix: str = "",
     ) -> None:

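The hunk above adds CustomSparseMoeBlock, which keeps the gate-then-fused-experts structure of the block it replaces. For orientation only, here is a plain-PyTorch sketch (not part of this commit) of the top-k routing that the gate + AscendFusedMoE pair performs; the real block dispatches to fused Ascend NPU kernels and additionally handles is_prefill and enable_force_load_balance:

# Illustrative sketch, not part of this commit: the routing math behind the
# gate + AscendFusedMoE pair, written in plain PyTorch.
import torch


def topk_route(hidden_states: torch.Tensor,
               gate_weight: torch.Tensor,
               top_k: int,
               renormalize: bool = True):
    # router_logits: (num_tokens, n_experts), like self.gate(hidden_states)
    router_logits = hidden_states @ gate_weight.t()
    probs = router_logits.softmax(dim=-1)
    topk_weights, topk_ids = probs.topk(top_k, dim=-1)
    if renormalize:  # corresponds to renormalize=config.norm_topk_prob
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


x = torch.randn(4, 8)        # (num_tokens, hidden_size)
w_gate = torch.randn(16, 8)  # (num_experts, hidden_size)
weights, ids = topk_route(x, w_gate, top_k=2)
print(weights.shape, ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])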
@@ -76,9 +155,15 @@ def __init__(
         if (layer_idx not in mlp_only_layers) and (
                 config.num_experts > 0 and
                 (layer_idx + 1) % config.decoder_sparse_step == 0):
-            self.mlp = AscendSparseMoeBlock(config=config,
-                                            quant_config=quant_config,
-                                            prefix=f"{prefix}.mlp")
+            if not parallel_config.enable_expert_parallel:
+                # custom sparse moe block doesn't work with ep currently.
+                self.mlp = CustomSparseMoeBlock(config=config,
+                                                quant_config=quant_config,
+                                                prefix=f"{prefix}.mlp")
+            else:
+                self.mlp = Qwen3MoeSparseMoeBlock(config=config,
+                                                  quant_config=quant_config,
+                                                  prefix=f"{prefix}.mlp")
         else:
             self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
@@ -99,6 +184,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
+        parallel_config = vllm_config.parallel_config

         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -113,6 +199,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 config=config,
                 cache_config=cache_config,
                 quant_config=quant_config,
+                parallel_config=parallel_config,
                 prefix=prefix),
             prefix=f"{prefix}.layers",
         )

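Taken together, these hunks thread parallel_config from VllmConfig through the model into each decoder layer, so the MoE implementation a Qwen3-MoE layer gets is driven by enable_expert_parallel. A hedged usage sketch of exercising both paths; the model name and the enable_expert_parallel engine argument are assumptions of this example, not part of the commit:

# Hypothetical usage sketch: exercising both code paths introduced above.
# Assumes vLLM's offline LLM API accepts an enable_expert_parallel engine
# argument; adjust model and parallel sizes to your Ascend NPU setup.
from vllm import LLM, SamplingParams

# enable_expert_parallel=False (default): sparse layers use the new
# CustomSparseMoeBlock backed by AscendFusedMoE.
# enable_expert_parallel=True: sparse layers fall back to the upstream
# Qwen3MoeSparseMoeBlock, since the custom block does not support EP yet.
llm = LLM(model="Qwen/Qwen3-30B-A3B",
          tensor_parallel_size=4,
          enable_expert_parallel=False)

outputs = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)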
vllm_ascend/ops/fused_moe.py

Lines changed: 0 additions & 79 deletions
@@ -22,8 +22,6 @@
 import torch.distributed as dist
 import torch_npu
 from torch import nn
-from transformers import PretrainedConfig
-from vllm.attention import AttentionMetadata
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -37,7 +35,6 @@
     FusedMoEParallelConfig # isort: skip
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
-from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.quantization.base_config import \
     QuantizationConfig

@@ -1546,79 +1543,3 @@ def _forward_ms_fused_moe_comp(
         )

         return hidden_states
-
-
-class AscendSparseMoeBlock(nn.Module):
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ):
-        super().__init__()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        if self.tp_size > config.num_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {config.num_experts}.")
-
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        self.enable_multistream_moe = (
-            ascend_config.torchair_graph_config.enable_multistream_moe)
-
-        self.gate = ReplicatedLinear(
-            config.hidden_size,
-            config.num_experts,
-            bias=False,
-            quant_config=None,
-            prefix=f"{prefix}.gate",
-        )
-
-        self.experts = AscendFusedMoE(
-            num_experts=config.num_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
-            renormalize=config.norm_topk_prob,
-            quant_config=quant_config,
-            prefix=f"{prefix}.experts",
-        )
-
-        self.top_k = config.num_experts_per_tok
-
-        self.dp_size = get_dp_group().world_size
-
-        self.tp_group = get_tp_group().device_group
-        self.tp_rank = get_tp_group().rank_in_group
-        self.ep_group = get_ep_group()
-
-        self.params_dtype = torch.get_default_dtype()
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attn_metadata: Optional[AttentionMetadata] = None,
-    ) -> torch.Tensor:
-        if attn_metadata is None:
-            attn_metadata = get_forward_context().attn_metadata
-        # when profile runs, force experts to load balanced tokens
-        # to avoid high memory consumption on a single rank.
-        enable_force_load_balance = get_forward_context().in_profile_run
-        is_prefill = get_forward_context().with_prefill
-
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-
-        hidden_states = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            is_prefill=is_prefill,
-            top_k=self.top_k,
-            enable_force_load_balance=enable_force_load_balance,
-            shared_experts=None,
-        )
-
-        return hidden_states

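Because AscendSparseMoeBlock is deleted here, any out-of-tree code that imported it from vllm_ascend.ops.fused_moe must switch to the model-local replacement added in the first file. A minimal migration sketch, assuming no other call-site changes are needed:

# Before this commit (this import now fails with ImportError):
# from vllm_ascend.ops.fused_moe import AscendSparseMoeBlock

# After this commit, the replacement block lives next to the model definition:
from vllm_ascend.models.qwen3_moe import CustomSparseMoeBlock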