Skip to content

[EPLB][refactor] Modification of the initialization logic for expert_map and log2phy(depend on pr5285)#5311

Merged
weijinqian0 merged 2 commits intovllm-project:mainfrom
shenchuxiaofugui:config
Dec 29, 2025
Merged

[EPLB][refactor] Modification of the initialization logic for expert_map and log2phy(depend on pr5285)#5311
weijinqian0 merged 2 commits intovllm-project:mainfrom
shenchuxiaofugui:config

Conversation

@shenchuxiaofugui
Copy link
Copy Markdown
Collaborator

@shenchuxiaofugui shenchuxiaofugui commented Dec 24, 2025

What this PR does / why we need it?

Unify the loading logic for expert_map and log2phy.

  1. The map generated when enabling the redundancy expert is incorrect.
    The community generation map function only accepts the number of global experts. When we pass in the number of logical experts plus redundant experts, the local expert ID of the last card will index to an expert ID that does not exist. Now we ensure that the index points to a real existing expert ID, and each expert can be accessed. Moreover, when redundant experts are not enabled, the output of our function remains consistent with the community's function.
  2. The map we generate is based on the length of the physical expert, but in reality, we only need to use the length of the logical expert. Later on, we will need to pad it accordingly, so we can simply generate a map with the length of the logical [expert.]
  3. Unify the initialization logic across different scenarios and simplify the code for fused_moe.

Before refactoring

  • map path is not None:

    expert map: get_rank_placement_map from 'expert_load_balancer.py', maintains the map for all ranks and all layers.

    log2phy: get_rank_log2phy_map from 'expert_load_balancer.py', maintains the map for all ranks and all layers.

  • map path is None:

    expert map: determine_expert_map from '_vllm.laye_r', The function does not support the redundant experts of vllm-ascend.
    log2phy: determine_default_log2phy_map from 'eplb_utils.py'. The function does not support the redundant experts of vllm-ascend.

Refactoring
eplb_utils.py
    init_eplb_config
         generate placement
         generate expert map
         generate log2phy

Does this PR introduce any user-facing change?

No

How was this patch tested?

Expert Mapping Test Generation:
ep size: 16, num of experts: 256, num of redundant experts: 16
+++++++++++++++++++++++++++++++++++++++++
Expert Mapping (Non-1 indicates the expert responsible for this rank) for Rank 15:
vllm map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 0 1 2 3 4 5 6 7 8
9 10 11 12 13 14 15 16]
+++++++++++++++++++++++++++++++++++++++++
Improved map:
[16 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15]

Expert Mapping Test Generation:
ep size: 16, num of experts: 256, num of redundant experts: 0
+++++++++++++++++++++++++++++++++++++++++
Expert Mapping (Non-1 indicates the expert responsible for this rank) for Rank 15:
vllm map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15]
+++++++++++++++++++++++++++++++++++++++
Improved map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15]

dsr1 baselie:

dataset version metric mode vllm-api-general-chat
gsm8k-lite 7cd45e accuracy gen 100.00

dsr1 eplb:

dataset version metric mode vllm-api-general-chat
gsm8k-lite 7cd45e accuracy gen 100.00

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the initialization logic for expert_map and log2phy to fix bugs related to redundant experts. The refactoring is a good step towards simplifying the logic. However, I've identified several critical issues in the new implementation that could lead to runtime errors. These issues are primarily related to incorrect tensor creation from lists of varying lengths and improper attribute handling, which need to be addressed.

