From dd8965c83657b6865cc85726a201d297311f1bf4 Mon Sep 17 00:00:00 2001 From: shen-shanshan <467638484@qq.com> Date: Thu, 21 Aug 2025 12:20:30 +0000 Subject: [PATCH 1/7] Fix the bug that qwen3 moe doesn't work with aclgraph Signed-off-by: shen-shanshan <467638484@qq.com> --- tests/multicard/test_qwen3_moe.py | 55 ++++++++++++++++ vllm_ascend/models/qwen3_moe.py | 100 +++++++++++++++++++++++++++--- vllm_ascend/ops/fused_moe.py | 80 ------------------------ 3 files changed, 148 insertions(+), 87 deletions(-) create mode 100644 tests/multicard/test_qwen3_moe.py diff --git a/tests/multicard/test_qwen3_moe.py b/tests/multicard/test_qwen3_moe.py new file mode 100644 index 00000000000..ccc31d4c1d7 --- /dev/null +++ b/tests/multicard/test_qwen3_moe.py @@ -0,0 +1,55 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py +# +"""Compare the short outputs of HF and vLLM when using greedy sampling. + +Run `pytest tests/test_offline_inference.py`. +""" + +from tests.e2e.conftest import VllmRunner + + +def test_models_distributed_Qwen3_MOE_TP2(): + example_prompts = [ + "Hello, my name is", + ] + dtype = "half" + max_tokens = 5 + with VllmRunner( + "Qwen/Qwen3-30B-A3B", + dtype=dtype, + tensor_parallel_size=2, + distributed_executor_backend="mp", + ) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) + + +def test_models_distributed_Qwen3_MOE_TP2_WITH_EP(): + example_prompts = [ + "Hello, my name is", + ] + dtype = "half" + max_tokens = 5 + with VllmRunner( + "Qwen/Qwen3-30B-A3B", + dtype=dtype, + tensor_parallel_size=2, + enable_expert_parallel=True, + distributed_executor_backend="mp", + ) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/vllm_ascend/models/qwen3_moe.py b/vllm_ascend/models/qwen3_moe.py index 088b25f5a27..2d0012cd3a8 100644 --- a/vllm_ascend/models/qwen3_moe.py +++ b/vllm_ascend/models/qwen3_moe.py @@ -22,9 +22,13 @@ from torch import nn from transformers import PretrainedConfig from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_pp_group +from vllm.config import CacheConfig, CompilationLevel, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, + get_tp_group) +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -32,17 +36,89 @@ from vllm.model_executor.models.interfaces import 
SupportsPP from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention, Qwen3MoeForCausalLM, - Qwen3MoeMLP, Qwen3MoeModel) + Qwen3MoeMLP, Qwen3MoeModel, + Qwen3MoeSparseMoeBlock) from vllm.model_executor.models.utils import ( extract_layer_index, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.sequence import IntermediateTensors -from vllm_ascend.ops.fused_moe import AscendSparseMoeBlock +from vllm_ascend.ops.fused_moe import AscendFusedMoE from vllm_ascend.ops.sequence_parallel import (MetadataForPadding, init_metadata_for_sp) +class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + nn.Module.__init__(self) + self.tp_size = get_tensor_model_parallel_world_size() + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + self.experts = AscendFusedMoE( + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + ) + + self.top_k = config.num_experts_per_tok + + self.dp_size = get_dp_group().world_size + + self.tp_group = get_tp_group().device_group + self.tp_rank = get_tp_group().rank_in_group + self.ep_group = get_ep_group() + + self.params_dtype = torch.get_default_dtype() + + def forward( + self, + hidden_states, + attn_metadata=None, + ): + if attn_metadata is None: + attn_metadata = get_forward_context().attn_metadata + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. + enable_force_load_balance = get_forward_context().in_profile_run + is_prefill = get_forward_context().with_prefill + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + + hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + is_prefill=is_prefill, + top_k=self.top_k, + enable_force_load_balance=enable_force_load_balance, + shared_experts=None, + ) + + return hidden_states + + class AscendQwen3MoeDecoderLayer(nn.Module): def __init__( @@ -78,12 +154,22 @@ def __init__( layer_idx = extract_layer_index(prefix) mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers) + use_aclgraph = (vllm_config is not None + and vllm_config.compilation_config.level + == CompilationLevel.PIECEWISE + and not vllm_config.model_config.enforce_eager) if (layer_idx not in mlp_only_layers) and ( config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0): - self.mlp = AscendSparseMoeBlock(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + if not use_aclgraph: + # FIXME: custom sparse moe block doesn't work with aclgraph. 
+ self.mlp = CustomSparseMoeBlock(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Qwen3MoeSparseMoeBlock(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") else: self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 39ee9acbda8..46d37700c3e 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -23,8 +23,6 @@ import torch.distributed as dist import torch_npu from torch import nn -from transformers import PretrainedConfig -from vllm.attention import AttentionMetadata from vllm.config import get_current_vllm_config from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -35,7 +33,6 @@ from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEParallelConfig, MoEConfig, UnquantizedFusedMoEMethod, determine_expert_map) -from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.quantization.base_config import \ QuantizationConfig @@ -1505,80 +1502,3 @@ def _forward_ms_fused_moe_comp( ) return hidden_states - - -class AscendSparseMoeBlock(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.tp_size = get_tensor_model_parallel_world_size() - if self.tp_size > config.num_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}.") - - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.enable_multistream_moe = ( - ascend_config.torchair_graph_config.enable_multistream_moe) - - self.gate = ReplicatedLinear( - config.hidden_size, - config.num_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate", - ) - - self.experts = AscendFusedMoE( - num_experts=config.num_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - prefix=f"{prefix}.experts", - ) - - self.top_k = config.num_experts_per_tok - - self.dp_size = get_dp_group().world_size - - self.tp_group = get_tp_group().device_group - self.tp_rank = get_tp_group().rank_in_group - self.ep_group = get_ep_group() - - self.params_dtype = torch.get_default_dtype() - - def forward( - self, - hidden_states: torch.Tensor, - attn_metadata: Optional[AttentionMetadata] = None, - _metadata_for_padding: Optional[MetadataForPadding] = None - ) -> torch.Tensor: - if attn_metadata is None: - attn_metadata = get_forward_context().attn_metadata - # when profile runs, force experts to load balanced tokens - # to avoid high memory consumption on a single rank. 
- enable_force_load_balance = get_forward_context().in_profile_run - is_prefill = get_forward_context().with_prefill - - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - - hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits, - is_prefill=is_prefill, - top_k=self.top_k, - enable_force_load_balance=enable_force_load_balance, - shared_experts=None, - _metadata_for_padding=_metadata_for_padding) - - return hidden_states From 46d426cc404ea4e8546f5b2e5d61104f4495b855 Mon Sep 17 00:00:00 2001 From: shen-shanshan <467638484@qq.com> Date: Thu, 21 Aug 2025 12:37:36 +0000 Subject: [PATCH 2/7] update Signed-off-by: shen-shanshan <467638484@qq.com> --- vllm_ascend/ops/fused_moe.py | 79 ++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 46d37700c3e..9e3f3d73e06 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -23,6 +23,8 @@ import torch.distributed as dist import torch_npu from torch import nn +from transformers import PretrainedConfig +from vllm.attention import AttentionMetadata from vllm.config import get_current_vllm_config from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -33,6 +35,7 @@ from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEParallelConfig, MoEConfig, UnquantizedFusedMoEMethod, determine_expert_map) +from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.quantization.base_config import \ QuantizationConfig @@ -1502,3 +1505,79 @@ def _forward_ms_fused_moe_comp( ) return hidden_states + +class AscendSparseMoeBlock(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_multistream_moe = ( + ascend_config.torchair_graph_config.enable_multistream_moe) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + self.experts = AscendFusedMoE( + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + ) + + self.top_k = config.num_experts_per_tok + + self.dp_size = get_dp_group().world_size + + self.tp_group = get_tp_group().device_group + self.tp_rank = get_tp_group().rank_in_group + self.ep_group = get_ep_group() + + self.params_dtype = torch.get_default_dtype() + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None, + _metadata_for_padding: Optional[MetadataForPadding] = None + ) -> torch.Tensor: + if attn_metadata is None: + attn_metadata = get_forward_context().attn_metadata + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. 
+ enable_force_load_balance = get_forward_context().in_profile_run + is_prefill = get_forward_context().with_prefill + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + + hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + is_prefill=is_prefill, + top_k=self.top_k, + enable_force_load_balance=enable_force_load_balance, + shared_experts=None, + _metadata_for_padding=_metadata_for_padding) + + return hidden_states From 8d8fd84e0ae54eff9cb170d2e60238ecdbbc507f Mon Sep 17 00:00:00 2001 From: shen-shanshan <467638484@qq.com> Date: Thu, 21 Aug 2025 12:46:47 +0000 Subject: [PATCH 3/7] update Signed-off-by: shen-shanshan <467638484@qq.com> --- vllm_ascend/ops/fused_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 9e3f3d73e06..39ee9acbda8 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -1506,6 +1506,7 @@ def _forward_ms_fused_moe_comp( return hidden_states + class AscendSparseMoeBlock(nn.Module): def __init__( From ff50da600b339b7c4329727b3ada718234f63105 Mon Sep 17 00:00:00 2001 From: shen-shanshan <467638484@qq.com> Date: Fri, 22 Aug 2025 01:29:05 +0000 Subject: [PATCH 4/7] update Signed-off-by: shen-shanshan <467638484@qq.com> --- tests/multicard/test_qwen3_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/multicard/test_qwen3_moe.py b/tests/multicard/test_qwen3_moe.py index ccc31d4c1d7..612aa7e1b14 100644 --- a/tests/multicard/test_qwen3_moe.py +++ b/tests/multicard/test_qwen3_moe.py @@ -21,7 +21,7 @@ Run `pytest tests/test_offline_inference.py`. """ -from tests.e2e.conftest import VllmRunner +from tests.conftest import VllmRunner def test_models_distributed_Qwen3_MOE_TP2(): From c5fe8171637c1b330e23da66effe387c4a0b340e Mon Sep 17 00:00:00 2001 From: shen-shanshan <467638484@qq.com> Date: Fri, 22 Aug 2025 02:53:20 +0000 Subject: [PATCH 5/7] update Signed-off-by: shen-shanshan <467638484@qq.com> --- tests/multicard/test_qwen3_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/multicard/test_qwen3_moe.py b/tests/multicard/test_qwen3_moe.py index 612aa7e1b14..ef2aaa52a98 100644 --- a/tests/multicard/test_qwen3_moe.py +++ b/tests/multicard/test_qwen3_moe.py @@ -33,7 +33,7 @@ def test_models_distributed_Qwen3_MOE_TP2(): with VllmRunner( "Qwen/Qwen3-30B-A3B", dtype=dtype, - tensor_parallel_size=2, + tensor_parallel_size=4, distributed_executor_backend="mp", ) as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) @@ -48,7 +48,7 @@ def test_models_distributed_Qwen3_MOE_TP2_WITH_EP(): with VllmRunner( "Qwen/Qwen3-30B-A3B", dtype=dtype, - tensor_parallel_size=2, + tensor_parallel_size=4, enable_expert_parallel=True, distributed_executor_backend="mp", ) as vllm_model: From a5e1e6b84bb409d70f18f968d09fbf724da35860 Mon Sep 17 00:00:00 2001 From: shen-shanshan <467638484@qq.com> Date: Fri, 22 Aug 2025 03:57:11 +0000 Subject: [PATCH 6/7] update Signed-off-by: shen-shanshan <467638484@qq.com> --- vllm_ascend/models/qwen3_moe.py | 74 +-------------------------------- 1 file changed, 1 insertion(+), 73 deletions(-) diff --git a/vllm_ascend/models/qwen3_moe.py b/vllm_ascend/models/qwen3_moe.py index 2d0012cd3a8..6d906803c10 100644 --- a/vllm_ascend/models/qwen3_moe.py +++ b/vllm_ascend/models/qwen3_moe.py @@ -48,77 +48,6 @@ init_metadata_for_sp) -class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock): - - def __init__( - self, - config: PretrainedConfig, 
- quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - nn.Module.__init__(self) - self.tp_size = get_tensor_model_parallel_world_size() - if self.tp_size > config.num_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}.") - - self.gate = ReplicatedLinear( - config.hidden_size, - config.num_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate", - ) - - self.experts = AscendFusedMoE( - num_experts=config.num_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - prefix=f"{prefix}.experts", - ) - - self.top_k = config.num_experts_per_tok - - self.dp_size = get_dp_group().world_size - - self.tp_group = get_tp_group().device_group - self.tp_rank = get_tp_group().rank_in_group - self.ep_group = get_ep_group() - - self.params_dtype = torch.get_default_dtype() - - def forward( - self, - hidden_states, - attn_metadata=None, - ): - if attn_metadata is None: - attn_metadata = get_forward_context().attn_metadata - # when profile runs, force experts to load balanced tokens - # to avoid high memory consumption on a single rank. - enable_force_load_balance = get_forward_context().in_profile_run - is_prefill = get_forward_context().with_prefill - - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - - hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits, - is_prefill=is_prefill, - top_k=self.top_k, - enable_force_load_balance=enable_force_load_balance, - shared_experts=None, - ) - - return hidden_states - - class AscendQwen3MoeDecoderLayer(nn.Module): def __init__( @@ -162,8 +91,7 @@ def __init__( config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0): if not use_aclgraph: - # FIXME: custom sparse moe block doesn't work with aclgraph. 
- self.mlp = CustomSparseMoeBlock(config=config, + self.mlp = AscendSparseMoeBlock(config=config, quant_config=quant_config, prefix=f"{prefix}.mlp") else: From e04fb7a3ff501d2bd730166cd63a4cd8c77714e8 Mon Sep 17 00:00:00 2001 From: shen-shanshan <467638484@qq.com> Date: Fri, 22 Aug 2025 04:37:45 +0000 Subject: [PATCH 7/7] update Signed-off-by: shen-shanshan <467638484@qq.com> --- vllm_ascend/models/qwen3_moe.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/vllm_ascend/models/qwen3_moe.py b/vllm_ascend/models/qwen3_moe.py index 6d906803c10..9be69a62cb7 100644 --- a/vllm_ascend/models/qwen3_moe.py +++ b/vllm_ascend/models/qwen3_moe.py @@ -23,12 +23,8 @@ from transformers import PretrainedConfig from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, CompilationLevel, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, - get_tp_group) -from vllm.forward_context import get_forward_context +from vllm.distributed import get_pp_group from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -43,7 +39,7 @@ maybe_prefix) from vllm.sequence import IntermediateTensors -from vllm_ascend.ops.fused_moe import AscendFusedMoE +from vllm_ascend.ops.fused_moe import AscendSparseMoeBlock from vllm_ascend.ops.sequence_parallel import (MetadataForPadding, init_metadata_for_sp)