
[EPLB] Optimize EPLB with numpy#29499

Merged
tlrmchlsmth merged 27 commits into vllm-project:main from neuralmagic:imarkov/eplb_optimizations
Jan 7, 2026

Conversation

Contributor

@ilmarkov ilmarkov commented Nov 26, 2025

This PR adds the following optimizations to the EPLB algorithm, applicable to both sync and async mode:

  • Removes multiple GPU-CPU transfers in move_from_buffer.
  • Reimplements get_ep_ranks_with_expert, move_to_buffer, and move_from_buffer in NumPy, which gives a 20-25% average speedup for each primitive.
  • Adds a preserve_intragpu_slots primitive that post-processes the rebalance algorithm's results. It ensures that experts assigned to move within the same GPU stay in their current slots, which avoids unnecessary GPU memcopies.

It also adds a log_balancedness_interval config parameter to tune the frequency of balancedness logs and possibly reduce the number of collectives required for this logging.
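To make the slot-preservation idea concrete, here is a minimal NumPy sketch of what such a post-processing step can look like. The function name matches the PR, but the signature, the flat slot layout, and the no-duplicate-experts-per-GPU assumption are simplifications for illustration, not the actual implementation:

```python
import numpy as np

def preserve_intragpu_slots(old_map: np.ndarray, new_map: np.ndarray,
                            slots_per_gpu: int) -> np.ndarray:
    """Illustrative sketch: keep experts that stay on a GPU in their old slots.

    old_map/new_map hold one expert id per physical slot, with each GPU
    owning `slots_per_gpu` consecutive slots. Assumes no expert appears
    twice on the same GPU.
    """
    result = new_map.copy()
    num_gpus = old_map.size // slots_per_gpu
    for gpu in range(num_gpus):
        lo, hi = gpu * slots_per_gpu, (gpu + 1) * slots_per_gpu
        old_slots, new_slots = old_map[lo:hi], result[lo:hi]
        stay = np.intersect1d(old_slots, new_slots)      # on this GPU before and after
        keeps_slot = np.isin(old_slots, stay)            # slots whose expert does not move
        incoming = new_slots[~np.isin(new_slots, stay)]  # experts newly placed on this GPU
        out = old_slots.copy()
        out[~keeps_slot] = incoming                      # newcomers fill only the freed slots
        result[lo:hi] = out
    return result

# Experts 3 and 2 stay on the GPU, so they keep slots 0 and 2;
# newcomers 5 and 8 take the freed slots.
print(preserve_intragpu_slots(np.array([3, 7, 2, 9]),
                              np.array([2, 5, 3, 8]), 4))  # [3 5 2 8]
```

With this post-processing, only the slots whose expert actually changes need a weight copy, which is the memcpy saving the bullet describes.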

Purpose

Improves performance of sync and async EPLB. Fixes a bug in EPLB config post-processing.
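As one illustration of the NumPy rewrites listed in the description, a vectorized get_ep_ranks_with_expert can be sketched as below. The function name comes from the PR; the flat placement layout and the signature are assumptions for illustration:

```python
import numpy as np

def get_ep_ranks_with_expert(expert_id: int, placement: np.ndarray,
                             slots_per_rank: int) -> np.ndarray:
    """Return the EP ranks whose physical slots currently hold `expert_id`.

    `placement` holds one expert id per physical slot, with each EP rank
    owning `slots_per_rank` consecutive slots.
    """
    slots = np.nonzero(placement == expert_id)[0]  # all slots holding the expert
    return np.unique(slots // slots_per_rank)      # map slot index to owning rank

# Expert 1 is replicated in slots 1, 3, and 5, i.e. on ranks 0, 1, and 2.
print(get_ep_ranks_with_expert(1, np.array([0, 1, 2, 1, 3, 1]), 2))  # [0 1 2]
```

The whole lookup is two vectorized array operations, with no Python-level loop over slots.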

Test Plan

Adds tests for grouped layer weight exchange and for preserve_intragpu_slots.

Validation Result

Client:
lm_eval --model local-completions --tasks gsm8k --model_args model=$model,base_url=http://0.0.0.0:$port/v1/completions,num_concurrent=50,max_retries=3,tokenized_requests=False

Main
Qwen-30B-FP8 DP=8

vllm serve Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 --disable-log-requests --no-enable-prefix-caching -tp 1 -dp 8 --enable-eplb --eplb-config.window_size 128 --eplb-config.step_interval 512 --eplb-config.num_redundant_experts 32 --eplb-config.use_async true
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8939|±  |0.0085|
|     |       |strict-match    |     5|exact_match|↑  |0.8870|±  |0.0087|

Qwen-30B-FP8 DP=8 AsyncEPLB=True

vllm serve Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 --disable-log-requests --no-enable-prefix-caching -tp 1 -dp 8 --max-num-seqs 256 --enable-expert-parallel --port $port --gpu-memory-utilization 0.8 --enable-eplb --eplb-config.window_size 128 --eplb-config.step_interval 512 --eplb-config.num_redundant_experts 32 --eplb-config.use_async true
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8848|±  |0.0088|
|     |       |strict-match    |     5|exact_match|↑  |0.8764|±  |0.0091|

Qwen-30B-FP8 DP=8 AsyncEPLB=False

vllm serve Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 --disable-log-requests --no-enable-prefix-caching -tp 1 -dp 8 --max-num-seqs 256 --enable-expert-parallel  --gpu-memory-utilization 0.8 --enable-eplb --eplb-config.window_size 128 --eplb-config.step_interval 512 --eplb-config.num_redundant_experts 32 --eplb-config.use_async false
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8855|±  |0.0088|
|     |       |strict-match    |     5|exact_match|↑  |0.8719|±  |0.0092|

Main
DeepSeek-V2-Lite DP=4
vllm serve deepseek-ai/DeepSeek-V2-Lite --disable-log-requests --no-enable-prefix-caching -tp 1 -dp 4 --enable-eplb --eplb-config.window_size 128 --eplb-config.step_interval 512 --eplb-config.num_redundant_experts 16 --eplb-config.use_async true

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.3813|±  |0.0134|
|     |       |strict-match    |     5|exact_match|↑  |0.3776|±  |0.0134|

DeepSeek-V2-Lite DP=4 AsyncEPLB=True

vllm serve deepseek-ai/DeepSeek-V2-Lite --disable-log-requests --no-enable-prefix-caching -tp 1 -dp 4 --max-num-seqs 256 --enable-expert-parallel --gpu-memory-utilization 0.9 --enable-eplb --eplb-config.window_size 128 --eplb-config.step_interval 512 --eplb-config.num_redundant_experts 16 --eplb-config.use_async true
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.3813|±  |0.0134|
|     |       |strict-match    |     5|exact_match|↑  |0.3768|±  |0.0133|

DeepSeek-V2-Lite DP=4 AsyncEPLB=False

vllm serve deepseek-ai/DeepSeek-V2-Lite --disable-log-requests --no-enable-prefix-caching -tp 1 -dp 4 --max-num-seqs 256 --enable-expert-parallel --gpu-memory-utilization 0.9 --enable-eplb --eplb-config.window_size 128 --eplb-config.step_interval 512 --eplb-config.num_redundant_experts 16 --eplb-config.use_async false
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.3783|±  |0.0134|
|     |       |strict-match    |     5|exact_match|↑  |0.3745|±  |0.0133|

Benchmark results

Server:

vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct --disable-log-requests --no-enable-prefix-caching -dp 8 -tp 1 --max-num-seqs 256 --enable-expert-parallel --gpu-memory-utilization 0.9 --enable-eplb --eplb-config.window_size 32 --eplb-config.step_interval 128 --eplb-config.num_redundant_experts 128 --eplb-config.use_async false --max-model-len 4096

We compare the sync versions of EPLB, profiling the third EPLB call to get more stable numbers. In the profile logs:
Main: eplb_state.py: rearrange takes 1.370 s
PR: eplb_state.py: rearrange takes 1.077 s (~27% speedup)

Profile logs of move_to_buffer and move_from_buffer

Main:

[Screenshot: profile on main, 2025-11-26 15:31]

PR:

[Screenshot: profile on this PR, 2025-11-26 15:32]
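The move_from_buffer saving comes largely from replacing per-slot synchronizing copies with a single batched indexing operation. A minimal PyTorch sketch under assumed, simplified signatures (the index tensors are precomputed once, e.g. with NumPy, rather than derived via per-element .item() calls):

```python
import torch

def move_from_buffer_batched(weights: torch.Tensor, buffer: torch.Tensor,
                             dst_slots: torch.Tensor,
                             src_slots: torch.Tensor) -> None:
    """Copy buffered expert weights into their destination slots in one shot.

    A single advanced-indexing assignment replaces a Python loop over slots,
    each iteration of which would trigger its own host-device transfer.
    """
    weights[dst_slots] = buffer[src_slots]

w = torch.zeros(4, 2)
b = torch.arange(8.0).reshape(4, 2)
move_from_buffer_batched(w, b, torch.tensor([0, 2]), torch.tensor([1, 3]))
# Rows 0 and 2 of w now hold rows 1 and 3 of the buffer.
```

This is a sketch of the transfer-reduction idea only; the actual PR code handles grouped layers and the EP communication around the buffer.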


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: ilmarkov <markovilya197@gmail.com>

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces significant optimizations to the EPLB algorithm by leveraging NumPy for performance gains, reducing GPU-CPU data transfers, and adding support for grouped layer weight exchanges. The introduction of preserve_intragpu_slots is a clever optimization to minimize unnecessary memory copies within the same GPU. The config processing bug fix is also a welcome improvement. Overall, the changes are well-implemented and the added tests provide good coverage for the new functionality. I have one suggestion to further optimize the local weight copy logic in move_to_buffer.

Signed-off-by: ilmarkov <markovilya197@gmail.com>
@david6666666
Contributor

LGTM.
The EPLB config post-processing being overwritten by default values also occurs in #29385; thanks for the fix and optimization.

Signed-off-by: ilmarkov <markovilya197@gmail.com>
@mergify

mergify Bot commented Dec 9, 2025

Hi @ilmarkov, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

Signed-off-by: ilmarkov <markovilya197@gmail.com>
Contributor

@SageMoore SageMoore left a comment


Nice work, @ilmarkov. These speedups are quite exciting. I'm still going through the code, but here are some initial, largely cosmetic, comments.

Signed-off-by: ilmarkov <markovilya197@gmail.com>
Contributor

@SageMoore SageMoore left a comment


I'm still going through the code, but here's another round of comments.

Signed-off-by: ilmarkov <markovilya197@gmail.com>
@mergify

mergify Bot commented Dec 12, 2025 with the same pre-commit failure notice as above.

Signed-off-by: ilmarkov <markovilya197@gmail.com>
Contributor

@SageMoore SageMoore left a comment


Ok I think this is getting pretty close. Given the complex nature of this PR, it would be good to include lm_eval runs for both Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 and deepseek-ai/DeepSeek-V2-Lite with sync and async EPLB.

Signed-off-by: ilmarkov <markovilya197@gmail.com>
Member

@tlrmchlsmth tlrmchlsmth left a comment


I am seeing a failure when running test_eplb_execute - @ilmarkov could you please take a look?

  File "/home/tms/vllm/tests/distributed/test_eplb_execute.py", line 308, in _test_async_transfer_layer_without_mtp_worker
    move_from_buffer(
TypeError: move_from_buffer() got an unexpected keyword argument 'ep_group'
[rank0]:[W106 22:51:45.365432802 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
=============================================== warnings summary ================================================
.venv/lib/python3.12/site-packages/schemathesis/generation/coverage.py:305
  /home/tms/vllm/.venv/lib/python3.12/site-packages/schemathesis/generation/coverage.py:305: DeprecationWarning: jsonschema.exceptions.RefResolutionError is deprecated as of version 4.18.0. If you wish to catch potential reference resolution errors, directly catch referencing.exceptions.Unresolvable.
    ref_error: type[Exception] = jsonschema.RefResolutionError,

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================================ short test summary info ============================================
FAILED tests/distributed/test_eplb_execute.py::test_async_transfer_layer_without_mtp[2-2-2-3] - AssertionError

@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 6, 2026
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Member

@tlrmchlsmth tlrmchlsmth left a comment


LGTM now, thank you!

@tlrmchlsmth tlrmchlsmth merged commit 6170d47 into vllm-project:main Jan 7, 2026
54 checks passed
yugong333 pushed a commit to yugong333/vllm that referenced this pull request Jan 9, 2026
akh64bit pushed a commit to akh64bit/vllm that referenced this pull request Jan 16, 2026
andrewbriand pushed a commit to andrewbriand/vllm that referenced this pull request Feb 10, 2026
@hmellor hmellor mentioned this pull request Mar 6, 2026
mystous pushed a commit to mystous/vllm_hybrid that referenced this pull request May 10, 2026