23 changes: 23 additions & 0 deletions tests/ut/quantization/methods/test_moe_logical_experts.py
@@ -0,0 +1,23 @@
+from types import SimpleNamespace
+
+from vllm_ascend.quantization.methods.base import get_moe_num_logical_experts
+
+
+def test_get_moe_num_logical_experts_uses_vllm_config_field():
+    layer = SimpleNamespace(moe_config=SimpleNamespace(num_logical_experts=128))
+
+    assert get_moe_num_logical_experts(layer, num_experts=130, global_redundant_expert_num=2) == 128
+
+
+def test_get_moe_num_logical_experts_falls_back_for_older_configs():
+    layer = SimpleNamespace(moe_config=SimpleNamespace())
+
+    assert (
+        get_moe_num_logical_experts(
+            layer,
+            num_experts=133,
+            global_redundant_expert_num=2,
+            num_shared_experts=3,
+        )
+        == 128
+    )
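Note: both tests expect 128 logical experts. The first reads the value straight from moe_config.num_logical_experts (so num_experts and global_redundant_expert_num are ignored); the second exercises the fallback arithmetic, 133 total minus 2 redundant minus 3 shared, which also equals 128.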
16 changes: 13 additions & 3 deletions vllm_ascend/ops/fused_moe/fused_moe.py
@@ -41,6 +41,7 @@
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute
 from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult, setup_moe_comm_method
 from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
+from vllm_ascend.quantization.methods.base import get_moe_num_logical_experts
 from vllm_ascend.quantization.quant_type import QuantType
 from vllm_ascend.utils import (
     ACL_FORMAT_FRACTAL_NZ,
@@ -132,6 +133,15 @@ def apply(
     ) -> torch.Tensor:
         zero_expert_num = getattr(layer, "zero_expert_num", 0)
         zero_expert_type = getattr(layer, "zero_expert_type", None)
+        num_shared_experts = getattr(layer, "n_shared_experts", 0)
+        if num_shared_experts is None:
+            num_shared_experts = 0
+        num_logical_experts = get_moe_num_logical_experts(
+            layer,
+            num_experts,
+            global_redundant_expert_num=global_redundant_expert_num,
+            num_shared_experts=num_shared_experts,
+        )
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -144,7 +154,7 @@
             scoring_func=scoring_func,
             routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
-            num_experts=num_experts,
+            num_experts=num_logical_experts,
         )
         if layer.vllm_config.model_config is not None and layer.vllm_config.model_config.enable_return_routed_experts:
             capturer = RoutedExpertsCapturer.get_instance()
@@ -158,7 +168,7 @@
             topk_ids, topk_weights, zero_expert_result = zero_experts_compute(
                 expert_indices=topk_ids,
                 expert_scales=topk_weights,
-                num_experts=num_experts,
+                num_experts=num_logical_experts,
                 zero_expert_type=zero_expert_type,
                 hidden_states=x,
             )
@@ -168,7 +178,7 @@
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
         if enable_force_load_balance:
-            random_matrix = torch.rand(topk_ids.size(0), num_experts, device=topk_ids.device)
+            random_matrix = torch.rand(topk_ids.size(0), num_logical_experts, device=topk_ids.device)
             topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
 
         moe_comm_method = _EXTRA_CTX.moe_comm_method
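The force-load-balance branch above draws one uniform score per (token, expert) over the logical experts and argsorts each row; because the argsort of i.i.d. uniform scores is a uniformly random permutation, every token ends up with a random set of distinct expert ids. A minimal standalone sketch, with illustrative sizes that are not taken from the PR:

import torch

# Illustrative sizes only (hypothetical, not from the diff).
num_tokens, num_logical_experts, top_k = 4, 8, 2

# argsort of i.i.d. uniform scores gives an independent random
# permutation of expert ids per row; keep the first top_k of each.
random_matrix = torch.rand(num_tokens, num_logical_experts)
topk_ids = torch.argsort(random_matrix, dim=1)[:, :top_k].to(torch.int32)

assert topk_ids.shape == (num_tokens, top_k)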
14 changes: 14 additions & 0 deletions vllm_ascend/quantization/methods/base.py
@@ -25,6 +25,20 @@
 from vllm_ascend.quantization.quant_type import QuantType
 
 
+def get_moe_num_logical_experts(
+    layer: torch.nn.Module,
+    num_experts: int,
+    global_redundant_expert_num: int = 0,
+    num_shared_experts: int = 0,
+) -> int:
+    moe_config = getattr(layer, "moe_config", None)
+    num_logical_experts = getattr(moe_config, "num_logical_experts", None)
+    if num_logical_experts is not None:
+        return int(num_logical_experts)
+
+    return int(num_experts - global_redundant_expert_num - num_shared_experts)
+
+
 class AscendLinearScheme(ABC):
     """Base class for all linear quantization schemes.
 
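The fallback line encodes the relationship between the counts: the num_experts a scheme receives is a physical expert count, which can include EPLB redundant replicas and fused shared experts, while the router only scores logical experts. A tiny sketch of that bookkeeping, with hypothetical numbers chosen to match the tests above:

# Hypothetical deployment, numbers for illustration only.
num_logical = 128     # experts the router actually scores
num_redundant = 2     # EPLB replicas of hot experts
num_shared = 3        # shared experts fused into the MoE layer
num_physical = num_logical + num_redundant + num_shared   # 133 slots on the layer

# get_moe_num_logical_experts inverts this when the config field is absent.
assert num_physical - num_redundant - num_shared == num_logical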
15 changes: 12 additions & 3 deletions vllm_ascend/quantization/methods/w4a16.py
@@ -27,7 +27,7 @@
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
 
-from .base import AscendMoEScheme, QuantType
+from .base import AscendMoEScheme, QuantType, get_moe_num_logical_experts
 from .registry import register_scheme
 
 
@@ -199,7 +199,16 @@ def apply(
         apply_router_weight_on_input: bool = False,
         mc2_mask: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        assert router_logits.shape[1] == num_experts, "Number of global experts mismatch (excluding redundancy)"
+        num_shared_experts = getattr(layer, "n_shared_experts", 0)
+        if num_shared_experts is None:
+            num_shared_experts = 0
+        num_logical_experts = get_moe_num_logical_experts(
+            layer,
+            num_experts,
+            global_redundant_expert_num=global_redundant_expert_num,
+            num_shared_experts=num_shared_experts,
+        )
+        assert router_logits.shape[1] == num_logical_experts, "Number of global experts mismatch (excluding redundancy)"
 
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -212,7 +221,7 @@
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
-            num_experts=num_experts,
+            num_experts=num_logical_experts,
         )
 
         topk_ids = topk_ids.to(torch.int32)
17 changes: 13 additions & 4 deletions vllm_ascend/quantization/methods/w4a4_mxfp4.py
@@ -33,7 +33,7 @@
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
 
-from .base import AscendLinearScheme, AscendMoEScheme, QuantType
+from .base import AscendLinearScheme, AscendMoEScheme, QuantType, get_moe_num_logical_experts
 from .registry import register_scheme
 
 
@@ -187,7 +187,16 @@ def apply(
         apply_router_weight_on_input: bool = False,
         mc2_mask: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        assert router_logits.shape[1] == num_experts, "Number of global experts mismatch (excluding redundancy)"
+        num_shared_experts = getattr(layer, "n_shared_experts", 0)
+        if num_shared_experts is None:
+            num_shared_experts = 0
+        num_logical_experts = get_moe_num_logical_experts(
+            layer,
+            num_experts,
+            global_redundant_expert_num=global_redundant_expert_num,
+            num_shared_experts=num_shared_experts,
+        )
+        assert router_logits.shape[1] == num_logical_experts, "Number of global experts mismatch (excluding redundancy)"
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -199,14 +208,14 @@
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
-            num_experts=num_experts,
+            num_experts=num_logical_experts,
         )
 
         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
         if enable_force_load_balance:
-            random_matrix = torch.rand(topk_ids.size(0), num_experts, device=topk_ids.device)
+            random_matrix = torch.rand(topk_ids.size(0), num_logical_experts, device=topk_ids.device)
             topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
 
         topk_weights = topk_weights.to(x.dtype)
17 changes: 13 additions & 4 deletions vllm_ascend/quantization/methods/w4a8.py
@@ -31,7 +31,7 @@
 from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
 from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD, maybe_trans_nz
 
-from .base import AscendLinearScheme, AscendMoEScheme, QuantType
+from .base import AscendLinearScheme, AscendMoEScheme, QuantType, get_moe_num_logical_experts
 from .registry import register_scheme
 
 
@@ -349,7 +349,16 @@ def apply(
         apply_router_weight_on_input: bool = False,
         mc2_mask: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        assert router_logits.shape[1] == num_experts, "Number of global experts mismatch (excluding redundancy)"
+        num_shared_experts = getattr(layer, "n_shared_experts", 0)
+        if num_shared_experts is None:
+            num_shared_experts = 0
+        num_logical_experts = get_moe_num_logical_experts(
+            layer,
+            num_experts,
+            global_redundant_expert_num=global_redundant_expert_num,
+            num_shared_experts=num_shared_experts,
+        )
+        assert router_logits.shape[1] == num_logical_experts, "Number of global experts mismatch (excluding redundancy)"
 
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
         topk_weights, topk_ids = select_experts(
@@ -363,14 +372,14 @@
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
-            num_experts=num_experts,
+            num_experts=num_logical_experts,
         )
 
         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
         if enable_force_load_balance:
-            random_matrix = torch.rand(topk_ids.size(0), num_experts, device=topk_ids.device)
+            random_matrix = torch.rand(topk_ids.size(0), num_logical_experts, device=topk_ids.device)
             topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
 
         topk_weights = topk_weights.to(x.dtype)
17 changes: 11 additions & 6 deletions vllm_ascend/quantization/methods/w8a8_dynamic.py
@@ -32,7 +32,7 @@
 from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, maybe_trans_nz
 
-from .base import AscendLinearScheme, AscendMoEScheme, QuantType
+from .base import AscendLinearScheme, AscendMoEScheme, QuantType, get_moe_num_logical_experts
 from .registry import register_scheme
 
 
@@ -193,9 +193,14 @@ def apply(
         mix_placement = getattr(layer, "mix_placement", False)
         if n_shared_experts is None:
             n_shared_experts = 0
-        valid_global_expert_num = num_experts - n_shared_experts
+        num_logical_experts = get_moe_num_logical_experts(
+            layer,
+            num_experts,
+            global_redundant_expert_num=global_redundant_expert_num,
+            num_shared_experts=n_shared_experts,
+        )
         if zero_expert_num == 0 or zero_expert_type is None:
-            assert router_logits.shape[1] == valid_global_expert_num, (
+            assert router_logits.shape[1] == num_logical_experts, (
                 "Number of global experts mismatch (excluding redundancy)"
             )
 
@@ -220,23 +225,23 @@
             mix_placement=mix_placement,
             num_logical_experts=router_logits.shape[1],
             num_shared_experts=n_shared_experts,
-            num_experts=num_experts,
+            num_experts=num_logical_experts,
         )
         assert topk_ids is not None
         assert topk_weights is not None
         if zero_expert_num > 0 and zero_expert_type is not None:
             topk_ids, topk_weights, zero_expert_result = zero_experts_compute(
                 expert_indices=topk_ids,
                 expert_scales=topk_weights,
-                num_experts=num_experts,
+                num_experts=num_logical_experts,
                 zero_expert_type=zero_expert_type,
                 hidden_states=x,
             )
         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
         if enable_force_load_balance:
-            random_matrix = torch.rand(topk_ids.size(0), num_experts, device=topk_ids.device)
+            random_matrix = torch.rand(topk_ids.size(0), num_logical_experts, device=topk_ids.device)
             topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
 
         assert topk_weights is not None
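Unlike the other schemes, this one previously derived the expected router width by hand as num_experts - n_shared_experts; as the removed line shows, that formula did not subtract the EPLB redundant replicas, which the shared helper now accounts for. A before/after comparison of the expected router_logits width, with illustrative numbers (hypothetical, not from the PR):

# Hypothetical counts for illustration only.
num_experts, n_shared_experts, global_redundant_expert_num = 133, 3, 2

old_width = num_experts - n_shared_experts                                 # 130
new_width = num_experts - global_redundant_expert_num - n_shared_experts  # 128

# The width assert is guarded by zero_expert_num == 0 in the diff,
# presumably because zero experts add extra router columns that
# zero_experts_compute handles separately.
assert (old_width, new_width) == (130, 128)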
17 changes: 13 additions & 4 deletions vllm_ascend/quantization/methods/w8a8_mxfp8.py
@@ -33,7 +33,7 @@
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
 
-from .base import AscendLinearScheme, AscendMoEScheme, QuantType
+from .base import AscendLinearScheme, AscendMoEScheme, QuantType, get_moe_num_logical_experts
 from .registry import register_scheme
 
 
@@ -246,7 +246,16 @@ def apply(
         apply_router_weight_on_input: bool = False,
         mc2_mask: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        assert router_logits.shape[1] == num_experts, "Number of global experts mismatch (excluding redundancy)"
+        num_shared_experts = getattr(layer, "n_shared_experts", 0)
+        if num_shared_experts is None:
+            num_shared_experts = 0
+        num_logical_experts = get_moe_num_logical_experts(
+            layer,
+            num_experts,
+            global_redundant_expert_num=global_redundant_expert_num,
+            num_shared_experts=num_shared_experts,
+        )
+        assert router_logits.shape[1] == num_logical_experts, "Number of global experts mismatch (excluding redundancy)"
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -258,14 +267,14 @@
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
-            num_experts=num_experts,
+            num_experts=num_logical_experts,
         )
 
         # this is a naive implementation for experts load balance so as
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
         if enable_force_load_balance:
-            random_matrix = torch.rand(topk_ids.size(0), num_experts, device=topk_ids.device)
+            random_matrix = torch.rand(topk_ids.size(0), num_logical_experts, device=topk_ids.device)
             topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
 
         topk_weights = topk_weights.to(x.dtype)