
[EPLB] Add alternative communication for EPLB weight exchange#33176

Merged
tlrmchlsmth merged 48 commits into vllm-project:main from neuralmagic:imarkov/refactor-eplb-comminication
Mar 31, 2026

Conversation

@ilmarkov
Contributor

@ilmarkov ilmarkov commented Jan 27, 2026

Purpose

This PR adds a communicator option to eplb_config ([torch_nccl | torch_gloo | pynccl]) and isolates the weight-exchange communication from the routing logic.

torch_gloo and nixl avoid async EPLB hangs when NCCL is also used by the all2all backend, so in this PR we force these EPLB communicators for async EPLB (instead of falling back to sync EPLB, as main does today).
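A minimal sketch of the selection logic this implies — the function and constant names here are illustrative, not vLLM's actual API; only the communicator option names come from the PR:

```python
# Hypothetical sketch of communicator validation for EPLB weight exchange.
# SUPPORTED_EPLB_COMMUNICATORS and resolve_eplb_communicator are illustrative
# names, not vLLM's real identifiers.

SUPPORTED_EPLB_COMMUNICATORS = ("torch_nccl", "torch_gloo", "pynccl")


def resolve_eplb_communicator(name: str, use_async: bool) -> str:
    """Validate the configured communicator; for the async path, reject
    NCCL-based backends, which can hang when NCCL is also used by the
    all2all backend."""
    if name not in SUPPORTED_EPLB_COMMUNICATORS:
        raise ValueError(f"unknown EPLB communicator: {name!r}")
    if use_async and name in ("torch_nccl", "pynccl"):
        raise ValueError(
            "async EPLB requires a non-NCCL communicator (e.g. torch_gloo) "
            "to avoid hangs when NCCL also serves the all2all backend"
        )
    return name
```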

Validation

Server command, sync EPLB:

VLLM_USE_FLASHINFER_MOE_FP8=0 vllm serve Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 --disable-log-requests --no-enable-prefix-caching -tp 1 -dp 4 --max-num-seqs 256 --enable-expert-parallel --port 8003 --gpu-memory-utilization 0.8 --max-model-len 4096 --all2all-backend allgather_reducescatter --async-scheduling --enable-eplb --eplb-config.window_size 10 --eplb-config.step_interval 10 --eplb-config.num_redundant_experts 16 --eplb-config.log_balancedness_interval 128 --eplb-config.log_balancedness false --eplb-config.communicator [pynccl | symm_mem]

gsm8k results with pynccl (same as on main):

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8855|±  |0.0088|
|     |       |strict-match    |     5|exact_match|↑  |0.8749|±  |0.0091|  

Async EPLB

vllm serve Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 --no-enable-prefix-caching -tp 1 -dp 4 --max-num-seqs 256 --enable-expert-parallel --port 8003 --gpu-memory-utilization 0.8 --max-model-len 4096 --async-scheduling --disable-nccl-for-dp-synchronization --all2all-backend allgather_reducescatter --enable-eplb --eplb-config.window_size 10 --eplb-config.step_interval 20 --eplb-config.num_redundant_experts 64 --eplb-config.log_balancedness_interval 128 --eplb-config.log_balancedness false --eplb-config.use_async true --eplb-config.communicator torch_gloo
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8855|±  |0.0088|
|     |       |strict-match    |     5|exact_match|↑  |0.8787|±  |0.0090|

Added tests for all communicators in test_eplb_execute.py and updated the corresponding timings in .buildkite.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: ilmarkov <markovilya197@gmail.com>
@mergify mergify Bot added the ci/build label Jan 27, 2026
Signed-off-by: ilmarkov <markovilya197@gmail.com>
@ilmarkov ilmarkov marked this pull request as ready for review January 27, 2026 16:04
@ilmarkov ilmarkov marked this pull request as draft January 27, 2026 16:59
Signed-off-by: ilmarkov <markovilya197@gmail.com>
@ilmarkov ilmarkov marked this pull request as ready for review February 4, 2026 15:19
Contributor

@SageMoore SageMoore left a comment


This looks like a good change.

I have two minor nits, but otherwise LGTM.

Comment thread vllm/distributed/eplb/rebalance_execute.py Outdated
Comment thread vllm/config/parallel.py Outdated
@mergify
Contributor

mergify Bot commented Feb 5, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ilmarkov.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Feb 5, 2026
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: ilmarkov <markovilya197@gmail.com>
@mergify mergify Bot removed the needs-rebase label Feb 6, 2026
Signed-off-by: ilmarkov <markovilya197@gmail.com>
@mergify
Contributor

mergify Bot commented Feb 11, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ilmarkov.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Feb 11, 2026
Signed-off-by: ilmarkov <markovilya197@gmail.com>
@mergify mergify Bot removed the needs-rebase label Feb 11, 2026
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Comment thread vllm/distributed/device_communicators/pynccl_wrapper.py
Comment thread vllm/distributed/eplb/async_worker.py Outdated
Comment thread vllm/distributed/eplb/eplb_communicator.py
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) March 27, 2026 09:16
@mergify
Contributor

mergify Bot commented Mar 27, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ilmarkov.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Mar 27, 2026
Signed-off-by: ilmarkov <markovilya197@gmail.com>
auto-merge was automatically disabled March 27, 2026 10:25

Head branch was pushed to by a user without write access

@mergify mergify Bot removed the needs-rebase label Mar 27, 2026
@mergify
Contributor

mergify Bot commented Mar 30, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ilmarkov.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Mar 30, 2026
Signed-off-by: Markov Ilya <markovilya19@gmail.com>
@mergify mergify Bot removed the needs-rebase label Mar 31, 2026
@tlrmchlsmth tlrmchlsmth merged commit abdbb68 into vllm-project:main Mar 31, 2026
65 checks passed
puririshi98 pushed a commit to puririshi98/vllm that referenced this pull request Apr 7, 2026
…roject#33176)

Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: Markov Ilya <markovilya19@gmail.com>
Co-authored-by: Markov Ilya <markovilya19@gmail.com>
Signed-off-by: Rishi Puri <riship@nvidia.com>
mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026
…roject#33176)

Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: Markov Ilya <markovilya19@gmail.com>
Co-authored-by: Markov Ilya <markovilya19@gmail.com>
TomerBN-Nvidia pushed a commit to TomerBN-Nvidia/vllm that referenced this pull request Apr 20, 2026
…roject#33176)

Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: Markov Ilya <markovilya19@gmail.com>
Co-authored-by: Markov Ilya <markovilya19@gmail.com>
SandishKumarHN added a commit to SandishKumarHN/vllm that referenced this pull request Apr 30, 2026
vllm/distributed/eplb/rebalance_execute.py:586 had a device-wide GPU sync
with a NOTE(bowen) comment admitting the original author didn't know why
it was needed. After investigation, the line is dead code in the SYNC
path (rearrange_expert_weights_inplace).

Why it's safe
-------------
The SYNC path runs entirely on the default CUDA stream end-to-end —
torch.empty_like, move_to_buffer's b.copy_(w, non_blocking=True), and
NCCL Send/Recv (default stream=None -> current_stream()) all share it.
No cross-stream hazard exists. PyTorch's ProcessGroupNCCL correctly
calls record_stream() on input/output tensors, so the caching allocator
is also safe across iterations.

The ASYNC path (transfer_layer + async_worker) uses its own design —
cuda_stream.synchronize() (async_worker.py:134) plus CpuGpuEvent for
thread handoff (eplb_utils.py) — and is unaffected by this change.
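The CpuGpuEvent-based thread handoff can be illustrated with a plain-threading analogue. This is illustrative only: the real async path also synchronizes a CUDA stream, which a pure-Python `threading.Event` cannot express.

```python
import threading

# Plain-threading analogue of the async-worker handoff described above.
# The worker thread performs the (stand-in) weight transfer, then signals
# the main thread, which blocks until the handoff completes.
done = threading.Event()
result = {}


def worker() -> None:
    result["layer"] = "transferred"  # stand-in for the actual weight transfer
    done.set()                       # signal completion to the main thread


t = threading.Thread(target=worker)
t.start()
done.wait()   # main thread blocks here until the worker signals
t.join()
assert result["layer"] == "transferred"
```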

Likely historical reason
------------------------
The original EPLB PR (vllm-project#18343) used torch.distributed.batch_isend_irecv
directly. req.wait() on those work objects only guarantees the NCCL
collective has been enqueued, NOT that the underlying tensors are safe
to free/reuse — there is no record_stream() linkage to the caching
allocator. torch.cuda.synchronize() was a hammer to flush all work
before the next iteration's torch.empty_like allocations.

The communicator refactor (vllm-project#33176) replaced batch_isend_irecv with
ProcessGroupNCCL-based send/recv, which calls record_stream() correctly.
The sync became dead code at that point but was never removed.

Verification
------------
- Bytecode: 'synchronize' in rearrange_expert_weights_inplace.__code__.co_names -> False
- Stress: 50 runs x 2000 iter x hidden=[1024,2048] on 2x A100 (torch_nccl)
  -> 50/50 race-clean (~100k effective sync-path iterations).
- Larger-scale (4-rank A100) re-validation in progress.
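The bytecode check above can be reproduced on a toy function. `co_names` records every global and attribute name a function's bytecode references, so name membership proves whether a call site is present; here `time.sleep` stands in for `torch.cuda.synchronize` since the check only inspects names.

```python
import time


def with_sync(xs):
    total = sum(xs)
    time.sleep(0)  # stand-in for torch.cuda.synchronize()
    return total


def without_sync(xs):
    return sum(xs)


# 'sleep' appears in co_names only if the bytecode references it,
# mirroring the 'synchronize' check in the commit message above.
assert "sleep" in with_sync.__code__.co_names
assert "sleep" not in without_sync.__code__.co_names
```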
SandishKumarHN added a commit to SandishKumarHN/vllm that referenced this pull request Apr 30, 2026
Signed-off-by: SandishKumarHN <sandish@fb.com>

Labels

ci/build, ready, v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants