
Bugfix: Align expert map shapes with redundant experts in EPLB adjustment#5285

Merged
wangxiyuan merged 6 commits into vllm-project:main from Mercykid-bash:eplb-fix-main
Jan 6, 2026

Conversation

@Mercykid-bash
Contributor

@Mercykid-bash Mercykid-bash commented Dec 23, 2025

Overview

This PR fixes a shape mismatch bug between expert_placement_map and log2phy_expert_map when redundant experts are enabled in the vLLM-Ascend platform. The issue occurred during the initialization of expert maps and their updates via EPLB (Expert Load Balancer) adjustment, leading to potential tensor shape errors and incorrect expert routing in distributed MoE deployments.

Key Changes

  1. Unify expert map shape calculation logic

    • Ensure the shape of expert_placement_map and log2phy_expert_map strictly aligns with the total number of experts (including redundant experts) during initialization.
    • Update the shape adjustment logic in the EPLB dynamic update process to match the initial expert map dimensions.
  2. Add shape consistency checks

    • Add assertion statements to verify the shape consistency of the two maps after initialization and EPLB adjustment, preventing silent shape mismatches in subsequent operations.
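The two changes above can be sketched as follows. This is an illustrative sketch only, using plain Python lists in place of tensors; the function and parameter names (build_expert_maps, num_logical, num_redundant, num_ranks) are hypothetical and not the actual vllm-ascend API.

```python
def build_expert_maps(num_logical: int, num_redundant: int, num_ranks: int):
    """Size both maps to the total physical expert count (logical + redundant)."""
    num_physical = num_logical + num_redundant
    # Both maps are allocated against num_physical, not num_logical,
    # so their shapes stay aligned once redundant experts are enabled.
    expert_placement_map = [[-1] * num_physical for _ in range(num_ranks)]
    log2phy_expert_map = [[-1] * num_physical for _ in range(num_ranks)]
    # Shape consistency check mirroring the assertions this PR adds after
    # initialization and after each EPLB adjustment:
    assert len(expert_placement_map[0]) == len(log2phy_expert_map[0]) == num_physical
    return expert_placement_map, log2phy_expert_map

placement, log2phy = build_expert_maps(num_logical=64, num_redundant=8, num_ranks=4)
```

With 64 logical experts and 8 redundant replicas, both maps here end up 72 columns wide; before the fix, one map could be sized against the logical count and the other against the physical count, producing the mismatch described above.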

Impact

  • Resolves tensor shape errors when using redundant experts with EPLB on Ascend platform.

  • Ensures correct expert routing and load balancing for MoE models with redundant expert configurations.

  • No breaking changes to existing functionality; compatible with non-redundant expert deployments.

  • vLLM version: release/v0.13.0

  • vLLM main: vllm-project/vllm@ad32e3e

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a shape mismatch bug concerning redundant experts in MoE layers. The modifications correctly adjust the shapes of expert_placement_map and log2phy_expert_map, and ensure the correct number of experts is passed to the underlying operators. The test updates are consistent with these fixes. My primary feedback focuses on enhancing the implementation within token_dispatcher.py. I suggest refactoring the code to avoid passing state via an instance attribute, which will improve the code's robustness and maintainability.

```diff
- quant_mode = 0
- moe_expert_num = len(expert_map)
+ quant_mode = 2 if self.with_quant else 0
+ self.moe_expert_num = len(expert_map) + global_redundant_expert_num
```
Contributor


Severity: high

Storing moe_expert_num as an instance attribute self.moe_expert_num creates an implicit dependency between token_dispatch (via get_dispatch_mc2_kwargs) and token_combine (via get_combine_mc_kwargs). This makes the code fragile and harder to reason about, as token_combine now relies on token_dispatch having been called first to set this attribute. A more robust approach is to pass state explicitly using the context_metadata dictionary.

I recommend the following refactoring:

  1. In get_dispatch_mc2_kwargs: Calculate moe_expert_num as a local variable and do not assign it to self.

```python
# In get_dispatch_mc2_kwargs
moe_expert_num = len(expert_map) + global_redundant_expert_num
kwargs_mc2 = {
    # ...
    "moe_expert_num": moe_expert_num,
    # ...
}
```

  2. In token_dispatch: Pass global_redundant_expert_num through context_metadata.

```python
# In token_dispatch
context_metadata["global_redundant_expert_num"] = global_redundant_expert_num
```

  3. In get_combine_mc_kwargs: Recalculate moe_expert_num using the value from context_metadata instead of reading from self.

```python
# In get_combine_mc_kwargs
global_redundant_expert_num = context_metadata["global_redundant_expert_num"]
moe_expert_num = len(expert_map) + global_redundant_expert_num
kwargs_mc2 = {
    # ...
    "moe_expert_num": moe_expert_num,
    # ...
}
```

This change will make the data flow explicit and improve the component's maintainability.
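Pulled together, the reviewer's suggestion amounts to the following self-contained sketch. The class and method names are simplified stand-ins for the real token_dispatcher.py code, not the actual vllm-ascend implementation.

```python
class TokenDispatcher:
    """Sketch of explicit state passing via a context_metadata dict."""

    def token_dispatch(self, expert_map, global_redundant_expert_num):
        # State that token_combine will need travels in context_metadata,
        # not in an instance attribute set as a side effect.
        context_metadata = {"global_redundant_expert_num": global_redundant_expert_num}
        # Local variable only; nothing is stashed on self.
        moe_expert_num = len(expert_map) + global_redundant_expert_num
        dispatch_kwargs = {"moe_expert_num": moe_expert_num}
        return dispatch_kwargs, context_metadata

    def token_combine(self, expert_map, context_metadata):
        # Recompute from the explicit metadata, so there is no hidden
        # dependency on token_dispatch having run first.
        redundant = context_metadata["global_redundant_expert_num"]
        combine_kwargs = {"moe_expert_num": len(expert_map) + redundant}
        return combine_kwargs
```

With 64 logical experts and 8 redundant replicas, both calls independently arrive at a moe_expert_num of 72, which is the consistency property the refactoring is after.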

@github-actions
Contributor

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by fulfilling the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

@Mercykid-bash
Contributor Author

Validation Result

We conducted comparative tests on the Qwen model under two configurations to verify the fix for expert map shape mismatch:

  1. Baseline (without redundant experts, EPLB disabled)
    The inference output is as follows:

    {"id":"chatcmpl-8e1df21e02cbf858","object":"chat.completion","created":1766491376,"model":"qwen","choices":[{"index":0,"message":{"role":"assistant","content":"\nOkay, the user asked, \"What is deeplearning?\" I need to explain this in a clear and simple way. Let me start by recalling what I know about deep learning.\n\nFirst, deep learning is a subset of machine learning,","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":null,"reasoning_content":null},"logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":14,"total_tokens":64,"completion_tokens":50,"prompt_tokens_details":null}
  2. Test Configuration (with redundant experts enabled, EPLB adjustment triggered)
    After enabling redundant experts and triggering dynamic EPLB (Expert Load Balancer) adjustment, the inference output is:

    {"id":"chatcmpl-9a72d52057ab55e0","object":"chat.completion","created":1766492392,"model":"qwen","choices":[{"index":0,"message":{"role":"assistant","content":"\nOkay, the user asked, \"What is deeplearning?\" I need to explain this in a clear and simple way. Let me start by recalling what I know about deep learning.\n\nFirst, deep learning is a subset of machine learning,","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":null,"reasoning_content":null},"logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":14,"total_tokens":64,"completion_tokens":50,"prompt_tokens_details":null}

Conclusion

The test results confirm that:

  • The content of the inference response, token count (prompt/completion/total tokens), and finish reason are completely consistent between the two configurations.
  • The fix for shape alignment of expert_placement_map and log2phy_expert_map ensures that enabling redundant experts and triggering EPLB adjustment does not introduce any accuracy degradation or output inconsistency in vLLM-Ascend.
  • The core inference logic remains stable and accurate under the MoE configuration with redundant experts and dynamic EPLB adjustment.

@shenchuxiaofugui shenchuxiaofugui force-pushed the eplb-fix-main branch 2 times, most recently from 75b4896 to 701f2f7 Compare December 25, 2025 12:29
@github-actions
Contributor

This pull request has conflicts, please resolve those before we can evaluate the pull request.


@MengqingCao MengqingCao added the ready and ready-for-test labels Dec 29, 2025

mercykid and others added 6 commits January 6, 2026 10:08
Signed-off-by: Che Ruan <cr623@ic.ac.uk>
Signed-off-by: Che Ruan <cr623@ic.ac.uk>
Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
Signed-off-by: Che Ruan <cr623@ic.ac.uk>
Signed-off-by: Che Ruan <cr623@ic.ac.uk>
Signed-off-by: Che Ruan <cr623@ic.ac.uk>
Signed-off-by: Che Ruan <cr623@ic.ac.uk>

rebase
@wangxiyuan wangxiyuan merged commit 29e2f9a into vllm-project:main Jan 6, 2026
19 checks passed
Rozwel-dx pushed a commit to Rozwel-dx/vllm-ascend that referenced this pull request Jan 8, 2026
…ment (vllm-project#5285)
@Mercykid-bash Mercykid-bash deleted the eplb-fix-main branch January 13, 2026 11:04
aipaes pushed a commit to aipaes/vllm-ascend that referenced this pull request Jan 15, 2026
…ment (vllm-project#5285)
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
…ment (vllm-project#5285)
maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
…ment (vllm-project#5285)
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
…ment (vllm-project#5285)
LCAIZJ pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Mar 7, 2026
…ment (vllm-project#5285)

Labels

module:ops, module:tests, ready, ready-for-test


6 participants