
[2/N] Elastic EP Milestone 2: Integrating NIXL-EP#35627

Merged
tlrmchlsmth merged 6 commits into vllm-project:main from itayalroy:nixl_ep_integration
Mar 13, 2026

Conversation

@itayalroy
Contributor

@itayalroy itayalroy commented Feb 28, 2026

This PR is a rebase of #29630, originally authored by @libertyeagle, that integrates NIXL-EP kernels into vLLM.
NIXL-EP is an implementation of expert-parallel communication kernels over NIXL's device API. It provides elastic scaling, enabling dynamic addition and removal of processes (ranks) at runtime without destroying and recreating communicators when scaling up or down.

This PR also includes a few small fixes to vLLM Elastic EP (#34861) that we found while thoroughly testing vLLM with NIXL-EP.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify

mergify bot commented Feb 28, 2026

Hi @itayalroy, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request integrates NIXL-EP kernels for elastic expert parallelism, which is a significant enhancement. The changes mostly add the new nixl_ep backend and its supporting logic, along with some important fixes, such as properly destroying NCCL communicators to prevent resource leaks.

My review focuses on ensuring thread safety in the new NixlEPAll2AllManager. I've identified a potential race condition when accessing the shared buffer and suggested a fix using a lock to ensure robustness in a multi-threaded environment.

Comment on lines +475 to +495
def get_handle(self, kwargs):
if (
NixlEPAll2AllManager._buffer is not None
and NixlEPAll2AllManager._buffer[1] == self.cpu_group.size()
):
return NixlEPAll2AllManager._buffer[0]

num_experts_per_rank = kwargs["num_global_experts"] // kwargs["num_ep_ranks"]
nixl_kwargs = dict(
max_num_tokens_per_dp_rank=kwargs["max_num_tokens_per_dp_rank"],
token_hidden_size=kwargs["token_hidden_size"],
num_experts_per_rank=num_experts_per_rank,
)
if NixlEPAll2AllManager._buffer is None:
self._init_buffer(**nixl_kwargs)
else:
self._update_buffer()

assert NixlEPAll2AllManager._buffer is not None
handle = NixlEPAll2AllManager._buffer[0]
return handle
Contributor

critical

The _buffer class attribute is a shared mutable state. The get_handle method reads and writes to this shared state without any synchronization, which can lead to a race condition if called from multiple threads concurrently. This could happen, for example, during dynamic LoRA loading, leading to incorrect behavior or crashes.

To prevent this, a lock should be used to protect access to _buffer.

First, please add a lock to the NixlEPAll2AllManager class:

class NixlEPAll2AllManager(All2AllManagerBase):
    ...
    _lock = threading.Lock()
    ...

Then, wrap the get_handle method's logic with this lock as suggested below.

    def get_handle(self, kwargs):
        with NixlEPAll2AllManager._lock:
            if (
                NixlEPAll2AllManager._buffer is not None
                and NixlEPAll2AllManager._buffer[1] == self.cpu_group.size()
            ):
                return NixlEPAll2AllManager._buffer[0]

            num_experts_per_rank = kwargs["num_global_experts"] // kwargs["num_ep_ranks"]
            nixl_kwargs = dict(
                max_num_tokens_per_dp_rank=kwargs["max_num_tokens_per_dp_rank"],
                token_hidden_size=kwargs["token_hidden_size"],
                num_experts_per_rank=num_experts_per_rank,
            )
            if NixlEPAll2AllManager._buffer is None:
                self._init_buffer(**nixl_kwargs)
            else:
                self._update_buffer()

            assert NixlEPAll2AllManager._buffer is not None
            handle = NixlEPAll2AllManager._buffer[0]
            return handle

Member

@itayalroy could you address this comment?

Contributor Author

I don't think this can actually happen, since get_handle() only appears to be called from a single thread during initial setup or elastic EP reconfiguration. In any case, this isn't on the data path and the cost of adding a lock here is negligible, so I added it to be safe.
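For readers skimming the thread, the pattern the review settles on, a class-level lock making the check-then-initialize sequence atomic, can be sketched in isolation. Class and method names below are illustrative stand-ins, not vLLM's actual API:

```python
import threading


class BufferManager:
    """Toy stand-in for an all2all manager that caches one shared buffer."""

    _buffer = None  # (handle, world_size), shared across all instances
    _lock = threading.Lock()

    def __init__(self, world_size: int):
        self.world_size = world_size

    def _init_buffer(self):
        # Placeholder for the expensive buffer allocation.
        BufferManager._buffer = (object(), self.world_size)

    def get_handle(self):
        # The lock makes check-then-create atomic, so two concurrent
        # callers cannot both observe _buffer is None and double-initialize.
        with BufferManager._lock:
            if (
                BufferManager._buffer is not None
                and BufferManager._buffer[1] == self.world_size
            ):
                return BufferManager._buffer[0]
            self._init_buffer()
            return BufferManager._buffer[0]


mgr = BufferManager(world_size=8)
handles = []


def _worker():
    handles.append(mgr.get_handle())


threads = [threading.Thread(target=_worker) for _ in range(16)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Every thread gets the same cached handle.
print(len(set(id(h) for h in handles)))  # prints 1
```

Since the critical section only runs during setup or reconfiguration, never on the per-token data path, the lock's overhead is negligible, which matches the reasoning above.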

@LucasWilkinson LucasWilkinson added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 3, 2026
@itayalroy itayalroy force-pushed the nixl_ep_integration branch 2 times, most recently from 9242642 to 020600d Compare March 3, 2026 22:05
@mergify

mergify bot commented Mar 3, 2026

Hi @itayalroy, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@itayalroy itayalroy force-pushed the nixl_ep_integration branch 6 times, most recently from 6abe5cf to 29070b4 Compare March 4, 2026 17:03
Member

@tlrmchlsmth tlrmchlsmth left a comment

One question: Does NIXL-EP use NVLink at all for intranode traffic? Is it suitable for MNNVL systems? And are there any de-duplication optimizations?


Member

this file is extremely similar to the DeepEP LL prepare_finalize implementation. Should we consolidate these?

@bnellnm WDYT?

Collaborator

@bnellnm bnellnm Mar 5, 2026

It might be good to factor out some of the common utilities, e.g. dequant_fp8, maybe_roundup_layer_hidden_size (maybe _do_quant?), but I think it would be good to keep the main implementations separate in case one or the other backend changes its API.

Contributor Author

@itayalroy itayalroy Mar 8, 2026

We expect nixl_ep_prepare_finalize.py and deepep_ll_prepare_finalize.py to diverge pretty quickly as NIXL-EP progresses, and possibly on the DeepEP side too, so we preferred to keep them separate.

Member

OK, fair enough!

@mergify mergify bot added the needs-rebase label Mar 7, 2026
@itayalroy itayalroy force-pushed the nixl_ep_integration branch from 29070b4 to 4429898 Compare March 8, 2026 11:55
@mergify mergify bot removed the needs-rebase label Mar 8, 2026
@itayalroy
Contributor Author

itayalroy commented Mar 8, 2026

One question: Does NIXL-EP use NVLink at all for intranode traffic? Is it suitable for MNNVL systems? And are there any de-duplication optimizations?

NVLink is used for intranode traffic; support for MNNVL is in review; and there are currently no de-duplication optimizations (although this might change, as we are still evaluating the tradeoffs of that approach).

@itayalroy itayalroy force-pushed the nixl_ep_integration branch 2 times, most recently from 52728b1 to 80253cc Compare March 8, 2026 23:09
@itayalroy itayalroy requested a review from njhill as a code owner March 9, 2026 10:46
@mergify mergify bot added the v1 label Mar 9, 2026
@itayalroy itayalroy force-pushed the nixl_ep_integration branch 3 times, most recently from 260d1a1 to 08facd0 Compare March 9, 2026 17:17
libertyeagle and others added 6 commits March 11, 2026 13:42
Signed-off-by: Yongji Wu <wuyongji317@gmail.com>

rebase fix

Signed-off-by: Yongji Wu <wuyongji317@gmail.com>

rebase fix

Signed-off-by: Yongji Wu <wuyongji317@gmail.com>

rebase fix

Signed-off-by: Yongji Wu <wuyongji317@gmail.com>

rebase fix

Signed-off-by: Yongji Wu <wuyongji317@gmail.com>

Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Scale-up after scale-down would hang indefinitely in the ZMQ poll loop waiting for engine identity messages. Without ROUTER_HANDOVER enabled on the ZMQ ROUTER socket, engines reconnecting with previously used identities had their messages silently dropped, because the ROUTER still held stale routing entries from the dead connections.

Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
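The ROUTER_HANDOVER behavior described in the commit message above can be demonstrated with a minimal pyzmq sketch. The endpoint name and identities here are illustrative, not vLLM's actual sockets:

```python
import zmq

ctx = zmq.Context()

# ROUTER socket. Without ROUTER_HANDOVER, a peer reconnecting with a
# previously used identity can have its messages silently dropped while
# the ROUTER still holds the stale routing entry for the dead connection.
router = ctx.socket(zmq.ROUTER)
router.setsockopt(zmq.ROUTER_HANDOVER, 1)
router.bind("inproc://engines")


def connect_engine(identity: bytes) -> zmq.Socket:
    # DEALER peer standing in for an engine announcing its identity.
    sock = ctx.socket(zmq.DEALER)
    sock.setsockopt(zmq.IDENTITY, identity)
    sock.connect("inproc://engines")
    return sock


# First engine connects and sends its identity message.
eng = connect_engine(b"engine-0")
eng.send(b"hello-1")
ident, msg = router.recv_multipart()

# Engine goes away (scale-down) ...
eng.close(linger=0)

# ... and a new engine reconnects with the SAME identity (scale-up).
# With ROUTER_HANDOVER=1 the ROUTER hands the identity over to the new
# connection, so this recv succeeds instead of blocking forever.
eng2 = connect_engine(b"engine-0")
eng2.send(b"hello-2")
ident2, msg2 = router.recv_multipart()

eng2.close(linger=0)
router.close(linger=0)
ctx.term()
```

Setting the option is a one-liner on the ROUTER socket; the subtlety is only that it must be set before peers reconnect with reused identities.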
@itayalroy itayalroy force-pushed the nixl_ep_integration branch from 08facd0 to 8df1eb6 Compare March 11, 2026 11:42
Member

@tlrmchlsmth tlrmchlsmth left a comment

The last thing missing is better integration testing, which is tricky given the heavy dependency. Per @itayalroy, we should be able to pip install nixl in either the next NIXL release or the one after, which will make this easier to manage in the test image.

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Mar 12, 2026
@tlrmchlsmth tlrmchlsmth merged commit d5af196 into vllm-project:main Mar 13, 2026
72 of 73 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Mar 13, 2026
whycoming pushed a commit to whycoming/vllm that referenced this pull request Mar 13, 2026
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Co-authored-by: Yongji Wu <wuyongji317@gmail.com>
Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>
Signed-off-by: whycoming <120623296@qq.com>
juliendenize pushed a commit to juliendenize/vllm that referenced this pull request Mar 13, 2026
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Co-authored-by: Yongji Wu <wuyongji317@gmail.com>
Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>
juliendenize pushed a commit to juliendenize/vllm that referenced this pull request Mar 13, 2026
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Co-authored-by: Yongji Wu <wuyongji317@gmail.com>
Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>
athrael-soju pushed a commit to athrael-soju/vllm that referenced this pull request Mar 15, 2026
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Co-authored-by: Yongji Wu <wuyongji317@gmail.com>
Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>
Signed-off-by: Athrael Soju <athrael.soju@gmail.com>
athrael-soju pushed a commit to athrael-soju/vllm that referenced this pull request Mar 16, 2026
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Co-authored-by: Yongji Wu <wuyongji317@gmail.com>
Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>
Signed-off-by: Athrael Soju <athrael.soju@gmail.com>
Lucaskabela pushed a commit to Lucaskabela/vllm that referenced this pull request Mar 17, 2026
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Co-authored-by: Yongji Wu <wuyongji317@gmail.com>
Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Co-authored-by: Yongji Wu <wuyongji317@gmail.com>
Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Co-authored-by: Yongji Wu <wuyongji317@gmail.com>
Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>

Labels

kv-connector nvidia ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

Status: Done

6 participants