
[Performance][Model Loader] Skip non-local expert weights during EP model loading #37136

Merged
ywang96 merged 5 commits into main from perf-weight-loading
Mar 16, 2026

Conversation

@esmeetu
Member

@esmeetu esmeetu commented Mar 16, 2026

Purpose

In DP+EP deployments, every rank currently reads all expert weights from disk via safe_open().get_tensor(), only for FusedMoE.weight_loader to discard non-local experts afterward. For large MoE models (e.g. Kimi-K2.5 with 384 experts at 591GB), each rank reads the full 591GB but only keeps ~144GB (dense + 1/8 experts).

This PR moves the filtering before the disk read by checking the tensor name against the local expert set in the weight iterator, so f.get_tensor() is never called for non-local experts.

  • Skip reading non-local expert weights from disk during model loading when expert parallelism (EP) is enabled
  • Each rank only reads its own expert shard + shared dense weights, avoiding ~87% of storage I/O for typical MoE models
  • No change to non-EP or non-MoE loading paths
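The "filter before the disk read" idea can be sketched as a generator that inspects each safetensors key before calling `get_tensor()`. This is a minimal illustration, not the PR's actual code: the `experts.<id>` naming pattern and the function name are assumptions based on common MoE checkpoint layouts.

```python
import re

def filter_expert_weights(weight_names, local_expert_ids):
    """Yield only the weight names this rank should read from disk.

    Hypothetical sketch: expert weights are assumed to follow the common
    'experts.<id>.' naming pattern. Names without an expert index
    (dense/shared weights) are always kept, so non-MoE weights pass
    through unchanged.
    """
    pattern = re.compile(r"\.experts\.(\d+)\.")
    for name in weight_names:
        match = pattern.search(name)
        if match is None or int(match.group(1)) in local_expert_ids:
            yield name

names = [
    "model.layers.0.mlp.experts.0.w1.weight",
    "model.layers.0.mlp.experts.7.w1.weight",
    "model.layers.0.self_attn.q_proj.weight",
]
# With local experts {0, 1, 2, 3}, expert 7 is skipped; the dense
# attention weight is kept.
print(list(filter_expert_weights(names, {0, 1, 2, 3})))
```

Because the skipped names are never passed to `f.get_tensor()`, the corresponding bytes are never faulted in from disk, which is where the I/O saving comes from.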

Test Plan

  • Verify on MoE model with EP enabled: only local expert weights loaded, model output unchanged
  • Verify on 3D MoE model: gpt-oss-120b
  • Verify non-MoE model loading is unaffected (local_expert_ids remains None)
  • Verify EP=1 (no filtering) produces identical results to baseline
  • Test with different EP sizes (8, 16, 24) to confirm correct expert distribution

Benchmark Results

Kimi-K2.5-NVFP4 (591GB, 384 experts):

DP/EP=4 (1 node × 4 GPUs)

| Metric | main | PR | Speedup |
|---|---|---|---|
| Loading Weights | 96.58s | 67.59s | 1.4x |
| Model Loading Total | 103.53s | 72.89s | 1.4x |

DP/EP=8 (2 nodes × 4 GPUs)

| Metric | main | PR | Speedup |
|---|---|---|---|
| Loading Weights | 58.45s | 25.35s | 2.3x |
| Model Loading Total | 63.42s | 32.04s | 2.0x |

DP/EP=16 (4 nodes × 4 GPUs)

| Metric | main | PR | Speedup |
|---|---|---|---|
| Loading Weights | 54.84s | 21.71s | 2.5x |
| Model Loading Total | 83.46s | 39.71s | 2.2x |

Note on warm vs cold cache: The numbers above were measured with the OS page cache already populated (repeated runs). On cold start (first load after boot, deployment, or scale-up), the speedup is expected to be larger because the dominant cost shifts from CPU-side mmap fault handling to network filesystem I/O, where reducing the read volume from 591GB to ~144GB per rank (EP=8) has a proportionally greater effect. Cold cache testing was not performed, but it can be reproduced by running `sync && echo 3 > /proc/sys/vm/drop_caches` (requires root) before the first load.

Speedup scales with EP size: higher EP → more experts skipped → greater I/O reduction.

| EP size | Experts/rank | Per-rank I/O | Expert I/O skipped |
|---|---|---|---|
| EP=4 | 96 | ~208GB | 75% |
| EP=8 | 48 | ~144GB | 87.5% |
| EP=16 | 24 | ~112GB | 93.75% |
| EP=24 | 16 | ~101GB | 95.8% |
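The scaling in this table follows directly from even expert partitioning. A quick sketch of the arithmetic (the function name and byte parameters are illustrative, not from the PR):

```python
def ep_shard_stats(num_experts, ep_size, expert_gb, dense_gb):
    """Per-rank expert count, fraction of expert I/O skipped, and
    per-rank read volume, assuming experts divide evenly across ranks.

    Illustrative helper only; sizes are approximations, not exact
    checkpoint accounting.
    """
    experts_per_rank = num_experts // ep_size
    skipped_fraction = 1 - experts_per_rank / num_experts
    per_rank_gb = dense_gb + experts_per_rank * expert_gb
    return experts_per_rank, skipped_fraction, per_rank_gb

# Kimi-K2.5 has 384 experts; reproduce the "Experts/rank" and
# "Expert I/O skipped" columns above.
for ep in (4, 8, 16, 24):
    per_rank, skipped, _ = ep_shard_stats(384, ep, expert_gb=1, dense_gb=0)
    print(f"EP={ep}: {per_rank} experts/rank, {skipped:.2%} skipped")
```

With 384 experts, EP=8 gives 48 experts per rank and skips 7/8 = 87.5% of expert reads, matching the table.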

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

esmeetu added 2 commits March 16, 2026 10:16
Signed-off-by: esmeetu <jasonailu87@gmail.com>
Signed-off-by: esmeetu <jasonailu87@gmail.com>
@esmeetu esmeetu requested a review from 22quinn as a code owner March 16, 2026 03:08
@esmeetu esmeetu changed the title [Performance][ModelL] Skip non-local expert weights during EP model loading [Performance][Model Loader] Skip non-local expert weights during EP model loading Mar 16, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant performance optimization for loading Mixture-of-Experts (MoE) models with expert parallelism (EP) enabled. By filtering out non-local expert weights before they are read from disk, it effectively reduces I/O, which is particularly beneficial for large models. The implementation is well-structured, with the core filtering logic encapsulated in a new ep_weight_filter.py module and accompanied by a comprehensive test suite. The changes are correctly integrated into the model loading pipeline. I've found one critical issue regarding the calculation of expert parallelism rank and size, which omits the prefill context parallelism dimension. Addressing this will ensure the correctness of expert filtering in all configurations.

@mergify

mergify bot commented Mar 16, 2026

Hi @esmeetu, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 90784738f0

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Signed-off-by: esmeetu <jasonailu87@gmail.com>
@mergify

mergify bot commented Mar 16, 2026

Hi @esmeetu, the pre-commit checks have failed. Please run the pre-commit steps above, then commit the changes and push to your branch.

esmeetu added 2 commits March 16, 2026 12:17
Signed-off-by: esmeetu <jasonailu87@gmail.com>
Signed-off-by: esmeetu <jasonailu87@gmail.com>
@esmeetu esmeetu added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 16, 2026
@ywang96 ywang96 merged commit 821eb80 into main Mar 16, 2026
55 checks passed
@ywang96 ywang96 deleted the perf-weight-loading branch March 16, 2026 08:33
elvircrn added a commit to elvircrn/vllm that referenced this pull request Mar 17, 2026
…37136)

The EP weight filter (PR vllm-project#37136) partitions logical experts across ranks
and skips non-local expert weights at the safetensors level. This breaks
EPLB because redundant physical expert slots map to logical experts that
belong to other ranks in the default partition. Those weights get filtered
out, leaving redundant slots uninitialized (zeros), which causes
catastrophic accuracy loss (~0.08 gsm8k vs ~0.95 baseline).

Fix: skip the EP weight filter entirely when EPLB is enabled, since the
weight loader needs to see ALL logical expert weights to populate
redundant physical slots.

Signed-off-by: Elvir Crncevic <elvircrn@gmail.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
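The fix described in the commit above amounts to a guard on when the disk-level filter may run. A minimal sketch of that condition (function and parameter names are hypothetical, not the actual vLLM API):

```python
def should_filter_expert_weights(ep_enabled: bool, eplb_enabled: bool) -> bool:
    """Whether the disk-level expert weight filter may be applied.

    Sketch of the fix described above: with EPLB, redundant physical
    expert slots can map to logical experts outside this rank's default
    partition, so every rank must read ALL logical expert weights and
    the filter must be disabled. Without EPLB, filtering is safe.
    """
    return ep_enabled and not eplb_enabled

# EP alone: filter is safe. EP + EPLB: filter must be skipped.
print(should_filter_expert_weights(True, False))   # filter applies
print(should_filter_expert_weights(True, True))    # filter skipped
```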
Lucaskabela pushed a commit to Lucaskabela/vllm that referenced this pull request Mar 17, 2026
…odel loading (vllm-project#37136)

Signed-off-by: esmeetu <jasonailu87@gmail.com>
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
…odel loading (vllm-project#37136)

Signed-off-by: esmeetu <jasonailu87@gmail.com>
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026
…odel loading (vllm-project#37136)

Signed-off-by: esmeetu <jasonailu87@gmail.com>
wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Mar 20, 2026
…ant experts (#7470)

### What this PR does / why we need it?
PR vllm-project/vllm#37136 breaks EPLB because it
filters out redundant experts.
PR vllm-project/vllm#37322 fixes it by using
parallel_config.enable_eplb to determine whether to skip the weight
loading filter.
But in vllm-ascend, parallel_config.enable_eplb is always false. When we
use EPLB, we temporarily set it to true.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

![Snipaste_2026-03-19_16-13-01](https://github.com/user-attachments/assets/b3a4911e-36b3-4c31-951c-7c091f416d00)
| dataset | version | metric | mode | vllm-api-stream-chat |
|----- | ----- | ----- | ----- | -----|
| aime2024 | 604a78 | accuracy | gen | 86.67 |

Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
starmountain1997 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Mar 25, 2026
…ant experts (vllm-project#7470)

(Same commit message as above.)

Signed-off-by: shenchuxiaofugui <1311027364@qq.com>

Labels

ready ONLY add when PR is ready to merge/full CI is needed