Comment thread vllm_ascend/eplb/core/eplb_utils.py
Comment thread vllm_ascend/eplb/core/eplb_utils.py
Comment thread vllm_ascend/eplb/core/eplb_utils.py Outdated
Comment on lines +113 to +117
log2phy_map = torch.scatter(torch.zeros(len(log2phy_map.keys()), dtype=torch.int32),
0,
torch.tensor(list(log2phy_map.keys()), dtype=torch.int64),
torch.tensor(list(log2phy_map.values()), dtype=torch.int32)
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The torch.zeros tensor is initialized with a size of len(log2phy_map.keys()). If the expert indices in log2phy_map.keys() are not contiguous from 0, the maximum index could be greater than len(log2phy_map.keys()) - 1. This will cause an out-of-bounds error when torch.scatter is called. The tensor should be initialized with the total number of experts, which can be obtained from len(expert_list[0]).

Suggested change
log2phy_map = torch.scatter(torch.zeros(len(log2phy_map.keys()), dtype=torch.int32),
0,
torch.tensor(list(log2phy_map.keys()), dtype=torch.int64),
torch.tensor(list(log2phy_map.values()), dtype=torch.int32)
)
log2phy_map = torch.scatter(torch.zeros(len(expert_list[0]), dtype=torch.int32),
0,
torch.tensor(list(log2phy_map.keys()), dtype=torch.int64),
torch.tensor(list(log2phy_map.values()), dtype=torch.int32)
)

Comment thread vllm_ascend/ops/fused_moe/fused_moe.py Outdated
self.moe_config.dp_group = get_dp_group()
self.moe_config.ep_group = get_ep_group()
self.moe_config.mc2_group = get_mc2_group()
self.supports_eplb = self.quant_method.supports_eplb
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The supports_eplb attribute is being set on self (the AscendFusedMoE instance), but it is later accessed from moe_config inside init_eplb_config. This will lead to an AttributeError because moe_config (an instance of FusedMoEConfig) does not have this attribute.

Suggested change
self.supports_eplb = self.quant_method.supports_eplb
self.moe_config.supports_eplb = self.quant_method.supports_eplb

Comment on lines +39 to +41
def __init__(self):
pass
self.transpose_weight = True
self.supports_eplb = True
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The supports_eplb attribute is being added to AscendW8A8DynamicLinearMethod, which is for general linear layers. However, EPLB (Expert Placement and Load Balancing) is a feature specific to Mixture-of-Experts (MoE) layers. This attribute should be set on AscendW8A8DynamicFusedMoEMethod instead. Please remove self.supports_eplb = True from this __init__ method and add it to the __init__ method of AscendW8A8DynamicFusedMoEMethod.

@github-actions
Copy link
Copy Markdown
Contributor

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:‌‌

  • A PR should do only one thing, smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests ‌to ensure it works and is not broken by other future PRs.
  • Write the commit message by fulfilling the PR description to help reviewer and future developers understand.

If CI fails, you can run linting and testing checks locally according Contributing and Testing.

@shenchuxiaofugui shenchuxiaofugui changed the title [EPLB] Modification of the initialization logic for expert_map and log2phy [EPLB][refactor] Modification of the initialization logic for expert_map and log2phy Dec 24, 2025
@shenchuxiaofugui shenchuxiaofugui force-pushed the config branch 3 times, most recently from 7596c6b to 40ade36 Compare December 24, 2025 08:21
@wangxiyuan wangxiyuan added ready read for review ready-for-test start test by label for PR labels Dec 24, 2025
@shenchuxiaofugui shenchuxiaofugui force-pushed the config branch 8 times, most recently from e418da5 to a131485 Compare December 25, 2025 11:58
Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
@shenchuxiaofugui shenchuxiaofugui changed the title [EPLB][refactor] Modification of the initialization logic for expert_map and log2phy [EPLB][refactor] Modification of the initialization logic for expert_map and log2phy(depend on pr8285) Dec 26, 2025
@shenchuxiaofugui shenchuxiaofugui changed the title [EPLB][refactor] Modification of the initialization logic for expert_map and log2phy(depend on pr8285) [EPLB][refactor] Modification of the initialization logic for expert_map and log2phy(depend on pr5285) Dec 26, 2025
@weijinqian0 weijinqian0 merged commit f81cf69 into vllm-project:main Dec 29, 2025
8 of 10 checks passed
845473182 pushed a commit to 845473182/vllm-ascend that referenced this pull request Dec 29, 2025
…to eplb_refactor

* 'main' of https://github.com/vllm-project/vllm-ascend: (46 commits)
  [Feature] Support to use fullgraph with eagle (vllm-project#5118)
  [EPLB][refactor] Modification of the initialization logic for expert_map and log2phy(depend on pr5285) (vllm-project#5311)
  [Refactor]6/N Extract common code of class AscendMLAImpl (vllm-project#5314)
  [Refactor] cache cos/sin in mla & remove parameter model in builder. (vllm-project#5277)
  update vllm pin to 12.27 (vllm-project#5412)
  [ReleaseNote] Add release note for v0.13.0rc1 (vllm-project#5334)
  [Bugfix] Correctly handle the output shape in multimodal attention (vllm-project#5443)
  Fix nightly (vllm-project#5413)
  [bugfix] fix typo of _skip_all_reduce_across_dp_group (vllm-project#5435)
  [Doc]modify pcp tutorial doc (vllm-project#5440)
  [Misc] fast fail for exiting if tools/install_flash_infer_attention_score_ops_a2.sh (vllm-project#5422)
  [Doc] Update DeepSeek V3.1/R1 2P1D doc (vllm-project#5387)
  [DOC]Fix model weight download links (vllm-project#5436)
  [Doc] Modify DeepSeek-R1/V3.1 documentation (vllm-project#5426)
  Revert "[feat] enable hierarchical mc2 ops on A2 by default (vllm-project#5300)" (vllm-project#5434)
  [Bugfix] fix greedy temperature detection (vllm-project#5417)
  [doc] Update Qwen3-235B doc for reproducing latest performance (vllm-project#5323)
  [feat] enable hierarchical mc2 ops on A2 by default (vllm-project#5300)
  [Doc] delete environment variable HCCL_OP_EXPANSION_MODE in DeepSeekV3.1/R1 (vllm-project#5419)
  [Doc] add long_sequence feature user guide (vllm-project#5343)
  ...
Mercykid-bash pushed a commit to Mercykid-bash/vllm-ascend that referenced this pull request Dec 29, 2025
…map and log2phy(depend on pr5285) (vllm-project#5311)

### What this PR does / why we need it?
Unify the loading logic for expert_map and log2phy.
1. The map generated when enabling the redundancy expert is incorrect.
The community generation map function only accepts the number of global
experts. When we pass in the number of logical experts plus redundant
experts, the local expert ID of the last card will index to an expert ID
that does not exist. Now we ensure that the index points to a real
existing expert ID, and each expert can be accessed. Moreover, when
redundant experts are not enabled, the output of our function remains
consistent with the community's function.
2. The map we generate is based on the length of the physical expert,
but in reality, we only need to use the length of the logical expert.
Later on, we will need to pad it accordingly, so we can simply generate
a map with the length of the logical [expert.]
3. Unify the initialization logic across different scenarios and
simplify the code for fused_moe.

**Before refactoring**

-   map path is not None:

expert map: get_rank_placement_map from _'expert_load_balancer.py'_,
maintains the map for all ranks and all layers.

log2phy: get_rank_log2phy_map from _'expert_load_balancer.py'_,
maintains the map for all ranks and all layers.

-   map path is None:

expert map: determine_expert_map from '_vllm.laye_r', The function does
not support the redundant experts of vllm-ascend.
log2phy: determine_default_log2phy_map from _'eplb_utils.py'_. The
function does not support the redundant experts of vllm-ascend.

**Refactoring**
eplb_utils.py
&nbsp;&nbsp;&nbsp;&nbsp;init_eplb_config
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; generate placement
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; generate expert map
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; generate log2phy

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

Expert Mapping Test Generation:
ep size: 16, num of experts: 256, num of redundant experts: 16
+++++++++++++++++++++++++++++++++++++++++
Expert Mapping (Non-1 indicates the expert responsible for this rank)
for Rank 15:
vllm map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1  0  1  2  3  4  5  6  7  8
  9 10 11 12 13 14 15 16]
+++++++++++++++++++++++++++++++++++++++++
Improved map:
[16 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]

Expert Mapping Test Generation:
ep size: 16, num of experts: 256, num of redundant experts: 0
+++++++++++++++++++++++++++++++++++++++++
Expert Mapping (Non-1 indicates the expert responsible for this rank)
for Rank 15:
vllm map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]
+++++++++++++++++++++++++++++++++++++++
Improved map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]

dsr1 baselie:

| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| gsm8k-lite | 7cd45e | accuracy | gen | 100.00 |

dsr1 eplb:

| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| gsm8k-lite | 7cd45e | accuracy | gen | 100.00 |

- vLLM version: release/v0.13.0
- vLLM main:
vllm-project/vllm@5fbfa8d

Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
Signed-off-by: Che Ruan <cr623@ic.ac.uk>
@jianzs
Copy link
Copy Markdown
Collaborator

jianzs commented Dec 29, 2025

image

Please! How can I enable EPLB with redundant experts and MTP?

@shenchuxiaofugui
Copy link
Copy Markdown
Collaborator Author

shenchuxiaofugui commented Dec 30, 2025

image Please! How can I enable EPLB with redundant experts and MTP?

depend on #5285 (unmerged)
The MTP layer must maintain the same number of redundant experts as the other layers. Therefore, either the incoming expert_map includes the MTP layer, or you set init_redundancy_expert to be consistent with the other layers in the table.

zhangxinyuehfad added a commit to zhangxinyuehfad/vllm-ascend that referenced this pull request Dec 30, 2025
… expert_map and log2phy(depend on pr5285) (vllm-project#5311)"

This reverts commit f81cf69.
zhangxinyuehfad added a commit to zhangxinyuehfad/vllm-ascend that referenced this pull request Dec 30, 2025
… expert_map and log2phy(depend on pr5285) (vllm-project#5311)"

This reverts commit f81cf69.

Signed-off-by: hfadzxy <starmoon_zhang@163.com>
845473182 pushed a commit to 845473182/vllm-ascend that referenced this pull request Dec 31, 2025
…to FIA_rebase

* 'main' of https://github.com/vllm-project/vllm-ascend: (88 commits)
  [1/N] Refactor nightly test structure (vllm-project#5479)
  Docs: Remove deprecated --task parameter for embedding models (vllm-project#5257)
  Revert "moe_gating_top_k" (vllm-project#5512)
  [Doc] Fix issue link for 0.12.0 (vllm-project#5500)
  [CI]update triton ascend version (vllm-project#5392)
  moe_gating_top_k (vllm-project#5271)
  [refactor] refactor model runner capture model (vllm-project#5230)
  Update corresponding vllm commit ID to 12 29 (vllm-project#5475)
  [Kernel]update csrc cmakelist for open-source cann (vllm-project#5458)
  [OP] add custom op aclnnMoeInitRoutingCustom (vllm-project#5251)
  [Refactor][EAGLE] 1/N delete __init__ in mtp_proposer (vllm-project#5176)
  [Refactor][Triton] Move reject sample triton kernels into ops/triton (vllm-project#5324)
  [Feature] support eager mode in model runner v2 (vllm-project#5210)
  [feature] fia support sliding windows (vllm-project#5239)
  Optimize some rejectsampler functions to make npu op launch non-blocking (vllm-project#4587)
  [Feature] Support to use fullgraph with eagle (vllm-project#5118)
  [EPLB][refactor] Modification of the initialization logic for expert_map and log2phy(depend on pr5285) (vllm-project#5311)
  [Refactor]6/N Extract common code of class AscendMLAImpl (vllm-project#5314)
  [Refactor] cache cos/sin in mla & remove parameter model in builder. (vllm-project#5277)
  update vllm pin to 12.27 (vllm-project#5412)
  ...
@shenchuxiaofugui shenchuxiaofugui deleted the config branch January 12, 2026 10:51
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
…map and log2phy(depend on pr5285) (vllm-project#5311)

### What this PR does / why we need it?
Unify the loading logic for expert_map and log2phy.
1. The map generated when enabling the redundancy expert is incorrect.
The community generation map function only accepts the number of global
experts. When we pass in the number of logical experts plus redundant
experts, the local expert ID of the last card will index to an expert ID
that does not exist. Now we ensure that the index points to a real
existing expert ID, and each expert can be accessed. Moreover, when
redundant experts are not enabled, the output of our function remains
consistent with the community's function.
2. The map we generate is based on the length of the physical expert,
but in reality, we only need to use the length of the logical expert.
Later on, we will need to pad it accordingly, so we can simply generate
a map with the length of the logical [expert.]
3. Unify the initialization logic across different scenarios and
simplify the code for fused_moe.

**Before refactoring**

-   map path is not None:

expert map: get_rank_placement_map from _'expert_load_balancer.py'_,
maintains the map for all ranks and all layers.

log2phy: get_rank_log2phy_map from _'expert_load_balancer.py'_,
maintains the map for all ranks and all layers.

-   map path is None:

expert map: determine_expert_map from '_vllm.laye_r', The function does
not support the redundant experts of vllm-ascend.
log2phy: determine_default_log2phy_map from _'eplb_utils.py'_. The
function does not support the redundant experts of vllm-ascend.

**Refactoring**
eplb_utils.py
&nbsp;&nbsp;&nbsp;&nbsp;init_eplb_config
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; generate placement
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; generate expert map
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; generate log2phy

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

Expert Mapping Test Generation:
ep size: 16, num of experts: 256, num of redundant experts: 16
+++++++++++++++++++++++++++++++++++++++++
Expert Mapping (Non-1 indicates the expert responsible for this rank)
for Rank 15:
vllm map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1  0  1  2  3  4  5  6  7  8
  9 10 11 12 13 14 15 16]
+++++++++++++++++++++++++++++++++++++++++
Improved map:
[16 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]

Expert Mapping Test Generation:
ep size: 16, num of experts: 256, num of redundant experts: 0
+++++++++++++++++++++++++++++++++++++++++
Expert Mapping (Non-1 indicates the expert responsible for this rank)
for Rank 15:
vllm map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]
+++++++++++++++++++++++++++++++++++++++
Improved map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]

dsr1 baselie:

| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| gsm8k-lite | 7cd45e | accuracy | gen | 100.00 |

dsr1 eplb:

| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| gsm8k-lite | 7cd45e | accuracy | gen | 100.00 |

- vLLM version: release/v0.13.0
- vLLM main:
vllm-project/vllm@5fbfa8d

Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
…map and log2phy(depend on pr5285) (vllm-project#5311)

### What this PR does / why we need it?
Unify the loading logic for expert_map and log2phy.
1. The map generated when enabling the redundancy expert is incorrect.
The community generation map function only accepts the number of global
experts. When we pass in the number of logical experts plus redundant
experts, the local expert ID of the last card will index to an expert ID
that does not exist. Now we ensure that the index points to a real
existing expert ID, and each expert can be accessed. Moreover, when
redundant experts are not enabled, the output of our function remains
consistent with the community's function.
2. The map we generate is based on the length of the physical expert,
but in reality, we only need to use the length of the logical expert.
Later on, we will need to pad it accordingly, so we can simply generate
a map with the length of the logical [expert.]
3. Unify the initialization logic across different scenarios and
simplify the code for fused_moe.

**Before refactoring**

-   map path is not None:

expert map: get_rank_placement_map from _'expert_load_balancer.py'_,
maintains the map for all ranks and all layers.

log2phy: get_rank_log2phy_map from _'expert_load_balancer.py'_,
maintains the map for all ranks and all layers.

-   map path is None:

expert map: determine_expert_map from '_vllm.laye_r', The function does
not support the redundant experts of vllm-ascend.
log2phy: determine_default_log2phy_map from _'eplb_utils.py'_. The
function does not support the redundant experts of vllm-ascend.

**Refactoring**
eplb_utils.py
&nbsp;&nbsp;&nbsp;&nbsp;init_eplb_config
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; generate placement
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; generate expert map
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; generate log2phy

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

Expert Mapping Test Generation:
ep size: 16, num of experts: 256, num of redundant experts: 16
+++++++++++++++++++++++++++++++++++++++++
Expert Mapping (Non-1 indicates the expert responsible for this rank)
for Rank 15:
vllm map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1  0  1  2  3  4  5  6  7  8
  9 10 11 12 13 14 15 16]
+++++++++++++++++++++++++++++++++++++++++
Improved map:
[16 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]

Expert Mapping Test Generation:
ep size: 16, num of experts: 256, num of redundant experts: 0
+++++++++++++++++++++++++++++++++++++++++
Expert Mapping (Non-1 indicates the expert responsible for this rank)
for Rank 15:
vllm map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]
+++++++++++++++++++++++++++++++++++++++
Improved map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]

dsr1 baselie:

| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| gsm8k-lite | 7cd45e | accuracy | gen | 100.00 |

dsr1 eplb:

| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| gsm8k-lite | 7cd45e | accuracy | gen | 100.00 |


- vLLM version: release/v0.13.0
- vLLM main:
vllm-project/vllm@5fbfa8d

Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
…map and log2phy(depend on pr5285) (vllm-project#5311)

### What this PR does / why we need it?
Unify the loading logic for expert_map and log2phy.
1. The map generated when enabling the redundancy expert is incorrect.
The community generation map function only accepts the number of global
experts. When we pass in the number of logical experts plus redundant
experts, the local expert ID of the last card will index to an expert ID
that does not exist. Now we ensure that the index points to a real
existing expert ID, and each expert can be accessed. Moreover, when
redundant experts are not enabled, the output of our function remains
consistent with the community's function.
2. The map we generate is based on the length of the physical expert,
but in reality, we only need to use the length of the logical expert.
Later on, we will need to pad it accordingly, so we can simply generate
a map with the length of the logical [expert.]
3. Unify the initialization logic across different scenarios and
simplify the code for fused_moe.

**Before refactoring**

-   map path is not None:

expert map: get_rank_placement_map from _'expert_load_balancer.py'_,
maintains the map for all ranks and all layers.

log2phy: get_rank_log2phy_map from _'expert_load_balancer.py'_,
maintains the map for all ranks and all layers.

-   map path is None:

expert map: determine_expert_map from '_vllm.laye_r', The function does
not support the redundant experts of vllm-ascend.
log2phy: determine_default_log2phy_map from _'eplb_utils.py'_. The
function does not support the redundant experts of vllm-ascend.

**Refactoring**
eplb_utils.py
&nbsp;&nbsp;&nbsp;&nbsp;init_eplb_config
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; generate placement
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; generate expert map
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; generate log2phy

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

Expert Mapping Test Generation:
ep size: 16, num of experts: 256, num of redundant experts: 16
+++++++++++++++++++++++++++++++++++++++++
Expert Mapping (Non-1 indicates the expert responsible for this rank)
for Rank 15:
vllm map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1  0  1  2  3  4  5  6  7  8
  9 10 11 12 13 14 15 16]
+++++++++++++++++++++++++++++++++++++++++
Improved map:
[16 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]

Expert Mapping Test Generation:
ep size: 16, num of experts: 256, num of redundant experts: 0
+++++++++++++++++++++++++++++++++++++++++
Expert Mapping (Non-1 indicates the expert responsible for this rank)
for Rank 15:
vllm map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]
+++++++++++++++++++++++++++++++++++++++
Improved map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]

dsr1 baselie:

| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| gsm8k-lite | 7cd45e | accuracy | gen | 100.00 |

dsr1 eplb:

| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| gsm8k-lite | 7cd45e | accuracy | gen | 100.00 |

- vLLM version: release/v0.13.0
- vLLM main:
vllm-project/vllm@5fbfa8d

Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants