diff --git a/tests/ut/quantization/methods/test_moe_logical_experts.py b/tests/ut/quantization/methods/test_moe_logical_experts.py
new file mode 100644
index 00000000000..f994a7280d7
--- /dev/null
+++ b/tests/ut/quantization/methods/test_moe_logical_experts.py
@@ -0,0 +1,23 @@
+from types import SimpleNamespace
+
+from vllm_ascend.quantization.methods.base import get_moe_num_logical_experts
+
+
+def test_get_moe_num_logical_experts_uses_vllm_config_field():
+    layer = SimpleNamespace(moe_config=SimpleNamespace(num_logical_experts=128))
+
+    assert get_moe_num_logical_experts(layer, num_experts=130, global_redundant_expert_num=2) == 128
+
+
+def test_get_moe_num_logical_experts_falls_back_for_older_configs():
+    layer = SimpleNamespace(moe_config=SimpleNamespace())
+
+    assert (
+        get_moe_num_logical_experts(
+            layer,
+            num_experts=133,
+            global_redundant_expert_num=2,
+            num_shared_experts=3,
+        )
+        == 128
+    )
diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py
index b0d7a946b33..7aa0f31f76d 100644
--- a/vllm_ascend/ops/fused_moe/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe/fused_moe.py
@@ -41,6 +41,7 @@
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute
 from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult, setup_moe_comm_method
 from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
+from vllm_ascend.quantization.methods.base import get_moe_num_logical_experts
 from vllm_ascend.quantization.quant_type import QuantType
 from vllm_ascend.utils import (
     ACL_FORMAT_FRACTAL_NZ,
@@ -132,6 +133,15 @@ def apply(
     ) -> torch.Tensor:
         zero_expert_num = getattr(layer, "zero_expert_num", 0)
         zero_expert_type = getattr(layer, "zero_expert_type", None)
+        num_shared_experts = getattr(layer, "n_shared_experts", 0)
+        if num_shared_experts is None:
+            num_shared_experts = 0
+        num_logical_experts = get_moe_num_logical_experts(
+            layer,
+            num_experts,
+            global_redundant_expert_num=global_redundant_expert_num,
+            num_shared_experts=num_shared_experts,
+        )
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -144,7 +154,7 @@ def apply(
             scoring_func=scoring_func,
             routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
-            num_experts=num_experts,
+            num_experts=num_logical_experts,
         )
         if layer.vllm_config.model_config is not None and layer.vllm_config.model_config.enable_return_routed_experts:
             capturer = RoutedExpertsCapturer.get_instance()
@@ -158,7 +168,7 @@ def apply(
             topk_ids, topk_weights, zero_expert_result = zero_experts_compute(
                 expert_indices=topk_ids,
                 expert_scales=topk_weights,
-                num_experts=num_experts,
+                num_experts=num_logical_experts,
                 zero_expert_type=zero_expert_type,
                 hidden_states=x,
             )
@@ -168,7 +178,7 @@ # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
         if enable_force_load_balance:
-            random_matrix = torch.rand(topk_ids.size(0), num_experts, device=topk_ids.device)
+            random_matrix = torch.rand(topk_ids.size(0), num_logical_experts, device=topk_ids.device)
             topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
 
         moe_comm_method = _EXTRA_CTX.moe_comm_method
diff --git a/vllm_ascend/quantization/methods/base.py b/vllm_ascend/quantization/methods/base.py
index adb9edd37f1..c029cfbe8c3 100644
--- a/vllm_ascend/quantization/methods/base.py
+++ b/vllm_ascend/quantization/methods/base.py
@@ -25,6 +25,20 @@
 from vllm_ascend.quantization.quant_type import QuantType
 
 
+def get_moe_num_logical_experts(
+    layer: torch.nn.Module,
+    num_experts: int,
+    global_redundant_expert_num: int = 0,
+    num_shared_experts: int = 0,
+) -> int:
+    moe_config = getattr(layer, "moe_config", None)
+    num_logical_experts = getattr(moe_config, "num_logical_experts", None)
+    if num_logical_experts is not None:
+        return int(num_logical_experts)
+
+    return int(num_experts - global_redundant_expert_num - num_shared_experts)
+
+
 class AscendLinearScheme(ABC):
     """Base class for all linear quantization schemes.
 
diff --git a/vllm_ascend/quantization/methods/w4a16.py b/vllm_ascend/quantization/methods/w4a16.py
index 260951ab4ee..84d9fc4137a 100644
--- a/vllm_ascend/quantization/methods/w4a16.py
+++ b/vllm_ascend/quantization/methods/w4a16.py
@@ -27,7 +27,7 @@
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
 
-from .base import AscendMoEScheme, QuantType
+from .base import AscendMoEScheme, QuantType, get_moe_num_logical_experts
 from .registry import register_scheme
 
 
@@ -199,7 +199,16 @@ def apply(
         apply_router_weight_on_input: bool = False,
         mc2_mask: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        assert router_logits.shape[1] == num_experts, "Number of global experts mismatch (excluding redundancy)"
+        num_shared_experts = getattr(layer, "n_shared_experts", 0)
+        if num_shared_experts is None:
+            num_shared_experts = 0
+        num_logical_experts = get_moe_num_logical_experts(
+            layer,
+            num_experts,
+            global_redundant_expert_num=global_redundant_expert_num,
+            num_shared_experts=num_shared_experts,
+        )
+        assert router_logits.shape[1] == num_logical_experts, "Number of global experts mismatch (excluding redundancy)"
 
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -212,7 +221,7 @@ def apply(
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
-            num_experts=num_experts,
+            num_experts=num_logical_experts,
         )
 
         topk_ids = topk_ids.to(torch.int32)
diff --git a/vllm_ascend/quantization/methods/w4a4_mxfp4.py b/vllm_ascend/quantization/methods/w4a4_mxfp4.py
index 870a9d6aa38..95c948bdb78 100644
--- a/vllm_ascend/quantization/methods/w4a4_mxfp4.py
+++ b/vllm_ascend/quantization/methods/w4a4_mxfp4.py
@@ -33,7 +33,7 @@
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
 
-from .base import AscendLinearScheme, AscendMoEScheme, QuantType
+from .base import AscendLinearScheme, AscendMoEScheme, QuantType, get_moe_num_logical_experts
 from .registry import register_scheme
 
 
@@ -187,7 +187,16 @@ def apply(
         apply_router_weight_on_input: bool = False,
         mc2_mask: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        assert router_logits.shape[1] == num_experts, "Number of global experts mismatch (excluding redundancy)"
+        num_shared_experts = getattr(layer, "n_shared_experts", 0)
+        if num_shared_experts is None:
+            num_shared_experts = 0
+        num_logical_experts = get_moe_num_logical_experts(
+            layer,
+            num_experts,
+            global_redundant_expert_num=global_redundant_expert_num,
+            num_shared_experts=num_shared_experts,
+        )
+        assert router_logits.shape[1] == num_logical_experts, "Number of global experts mismatch (excluding redundancy)"
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -199,14 +208,14 @@ def apply(
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
-            num_experts=num_experts,
+            num_experts=num_logical_experts,
         )
 
         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
         if enable_force_load_balance:
-            random_matrix = torch.rand(topk_ids.size(0), num_experts, device=topk_ids.device)
+            random_matrix = torch.rand(topk_ids.size(0), num_logical_experts, device=topk_ids.device)
             topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
 
         topk_weights = topk_weights.to(x.dtype)
diff --git a/vllm_ascend/quantization/methods/w4a8.py b/vllm_ascend/quantization/methods/w4a8.py
index 7d7ce720d47..7f11ca70314 100644
--- a/vllm_ascend/quantization/methods/w4a8.py
+++ b/vllm_ascend/quantization/methods/w4a8.py
@@ -31,7 +31,7 @@
 from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
 from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD, maybe_trans_nz
 
-from .base import AscendLinearScheme, AscendMoEScheme, QuantType
+from .base import AscendLinearScheme, AscendMoEScheme, QuantType, get_moe_num_logical_experts
 from .registry import register_scheme
 
 
@@ -349,7 +349,16 @@ def apply(
         apply_router_weight_on_input: bool = False,
         mc2_mask: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        assert router_logits.shape[1] == num_experts, "Number of global experts mismatch (excluding redundancy)"
+        num_shared_experts = getattr(layer, "n_shared_experts", 0)
+        if num_shared_experts is None:
+            num_shared_experts = 0
+        num_logical_experts = get_moe_num_logical_experts(
+            layer,
+            num_experts,
+            global_redundant_expert_num=global_redundant_expert_num,
+            num_shared_experts=num_shared_experts,
+        )
+        assert router_logits.shape[1] == num_logical_experts, "Number of global experts mismatch (excluding redundancy)"
 
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
         topk_weights, topk_ids = select_experts(
@@ -363,14 +372,14 @@ def apply(
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
-            num_experts=num_experts,
+            num_experts=num_logical_experts,
         )
 
         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
         if enable_force_load_balance:
-            random_matrix = torch.rand(topk_ids.size(0), num_experts, device=topk_ids.device)
+            random_matrix = torch.rand(topk_ids.size(0), num_logical_experts, device=topk_ids.device)
             topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
 
         topk_weights = topk_weights.to(x.dtype)
diff --git a/vllm_ascend/quantization/methods/w8a8_dynamic.py b/vllm_ascend/quantization/methods/w8a8_dynamic.py
index 334577ac8e6..5e4aff5d74a 100644
--- a/vllm_ascend/quantization/methods/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/methods/w8a8_dynamic.py
@@ -32,7 +32,7 @@
 from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, maybe_trans_nz
 
-from .base import AscendLinearScheme, AscendMoEScheme, QuantType
+from .base import AscendLinearScheme, AscendMoEScheme, QuantType, get_moe_num_logical_experts
 from .registry import register_scheme
 
 
@@ -193,9 +193,14 @@ def apply(
         mix_placement = getattr(layer, "mix_placement", False)
         if n_shared_experts is None:
             n_shared_experts = 0
-        valid_global_expert_num = num_experts - n_shared_experts
+        num_logical_experts = get_moe_num_logical_experts(
+            layer,
+            num_experts,
+            global_redundant_expert_num=global_redundant_expert_num,
+            num_shared_experts=n_shared_experts,
+        )
         if zero_expert_num == 0 or zero_expert_type is None:
-            assert router_logits.shape[1] == valid_global_expert_num, (
+            assert router_logits.shape[1] == num_logical_experts, (
                 "Number of global experts mismatch (excluding redundancy)"
             )
 
@@ -220,7 +225,7 @@ def apply(
             mix_placement=mix_placement,
             num_logical_experts=router_logits.shape[1],
             num_shared_experts=n_shared_experts,
-            num_experts=num_experts,
+            num_experts=num_logical_experts,
         )
         assert topk_ids is not None
         assert topk_weights is not None
@@ -228,7 +233,7 @@ def apply(
             topk_ids, topk_weights, zero_expert_result = zero_experts_compute(
                 expert_indices=topk_ids,
                 expert_scales=topk_weights,
-                num_experts=num_experts,
+                num_experts=num_logical_experts,
                 zero_expert_type=zero_expert_type,
                 hidden_states=x,
             )
@@ -236,7 +241,7 @@ def apply(
         # to avoid accumulating too much tokens on a single rank.
        # currently it is only activated when doing profile runs.
         if enable_force_load_balance:
-            random_matrix = torch.rand(topk_ids.size(0), num_experts, device=topk_ids.device)
+            random_matrix = torch.rand(topk_ids.size(0), num_logical_experts, device=topk_ids.device)
             topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
 
         assert topk_weights is not None
diff --git a/vllm_ascend/quantization/methods/w8a8_mxfp8.py b/vllm_ascend/quantization/methods/w8a8_mxfp8.py
index 64a6b13da10..cb57e063e55 100644
--- a/vllm_ascend/quantization/methods/w8a8_mxfp8.py
+++ b/vllm_ascend/quantization/methods/w8a8_mxfp8.py
@@ -33,7 +33,7 @@
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
 
-from .base import AscendLinearScheme, AscendMoEScheme, QuantType
+from .base import AscendLinearScheme, AscendMoEScheme, QuantType, get_moe_num_logical_experts
 from .registry import register_scheme
 
 
@@ -246,7 +246,16 @@ def apply(
         apply_router_weight_on_input: bool = False,
         mc2_mask: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        assert router_logits.shape[1] == num_experts, "Number of global experts mismatch (excluding redundancy)"
+        num_shared_experts = getattr(layer, "n_shared_experts", 0)
+        if num_shared_experts is None:
+            num_shared_experts = 0
+        num_logical_experts = get_moe_num_logical_experts(
+            layer,
+            num_experts,
+            global_redundant_expert_num=global_redundant_expert_num,
+            num_shared_experts=num_shared_experts,
+        )
+        assert router_logits.shape[1] == num_logical_experts, "Number of global experts mismatch (excluding redundancy)"
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -258,14 +267,14 @@ def apply(
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
-            num_experts=num_experts,
+            num_experts=num_logical_experts,
        )
 
         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
         if enable_force_load_balance:
-            random_matrix = torch.rand(topk_ids.size(0), num_experts, device=topk_ids.device)
+            random_matrix = torch.rand(topk_ids.size(0), num_logical_experts, device=topk_ids.device)
             topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
 
         topk_weights = topk_weights.to(x.dtype)