 # This file is a part of the vllm-ascend project.
 from typing import Optional
 
+import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig
+from vllm.config import CacheConfig, ParallelConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
+                                             get_tp_group)
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
                                                   Qwen3MoeDecoderLayer,
                                                   Qwen3MoeForCausalLM,
-                                                  Qwen3MoeMLP, Qwen3MoeModel)
+                                                  Qwen3MoeMLP, Qwen3MoeModel,
+                                                  Qwen3MoeSparseMoeBlock)
 from vllm.model_executor.models.utils import (
     extract_layer_index, make_empty_intermediate_tensors_factory, make_layers,
     maybe_prefix)
 
-from vllm_ascend.ops.fused_moe import AscendSparseMoeBlock
+from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.platform import VllmConfig
 
 
+class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        nn.Module.__init__(self)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        if self.tp_size > config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_experts}.")
+
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.num_experts,
+            bias=False,
+            quant_config=None,
+            prefix=f"{prefix}.gate",
+        )
+
+        self.experts = AscendFusedMoE(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts",
+        )
+
+        self.top_k = config.num_experts_per_tok
+
+        self.dp_size = get_dp_group().world_size
+
+        self.tp_group = get_tp_group().device_group
+        self.tp_rank = get_tp_group().rank_in_group
+        self.ep_group = get_ep_group()
+
+        self.params_dtype = torch.get_default_dtype()
+
+    def forward(
+        self,
+        hidden_states,
+        attn_metadata=None,
+    ):
+        if attn_metadata is None:
+            attn_metadata = get_forward_context().attn_metadata
+        # During the profile run, force the experts to load balanced tokens
+        # to avoid high memory consumption on a single rank.
+        enable_force_load_balance = get_forward_context().in_profile_run
+        is_prefill = get_forward_context().with_prefill
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+
+        hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            is_prefill=is_prefill,
+            top_k=self.top_k,
+            enable_force_load_balance=enable_force_load_balance,
+            shared_experts=None,
+        )
+
+        return hidden_states
+
+
 class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
 
     def __init__(
         self,
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        parallel_config: Optional[ParallelConfig] = None,
         prefix: str = "",
     ) -> None:
 
@@ -76,9 +155,15 @@ def __init__(
         if (layer_idx not in mlp_only_layers) and (
                 config.num_experts > 0 and
                 (layer_idx + 1) % config.decoder_sparse_step == 0):
-            self.mlp = AscendSparseMoeBlock(config=config,
-                                            quant_config=quant_config,
-                                            prefix=f"{prefix}.mlp")
+            if not parallel_config.enable_expert_parallel:
+                # The custom sparse MoE block does not work with expert
+                # parallelism (EP) yet.
+                self.mlp = CustomSparseMoeBlock(config=config,
+                                                quant_config=quant_config,
+                                                prefix=f"{prefix}.mlp")
+            else:
+                self.mlp = Qwen3MoeSparseMoeBlock(config=config,
+                                                  quant_config=quant_config,
+                                                  prefix=f"{prefix}.mlp")
         else:
             self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
@@ -99,6 +184,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
+        parallel_config = vllm_config.parallel_config
 
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -113,6 +199,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 config=config,
                 cache_config=cache_config,
                 quant_config=quant_config,
+                parallel_config=parallel_config,
                 prefix=prefix),
             prefix=f"{prefix}.layers",
         )
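
For context, the gate in CustomSparseMoeBlock produces per-token router logits and AscendFusedMoE then routes each token to top_k = config.num_experts_per_tok experts, renormalizing the selected weights when config.norm_topk_prob is set (the renormalize argument above). The snippet below is a minimal, self-contained sketch of just that routing step under those assumptions; naive_topk_route is an illustrative name and nothing here calls a vllm or vllm-ascend API.

import torch


def naive_topk_route(router_logits: torch.Tensor,
                     top_k: int,
                     renormalize: bool = True):
    # router_logits: (num_tokens, n_experts), as produced by the gate.
    probs = torch.softmax(router_logits, dim=-1)
    # Pick the top_k experts and their routing weights for each token.
    weights, expert_ids = probs.topk(top_k, dim=-1)
    if renormalize:
        # Mirrors renormalize=config.norm_topk_prob: rescale the selected
        # weights so they sum to 1 per token.
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, expert_ids


# Example: 4 tokens routed over 8 experts, 2 experts per token.
weights, expert_ids = naive_topk_route(torch.randn(4, 8), top_k=2)
print(weights.shape, expert_ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])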